make central connection unable to overflow, decrease buffer size

This commit is contained in:
Daniella / Tove 2023-01-31 22:06:55 +01:00
parent c3fd53bfbf
commit 1792c2b8d3
3 changed files with 48 additions and 28 deletions

View file

@ -31,9 +31,8 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle
println!("READY!"); println!("READY!");
tcp.set_nonblocking(true).unwrap(); let mut tcp = SocketAdapter::new(tcp, true);
tcp.set_nonblocking(true);
let mut tcp = SocketAdapter::new(tcp);
let mut sockets: Vec<SocketAdapter> = Vec::new(); let mut sockets: Vec<SocketAdapter> = Vec::new();
let mut last_keep_alive = SystemTime::now(); let mut last_keep_alive = SystemTime::now();
loop { loop {
@ -83,12 +82,12 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle
let pt = PacketType::from_ordinal(buf1[0] as i8) let pt = PacketType::from_ordinal(buf1[0] as i8)
.expect("server/client version mismatch or broken TCP"); .expect("server/client version mismatch or broken TCP");
tcp.internal.set_nonblocking(false).unwrap(); tcp.set_nonblocking(false);
match pt { match pt {
PacketType::NewClient => { PacketType::NewClient => {
let tcp = TcpStream::connect((dest_ip, dest_port)).unwrap(); let mut tcp = SocketAdapter::new(TcpStream::connect((dest_ip, dest_port)).unwrap(), false);
tcp.set_nonblocking(true).unwrap(); tcp.set_nonblocking(true);
sockets.push(SocketAdapter::new(tcp)); sockets.push(tcp);
} }
PacketType::CloseClient => { PacketType::CloseClient => {
@ -116,6 +115,6 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle
PacketType::ServerData => unreachable!(), PacketType::ServerData => unreachable!(),
} }
tcp.internal.set_nonblocking(true).unwrap(); tcp.set_nonblocking(true);
} }
} }

View file

@ -38,10 +38,10 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
tcp.write_all(&mut ['R' as u8, 'P' as u8, 'F' as u8, 30]) tcp.write_all(&mut ['R' as u8, 'P' as u8, 'F' as u8, 30])
.unwrap(); .unwrap();
tcp.set_nonblocking(true).unwrap();
tcpl.set_nonblocking(true).unwrap(); tcpl.set_nonblocking(true).unwrap();
let mut tcp = SocketAdapter::new(tcp); let mut tcp = SocketAdapter::new(tcp, true);
tcp.set_nonblocking(true);
let mut sockets: Vec<SocketAdapter> = Vec::new(); let mut sockets: Vec<SocketAdapter> = Vec::new();
let mut last_keep_alive_sent = SystemTime::now(); let mut last_keep_alive_sent = SystemTime::now();
let mut last_keep_alive = SystemTime::now(); let mut last_keep_alive = SystemTime::now();
@ -57,8 +57,9 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
} }
if let Ok(new) = tcpl.accept() { if let Ok(new) = tcpl.accept() {
new.0.set_nonblocking(true).unwrap(); let mut new = SocketAdapter::new(new.0, false);
sockets.push(SocketAdapter::new(new.0)); new.set_nonblocking(true);
sockets.push(new);
tcp.write(&[PacketType::NewClient.ordinal() as u8]).unwrap(); tcp.write(&[PacketType::NewClient.ordinal() as u8]).unwrap();
did_anything = true; did_anything = true;
} }
@ -103,7 +104,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
let pt = PacketType::from_ordinal(buf1[0] as i8) let pt = PacketType::from_ordinal(buf1[0] as i8)
.expect("server/client version mismatch or broken TCP"); .expect("server/client version mismatch or broken TCP");
tcp.internal.set_nonblocking(false).unwrap(); tcp.set_nonblocking(false);
match pt { match pt {
PacketType::NewClient => unreachable!(), PacketType::NewClient => unreachable!(),
@ -131,6 +132,6 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
let _ = sockets[idx].write_later(&buf[..len]); let _ = sockets[idx].write_later(&buf[..len]);
} }
} }
tcp.internal.set_nonblocking(true).unwrap(); tcp.set_nonblocking(true);
} }
} }

