Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 64 additions & 25 deletions src/gateway/nat.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use bimap::BiMap;
use std::{
io::Error,
io::{Error, ErrorKind},
net::{self, Ipv4Addr},
sync::Arc,
time::Duration,
Expand All @@ -9,10 +9,18 @@ use tokio::sync::{RwLock, mpsc::UnboundedSender};

use moka::future::Cache;

/// Key identifying a NAT session by its full connection 4-tuple
/// (source address/port and destination address/port).
///
/// Derives `Copy`/`Eq`/`Hash` so it can serve as the lookup key in the
/// session `Cache` and as the left-hand side of the port `BiMap` in `Nat`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionKey {
    /// Source IPv4 address of the originating host.
    pub src_addr: Ipv4Addr,
    /// Destination IPv4 address the session targets.
    pub dst_addr: Ipv4Addr,
    /// Source port on the originating host.
    pub src_port: u16,
    /// Destination port on the remote host.
    pub dst_port: u16,
}

pub struct Nat {
nat_type: Type,
cache: Cache<u32, Arc<Session>>,
mapping: Arc<RwLock<BiMap<u32, u16>>>,
cache: Cache<SessionKey, Arc<Session>>,
mapping: Arc<RwLock<BiMap<SessionKey, u16>>>,
}

pub enum Type {
Expand All @@ -32,8 +40,8 @@ pub struct Session {
impl Nat {
pub fn new(nat_type: Type, tx: Option<UnboundedSender<u16>>) -> Self {
let ttl = match nat_type {
Type::Tcp => Duration::from_secs(60),
Type::Udp => Duration::from_secs(20),
Type::Tcp => Duration::from_secs(600),
Type::Udp => Duration::from_secs(60),
};

let mapping = Arc::new(RwLock::new(BiMap::new()));
Expand All @@ -53,7 +61,12 @@ impl Nat {
dst_addr: Ipv4Addr,
dst_port: u16,
) -> Result<Session, Error> {
let addr_key = u32::from_be_bytes(src_addr.octets()) + src_port as u32;
let addr_key = SessionKey {
src_addr,
dst_addr,
src_port,
dst_port,
};
Comment on lines +64 to +69
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create() now keys sessions by the full 4-tuple (SessionKey), but the unit test only exercises a single tuple and doesn’t assert that different destinations (or ports) produce distinct sessions/ports. Consider extending the tests to create at least two different SessionKey values and assert they don’t alias (and that find() resolves each nat_port correctly).

Copilot uses AI. Check for mistakes.

if let Some(session) = self.cache.get(&addr_key).await {
return Ok(*session);
Expand All @@ -72,7 +85,23 @@ impl Nat {
});
}

self.get_available_port()?
let mut assigned_port = 0;
for _ in 0..10 {
let port = self.get_available_port()?;
if !mapping.contains_right(&port) {
assigned_port = port;
break;
}
}
Comment on lines +88 to +95
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create() chooses a NAT port while holding only a read-lock on mapping, then later inserts under a separate write-lock. Two concurrent create() calls can still pick the same port (OS can reuse ephemeral ports quickly), and BiMap::insert() will typically evict/replace an existing entry when the right-side value already exists, potentially breaking an existing session mapping. Consider selecting+reserving the port under a write-lock (or re-checking and retrying under the write-lock right before insert) so port assignment is race-free.

Copilot uses AI. Check for mistakes.

if assigned_port == 0 {
return Err(Error::new(
ErrorKind::AddrInUse,
"No available NAT port via OS allocation",
));
}

assigned_port
};

let session = Arc::new(Session {
Expand All @@ -94,10 +123,18 @@ impl Nat {
}

pub async fn find(&self, nat_port: u16) -> Option<Session> {
if let Some(addr_key) = self.get_addr_key_by_port_fast(&nat_port).await
&& let Some(session) = self.cache.get(&addr_key).await
{
return Some(*session);
if let Some(addr_key) = self.get_addr_key_by_port_fast(&nat_port).await {
if let Some(session) = self.cache.get(&addr_key).await {
return Some(*session);
}

return Some(Session {
src_addr: addr_key.src_addr,
dst_addr: addr_key.dst_addr,
src_port: addr_key.src_port,
dst_port: addr_key.dst_port,
nat_port,
});
Comment on lines 125 to +137
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

find() now reconstructs a Session from SessionKey when the cache misses but the mapping still contains the NAT port. This new behavior isn’t exercised by the current tests (they only hit the cache-present path). Consider adding a unit test that forces a cache miss (e.g., invalidate the entry) while keeping the mapping present, and asserts find() still returns the expected session.

Copilot uses AI. Check for mistakes.
}

None
Expand All @@ -121,27 +158,29 @@ impl Nat {

fn new_cache(
ttl: Duration,
mapping: Arc<RwLock<BiMap<u32, u16>>>,
mapping: Arc<RwLock<BiMap<SessionKey, u16>>>,
tx: Option<UnboundedSender<u16>>,
) -> Cache<u32, Arc<Session>> {
) -> Cache<SessionKey, Arc<Session>> {
Cache::builder()
.max_capacity(5000)
.time_to_idle(ttl)
.eviction_listener(move |addr_key: Arc<u32>, session: Arc<Session>, _cause| {
let mapping = mapping.clone();
let tx = tx.clone();
tokio::task::spawn(async move {
let mut mapping_guard = mapping.write().await;
let _ = mapping_guard.remove_by_left(&*addr_key);
if let Some(ref tx) = tx {
let _ = tx.send(session.nat_port);
}
});
})
.eviction_listener(
move |addr_key: Arc<SessionKey>, session: Arc<Session>, _cause| {
let mapping = mapping.clone();
let tx = tx.clone();
tokio::task::spawn(async move {
let mut mapping_guard = mapping.write().await;
let _ = mapping_guard.remove_by_left(&*addr_key);
if let Some(ref tx) = tx {
let _ = tx.send(session.nat_port);
}
});
},
)
.build()
}

async fn get_addr_key_by_port_fast(&self, nat_port: &u16) -> Option<u32> {
async fn get_addr_key_by_port_fast(&self, nat_port: &u16) -> Option<SessionKey> {
let mapping = self.mapping.read().await;
mapping.get_by_right(nat_port).copied()
}
Expand Down
151 changes: 65 additions & 86 deletions src/gateway/relay_tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
atomic::{AtomicU64, Ordering},
},
task::{Context, Poll},
time::{Duration, SystemTime, UNIX_EPOCH},
time::Duration,
};

use anyhow::{Context as _, Result};
Expand All @@ -27,7 +27,7 @@ use super::{
};

const PROXY_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
const IDLE_TIMEOUT: Duration = Duration::from_secs(60);
const IDLE_TIMEOUT: Duration = Duration::from_secs(600);

pub(crate) struct TcpRelay {
runtime: ArcRuntime,
Expand Down Expand Up @@ -109,20 +109,44 @@ async fn copy_with_idle_timeout(
target_addr: &str,
) -> Result<()> {
let tracker = Arc::new(SharedIdleTracker::new());

let mut timeout_client = IdleTimeoutStream::new(client, tracker.clone(), IDLE_TIMEOUT);
let mut timeout_proxy = IdleTimeoutStream::new(proxy, tracker, IDLE_TIMEOUT);

match copy_bidirectional(&mut timeout_client, &mut timeout_proxy).await {
Ok((up, down)) => {
stats::update_metrics(runtime, Protocol::Tcp, proxy_name, target_addr, up, down);
Ok(())
let bytes_to_client = Arc::new(AtomicU64::new(0));
let bytes_to_proxy = Arc::new(AtomicU64::new(0));

let mut active_client = ActiveStream::new(client, tracker.clone(), bytes_to_client.clone());
let mut active_proxy = ActiveStream::new(proxy, tracker.clone(), bytes_to_proxy.clone());

let copy_task = copy_bidirectional(&mut active_client, &mut active_proxy);

let monitor_task = async {
loop {
let last = tracker.last_activity_instant();
let deadline = last + IDLE_TIMEOUT;
tokio::time::sleep_until(deadline).await;
if tracker.last_activity_instant() <= last {
// No activity since we woke up (or last activity is older) -> idle timeout
break;
}
}
Err(e) => {
debug!("TCP relay error: {}", e);
Ok(())
};

let (up, down) = tokio::select! {
res = copy_task => {
match res {
Ok((up, down)) => (up, down),
Err(e) => {
debug!("TCP relay error or closed: {}", e);
(bytes_to_proxy.load(Ordering::Relaxed), bytes_to_client.load(Ordering::Relaxed))
}
}
}
}
_ = monitor_task => {
debug!("TCP relay idle timeout for {}", target_addr);
(bytes_to_proxy.load(Ordering::Relaxed), bytes_to_client.load(Ordering::Relaxed))
}
};

stats::update_metrics(runtime, Protocol::Tcp, proxy_name, target_addr, up, down);
Ok(())
}

async fn find_session_target(
Expand All @@ -141,131 +165,86 @@ async fn find_session_target(
}

struct SharedIdleTracker {
last_activity: Arc<AtomicU64>,
base_instant: tokio::time::Instant,
last_activity_micros: Arc<AtomicU64>,
}

impl SharedIdleTracker {
fn new() -> Self {
let now_millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64;

Self {
last_activity: Arc::new(AtomicU64::new(now_millis)),
base_instant: tokio::time::Instant::now(),
last_activity_micros: Arc::new(AtomicU64::new(0)),
}
}

fn update_activity(&self) {
let now_millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
self.last_activity.store(now_millis, Ordering::Relaxed);
}

fn elapsed(&self) -> Duration {
let last_millis = self.last_activity.load(Ordering::Relaxed);
let now_millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
Duration::from_millis(now_millis.saturating_sub(last_millis))
let elapsed = self.base_instant.elapsed().as_micros() as u64;
self.last_activity_micros
.fetch_max(elapsed, Ordering::Relaxed);
Comment on lines +182 to +183
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SharedIdleTracker::update_activity stores elapsed time in microseconds and monitor_task treats last_activity_instant() not advancing as idle. If activity occurs within the same microsecond tick (e.g., very soon after tracker creation), fetch_max(elapsed) may not advance the value, and the connection can be incorrectly timed out after IDLE_TIMEOUT. Consider storing higher-resolution time (e.g., nanoseconds) and/or ensuring each activity update monotonically increments the stored value even when the timestamp quantizes to the same unit.

Suggested change
self.last_activity_micros
.fetch_max(elapsed, Ordering::Relaxed);
let mut current = self.last_activity_micros.load(Ordering::Relaxed);
loop {
let new_value = if elapsed > current {
elapsed
} else {
current.wrapping_add(1)
};
match self.last_activity_micros.compare_exchange(
current,
new_value,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(actual) => {
current = actual;
}
}
}

Copilot uses AI. Check for mistakes.
}

fn is_idle(&self, timeout: Duration) -> bool {
self.elapsed() > timeout
fn last_activity_instant(&self) -> tokio::time::Instant {
let micros = self.last_activity_micros.load(Ordering::Relaxed);
self.base_instant + Duration::from_micros(micros)
}
}

struct IdleTimeoutStream<T> {
struct ActiveStream<T> {
inner: T,
tracker: Arc<SharedIdleTracker>,
timeout: Duration,
written_bytes: Arc<AtomicU64>,
}

impl<T> IdleTimeoutStream<T> {
fn new(inner: T, tracker: Arc<SharedIdleTracker>, timeout: Duration) -> Self {
impl<T> ActiveStream<T> {
fn new(inner: T, tracker: Arc<SharedIdleTracker>, written_bytes: Arc<AtomicU64>) -> Self {
Self {
inner,
tracker,
timeout,
written_bytes,
}
}

fn update_activity(&self) {
self.tracker.update_activity();
}

fn check_idle(&self) -> tokio::io::Result<()> {
if self.tracker.is_idle(self.timeout) {
return Err(tokio::io::Error::new(
tokio::io::ErrorKind::TimedOut,
"idle timeout - no activity on either side",
));
}
Ok(())
}

fn is_normal_close(e: &std::io::Error) -> bool {
matches!(
e.kind(),
std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::UnexpectedEof
)
}
}

impl<T: AsyncRead + Unpin> AsyncRead for IdleTimeoutStream<T> {
impl<T: AsyncRead + Unpin> AsyncRead for ActiveStream<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.check_idle()?;

let initial_filled = buf.filled().len();
let poll = Pin::new(&mut self.inner).poll_read(cx, buf);

match &poll {
Poll::Ready(Ok(())) if buf.filled().len() > initial_filled => {
if let Poll::Ready(Ok(())) = &poll {
let n = buf.filled().len() - initial_filled;
if n > 0 {
self.update_activity();
}
Poll::Ready(Err(e)) if Self::is_normal_close(e) => {
return Poll::Ready(Ok(()));
}
_ => {}
}

poll
}
}

impl<T: AsyncWrite + Unpin> AsyncWrite for IdleTimeoutStream<T> {
impl<T: AsyncWrite + Unpin> AsyncWrite for ActiveStream<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
self.check_idle()?;

let poll = Pin::new(&mut self.inner).poll_write(cx, buf);

match poll {
Poll::Ready(Ok(n)) => {
if n > 0 {
self.update_activity();
}
Poll::Ready(Ok(n))
}
Poll::Ready(Err(e)) if Self::is_normal_close(&e) => {
// Treat normal close as successful write of all bytes
Poll::Ready(Ok(buf.len()))
}
_ => poll,
if let Poll::Ready(Ok(n)) = &poll
&& *n > 0
{
self.update_activity();
self.written_bytes.fetch_add(*n as u64, Ordering::Relaxed);
}

poll
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![feature(test)]
#![cfg_attr(test, feature(test))]

#[macro_use]
extern crate lazy_static;
Expand Down
Loading