Skip to content
189 changes: 188 additions & 1 deletion src/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,30 @@ impl Display for ConnectError {
#[cfg(feature = "std")]
impl std::error::Error for ConnectError {}

/// Error returned by set_*
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ArgumentError {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the creation of a new error type that is used only in this API follows our API design principles.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't expect a new error type for this API, too. However, I didn't find an appropriate existent error type to add enum variants. Do you have any suggestion for which of the existent {Send, Recv, Listen, Connect}Error should I add variants?

InvalidArgs,
InvalidState,
InsufficientResource,
}

impl Display for crate::socket::tcp::ArgumentError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
crate::socket::tcp::ArgumentError::InvalidArgs => write!(f, "invalid arguments by RFC"),
crate::socket::tcp::ArgumentError::InvalidState => write!(f, "invalid state"),
crate::socket::tcp::ArgumentError::InsufficientResource => {
write!(f, "insufficient runtime resource")
}
}
}
}

#[cfg(feature = "std")]
impl std::error::Error for crate::socket::tcp::ArgumentError {}

/// Error returned by [`Socket::send`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
Expand Down Expand Up @@ -774,6 +798,41 @@ impl<'a> Socket<'a> {
}
}

/// Return the local receive window scaling factor defined in [RFC 1323].
///
/// The value will become constant after the connection is established.
/// It may be reset to 0 during the handshake if remote side does not support window scaling.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that this API has a bidirectional data flow, where a value set by the consumer is read by the TCP/IP stack, but also then is updated by the TCP/IP stack. Is there any precedent (whether in smoltcp or in BSD TCP/IP) to have an option like this? I would there to be two API entry points, one to request an option, one to see if it was applied.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In google's gVisor userspace TCP/IP stack, they have similar options. I believe the issue is that, smoltcp has its default, and we want to override this default. That's why we have such a bidirectional data flow. Our TCP/IP stack reads the value (after connect/accept), and sets the value if no appropriate default value is set (before connect/accept).

As for the "two API entry points", I don't quite get that. Could you give some examples?

pub fn local_recv_win_scale(&self) -> u8 {
self.remote_win_shift
}

/// Set the local receive window scaling factor defined in [RFC 1323].
///
/// The value will become constant after the connection is established.
/// It may be reset to 0 during the handshake if remote side does not support window scaling.
///
/// # Errors
/// `Err(ArgumentError::InvalidArgs)` if the scale is greater than 14.
/// `Err(ArgumentError::InvalidState)` if the socket is not in the `Closed` or `Listen` state.
/// `Err(ArgumentError::InsufficientResource)` if the receive buffer is smaller than (1<<scale) bytes.
pub fn set_local_recv_win_scale(&mut self, scale: u8) -> Result<(), ArgumentError> {
if scale > 14 {
return Err(ArgumentError::InvalidArgs);
}

if self.rx_buffer.capacity() < (1 << scale) as usize {
return Err(ArgumentError::InsufficientResource);
}

match self.state {
State::Closed | State::Listen => {
self.remote_win_shift = scale;
Ok(())
}
_ => Err(ArgumentError::InvalidState),
}
}

/// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
///
/// See also the [set_hop_limit](#method.set_hop_limit) method
Expand Down Expand Up @@ -828,6 +887,7 @@ impl<'a> Socket<'a> {
fn reset(&mut self) {
let rx_cap_log2 =
mem::size_of::<usize>() * 8 - self.rx_buffer.capacity().leading_zeros() as usize;
let new_rx_win_shift = rx_cap_log2.saturating_sub(16) as u8;

self.state = State::Closed;
self.timer = Timer::new();
Expand All @@ -845,7 +905,10 @@ impl<'a> Socket<'a> {
self.remote_last_win = 0;
self.remote_win_len = 0;
self.remote_win_scale = None;
self.remote_win_shift = rx_cap_log2.saturating_sub(16) as u8;
// keep user-specified window scaling across connect()/listen()
if self.remote_win_shift < new_rx_win_shift {
self.remote_win_shift = new_rx_win_shift;
}
self.remote_mss = DEFAULT_MSS;
self.remote_last_ts = None;
self.ack_delay_timer = AckDelayTimer::Idle;
Expand Down Expand Up @@ -2329,6 +2392,7 @@ impl<'a> Socket<'a> {
} else if self.timer.should_close(cx.now()) {
// If we have spent enough time in the TIME-WAIT state, close the socket.
tcp_trace!("TIME-WAIT timer expired");
self.remote_win_shift = 0;
self.reset();
return Ok(());
} else {
Expand Down Expand Up @@ -2601,6 +2665,63 @@ impl<'a> Socket<'a> {
.unwrap_or(&PollAt::Ingress)
}
}

/// Replace the receive buffer with a new one.
///
/// The requirements for the new buffer are:
/// 1. The new buffer must be larger than the length of remaining data in the current buffer
/// 2. The new buffer must be multiple of (1 << self.remote_win_shift)
///
/// If the new buffer does not meet the requirements, the new buffer is returned as an error;
/// otherwise, the old buffer is returned as an Ok value.
///
/// See also the [local_recv_win_scale](struct.Socket.html#method.local_recv_win_scale) methods.
pub fn replace_recv_buffer<T: Into<SocketBuffer<'a>>>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a generic method that has no direct correspondence to TCP receive window resizing and as such I don't see why it should be a part of this PR at all.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method does enable the TCP receive window resizing. Smoltcp uses rx_buffer.capacity()-rx_buffer.len() to compute the current receive window. Without this method, it's impossible to adjust rx_buffer.capacity(). By replaced with a larger receive buffer, smoltcp knows it can advertise a larger receive window size safely without correctness issue.

&mut self,
new_buffer: T,
) -> Result<SocketBuffer<'a>, SocketBuffer<'a>> {
let mut replaced_buf = new_buffer.into();
/* Check if the new buffer is valid
* Requirements:
* 1. The new buffer must be larger than the length of remaining data in the current buffer
* 2. The new buffer must be multiple of (1 << self.remote_win_shift)
*/
if replaced_buf.capacity() < self.rx_buffer.len()
|| replaced_buf.capacity() % (1 << self.remote_win_shift) != 0
{
return Err(replaced_buf);
}
replaced_buf.clear();

// We should copy both allocated data and unallocated data (for assembler)
let allocated1 = self.rx_buffer.get_allocated(0, self.rx_buffer.len());
let l = replaced_buf.enqueue_slice(allocated1);
assert_eq!(l, allocated1.len());
if allocated1.len() < self.rx_buffer.len() {
let allocated2 = self
.rx_buffer
.get_allocated(allocated1.len(), self.rx_buffer.len() - allocated1.len());
let l = replaced_buf.enqueue_slice(allocated2);
assert_eq!(l, allocated2.len());
}

// make sure assembler can work properly
let unallocated1 = self.rx_buffer.get_unallocated(0, self.rx_buffer.window());
let unallocated1_len = unallocated1.len();
let l = replaced_buf.write_unallocated(0, unallocated1);
assert_eq!(l, unallocated1.len());
if unallocated1_len < self.rx_buffer.window() {
let unallocated2 = self
.rx_buffer
.get_unallocated(unallocated1_len, self.rx_buffer.window() - unallocated1_len);
let l = replaced_buf.write_unallocated(unallocated1_len, unallocated2);
assert_eq!(l, unallocated2.len());
}
assert_eq!(replaced_buf.len(), self.rx_buffer.len());

mem::swap(&mut self.rx_buffer, &mut replaced_buf);
Ok(replaced_buf)
}
}

impl<'a> fmt::Write for Socket<'a> {
Expand Down Expand Up @@ -8151,4 +8272,70 @@ mod test {
}]
);
}

// =========================================================================================//
// Tests for window scaling
// =========================================================================================//

fn socket_established_with_window_scaling() -> TestSocket {
let mut s = socket_established();
s.remote_win_shift = 10;
const BASE: usize = 1 << 10;
s.tx_buffer = SocketBuffer::new(vec![0u8; 64 * BASE]);
s.rx_buffer = SocketBuffer::new(vec![0u8; 64 * BASE]);
s
}

#[test]
fn test_too_large_window_scale() {
let mut socket = Socket::new(
SocketBuffer::new(vec![0; 8 * (1 << 15)]),
SocketBuffer::new(vec![0; 8 * (1 << 15)]),
);
assert!(socket.set_local_recv_win_scale(15).is_err())
}

#[test]
fn test_set_window_scale() {
let mut socket = Socket::new(
SocketBuffer::new(vec![0; 128]),
SocketBuffer::new(vec![0; 128]),
);
assert!(matches!(socket.state, State::Closed));
assert_eq!(socket.rx_buffer.capacity(), 128);
assert!(socket.set_local_recv_win_scale(6).is_ok());
assert!(socket.set_local_recv_win_scale(14).is_err());
assert_eq!(socket.local_recv_win_scale(), 6);
}

#[test]
fn test_set_scale_with_tcp_state() {
let mut socket = socket();
assert!(socket.set_local_recv_win_scale(1).is_ok());
let mut socket = socket_established();
assert!(socket.set_local_recv_win_scale(1).is_err());
let mut socket = socket_listen();
assert!(socket.set_local_recv_win_scale(1).is_ok());
let mut socket = socket_syn_received();
assert!(socket.set_local_recv_win_scale(1).is_err());
}

#[test]
fn test_resize_recv_buffer_invalid_size() {
let mut s = socket_established_with_window_scaling();
assert_eq!(s.rx_buffer.enqueue_slice(&[42; 31 * 1024]), 31 * 1024);
assert_eq!(s.rx_buffer.len(), 31 * 1024);
assert!(s
.replace_recv_buffer(SocketBuffer::new(vec![7u8; 32 * 1024 + 512]))
.is_err());
assert!(s
.replace_recv_buffer(SocketBuffer::new(vec![7u8; 16 * 1024]))
.is_err());
let old_buffer = s
.replace_recv_buffer(SocketBuffer::new(vec![7u8; 32 * 1024]))
.unwrap();
assert_eq!(old_buffer.capacity(), 64 * 1024);
assert_eq!(s.rx_buffer.len(), 31 * 1024);
assert_eq!(s.rx_buffer.capacity(), 32 * 1024);
}
}
Loading