diff --git a/src/client.rs b/src/client.rs index 82a765c..943a547 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,7 +10,7 @@ use std::{ use serial::SerialPort; -use crate::{io_sync, Connection, PacketType, SocketAdapter}; +use crate::{Connection, PacketType, SocketAdapter}; pub struct ClientParams<'a> { pub server_ip: &'a str, @@ -79,7 +79,6 @@ fn connect(params: &ClientParams) -> Connection { fn resync(tcp: &mut SocketAdapter) { println!(); eprintln!("Server version mismatch or broken connection. Re-syncing in case of the latter..."); - tcp.set_nonblocking(true); tcp.internal.set_print(false); tcp.write_now().unwrap(); tcp.write(&[PacketType::Resync.ordinal() as u8]).unwrap(); @@ -94,13 +93,10 @@ fn resync(tcp: &mut SocketAdapter) { thread::sleep(Duration::from_secs(5)); // read all packets that are still pending. while let Some(Some(_x @ 1..)) = tcp.poll(&mut buf).ok() {} - // server should now have stopped sending packets. waiting 5 more seconds so the server has time to - // send the resync packet. - thread::sleep(Duration::from_secs(5)); + // server should now have stopped sending packets. let mut buf = [0]; eprintln!("Trying to receive the resync echo..."); - tcp.set_nonblocking(false); - tcp.poll(&mut buf).unwrap(); + tcp.read_now(&mut buf).unwrap(); if buf[0] as i8 == PacketType::ResyncEcho.ordinal() { eprintln!("Successfully resynced. RevPFW3 can continue."); } else { @@ -141,7 +137,6 @@ pub fn client(params: ClientParams) { let mut id = 0; let mut last_keep_alive = SystemTime::now(); loop { - tcp.set_nonblocking(true); thread::sleep(Duration::from_millis(params.rate_limit_sleep)); let mut did_anything = false; @@ -186,10 +181,7 @@ pub fn client(params: ClientParams) { } tcp.update().unwrap(); - if io_sync(tcp.internal.read_exact(&mut buf1)) - .unwrap() - .is_none() - { + if tcp.poll_exact(&mut buf1).unwrap().is_none() { if !did_anything { thread::sleep(Duration::from_millis(params.sleep_delay_ms)); } @@ -200,19 +192,17 @@ pub fn client(params: ClientParams) { resync(&mut tcp); continue; }; - tcp.set_nonblocking(false); match pt { PacketType::NewClient => { - let mut tcp = SocketAdapter::new(Connection::new_tcp( + let tcp = SocketAdapter::new(Connection::new_tcp( TcpStream::connect((params.dest_ip, params.dest_port)).unwrap(), false, )); - tcp.set_nonblocking(true); sockets.insert((id, id += 1).0, tcp); } PacketType::CloseClient => { - tcp.internal.read_exact(&mut buf8).unwrap(); + tcp.read_now(&mut buf8).unwrap(); if let Some(x) = sockets.remove(&u64::from_be_bytes(buf8)) { let _ = x.internal.close(); } @@ -224,11 +214,11 @@ pub fn client(params: ClientParams) { } PacketType::ClientData => { - tcp.internal.read_exact(&mut buf8).unwrap(); + tcp.read_now(&mut buf8).unwrap(); let idx = u64::from_be_bytes(buf8); - tcp.internal.read_exact(&mut buf4).unwrap(); + tcp.read_now(&mut buf4).unwrap(); let len = u32::from_be_bytes(buf4) as usize; - tcp.internal.read_exact(&mut buf[..len]).unwrap(); + tcp.read_now(&mut buf[..len]).unwrap(); if let Some(socket) = sockets.get_mut(&idx) { let _ = socket.write_later(&buf[..len]); @@ -238,9 +228,9 @@ pub fn client(params: ClientParams) { PacketType::ServerData => resync(&mut tcp), PacketType::ClientExceededBuffer => { - tcp.internal.read_exact(&mut buf8).unwrap(); + tcp.read_now(&mut buf8).unwrap(); let idx = u64::from_be_bytes(buf8); - tcp.internal.read_exact(&mut buf16).unwrap(); + tcp.read_now(&mut buf16).unwrap(); let amount = u128::from_be_bytes(buf16); // a single connection doesn't need overuse-penalties @@ -260,6 +250,5 @@ pub fn client(params: ClientParams) { // this one shouldnt happen. PacketType::ResyncEcho => resync(&mut tcp), } - tcp.set_nonblocking(true); } } diff --git a/src/connection.rs b/src/connection.rs index 775a21c..5a5327a 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -48,17 +48,18 @@ impl Read for Connection { } fn read_exact(&mut self, mut buf: &mut [u8]) -> io::Result<()> { + let len = buf.len(); while !buf.is_empty() { match self.read(buf) { - Ok(0) if self.is_nb => { + Ok(0) if self.is_nb && buf.len() == len => { return Err(io::Error::new(ErrorKind::WouldBlock, "would block")) } - Ok(0) => (), + Ok(0) => break, Ok(n) => { let tmp = buf; buf = &mut tmp[n..]; } - Err(ref e) if e.kind() == ErrorKind::Interrupted => (), + Err(ref e) if e.kind() == ErrorKind::Interrupted => {} Err(e) => return Err(e), } } diff --git a/src/server.rs b/src/server.rs index 9634c9e..bc35958 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,12 +8,11 @@ use std::{ vec, }; -use crate::{io_sync, Connection, PacketType, SocketAdapter}; +use crate::{Connection, PacketType, SocketAdapter}; fn resync(tcp: &mut SocketAdapter) { println!(); eprintln!("Client version mismatch or broken connection. Re-syncing in case of the latter..."); - tcp.set_nonblocking(true); tcp.internal.set_print(false); tcp.write_now().unwrap(); tcp.write(&[PacketType::Resync.ordinal() as u8]).unwrap(); @@ -28,9 +27,7 @@ fn resync(tcp: &mut SocketAdapter) { thread::sleep(Duration::from_secs(5)); // read all packets that are still pending. while let Some(Some(_x @ 1..)) = tcp.poll(&mut buf).ok() {} - // server should now have stopped sending packets. waiting 5 more seconds so the client has time to - // send the resync packet. - thread::sleep(Duration::from_secs(5)); + // client should now have stopped sending packets. tcp.internal.set_print(true); } @@ -72,7 +69,6 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { let mut last_keep_alive_sent = SystemTime::now(); let mut last_keep_alive = SystemTime::now(); loop { - tcp.set_nonblocking(true); let mut did_anything = false; if last_keep_alive_sent.elapsed().unwrap().as_secs() >= 10 { @@ -84,8 +80,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { } if let Ok(new) = tcpl.accept() { - let mut new = SocketAdapter::new(Connection::new_tcp(new.0, false)); - new.set_nonblocking(true); + let new = SocketAdapter::new(Connection::new_tcp(new.0, false)); sockets.insert((id, id += 1).0, new); tcp.write(&[PacketType::NewClient.ordinal() as u8]).unwrap(); did_anything = true; @@ -128,10 +123,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { } tcp.update().unwrap(); - if io_sync(tcp.internal.read_exact(&mut buf1)) - .unwrap() - .is_none() - { + if tcp.poll_exact(&mut buf1).unwrap().is_none() { if !did_anything { thread::sleep(Duration::from_millis(sleep_delay_ms)); } @@ -142,12 +134,11 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { resync(&mut tcp); continue; }; - tcp.set_nonblocking(false); match pt { PacketType::NewClient => resync(&mut tcp), PacketType::CloseClient => { - tcp.internal.read_exact(&mut buf8).unwrap(); + tcp.read_now(&mut buf8).unwrap(); if let Some(x) = sockets.remove(&u64::from_be_bytes(buf8)) { let _ = x.internal.close(); } @@ -160,11 +151,11 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { PacketType::ClientData => resync(&mut tcp), PacketType::ServerData => { - tcp.internal.read_exact(&mut buf8).unwrap(); + tcp.read_now(&mut buf8).unwrap(); let idx = u64::from_be_bytes(buf8); - tcp.internal.read_exact(&mut buf4).unwrap(); + tcp.read_now(&mut buf4).unwrap(); let len = u32::from_be_bytes(buf4) as usize; - tcp.internal.read_exact(&mut buf[..len]).unwrap(); + tcp.read_now(&mut buf[..len]).unwrap(); if let Some(socket) = sockets.get_mut(&idx) { let _ = socket.write_later(&buf[..len]); @@ -172,9 +163,9 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { } PacketType::ClientExceededBuffer => { - tcp.internal.read_exact(&mut buf8).unwrap(); + tcp.read_now(&mut buf8).unwrap(); let idx = u64::from_be_bytes(buf8); - tcp.internal.read_exact(&mut buf16).unwrap(); + tcp.read_now(&mut buf16).unwrap(); let amount = u128::from_be_bytes(buf16); // a single connection doesn't need overuse-penalties @@ -201,6 +192,5 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { // this one can't happen, it should only come from the server PacketType::ResyncEcho => resync(&mut tcp), } - tcp.set_nonblocking(true); } } diff --git a/src/socket_adapter.rs b/src/socket_adapter.rs index 1f97946..2606570 100644 --- a/src/socket_adapter.rs +++ b/src/socket_adapter.rs @@ -8,14 +8,12 @@ use crate::{io_sync, Connection}; #[derive(Clone, Copy)] enum Broken { - OsErr(i32), DirectErr(ErrorKind, &'static str), } impl From for Error { fn from(value: Broken) -> Self { match value { - Broken::OsErr(x) => Error::from_raw_os_error(x), Broken::DirectErr(x, s) => Error::new(x, s), } } @@ -28,7 +26,6 @@ pub(crate) struct SocketAdapter { write: [u8; 65536], broken: Option, accumulated_delay: u128, - is_nonblocking: bool, ignore_until: Option, } @@ -41,19 +38,10 @@ impl SocketAdapter { write: [0u8; 65536], broken: None, accumulated_delay: 0, - is_nonblocking: false, ignore_until: None, } } - pub fn set_nonblocking(&mut self, nonblocking: bool) { - if let Err(x) = self.internal.set_nonblocking(nonblocking) { - self.broken = Some(Broken::OsErr(x.raw_os_error().unwrap())); - return; - } - self.is_nonblocking = nonblocking; - } - pub fn write_later(&mut self, buf: &[u8]) -> Result<(), Error> { if let Some(ref x) = self.broken { return Err(Error::from(*x)); @@ -72,7 +60,6 @@ impl SocketAdapter { self.internal.set_nonblocking(false)?; self.internal .write_all(&self.write[self.written..self.written + self.to_write])?; - self.internal.set_nonblocking(self.is_nonblocking)?; self.written = 0; self.to_write = buf.len(); self.write[..buf.len()].copy_from_slice(buf); @@ -101,7 +88,6 @@ impl SocketAdapter { let r = self .internal .write_all(&self.write[self.written..self.written + self.to_write]); - self.internal.set_nonblocking(self.is_nonblocking)?; r } { Ok(()) => { @@ -131,7 +117,6 @@ impl SocketAdapter { let r = self .internal .write(&self.write[self.written..self.written + self.to_write]); - self.internal.set_nonblocking(self.is_nonblocking)?; r } { Ok(x) => { @@ -150,11 +135,30 @@ impl SocketAdapter { } } + pub fn read_now(&mut self, buf: &mut [u8]) -> Result, Error> { + if Some(SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros()) < self.ignore_until { + return Ok(None); + } + self.update()?; + self.internal.set_nonblocking(false)?; + io_sync(self.internal.read_exact(buf)) + } + + pub fn poll_exact(&mut self, buf: &mut [u8]) -> Result, Error> { + if Some(SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros()) < self.ignore_until { + return Ok(None); + } + self.update()?; + self.internal.set_nonblocking(true)?; + io_sync(self.internal.read_exact(buf)) + } + pub fn poll(&mut self, buf: &mut [u8]) -> Result, Error> { if Some(SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros()) < self.ignore_until { return Ok(None); } self.update()?; + self.internal.set_nonblocking(true)?; io_sync(self.internal.read(buf)) }