diff --git a/src/client.rs b/src/client.rs index a97c824..d67826e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -31,9 +31,8 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle println!("READY!"); - tcp.set_nonblocking(true).unwrap(); - - let mut tcp = SocketAdapter::new(tcp); + let mut tcp = SocketAdapter::new(tcp, true); + tcp.set_nonblocking(true); let mut sockets: Vec = Vec::new(); let mut last_keep_alive = SystemTime::now(); loop { @@ -83,12 +82,12 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle let pt = PacketType::from_ordinal(buf1[0] as i8) .expect("server/client version mismatch or broken TCP"); - tcp.internal.set_nonblocking(false).unwrap(); + tcp.set_nonblocking(false); match pt { PacketType::NewClient => { - let tcp = TcpStream::connect((dest_ip, dest_port)).unwrap(); - tcp.set_nonblocking(true).unwrap(); - sockets.push(SocketAdapter::new(tcp)); + let mut tcp = SocketAdapter::new(TcpStream::connect((dest_ip, dest_port)).unwrap(), false); + tcp.set_nonblocking(true); + sockets.push(tcp); } PacketType::CloseClient => { @@ -116,6 +115,6 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle PacketType::ServerData => unreachable!(), } - tcp.internal.set_nonblocking(true).unwrap(); + tcp.set_nonblocking(true); } } diff --git a/src/server.rs b/src/server.rs index 0717e90..61b9529 100644 --- a/src/server.rs +++ b/src/server.rs @@ -38,10 +38,10 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { tcp.write_all(&mut ['R' as u8, 'P' as u8, 'F' as u8, 30]) .unwrap(); - tcp.set_nonblocking(true).unwrap(); tcpl.set_nonblocking(true).unwrap(); - let mut tcp = SocketAdapter::new(tcp); + let mut tcp = SocketAdapter::new(tcp, true); + tcp.set_nonblocking(true); let mut sockets: Vec = Vec::new(); let mut last_keep_alive_sent = SystemTime::now(); let mut last_keep_alive = SystemTime::now(); @@ -57,8 +57,9 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { } if let Ok(new) = tcpl.accept() { - new.0.set_nonblocking(true).unwrap(); - sockets.push(SocketAdapter::new(new.0)); + let mut new = SocketAdapter::new(new.0, false); + new.set_nonblocking(true); + sockets.push(new); tcp.write(&[PacketType::NewClient.ordinal() as u8]).unwrap(); did_anything = true; } @@ -103,7 +104,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { let pt = PacketType::from_ordinal(buf1[0] as i8) .expect("server/client version mismatch or broken TCP"); - tcp.internal.set_nonblocking(false).unwrap(); + tcp.set_nonblocking(false); match pt { PacketType::NewClient => unreachable!(), @@ -131,6 +132,6 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { let _ = sockets[idx].write_later(&buf[..len]); } } - tcp.internal.set_nonblocking(true).unwrap(); + tcp.set_nonblocking(true); } } diff --git a/src/socket_adapter.rs b/src/socket_adapter.rs index 1933311..adc300b 100644 --- a/src/socket_adapter.rs +++ b/src/socket_adapter.rs @@ -25,32 +25,53 @@ pub(crate) struct SocketAdapter { pub(crate) internal: TcpStream, written: usize, to_write: usize, - write: [u8; 1_048_576], // 1MiB + write: [u8; 65536], broken: Option, + wait_if_full: bool, + is_nonblocking: bool, } impl SocketAdapter { - pub fn new(tcp: TcpStream) -> SocketAdapter { + pub fn new(tcp: TcpStream, wait_if_full: bool) -> SocketAdapter { Self { internal: tcp, written: 0, to_write: 0, - write: [0u8; 1_048_576], + write: [0u8; 65536], broken: None, + wait_if_full, + is_nonblocking: false, } } + pub fn set_nonblocking(&mut self, nonblocking: bool) { + if let Err(x) = self.internal.set_nonblocking(nonblocking) { + self.broken = Some(Broken::OsErr(x.raw_os_error().unwrap())); + return; + } + self.is_nonblocking = nonblocking; + } + pub fn write_later(&mut self, buf: &[u8]) -> Result<(), Error> { if let Some(ref x) = self.broken { return Err(Error::from(*x)); } let lidx = self.to_write + self.written + buf.len(); - if lidx > self.write.len() && lidx - self.to_write < self.write.len() { + if lidx > self.write.len() && self.to_write + buf.len() < self.write.len() { self.write .copy_within(self.written..self.written + self.to_write, 0); self.written = 0; } let Some(x) = self.write.get_mut(self.to_write + self.written..self.to_write + self.written + buf.len()) else { + if self.wait_if_full { + self.internal.set_nonblocking(false)?; + self.internal.write_all(&self.write[self.written..self.written + self.to_write])?; + self.internal.set_nonblocking(self.is_nonblocking)?; + self.written = 0; + self.to_write = buf.len(); + self.write[..buf.len()].copy_from_slice(buf); + return Ok(()); + } self.broken = Some(Broken::DirectErr(ErrorKind::TimedOut)); return Err(ErrorKind::TimedOut.into()); }; @@ -61,12 +82,7 @@ impl SocketAdapter { pub fn write(&mut self, buf: &[u8]) -> Result<(), Error> { self.write_later(buf)?; - if let Err(x) = self.update() { - self.broken = Some(Broken::OsErr(x.raw_os_error().unwrap())); - Err(x) - } else { - Ok(()) - } + self.update() } pub fn update(&mut self) -> Result<(), Error> { @@ -76,10 +92,14 @@ impl SocketAdapter { if self.to_write == 0 { return Ok(()); } - match self - .internal - .write(&self.write[self.written..self.written + self.to_write]) - { + match { + self.internal.set_nonblocking(true)?; + let r = self + .internal + .write(&self.write[self.written..self.written + self.to_write]); + self.internal.set_nonblocking(self.is_nonblocking)?; + r + } { Ok(x) => { self.to_write -= x; self.written += x;