Skip to content

Commit e000f15

Browse files
authored
Merge pull request #1067 from KingCol13/raw-socket-optional-ip-version-protocol
raw: optional IP version and next_header
2 parents a68620f + 9779059 commit e000f15

File tree

3 files changed

+134
-32
lines changed

3 files changed

+134
-32
lines changed

examples/multicast.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ fn main() {
6666
// Will not send IGMP
6767
let raw_tx_buffer = raw::PacketBuffer::new(vec![], vec![]);
6868
let raw_socket = raw::Socket::new(
69-
IpVersion::Ipv4,
70-
IpProtocol::Igmp,
69+
Some(IpVersion::Ipv4),
70+
Some(IpProtocol::Igmp),
7171
raw_rx_buffer,
7272
raw_tx_buffer,
7373
);

src/iface/interface/tests/ipv4.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,12 @@ fn test_raw_socket_no_reply(#[case] medium: Medium) {
852852
vec![raw::PacketMetadata::EMPTY; packets],
853853
vec![0; 48 * packets],
854854
);
855-
let raw_socket = raw::Socket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer);
855+
let raw_socket = raw::Socket::new(
856+
Some(IpVersion::Ipv4),
857+
Some(IpProtocol::Udp),
858+
rx_buffer,
859+
tx_buffer,
860+
);
856861
sockets.add(raw_socket);
857862

858863
let src_addr = Ipv4Address::new(127, 0, 0, 2);
@@ -948,8 +953,8 @@ fn test_raw_socket_with_udp_socket(#[case] medium: Medium) {
948953
vec![0; 48 * packets],
949954
);
950955
let raw_socket = raw::Socket::new(
951-
IpVersion::Ipv4,
952-
IpProtocol::Udp,
956+
Some(IpVersion::Ipv4),
957+
Some(IpProtocol::Udp),
953958
raw_rx_buffer,
954959
raw_tx_buffer,
955960
);

src/socket/raw.rs

Lines changed: 124 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ()>;
8080

