make central connection unable to overflow, decrease buffer size

This commit is contained in:
Daniella / Tove 2023-01-31 22:06:55 +01:00
parent c3fd53bfbf
commit 1792c2b8d3
3 changed files with 48 additions and 28 deletions

View file

@ -31,9 +31,8 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle
println!("READY!");
tcp.set_nonblocking(true).unwrap();
let mut tcp = SocketAdapter::new(tcp);
let mut tcp = SocketAdapter::new(tcp, true);
tcp.set_nonblocking(true);
let mut sockets: Vec<SocketAdapter> = Vec::new();
let mut last_keep_alive = SystemTime::now();
loop {
@ -83,12 +82,12 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle
let pt = PacketType::from_ordinal(buf1[0] as i8)
.expect("server/client version mismatch or broken TCP");
tcp.internal.set_nonblocking(false).unwrap();
tcp.set_nonblocking(false);
match pt {
PacketType::NewClient => {
let tcp = TcpStream::connect((dest_ip, dest_port)).unwrap();
tcp.set_nonblocking(true).unwrap();
sockets.push(SocketAdapter::new(tcp));
let mut tcp = SocketAdapter::new(TcpStream::connect((dest_ip, dest_port)).unwrap(), false);
tcp.set_nonblocking(true);
sockets.push(tcp);
}
PacketType::CloseClient => {
@ -116,6 +115,6 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle
PacketType::ServerData => unreachable!(),
}
tcp.internal.set_nonblocking(true).unwrap();
tcp.set_nonblocking(true);
}
}

View file

@ -38,10 +38,10 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
tcp.write_all(&mut ['R' as u8, 'P' as u8, 'F' as u8, 30])
.unwrap();
tcp.set_nonblocking(true).unwrap();
tcpl.set_nonblocking(true).unwrap();
let mut tcp = SocketAdapter::new(tcp);
let mut tcp = SocketAdapter::new(tcp, true);
tcp.set_nonblocking(true);
let mut sockets: Vec<SocketAdapter> = Vec::new();
let mut last_keep_alive_sent = SystemTime::now();
let mut last_keep_alive = SystemTime::now();
@ -57,8 +57,9 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
}
if let Ok(new) = tcpl.accept() {
new.0.set_nonblocking(true).unwrap();
sockets.push(SocketAdapter::new(new.0));
let mut new = SocketAdapter::new(new.0, false);
new.set_nonblocking(true);
sockets.push(new);
tcp.write(&[PacketType::NewClient.ordinal() as u8]).unwrap();
did_anything = true;
}
@ -103,7 +104,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
let pt = PacketType::from_ordinal(buf1[0] as i8)
.expect("server/client version mismatch or broken TCP");
tcp.internal.set_nonblocking(false).unwrap();
tcp.set_nonblocking(false);
match pt {
PacketType::NewClient => unreachable!(),
@ -131,6 +132,6 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
let _ = sockets[idx].write_later(&buf[..len]);
}
}
tcp.internal.set_nonblocking(true).unwrap();
tcp.set_nonblocking(true);
}
}

View file

@ -25,32 +25,53 @@ pub(crate) struct SocketAdapter {
pub(crate) internal: TcpStream,
written: usize,
to_write: usize,
write: [u8; 1_048_576], // 1MiB
write: [u8; 65536],
broken: Option<Broken>,
wait_if_full: bool,
is_nonblocking: bool,
}
impl SocketAdapter {
pub fn new(tcp: TcpStream) -> SocketAdapter {
pub fn new(tcp: TcpStream, wait_if_full: bool) -> SocketAdapter {
Self {
internal: tcp,
written: 0,
to_write: 0,
write: [0u8; 1_048_576],
write: [0u8; 65536],
broken: None,
wait_if_full,
is_nonblocking: false,
}
}
pub fn set_nonblocking(&mut self, nonblocking: bool) {
if let Err(x) = self.internal.set_nonblocking(nonblocking) {
self.broken = Some(Broken::OsErr(x.raw_os_error().unwrap()));
return;
}
self.is_nonblocking = nonblocking;
}
pub fn write_later(&mut self, buf: &[u8]) -> Result<(), Error> {
if let Some(ref x) = self.broken {
return Err(Error::from(*x));
}
let lidx = self.to_write + self.written + buf.len();
if lidx > self.write.len() && lidx - self.to_write < self.write.len() {
if lidx > self.write.len() && self.to_write + buf.len() < self.write.len() {
self.write
.copy_within(self.written..self.written + self.to_write, 0);
self.written = 0;
}
let Some(x) = self.write.get_mut(self.to_write + self.written..self.to_write + self.written + buf.len()) else {
if self.wait_if_full {
self.internal.set_nonblocking(false)?;
self.internal.write_all(&self.write[self.written..self.written + self.to_write])?;
self.internal.set_nonblocking(self.is_nonblocking)?;
self.written = 0;
self.to_write = buf.len();
self.write[..buf.len()].copy_from_slice(buf);
return Ok(());
}
self.broken = Some(Broken::DirectErr(ErrorKind::TimedOut));
return Err(ErrorKind::TimedOut.into());
};
@ -61,12 +82,7 @@ impl SocketAdapter {
pub fn write(&mut self, buf: &[u8]) -> Result<(), Error> {
self.write_later(buf)?;
if let Err(x) = self.update() {
self.broken = Some(Broken::OsErr(x.raw_os_error().unwrap()));
Err(x)
} else {
Ok(())
}
self.update()
}
pub fn update(&mut self) -> Result<(), Error> {
@ -76,10 +92,14 @@ impl SocketAdapter {
if self.to_write == 0 {
return Ok(());
}
match self
match {
self.internal.set_nonblocking(true)?;
let r = self
.internal
.write(&self.write[self.written..self.written + self.to_write])
{
.write(&self.write[self.written..self.written + self.to_write]);
self.internal.set_nonblocking(self.is_nonblocking)?;
r
} {
Ok(x) => {
self.to_write -= x;
self.written += x;