diff --git a/src/gateway/nat.rs b/src/gateway/nat.rs index 0b81250..9872a70 100644 --- a/src/gateway/nat.rs +++ b/src/gateway/nat.rs @@ -1,6 +1,6 @@ use bimap::BiMap; use std::{ - io::Error, + io::{Error, ErrorKind}, net::{self, Ipv4Addr}, sync::Arc, time::Duration, @@ -9,10 +9,18 @@ use tokio::sync::{RwLock, mpsc::UnboundedSender}; use moka::future::Cache; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SessionKey { + pub src_addr: Ipv4Addr, + pub dst_addr: Ipv4Addr, + pub src_port: u16, + pub dst_port: u16, +} + pub struct Nat { nat_type: Type, - cache: Cache>, - mapping: Arc>>, + cache: Cache>, + mapping: Arc>>, } pub enum Type { @@ -32,8 +40,8 @@ pub struct Session { impl Nat { pub fn new(nat_type: Type, tx: Option>) -> Self { let ttl = match nat_type { - Type::Tcp => Duration::from_secs(60), - Type::Udp => Duration::from_secs(20), + Type::Tcp => Duration::from_secs(600), + Type::Udp => Duration::from_secs(60), }; let mapping = Arc::new(RwLock::new(BiMap::new())); @@ -53,7 +61,12 @@ impl Nat { dst_addr: Ipv4Addr, dst_port: u16, ) -> Result { - let addr_key = u32::from_be_bytes(src_addr.octets()) + src_port as u32; + let addr_key = SessionKey { + src_addr, + dst_addr, + src_port, + dst_port, + }; if let Some(session) = self.cache.get(&addr_key).await { return Ok(*session); @@ -72,7 +85,23 @@ impl Nat { }); } - self.get_available_port()? + let mut assigned_port = 0; + for _ in 0..10 { + let port = self.get_available_port()?; + if !mapping.contains_right(&port) { + assigned_port = port; + break; + } + } + + if assigned_port == 0 { + return Err(Error::new( + ErrorKind::AddrInUse, + "No available NAT port via OS allocation", + )); + } + + assigned_port }; let session = Arc::new(Session { @@ -94,10 +123,18 @@ impl Nat { } pub async fn find(&self, nat_port: u16) -> Option { - if let Some(addr_key) = self.get_addr_key_by_port_fast(&nat_port).await - && let Some(session) = self.cache.get(&addr_key).await - { - return Some(*session); + if let Some(addr_key) = self.get_addr_key_by_port_fast(&nat_port).await { + if let Some(session) = self.cache.get(&addr_key).await { + return Some(*session); + } + + return Some(Session { + src_addr: addr_key.src_addr, + dst_addr: addr_key.dst_addr, + src_port: addr_key.src_port, + dst_port: addr_key.dst_port, + nat_port, + }); } None @@ -121,27 +158,29 @@ impl Nat { fn new_cache( ttl: Duration, - mapping: Arc>>, + mapping: Arc>>, tx: Option>, - ) -> Cache> { + ) -> Cache> { Cache::builder() .max_capacity(5000) .time_to_idle(ttl) - .eviction_listener(move |addr_key: Arc, session: Arc, _cause| { - let mapping = mapping.clone(); - let tx = tx.clone(); - tokio::task::spawn(async move { - let mut mapping_guard = mapping.write().await; - let _ = mapping_guard.remove_by_left(&*addr_key); - if let Some(ref tx) = tx { - let _ = tx.send(session.nat_port); - } - }); - }) + .eviction_listener( + move |addr_key: Arc, session: Arc, _cause| { + let mapping = mapping.clone(); + let tx = tx.clone(); + tokio::task::spawn(async move { + let mut mapping_guard = mapping.write().await; + let _ = mapping_guard.remove_by_left(&*addr_key); + if let Some(ref tx) = tx { + let _ = tx.send(session.nat_port); + } + }); + }, + ) .build() } - async fn get_addr_key_by_port_fast(&self, nat_port: &u16) -> Option { + async fn get_addr_key_by_port_fast(&self, nat_port: &u16) -> Option { let mapping = self.mapping.read().await; mapping.get_by_right(nat_port).copied() } diff --git a/src/gateway/relay_tcp.rs b/src/gateway/relay_tcp.rs index cbcead8..71952ea 100644 --- a/src/gateway/relay_tcp.rs +++ b/src/gateway/relay_tcp.rs @@ -7,7 +7,7 @@ use std::{ atomic::{AtomicU64, Ordering}, }, task::{Context, Poll}, - time::{Duration, SystemTime, UNIX_EPOCH}, + time::Duration, }; use anyhow::{Context as _, Result}; @@ -27,7 +27,7 @@ use super::{ }; const PROXY_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); -const IDLE_TIMEOUT: Duration = Duration::from_secs(60); +const IDLE_TIMEOUT: Duration = Duration::from_secs(600); pub(crate) struct TcpRelay { runtime: ArcRuntime, @@ -109,20 +109,44 @@ async fn copy_with_idle_timeout( target_addr: &str, ) -> Result<()> { let tracker = Arc::new(SharedIdleTracker::new()); - - let mut timeout_client = IdleTimeoutStream::new(client, tracker.clone(), IDLE_TIMEOUT); - let mut timeout_proxy = IdleTimeoutStream::new(proxy, tracker, IDLE_TIMEOUT); - - match copy_bidirectional(&mut timeout_client, &mut timeout_proxy).await { - Ok((up, down)) => { - stats::update_metrics(runtime, Protocol::Tcp, proxy_name, target_addr, up, down); - Ok(()) + let bytes_to_client = Arc::new(AtomicU64::new(0)); + let bytes_to_proxy = Arc::new(AtomicU64::new(0)); + + let mut active_client = ActiveStream::new(client, tracker.clone(), bytes_to_client.clone()); + let mut active_proxy = ActiveStream::new(proxy, tracker.clone(), bytes_to_proxy.clone()); + + let copy_task = copy_bidirectional(&mut active_client, &mut active_proxy); + + let monitor_task = async { + loop { + let last = tracker.last_activity_instant(); + let deadline = last + IDLE_TIMEOUT; + tokio::time::sleep_until(deadline).await; + if tracker.last_activity_instant() <= last { + // No activity since we woke up (or last activity is older) -> idle timeout + break; + } } - Err(e) => { - debug!("TCP relay error: {}", e); - Ok(()) + }; + + let (up, down) = tokio::select! { + res = copy_task => { + match res { + Ok((up, down)) => (up, down), + Err(e) => { + debug!("TCP relay error or closed: {}", e); + (bytes_to_proxy.load(Ordering::Relaxed), bytes_to_client.load(Ordering::Relaxed)) + } + } } - } + _ = monitor_task => { + debug!("TCP relay idle timeout for {}", target_addr); + (bytes_to_proxy.load(Ordering::Relaxed), bytes_to_client.load(Ordering::Relaxed)) + } + }; + + stats::update_metrics(runtime, Protocol::Tcp, proxy_name, target_addr, up, down); + Ok(()) } async fn find_session_target( @@ -141,131 +165,86 @@ async fn find_session_target( } struct SharedIdleTracker { - last_activity: Arc, + base_instant: tokio::time::Instant, + last_activity_micros: Arc, } impl SharedIdleTracker { fn new() -> Self { - let now_millis = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis() as u64; - Self { - last_activity: Arc::new(AtomicU64::new(now_millis)), + base_instant: tokio::time::Instant::now(), + last_activity_micros: Arc::new(AtomicU64::new(0)), } } fn update_activity(&self) { - let now_millis = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis() as u64; - self.last_activity.store(now_millis, Ordering::Relaxed); - } - - fn elapsed(&self) -> Duration { - let last_millis = self.last_activity.load(Ordering::Relaxed); - let now_millis = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis() as u64; - Duration::from_millis(now_millis.saturating_sub(last_millis)) + let elapsed = self.base_instant.elapsed().as_micros() as u64; + self.last_activity_micros + .fetch_max(elapsed, Ordering::Relaxed); } - fn is_idle(&self, timeout: Duration) -> bool { - self.elapsed() > timeout + fn last_activity_instant(&self) -> tokio::time::Instant { + let micros = self.last_activity_micros.load(Ordering::Relaxed); + self.base_instant + Duration::from_micros(micros) } } -struct IdleTimeoutStream { +struct ActiveStream { inner: T, tracker: Arc, - timeout: Duration, + written_bytes: Arc, } -impl IdleTimeoutStream { - fn new(inner: T, tracker: Arc, timeout: Duration) -> Self { +impl ActiveStream { + fn new(inner: T, tracker: Arc, written_bytes: Arc) -> Self { Self { inner, tracker, - timeout, + written_bytes, } } fn update_activity(&self) { self.tracker.update_activity(); } - - fn check_idle(&self) -> tokio::io::Result<()> { - if self.tracker.is_idle(self.timeout) { - return Err(tokio::io::Error::new( - tokio::io::ErrorKind::TimedOut, - "idle timeout - no activity on either side", - )); - } - Ok(()) - } - - fn is_normal_close(e: &std::io::Error) -> bool { - matches!( - e.kind(), - std::io::ErrorKind::BrokenPipe - | std::io::ErrorKind::ConnectionReset - | std::io::ErrorKind::ConnectionAborted - | std::io::ErrorKind::UnexpectedEof - ) - } } -impl AsyncRead for IdleTimeoutStream { +impl AsyncRead for ActiveStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { - self.check_idle()?; - let initial_filled = buf.filled().len(); let poll = Pin::new(&mut self.inner).poll_read(cx, buf); - match &poll { - Poll::Ready(Ok(())) if buf.filled().len() > initial_filled => { + if let Poll::Ready(Ok(())) = &poll { + let n = buf.filled().len() - initial_filled; + if n > 0 { self.update_activity(); } - Poll::Ready(Err(e)) if Self::is_normal_close(e) => { - return Poll::Ready(Ok(())); - } - _ => {} } poll } } -impl AsyncWrite for IdleTimeoutStream { +impl AsyncWrite for ActiveStream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.check_idle()?; - let poll = Pin::new(&mut self.inner).poll_write(cx, buf); - match poll { - Poll::Ready(Ok(n)) => { - if n > 0 { - self.update_activity(); - } - Poll::Ready(Ok(n)) - } - Poll::Ready(Err(e)) if Self::is_normal_close(&e) => { - // Treat normal close as successful write of all bytes - Poll::Ready(Ok(buf.len())) - } - _ => poll, + if let Poll::Ready(Ok(n)) = &poll + && *n > 0 + { + self.update_activity(); + self.written_bytes.fetch_add(*n as u64, Ordering::Relaxed); } + + poll } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/src/main.rs b/src/main.rs index e651278..988df9c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -#![feature(test)] +#![cfg_attr(test, feature(test))] #[macro_use] extern crate lazy_static;