View file

@ -25,32 +25,53 @@ pub(crate) struct SocketAdapter {
pub(crate) internal: TcpStream, pub(crate) internal: TcpStream,
written: usize, written: usize,
to_write: usize, to_write: usize,
write: [u8; 1_048_576], // 1MiB write: [u8; 65536],
broken: Option<Broken>, broken: Option<Broken>,
wait_if_full: bool,
is_nonblocking: bool,
} }
impl SocketAdapter { impl SocketAdapter {
pub fn new(tcp: TcpStream) -> SocketAdapter { pub fn new(tcp: TcpStream, wait_if_full: bool) -> SocketAdapter {
Self { Self {
internal: tcp, internal: tcp,
written: 0, written: 0,
to_write: 0, to_write: 0,
write: [0u8; 1_048_576], write: [0u8; 65536],
broken: None, broken: None,
wait_if_full,
is_nonblocking: false,
} }
} }
pub fn set_nonblocking(&mut self, nonblocking: bool) {
if let Err(x) = self.internal.set_nonblocking(nonblocking) {
self.broken = Some(Broken::OsErr(x.raw_os_error().unwrap()));
return;
}
self.is_nonblocking = nonblocking;
}
pub fn write_later(&mut self, buf: &[u8]) -> Result<(), Error> { pub fn write_later(&mut self, buf: &[u8]) -> Result<(), Error> {
if let Some(ref x) = self.broken { if let Some(ref x) = self.broken {
return Err(Error::from(*x)); return Err(Error::from(*x));
} }
let lidx = self.to_write + self.written + buf.len(); let lidx = self.to_write + self.written + buf.len();
if lidx > self.write.len() && lidx - self.to_write < self.write.len() { if lidx > self.write.len() && self.to_write + buf.len() < self.write.len() {
self.write self.write
.copy_within(self.written..self.written + self.to_write, 0); .copy_within(self.written..self.written + self.to_write, 0);
self.written = 0; self.written = 0;
} }
let Some(x) = self.write.get_mut(self.to_write + self.written..self.to_write + self.written + buf.len()) else { let Some(x) = self.write.get_mut(self.to_write + self.written..self.to_write + self.written + buf.len()) else {
if self.wait_if_full {
self.internal.set_nonblocking(false)?;
self.internal.write_all(&self.write[self.written..self.written + self.to_write])?;
self.internal.set_nonblocking(self.is_nonblocking)?;
self.written = 0;
self.to_write = buf.len();
self.write[..buf.len()].copy_from_slice(buf);
return Ok(());
}
self.broken = Some(Broken::DirectErr(ErrorKind::TimedOut)); self.broken = Some(Broken::DirectErr(ErrorKind::TimedOut));
return Err(ErrorKind::TimedOut.into()); return Err(ErrorKind::TimedOut.into());
}; };
@ -61,12 +82,7 @@ impl SocketAdapter {
pub fn write(&mut self, buf: &[u8]) -> Result<(), Error> { pub fn write(&mut self, buf: &[u8]) -> Result<(), Error> {
self.write_later(buf)?; self.write_later(buf)?;
if let Err(x) = self.update() { self.update()
self.broken = Some(Broken::OsErr(x.raw_os_error().unwrap()));
Err(x)
} else {
Ok(())
}
} }
pub fn update(&mut self) -> Result<(), Error> { pub fn update(&mut self) -> Result<(), Error> {
@ -76,10 +92,14 @@ impl SocketAdapter {
if self.to_write == 0 { if self.to_write == 0 {
return Ok(()); return Ok(());
} }
match self match {
.internal self.internal.set_nonblocking(true)?;
.write(&self.write[self.written..self.written + self.to_write]) let r = self
{ .internal
.write(&self.write[self.written..self.written + self.to_write]);
self.internal.set_nonblocking(self.is_nonblocking)?;
r
} {
Ok(x) => { Ok(x) => {
self.to_write -= x; self.to_write -= x;
self.written += x; self.written += x;