8181
/// A raw IP socket.
8282
///
83-
/// A raw socket is bound to a specific IP protocol, and owns
83+
/// A raw socket may be bound to a specific IP protocol, and owns
8484
/// transmit and receive packet buffers.
8585
#[derive(Debug)]
8686
pub struct Socket<'a> {
87-
ip_version: IpVersion,
88-
ip_protocol: IpProtocol,
87+
ip_version: Option<IpVersion>,
88+
ip_protocol: Option<IpProtocol>,
8989
rx_buffer: PacketBuffer<'a>,
9090
tx_buffer: PacketBuffer<'a>,
9191
#[cfg(feature = "async")]
@@ -98,8 +98,8 @@ impl<'a> Socket<'a> {
9898
/// Create a raw IP socket bound to the given IP version and datagram protocol,
9999
/// with the given buffers.
100100
pub fn new(
101-
ip_version: IpVersion,
102-
ip_protocol: IpProtocol,
101+
ip_version: Option<IpVersion>,
102+
ip_protocol: Option<IpProtocol>,
103103
rx_buffer: PacketBuffer<'a>,
104104
tx_buffer: PacketBuffer<'a>,
105105
) -> Socket<'a> {
@@ -152,13 +152,13 @@ impl<'a> Socket<'a> {
152152

153153
/// Return the IP version the socket is bound to.
154154
#[inline]
155-
pub fn ip_version(&self) -> IpVersion {
155+
pub fn ip_version(&self) -> Option<IpVersion> {
156156
self.ip_version
157157
}
158158

159159
/// Return the IP protocol the socket is bound to.
160160
#[inline]
161-
pub fn ip_protocol(&self) -> IpProtocol {
161+
pub fn ip_protocol(&self) -> Option<IpProtocol> {
162162
self.ip_protocol
163163
}
164164

@@ -216,7 +216,7 @@ impl<'a> Socket<'a> {
216216
.map_err(|_| SendError::BufferFull)?;
217217

218218
net_trace!(
219-
"raw:{}:{}: buffer to send {} octets",
219+
"raw:{:?}:{:?}: buffer to send {} octets",
220220
self.ip_version,
221221
self.ip_protocol,
222222
packet_buf.len()
@@ -238,7 +238,7 @@ impl<'a> Socket<'a> {
238238
.map_err(|_| SendError::BufferFull)?;
239239

240240
net_trace!(
241-
"raw:{}:{}: buffer to send {} octets",
241+
"raw:{:?}:{:?}: buffer to send {} octets",
242242
self.ip_version,
243243
self.ip_protocol,
244244
size
@@ -265,7 +265,7 @@ impl<'a> Socket<'a> {
265265
let ((), packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?;
266266

267267
net_trace!(
268-
"raw:{}:{}: receive {} buffered octets",
268+
"raw:{:?}:{:?}: receive {} buffered octets",
269269
self.ip_version,
270270
self.ip_protocol,
271271
packet_buf.len()
@@ -299,7 +299,7 @@ impl<'a> Socket<'a> {
299299
let ((), packet_buf) = self.rx_buffer.peek().map_err(|_| RecvError::Exhausted)?;
300300

301301
net_trace!(
302-
"raw:{}:{}: receive {} buffered octets",
302+
"raw:{:?}:{:?}: receive {} buffered octets",
303303
self.ip_version,
304304
self.ip_protocol,
305305
packet_buf.len()
@@ -338,10 +338,17 @@ impl<'a> Socket<'a> {
338338
}
339339

340340
pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool {
341-
if ip_repr.version() != self.ip_version {
341+
if self
342+
.ip_version
343+
.is_some_and(|version| version != ip_repr.version())
344+
{
342345
return false;
343346
}
344-
if ip_repr.next_header() != self.ip_protocol {
347+
348+
if self
349+
.ip_protocol
350+
.is_some_and(|next_header| next_header != ip_repr.next_header())
351+
{
345352
return false;
346353
}
347354

@@ -355,7 +362,7 @@ impl<'a> Socket<'a> {
355362
let total_len = header_len + payload.len();
356363

357364
net_trace!(
358-
"raw:{}:{}: receiving {} octets",
365+
"raw:{:?}:{:?}: receiving {} octets",
359366
self.ip_version,
360367
self.ip_protocol,
361368
total_len
@@ -367,7 +374,7 @@ impl<'a> Socket<'a> {
367374
buf[header_len..].copy_from_slice(payload);
368375
}
369376
Err(_) => net_trace!(
370-
"raw:{}:{}: buffer full, dropped incoming packet",
377+
"raw:{:?}:{:?}: buffer full, dropped incoming packet",
371378
self.ip_version,
372379
self.ip_protocol
373380
),
@@ -395,7 +402,7 @@ impl<'a> Socket<'a> {
395402
return Ok(());
396403
}
397404
};
398-
if packet.next_header() != ip_protocol {
405+
if ip_protocol.is_some_and(|next_header| next_header != packet.next_header()) {
399406
net_trace!("raw: sent packet with wrong ip protocol, dropping.");
400407
return Ok(());
401408
}
@@ -415,7 +422,7 @@ impl<'a> Socket<'a> {
415422
return Ok(());
416423
}
417424
};
418-
net_trace!("raw:{}:{}: sending", ip_version, ip_protocol);
425+
net_trace!("raw:{:?}:{:?}: sending", ip_version, ip_protocol);
419426
emit(cx, (IpRepr::Ipv4(ipv4_repr), packet.payload()))
420427
}
421428
#[cfg(feature = "proto-ipv6")]
@@ -427,7 +434,7 @@ impl<'a> Socket<'a> {
427434
return Ok(());
428435
}
429436
};
430-
if packet.next_header() != ip_protocol {
437+
if ip_protocol.is_some_and(|next_header| next_header != packet.next_header()) {
431438
net_trace!("raw: sent ipv6 packet with wrong ip protocol, dropping.");
432439
return Ok(());
433440
}
@@ -440,7 +447,7 @@ impl<'a> Socket<'a> {
440447
}
441448
};
442449

443-
net_trace!("raw:{}:{}: sending", ip_version, ip_protocol);
450+
net_trace!("raw:{:?}:{:?}: sending", ip_version, ip_protocol);
444451
emit(cx, (IpRepr::Ipv6(ipv6_repr), packet.payload()))
445452
}
446453
Err(_) => {
@@ -495,8 +502,8 @@ mod test {
495502
tx_buffer: PacketBuffer<'static>,
496503
) -> Socket<'static> {
497504
Socket::new(
498-
IpVersion::Ipv4,
499-
IpProtocol::Unknown(IP_PROTO),
505+
Some(IpVersion::Ipv4),
506+
Some(IpProtocol::Unknown(IP_PROTO)),
500507
rx_buffer,
501508
tx_buffer,
502509
)
@@ -526,8 +533,8 @@ mod test {
526533
tx_buffer: PacketBuffer<'static>,
527534
) -> Socket<'static> {
528535
Socket::new(
529-
IpVersion::Ipv6,
530-
IpProtocol::Unknown(IP_PROTO),
536+
Some(IpVersion::Ipv6),
537+
Some(IpProtocol::Unknown(IP_PROTO)),
531538
rx_buffer,
532539
tx_buffer,
533540
)
@@ -827,8 +834,8 @@ mod test {
827834
#[cfg(feature = "proto-ipv4")]
828835
{
829836
let socket = Socket::new(
830-
IpVersion::Ipv4,
831-
IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1),
837+
Some(IpVersion::Ipv4),
838+
Some(IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1)),
832839
buffer(1),
833840
buffer(1),
834841
);
@@ -839,8 +846,8 @@ mod test {
839846
#[cfg(feature = "proto-ipv6")]
840847
{
841848
let socket = Socket::new(
842-
IpVersion::Ipv6,
843-
IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1),
849+
Some(IpVersion::Ipv6),
850+
Some(IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1)),
844851
buffer(1),
845852
buffer(1),
846853
);
@@ -849,4 +856,94 @@ mod test {
849856
assert!(!socket.accepts(&ipv4_locals::HEADER_REPR));
850857
}
851858
}
859+
860+
fn check_dispatch(socket: &mut Socket<'_>, cx: &mut Context) {
861+
// Check dispatch returns Ok(()) and calls the emit closure
862+
let mut emitted = false;
863+
assert_eq!(
864+
socket.dispatch(cx, |_, _| {
865+
emitted = true;
866+
Ok(())
867+
}),
868+
Ok::<_, ()>(())
869+
);
870+
assert!(emitted);
871+
}
872+
873+
#[rstest]
874+
#[case::ip(Medium::Ip)]
875+
#[case::ethernet(Medium::Ethernet)]
876+
#[cfg(feature = "medium-ethernet")]
877+
#[case::ieee802154(Medium::Ieee802154)]
878+
#[cfg(feature = "medium-ieee802154")]
879+
fn test_unfiltered_sends_all(#[case] medium: Medium) {
880+
// Test a single unfiltered socket can send packets with different IP versions and next
881+
// headers
882+
let mut socket = Socket::new(None, None, buffer(0), buffer(2));
883+
#[cfg(feature = "proto-ipv4")]
884+
{
885+
let (mut iface, _, _) = setup(medium);
886+
let cx = iface.context();
887+
888+
let mut udp_packet = ipv4_locals::PACKET_BYTES;
889+
Ipv4Packet::new_unchecked(&mut udp_packet).set_next_header(IpProtocol::Udp);
890+
891+
assert_eq!(socket.send_slice(&udp_packet), Ok(()));
892+
check_dispatch(&mut socket, cx);
893+
894+
let mut tcp_packet = ipv4_locals::PACKET_BYTES;
895+
Ipv4Packet::new_unchecked(&mut tcp_packet).set_next_header(IpProtocol::Tcp);
896+
897+
assert_eq!(socket.send_slice(&tcp_packet[..]), Ok(()));
898+
check_dispatch(&mut socket, cx);
899+
}
900+
#[cfg(feature = "proto-ipv6")]
901+
{
902+
let (mut iface, _, _) = setup(medium);
903+
let cx = iface.context();
904+
905+
let mut udp_packet = ipv6_locals::PACKET_BYTES;
906+
Ipv6Packet::new_unchecked(&mut udp_packet).set_next_header(IpProtocol::Udp);
907+
908+
assert_eq!(socket.send_slice(&ipv6_locals::PACKET_BYTES), Ok(()));
909+
check_dispatch(&mut socket, cx);
910+
911+
let mut tcp_packet = ipv6_locals::PACKET_BYTES;
912+
Ipv6Packet::new_unchecked(&mut tcp_packet).set_next_header(IpProtocol::Tcp);
913+
914+
assert_eq!(socket.send_slice(&tcp_packet[..]), Ok(()));
915+
check_dispatch(&mut socket, cx);
916+
}
917+
}
918+
919+
#[rstest]
920+
#[case::proto(IpProtocol::Icmp)]
921+
#[case::proto(IpProtocol::Tcp)]
922+
#[case::proto(IpProtocol::Udp)]
923+
fn test_unfiltered_accepts_all(#[case] proto: IpProtocol) {
924+
// Test an unfiltered socket can accept packets with different IP versions and next headers
925+
let socket = Socket::new(None, None, buffer(0), buffer(0));
926+
#[cfg(feature = "proto-ipv4")]
927+
{
928+
let header_repr = IpRepr::Ipv4(Ipv4Repr {
929+
src_addr: Ipv4Address::new(10, 0, 0, 1),
930+
dst_addr: Ipv4Address::new(10, 0, 0, 2),
931+
next_header: proto,
932+
payload_len: 4,
933+
hop_limit: 64,
934+
});
935+
assert!(socket.accepts(&header_repr));
936+
}
937+
#[cfg(feature = "proto-ipv6")]
938+
{
939+
let header_repr = IpRepr::Ipv6(Ipv6Repr {
940+
src_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1),
941+
dst_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2),
942+
next_header: proto,
943+
payload_len: 4,
944+
hop_limit: 64,
945+
});
946+
assert!(socket.accepts(&header_repr));
947+
}
948+
}
852949
}

0 commit comments

Comments
 (0)