migrate away from directly using internal connections, fix unreliable read_exact

This commit is contained in:
Daniella / Tove 2023-10-04 17:02:49 +02:00
parent 63287f17fd
commit 0d197f885b
Signed by: TudbuT
GPG key ID: 7D63D5634B7C417F
4 changed files with 44 additions and 60 deletions

View file

@ -10,7 +10,7 @@ use std::{
use serial::SerialPort;
use crate::{io_sync, Connection, PacketType, SocketAdapter};
use crate::{Connection, PacketType, SocketAdapter};
pub struct ClientParams<'a> {
pub server_ip: &'a str,
@ -79,7 +79,6 @@ fn connect(params: &ClientParams) -> Connection {
fn resync(tcp: &mut SocketAdapter) {
println!();
eprintln!("Server version mismatch or broken connection. Re-syncing in case of the latter...");
tcp.set_nonblocking(true);
tcp.internal.set_print(false);
tcp.write_now().unwrap();
tcp.write(&[PacketType::Resync.ordinal() as u8]).unwrap();
@ -94,13 +93,10 @@ fn resync(tcp: &mut SocketAdapter) {
thread::sleep(Duration::from_secs(5));
// read all packets that are still pending.
while let Some(Some(_x @ 1..)) = tcp.poll(&mut buf).ok() {}
// server should now have stopped sending packets. waiting 5 more seconds so the server has time to
// send the resync packet.
thread::sleep(Duration::from_secs(5));
// server should now have stopped sending packets.
let mut buf = [0];
eprintln!("Trying to receive the resync echo...");
tcp.set_nonblocking(false);
tcp.poll(&mut buf).unwrap();
tcp.read_now(&mut buf).unwrap();
if buf[0] as i8 == PacketType::ResyncEcho.ordinal() {
eprintln!("Successfully resynced. RevPFW3 can continue.");
} else {
@ -141,7 +137,6 @@ pub fn client(params: ClientParams) {
let mut id = 0;
let mut last_keep_alive = SystemTime::now();
loop {
tcp.set_nonblocking(true);
thread::sleep(Duration::from_millis(params.rate_limit_sleep));
let mut did_anything = false;
@ -186,10 +181,7 @@ pub fn client(params: ClientParams) {
}
tcp.update().unwrap();
if io_sync(tcp.internal.read_exact(&mut buf1))
.unwrap()
.is_none()
{
if tcp.poll_exact(&mut buf1).unwrap().is_none() {
if !did_anything {
thread::sleep(Duration::from_millis(params.sleep_delay_ms));
}
@ -200,19 +192,17 @@ pub fn client(params: ClientParams) {
resync(&mut tcp);
continue;
};
tcp.set_nonblocking(false);
match pt {
PacketType::NewClient => {
let mut tcp = SocketAdapter::new(Connection::new_tcp(
let tcp = SocketAdapter::new(Connection::new_tcp(
TcpStream::connect((params.dest_ip, params.dest_port)).unwrap(),
false,
));
tcp.set_nonblocking(true);
sockets.insert((id, id += 1).0, tcp);
}
PacketType::CloseClient => {
tcp.internal.read_exact(&mut buf8).unwrap();
tcp.read_now(&mut buf8).unwrap();
if let Some(x) = sockets.remove(&u64::from_be_bytes(buf8)) {
let _ = x.internal.close();
}
@ -224,11 +214,11 @@ pub fn client(params: ClientParams) {
}
PacketType::ClientData => {
tcp.internal.read_exact(&mut buf8).unwrap();
tcp.read_now(&mut buf8).unwrap();
let idx = u64::from_be_bytes(buf8);
tcp.internal.read_exact(&mut buf4).unwrap();
tcp.read_now(&mut buf4).unwrap();
let len = u32::from_be_bytes(buf4) as usize;
tcp.internal.read_exact(&mut buf[..len]).unwrap();
tcp.read_now(&mut buf[..len]).unwrap();
if let Some(socket) = sockets.get_mut(&idx) {
let _ = socket.write_later(&buf[..len]);
@ -238,9 +228,9 @@ pub fn client(params: ClientParams) {
PacketType::ServerData => resync(&mut tcp),
PacketType::ClientExceededBuffer => {
tcp.internal.read_exact(&mut buf8).unwrap();
tcp.read_now(&mut buf8).unwrap();
let idx = u64::from_be_bytes(buf8);
tcp.internal.read_exact(&mut buf16).unwrap();
tcp.read_now(&mut buf16).unwrap();
let amount = u128::from_be_bytes(buf16);
// a single connection doesn't need overuse-penalties
@ -260,6 +250,5 @@ pub fn client(params: ClientParams) {
// this one shouldnt happen.
PacketType::ResyncEcho => resync(&mut tcp),
}
tcp.set_nonblocking(true);
}
}

View file

@ -48,17 +48,18 @@ impl Read for Connection {
}
fn read_exact(&mut self, mut buf: &mut [u8]) -> io::Result<()> {
let len = buf.len();
while !buf.is_empty() {
match self.read(buf) {
Ok(0) if self.is_nb => {
Ok(0) if self.is_nb && buf.len() == len => {
return Err(io::Error::new(ErrorKind::WouldBlock, "would block"))
}
Ok(0) => (),
Ok(0) => break,
Ok(n) => {
let tmp = buf;
buf = &mut tmp[n..];
}
Err(ref e) if e.kind() == ErrorKind::Interrupted => (),
Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}

View file

@ -8,12 +8,11 @@ use std::{
vec,
};
use crate::{io_sync, Connection, PacketType, SocketAdapter};
use crate::{Connection, PacketType, SocketAdapter};
fn resync(tcp: &mut SocketAdapter) {
println!();
eprintln!("Client version mismatch or broken connection. Re-syncing in case of the latter...");
tcp.set_nonblocking(true);
tcp.internal.set_print(false);
tcp.write_now().unwrap();
tcp.write(&[PacketType::Resync.ordinal() as u8]).unwrap();
@ -28,9 +27,7 @@ fn resync(tcp: &mut SocketAdapter) {
thread::sleep(Duration::from_secs(5));
// read all packets that are still pending.
while let Some(Some(_x @ 1..)) = tcp.poll(&mut buf).ok() {}
// server should now have stopped sending packets. waiting 5 more seconds so the client has time to
// send the resync packet.
thread::sleep(Duration::from_secs(5));
// client should now have stopped sending packets.
tcp.internal.set_print(true);
}
@ -72,7 +69,6 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
let mut last_keep_alive_sent = SystemTime::now();
let mut last_keep_alive = SystemTime::now();
loop {
tcp.set_nonblocking(true);
let mut did_anything = false;
if last_keep_alive_sent.elapsed().unwrap().as_secs() >= 10 {
@ -84,8 +80,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
}
if let Ok(new) = tcpl.accept() {
let mut new = SocketAdapter::new(Connection::new_tcp(new.0, false));
new.set_nonblocking(true);
let new = SocketAdapter::new(Connection::new_tcp(new.0, false));
sockets.insert((id, id += 1).0, new);
tcp.write(&[PacketType::NewClient.ordinal() as u8]).unwrap();
did_anything = true;
@ -128,10 +123,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
}
tcp.update().unwrap();
if io_sync(tcp.internal.read_exact(&mut buf1))
.unwrap()
.is_none()
{
if tcp.poll_exact(&mut buf1).unwrap().is_none() {
if !did_anything {
thread::sleep(Duration::from_millis(sleep_delay_ms));
}
@ -142,12 +134,11 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
resync(&mut tcp);
continue;
};
tcp.set_nonblocking(false);
match pt {
PacketType::NewClient => resync(&mut tcp),
PacketType::CloseClient => {
tcp.internal.read_exact(&mut buf8).unwrap();
tcp.read_now(&mut buf8).unwrap();
if let Some(x) = sockets.remove(&u64::from_be_bytes(buf8)) {
let _ = x.internal.close();
}
@ -160,11 +151,11 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
PacketType::ClientData => resync(&mut tcp),
PacketType::ServerData => {
tcp.internal.read_exact(&mut buf8).unwrap();
tcp.read_now(&mut buf8).unwrap();
let idx = u64::from_be_bytes(buf8);
tcp.internal.read_exact(&mut buf4).unwrap();
tcp.read_now(&mut buf4).unwrap();
let len = u32::from_be_bytes(buf4) as usize;
tcp.internal.read_exact(&mut buf[..len]).unwrap();
tcp.read_now(&mut buf[..len]).unwrap();
if let Some(socket) = sockets.get_mut(&idx) {
let _ = socket.write_later(&buf[..len]);
@ -172,9 +163,9 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
}
PacketType::ClientExceededBuffer => {
tcp.internal.read_exact(&mut buf8).unwrap();
tcp.read_now(&mut buf8).unwrap();
let idx = u64::from_be_bytes(buf8);
tcp.internal.read_exact(&mut buf16).unwrap();
tcp.read_now(&mut buf16).unwrap();
let amount = u128::from_be_bytes(buf16);
// a single connection doesn't need overuse-penalties
@ -201,6 +192,5 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) {
// this one can't happen, it should only come from the server
PacketType::ResyncEcho => resync(&mut tcp),
}
tcp.set_nonblocking(true);
}
}

View file

@ -8,14 +8,12 @@ use crate::{io_sync, Connection};
#[derive(Clone, Copy)]
enum Broken {
OsErr(i32),
DirectErr(ErrorKind, &'static str),
}
impl From<Broken> for Error {
fn from(value: Broken) -> Self {
match value {
Broken::OsErr(x) => Error::from_raw_os_error(x),
Broken::DirectErr(x, s) => Error::new(x, s),
}
}
@ -28,7 +26,6 @@ pub(crate) struct SocketAdapter {
write: [u8; 65536],
broken: Option<Broken>,
accumulated_delay: u128,
is_nonblocking: bool,
ignore_until: Option<u128>,
}
@ -41,19 +38,10 @@ impl SocketAdapter {
write: [0u8; 65536],
broken: None,
accumulated_delay: 0,
is_nonblocking: false,
ignore_until: None,
}
}
pub fn set_nonblocking(&mut self, nonblocking: bool) {
if let Err(x) = self.internal.set_nonblocking(nonblocking) {
self.broken = Some(Broken::OsErr(x.raw_os_error().unwrap()));
return;
}
self.is_nonblocking = nonblocking;
}
pub fn write_later(&mut self, buf: &[u8]) -> Result<(), Error> {
if let Some(ref x) = self.broken {
return Err(Error::from(*x));
@ -72,7 +60,6 @@ impl SocketAdapter {
self.internal.set_nonblocking(false)?;
self.internal
.write_all(&self.write[self.written..self.written + self.to_write])?;
self.internal.set_nonblocking(self.is_nonblocking)?;
self.written = 0;
self.to_write = buf.len();
self.write[..buf.len()].copy_from_slice(buf);
@ -101,7 +88,6 @@ impl SocketAdapter {
let r = self
.internal
.write_all(&self.write[self.written..self.written + self.to_write]);
self.internal.set_nonblocking(self.is_nonblocking)?;
r
} {
Ok(()) => {
@ -131,7 +117,6 @@ impl SocketAdapter {
let r = self
.internal
.write(&self.write[self.written..self.written + self.to_write]);
self.internal.set_nonblocking(self.is_nonblocking)?;
r
} {
Ok(x) => {
@ -150,11 +135,30 @@ impl SocketAdapter {
}
}
pub fn read_now(&mut self, buf: &mut [u8]) -> Result<Option<()>, Error> {
if Some(SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros()) < self.ignore_until {
return Ok(None);
}
self.update()?;
self.internal.set_nonblocking(false)?;
io_sync(self.internal.read_exact(buf))
}
pub fn poll_exact(&mut self, buf: &mut [u8]) -> Result<Option<()>, Error> {
if Some(SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros()) < self.ignore_until {
return Ok(None);
}
self.update()?;
self.internal.set_nonblocking(true)?;
io_sync(self.internal.read_exact(buf))
}
pub fn poll(&mut self, buf: &mut [u8]) -> Result<Option<usize>, Error> {
if Some(SystemTime::UNIX_EPOCH.elapsed().unwrap().as_micros()) < self.ignore_until {
return Ok(None);
}
self.update()?;
self.internal.set_nonblocking(true)?;
io_sync(self.internal.read(buf))
}