diff --git a/Cargo.lock b/Cargo.lock index 6bebb6f..25ecf98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,7 +72,7 @@ dependencies = [ [[package]] name = "revpfw3" -version = "0.2.0" +version = "0.2.1" dependencies = [ "enum-ordinalize", ] diff --git a/Cargo.toml b/Cargo.toml index fe8c2b1..63060d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "revpfw3" repository = "https://github.com/tudbut/revpfw3" description = "A tool to bypass portforwarding restrictions using some cheap VServer" license = "MIT" -version = "0.2.0" +version = "0.2.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/client.rs b/src/client.rs index c75e9be..3abc4fd 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,7 +4,7 @@ use std::{ net::{Shutdown, TcpStream}, thread, time::{Duration, SystemTime}, - vec, + vec, collections::HashMap, }; use crate::{io_sync, PacketType, SocketAdapter}; @@ -12,6 +12,7 @@ use crate::{io_sync, PacketType, SocketAdapter}; pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sleep_delay_ms: u64) { let mut buf1 = [0u8; 1]; let mut buf4 = [0u8; 4]; + let mut buf8 = [0u8; 8]; let mut buf16 = [0u8; 16]; let mut buf = [0; 1024]; let mut tcp = TcpStream::connect((ip, port)).unwrap(); @@ -34,7 +35,8 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle let mut tcp = SocketAdapter::new(tcp); tcp.set_nonblocking(true); - let mut sockets: Vec = Vec::new(); + let mut sockets: HashMap = HashMap::new(); + let mut id = 0; let mut last_keep_alive = SystemTime::now(); loop { let mut did_anything = false; @@ -44,7 +46,7 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle } let mut to_remove = vec![]; - for (i, socket) in sockets.iter_mut().enumerate() { + for (&i, socket) in sockets.iter_mut() { if let Ok(x) = socket.poll(&mut buf) { if let Some(len) = x { if len == 0 { @@ -52,7 +54,7 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle } else { tcp.write(&[PacketType::ServerData.ordinal() as u8]) .unwrap(); - tcp.write(&(i as u32).to_be_bytes()).unwrap(); + tcp.write(&i.to_be_bytes()).unwrap(); tcp.write(&(len as u32).to_be_bytes()).unwrap(); tcp.write(&buf[..len]).unwrap(); } @@ -65,15 +67,17 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle if let x @ 1.. = socket.clear_delay() { tcp.write(&[PacketType::ClientExceededBuffer.ordinal() as u8]) .unwrap(); - tcp.write(&(i as u32).to_be_bytes()).unwrap(); + tcp.write(&i.to_be_bytes()).unwrap(); tcp.write(&x.to_be_bytes()).unwrap(); } } for i in to_remove.into_iter().rev() { tcp.write(&[PacketType::CloseClient.ordinal() as u8]) .unwrap(); - tcp.write(&(i as u32).to_be_bytes()).unwrap(); - let _ = sockets.remove(i).internal.shutdown(Shutdown::Both); + tcp.write(&i.to_be_bytes()).unwrap(); + if let Some(x) = sockets.remove(&i) { + let _ = x.internal.shutdown(Shutdown::Both); + } } tcp.update().unwrap(); @@ -94,15 +98,14 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle PacketType::NewClient => { let mut tcp = SocketAdapter::new(TcpStream::connect((dest_ip, dest_port)).unwrap()); tcp.set_nonblocking(true); - sockets.push(tcp); + sockets.insert((id, id += 1).0, tcp); } PacketType::CloseClient => { - tcp.internal.read_exact(&mut buf4).unwrap(); - let _ = sockets - .remove(u32::from_be_bytes(buf4) as usize) - .internal - .shutdown(Shutdown::Both); + tcp.internal.read_exact(&mut buf8).unwrap(); + if let Some(x) = sockets.remove(&u64::from_be_bytes(buf8)) { + let _ = x.internal.shutdown(Shutdown::Both); + } } PacketType::KeepAlive => { @@ -111,25 +114,28 @@ pub fn client(ip: &str, port: u16, dest_ip: &str, dest_port: u16, key: &str, sle } PacketType::ClientData => { - tcp.internal.read_exact(&mut buf4).unwrap(); - let idx = u32::from_be_bytes(buf4) as usize; + tcp.internal.read_exact(&mut buf8).unwrap(); + let idx = u64::from_be_bytes(buf8); tcp.internal.read_exact(&mut buf4).unwrap(); let len = u32::from_be_bytes(buf4) as usize; tcp.internal.read_exact(&mut buf[..len]).unwrap(); - let _ = sockets[idx].write_later(&buf[..len]); + if let Some(socket) = sockets.get_mut(&idx) { + let _ = socket.write_later(&buf[..len]); + } } PacketType::ServerData => unreachable!(), PacketType::ClientExceededBuffer => { - tcp.internal.read_exact(&mut buf4).unwrap(); - let idx = u32::from_be_bytes(buf4) as usize; + tcp.internal.read_exact(&mut buf8).unwrap(); + let idx = u64::from_be_bytes(buf8); tcp.internal.read_exact(&mut buf16).unwrap(); let amount = u128::from_be_bytes(buf16); - if sockets.len() != 1 { // a single connection doesn't need overuse-penalties - sockets[idx].punish(amount); + // a single connection doesn't need overuse-penalties + if let (true, Some(socket)) = (sockets.len() > 1, sockets.get_mut(&idx)) { + socket.punish(amount); } } } diff --git a/src/server.rs b/src/server.rs index 273338a..97d34fa 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,5 @@ use std::{ + collections::HashMap, io::Read, io::Write, net::{Shutdown, TcpListener}, @@ -12,6 +13,7 @@ use crate::{io_sync, PacketType, SocketAdapter}; pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { let mut buf1 = [0u8; 1]; let mut buf4 = [0u8; 4]; + let mut buf8 = [0u8; 8]; let mut buf16 = [0u8; 16]; let mut buf = [0; 1024]; let tcpl = TcpListener::bind(("0.0.0.0", port)).unwrap(); @@ -43,7 +45,8 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { let mut tcp = SocketAdapter::new(tcp); tcp.set_nonblocking(true); - let mut sockets: Vec = Vec::new(); + let mut sockets: HashMap = HashMap::new(); + let mut id = 0; let mut last_keep_alive_sent = SystemTime::now(); let mut last_keep_alive = SystemTime::now(); loop { @@ -60,13 +63,13 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { if let Ok(new) = tcpl.accept() { let mut new = SocketAdapter::new(new.0); new.set_nonblocking(true); - sockets.push(new); + sockets.insert((id, id += 1).0, new); tcp.write(&[PacketType::NewClient.ordinal() as u8]).unwrap(); did_anything = true; } let mut to_remove = vec![]; - for (i, socket) in sockets.iter_mut().enumerate() { + for (&i, socket) in sockets.iter_mut() { if let Ok(x) = socket.poll(&mut buf) { if let Some(len) = x { if len == 0 { @@ -74,7 +77,7 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { } else { tcp.write(&[PacketType::ClientData.ordinal() as u8]) .unwrap(); - tcp.write(&(i as u32).to_be_bytes()).unwrap(); + tcp.write(&i.to_be_bytes()).unwrap(); tcp.write(&(len as u32).to_be_bytes()).unwrap(); tcp.write(&buf[..len]).unwrap(); } @@ -87,15 +90,17 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { if let x @ 1.. = socket.clear_delay() { tcp.write(&[PacketType::ClientExceededBuffer.ordinal() as u8]) .unwrap(); - tcp.write(&(i as u32).to_be_bytes()).unwrap(); + tcp.write(&i.to_be_bytes()).unwrap(); tcp.write(&x.to_be_bytes()).unwrap(); } } for i in to_remove.into_iter().rev() { tcp.write(&[PacketType::CloseClient.ordinal() as u8]) .unwrap(); - tcp.write(&(i as u32).to_be_bytes()).unwrap(); - let _ = sockets.remove(i).internal.shutdown(Shutdown::Both); + tcp.write(&i.to_be_bytes()).unwrap(); + if let Some(x) = sockets.remove(&i) { + let _ = x.internal.shutdown(Shutdown::Both); + } } tcp.update().unwrap(); @@ -116,11 +121,10 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { PacketType::NewClient => unreachable!(), PacketType::CloseClient => { - tcp.internal.read_exact(&mut buf4).unwrap(); - let _ = sockets - .remove(u32::from_be_bytes(buf4) as usize) - .internal - .shutdown(Shutdown::Both); + tcp.internal.read_exact(&mut buf8).unwrap(); + if let Some(x) = sockets.remove(&u64::from_be_bytes(buf8)) { + let _ = x.internal.shutdown(Shutdown::Both); + } } PacketType::KeepAlive => { @@ -130,23 +134,26 @@ pub fn server(port: u16, key: &str, sleep_delay_ms: u64) { PacketType::ClientData => unreachable!(), PacketType::ServerData => { - tcp.internal.read_exact(&mut buf4).unwrap(); - let idx = u32::from_be_bytes(buf4) as usize; + tcp.internal.read_exact(&mut buf8).unwrap(); + let idx = u64::from_be_bytes(buf8); tcp.internal.read_exact(&mut buf4).unwrap(); let len = u32::from_be_bytes(buf4) as usize; tcp.internal.read_exact(&mut buf[..len]).unwrap(); - let _ = sockets[idx].write_later(&buf[..len]); + if let Some(socket) = sockets.get_mut(&idx) { + let _ = socket.write_later(&buf[..len]); + } } PacketType::ClientExceededBuffer => { - tcp.internal.read_exact(&mut buf4).unwrap(); - let idx = u32::from_be_bytes(buf4) as usize; + tcp.internal.read_exact(&mut buf8).unwrap(); + let idx = u64::from_be_bytes(buf8); tcp.internal.read_exact(&mut buf16).unwrap(); let amount = u128::from_be_bytes(buf16); - if sockets.len() != 1 { // a single connection doesn't need overuse-penalties - sockets[idx].punish(amount); + // a single connection doesn't need overuse-penalties + if let (true, Some(socket)) = (sockets.len() > 1, sockets.get_mut(&idx)) { + socket.punish(amount); } } }