diff --git a/src/client.rs b/src/client.rs index dd8ab90..ff7a62b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,3 +1,4 @@ +use core::panic; use std::{ collections::HashMap, io::{Read, Write}, @@ -75,6 +76,41 @@ fn connect(params: &ClientParams) -> Connection { ) } +fn resync(tcp: &mut SocketAdapter) { + println!(); + eprintln!("Server version mismatch or broken connection. Re-syncing in case of the latter..."); + tcp.internal.set_print(false); + tcp.write_now().unwrap(); + tcp.write(&[PacketType::Resync.ordinal() as u8]).unwrap(); + tcp.write_now().unwrap(); + eprintln!( + "Sent resync packet. Server should now wait 8 seconds and then send a resync-echo packet." + ); + tcp.set_nonblocking(true); + let mut buf = [0; 4096]; + // read all packets that are still pending. + while Some(Some(4096)) == tcp.poll(&mut buf).ok() {} + // wait 5 seconds + thread::sleep(Duration::from_secs(5)); + // read all packets that are still pending. + while Some(Some(4096)) == tcp.poll(&mut buf).ok() {} + // server should now have stopped sending packets. waiting 5 more seconds so the server has time to + // send the resync packet. + thread::sleep(Duration::from_secs(5)); + let mut buf = [0]; + eprintln!("Trying to receive the resync echo..."); + tcp.set_nonblocking(false); + tcp.poll(&mut buf).unwrap(); + if buf[0] as i8 == PacketType::ResyncEcho.ordinal() { + eprintln!("Successfully resynced. RevPFW3 can continue."); + } else { + eprintln!("Resync was not successful. Stopping."); + panic!("broken connection or server version mismatch."); + } + tcp.set_nonblocking(true); + tcp.internal.set_print(true); +} + pub fn client(params: ClientParams) { let mut buf1 = [0u8; 1]; let mut buf4 = [0u8; 4]; @@ -82,6 +118,7 @@ pub fn client(params: ClientParams) { let mut buf16 = [0u8; 16]; let mut buf = [0; 1024]; let mut tcp = connect(¶ms); + tcp.set_print(false); println!("Syncing..."); tcp.write_all(&[b'R', b'P', b'F', 30]).unwrap(); println!("Authenticating..."); @@ -96,8 +133,8 @@ pub fn client(params: ClientParams) { } tcp.write_all(&[PacketType::KeepAlive.ordinal() as u8]) .unwrap(); + tcp.set_print(true); - println!(); println!("READY!"); let mut tcp = SocketAdapter::new(tcp); @@ -160,8 +197,10 @@ pub fn client(params: ClientParams) { continue; } - let pt = PacketType::from_ordinal(buf1[0] as i8) - .expect("server/client version mismatch or broken TCP"); + let Some(pt) = PacketType::from_ordinal(buf1[0] as i8) else { + resync(&mut tcp); + continue; + }; tcp.set_nonblocking(false); match pt { PacketType::NewClient => { @@ -197,7 +236,7 @@ pub fn client(params: ClientParams) { } } - PacketType::ServerData => unreachable!(), + PacketType::ServerData => resync(&mut tcp), PacketType::ClientExceededBuffer => { tcp.internal.read_exact(&mut buf8).unwrap(); @@ -210,6 +249,17 @@ pub fn client(params: ClientParams) { socket.punish(amount); } } + + PacketType::Resync => { + println!(); + tcp.internal.set_print(false); + eprintln!("Server asked for re-sync. Waiting 8 seconds, then initiating resync."); + thread::sleep(Duration::from_secs(8)); + resync(&mut tcp); + } + + // this one shouldnt happen. + PacketType::ResyncEcho => resync(&mut tcp), } tcp.set_nonblocking(true); } diff --git a/src/server_connection.rs b/src/connection.rs similarity index 93% rename from src/server_connection.rs rename to src/connection.rs index f8662b0..84b007e 100644 --- a/src/server_connection.rs +++ b/src/connection.rs @@ -26,6 +26,7 @@ pub struct Connection { close_thunk: fn(NonNull<()>) -> io::Result<()>, is_nb: bool, is_serial: bool, + print: bool, print_status: PrintStatus, } @@ -86,6 +87,7 @@ impl Connection { }, is_nb: false, is_serial: false, + print: true, print_status: if print { PrintStatus::Yes { last_print: SystemTime::now(), @@ -114,6 +116,7 @@ impl Connection { close_thunk: |_data| Ok(()), is_nb: false, is_serial: true, + print: true, print_status: if print { PrintStatus::Yes { last_print: SystemTime::now(), @@ -147,6 +150,10 @@ impl Connection { self.is_serial } + pub fn set_print(&mut self, print: bool) { + self.print = print; + } + fn print_status(&mut self, add: usize) { if let &mut PrintStatus::Yes { ref mut last_print, @@ -159,8 +166,12 @@ impl Connection { let diff = *bytes - *last_bytes; let bps = to_units(diff); let total = to_units(*bytes); - print!("\r\x1b[KCurrent transfer speed: {bps}B/s, transferred {total}B so far."); - stdout().flush().unwrap(); + if self.print { + print!( + "\r\x1b[KCurrent transfer speed: {bps}B/s, transferred {total}B so far." + ); + stdout().flush().unwrap(); + } *last_bytes = *bytes; *last_print = SystemTime::now(); } diff --git a/src/lib.rs b/src/lib.rs index c26ba90..818f7a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,15 @@ mod client; +mod connection; mod packet; mod server; -mod server_connection; mod socket_adapter; use std::io::{Error, ErrorKind}; pub use client::*; +pub(crate) use connection::*; pub(crate) use packet::*; pub use server::*; -pub(crate) use server_connection::*; pub(crate) use socket_adapter::*; pub(crate) fn io_sync(result: Result) -> Result, Error> { diff --git a/src/packet.rs b/src/packet.rs index 921e688..3c938cf 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -8,4 +8,6 @@ pub(crate) enum PacketType { ClientData, ServerData, ClientExceededBuffer, + Resync, + ResyncEcho, } diff --git a/src/server.rs b/src/server.rs index f1de154..ad069d7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,6 +10,31 @@ use std::{ use crate::{io_sync, Connection, PacketType, SocketAdapter}; +fn resync(tcp: &mut SocketAdapter) { + println!(); + eprintln!("Client version mismatch or broken connection. Re-syncing in case of the latter..."); + tcp.internal.set_print(false); + tcp.write_now().unwrap(); + tcp.write(&[PacketType::Resync.ordinal() as u8]).unwrap(); + tcp.write_now().unwrap(); + eprintln!( + "Sent resync packet. Client should now wait 8 seconds and then send a resync packet back, initiating a normal re-sync." + ); + tcp.set_nonblocking(true); + let mut buf = [0; 4096]; + // read all packets that are still pending. + while Some(Some(4096)) == tcp.poll(&mut buf).ok() {} + // wait 5 seconds + thread::sleep(Duration::from_secs(5)); + // read all packets that are still pending. + while Some(Some(4096)) == tcp.poll(&mut buf).ok() {} + // server should now have stopped sending packets. waiting 5 more seconds so the client has time to + // send the resync packet. + thread::sleep(Duration::from_secs(5)); + tcp.set_nonblocking(true); + tcp.internal.set_print(true); +} + pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { let mut buf1 = [0u8; 1]; let mut buf4 = [0u8; 4]; @@ -114,8 +139,10 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { continue; } - let pt = PacketType::from_ordinal(buf1[0] as i8) - .expect("server/client version mismatch or broken TCP"); + let Some(pt) = PacketType::from_ordinal(buf1[0] as i8) else { + resync(&mut tcp); + continue; + }; tcp.set_nonblocking(false); match pt { PacketType::NewClient => unreachable!(), @@ -131,7 +158,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { last_keep_alive = SystemTime::now(); } - PacketType::ClientData => unreachable!(), + PacketType::ClientData => resync(&mut tcp), PacketType::ServerData => { tcp.internal.read_exact(&mut buf8).unwrap(); @@ -156,6 +183,24 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { socket.punish(amount); } } + + PacketType::Resync => { + println!(); + tcp.internal.set_print(false); + eprintln!( + "Client asked for a re-sync. Waiting 8 seconds, then sending resync-echo." + ); + tcp.write_now().unwrap(); + thread::sleep(Duration::from_secs(8)); + tcp.write(&[PacketType::ResyncEcho.ordinal() as u8]) + .unwrap(); + tcp.write_now().unwrap(); + eprintln!("Resync-Echo sent. Going back to normal operation."); + tcp.internal.set_print(true); + } + + // this one can't happen, it should only come from the server + PacketType::ResyncEcho => resync(&mut tcp), } tcp.set_nonblocking(true); } diff --git a/src/socket_adapter.rs b/src/socket_adapter.rs index 5534760..feaac32 100644 --- a/src/socket_adapter.rs +++ b/src/socket_adapter.rs @@ -89,6 +89,29 @@ impl SocketAdapter { self.update() } + pub fn write_now(&mut self) -> Result<(), Error> { + if let Some(ref x) = self.broken { + return Err(Error::from(*x)); + } + if self.to_write == 0 { + return Ok(()); + } + match { + self.internal.set_nonblocking(false)?; + let r = self + .internal + .write_all(&self.write[self.written..self.written + self.to_write]); + self.internal.set_nonblocking(self.is_nonblocking)?; + r + } { + Ok(()) => Ok(()), + Err(x) => { + self.broken = Some(Broken::DirectErr(x.kind(), "io error")); + Err(x) + } + } + } + pub fn update(&mut self) -> Result<(), Error> { if Some(SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros()) < self.ignore_until { return Ok(());