@@ -80,12 +80,12 @@ pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ()>;
80
80
81
81
/// A raw IP socket.
82
82
///
83
- /// A raw socket is bound to a specific IP protocol, and owns
83
+ /// A raw socket may be bound to a specific IP protocol, and owns
84
84
/// transmit and receive packet buffers.
85
85
#[ derive( Debug ) ]
86
86
pub struct Socket < ' a > {
87
- ip_version : IpVersion ,
88
- ip_protocol : IpProtocol ,
87
+ ip_version : Option < IpVersion > ,
88
+ ip_protocol : Option < IpProtocol > ,
89
89
rx_buffer : PacketBuffer < ' a > ,
90
90
tx_buffer : PacketBuffer < ' a > ,
91
91
#[ cfg( feature = "async" ) ]
@@ -98,8 +98,8 @@ impl<'a> Socket<'a> {
98
98
/// Create a raw IP socket bound to the given IP version and datagram protocol,
99
99
/// with the given buffers.
100
100
pub fn new (
101
- ip_version : IpVersion ,
102
- ip_protocol : IpProtocol ,
101
+ ip_version : Option < IpVersion > ,
102
+ ip_protocol : Option < IpProtocol > ,
103
103
rx_buffer : PacketBuffer < ' a > ,
104
104
tx_buffer : PacketBuffer < ' a > ,
105
105
) -> Socket < ' a > {
@@ -152,13 +152,13 @@ impl<'a> Socket<'a> {
152
152
153
153
/// Return the IP version the socket is bound to.
154
154
#[ inline]
155
- pub fn ip_version ( & self ) -> IpVersion {
155
+ pub fn ip_version ( & self ) -> Option < IpVersion > {
156
156
self . ip_version
157
157
}
158
158
159
159
/// Return the IP protocol the socket is bound to.
160
160
#[ inline]
161
- pub fn ip_protocol ( & self ) -> IpProtocol {
161
+ pub fn ip_protocol ( & self ) -> Option < IpProtocol > {
162
162
self . ip_protocol
163
163
}
164
164
@@ -216,7 +216,7 @@ impl<'a> Socket<'a> {
216
216
. map_err ( |_| SendError :: BufferFull ) ?;
217
217
218
218
net_trace ! (
219
- "raw:{}:{}: buffer to send {} octets" ,
219
+ "raw:{:? }:{:? }: buffer to send {} octets" ,
220
220
self . ip_version,
221
221
self . ip_protocol,
222
222
packet_buf. len( )
@@ -238,7 +238,7 @@ impl<'a> Socket<'a> {
238
238
. map_err ( |_| SendError :: BufferFull ) ?;
239
239
240
240
net_trace ! (
241
- "raw:{}:{}: buffer to send {} octets" ,
241
+ "raw:{:? }:{:? }: buffer to send {} octets" ,
242
242
self . ip_version,
243
243
self . ip_protocol,
244
244
size
@@ -265,7 +265,7 @@ impl<'a> Socket<'a> {
265
265
let ( ( ) , packet_buf) = self . rx_buffer . dequeue ( ) . map_err ( |_| RecvError :: Exhausted ) ?;
266
266
267
267
net_trace ! (
268
- "raw:{}:{}: receive {} buffered octets" ,
268
+ "raw:{:? }:{:? }: receive {} buffered octets" ,
269
269
self . ip_version,
270
270
self . ip_protocol,
271
271
packet_buf. len( )
@@ -299,7 +299,7 @@ impl<'a> Socket<'a> {
299
299
let ( ( ) , packet_buf) = self . rx_buffer . peek ( ) . map_err ( |_| RecvError :: Exhausted ) ?;
300
300
301
301
net_trace ! (
302
- "raw:{}:{}: receive {} buffered octets" ,
302
+ "raw:{:? }:{:? }: receive {} buffered octets" ,
303
303
self . ip_version,
304
304
self . ip_protocol,
305
305
packet_buf. len( )
@@ -338,10 +338,17 @@ impl<'a> Socket<'a> {
338
338
}
339
339
340
340
pub ( crate ) fn accepts ( & self , ip_repr : & IpRepr ) -> bool {
341
- if ip_repr. version ( ) != self . ip_version {
341
+ if self
342
+ . ip_version
343
+ . is_some_and ( |version| version != ip_repr. version ( ) )
344
+ {
342
345
return false ;
343
346
}
344
- if ip_repr. next_header ( ) != self . ip_protocol {
347
+
348
+ if self
349
+ . ip_protocol
350
+ . is_some_and ( |next_header| next_header != ip_repr. next_header ( ) )
351
+ {
345
352
return false ;
346
353
}
347
354
@@ -355,7 +362,7 @@ impl<'a> Socket<'a> {
355
362
let total_len = header_len + payload. len ( ) ;
356
363
357
364
net_trace ! (
358
- "raw:{}:{}: receiving {} octets" ,
365
+ "raw:{:? }:{:? }: receiving {} octets" ,
359
366
self . ip_version,
360
367
self . ip_protocol,
361
368
total_len
@@ -367,7 +374,7 @@ impl<'a> Socket<'a> {
367
374
buf[ header_len..] . copy_from_slice ( payload) ;
368
375
}
369
376
Err ( _) => net_trace ! (
370
- "raw:{}:{}: buffer full, dropped incoming packet" ,
377
+ "raw:{:? }:{:? }: buffer full, dropped incoming packet" ,
371
378
self . ip_version,
372
379
self . ip_protocol
373
380
) ,
@@ -395,7 +402,7 @@ impl<'a> Socket<'a> {
395
402
return Ok ( ( ) ) ;
396
403
}
397
404
} ;
398
- if packet . next_header ( ) != ip_protocol {
405
+ if ip_protocol . is_some_and ( |next_header| next_header != packet . next_header ( ) ) {
399
406
net_trace ! ( "raw: sent packet with wrong ip protocol, dropping." ) ;
400
407
return Ok ( ( ) ) ;
401
408
}
@@ -415,7 +422,7 @@ impl<'a> Socket<'a> {
415
422
return Ok ( ( ) ) ;
416
423
}
417
424
} ;
418
- net_trace ! ( "raw:{}:{}: sending" , ip_version, ip_protocol) ;
425
+ net_trace ! ( "raw:{:? }:{:? }: sending" , ip_version, ip_protocol) ;
419
426
emit ( cx, ( IpRepr :: Ipv4 ( ipv4_repr) , packet. payload ( ) ) )
420
427
}
421
428
#[ cfg( feature = "proto-ipv6" ) ]
@@ -427,7 +434,7 @@ impl<'a> Socket<'a> {
427
434
return Ok ( ( ) ) ;
428
435
}
429
436
} ;
430
- if packet . next_header ( ) != ip_protocol {
437
+ if ip_protocol . is_some_and ( |next_header| next_header != packet . next_header ( ) ) {
431
438
net_trace ! ( "raw: sent ipv6 packet with wrong ip protocol, dropping." ) ;
432
439
return Ok ( ( ) ) ;
433
440
}
@@ -440,7 +447,7 @@ impl<'a> Socket<'a> {
440
447
}
441
448
} ;
442
449
443
- net_trace ! ( "raw:{}:{}: sending" , ip_version, ip_protocol) ;
450
+ net_trace ! ( "raw:{:? }:{:? }: sending" , ip_version, ip_protocol) ;
444
451
emit ( cx, ( IpRepr :: Ipv6 ( ipv6_repr) , packet. payload ( ) ) )
445
452
}
446
453
Err ( _) => {
@@ -495,8 +502,8 @@ mod test {
495
502
tx_buffer : PacketBuffer < ' static > ,
496
503
) -> Socket < ' static > {
497
504
Socket :: new (
498
- IpVersion :: Ipv4 ,
499
- IpProtocol :: Unknown ( IP_PROTO ) ,
505
+ Some ( IpVersion :: Ipv4 ) ,
506
+ Some ( IpProtocol :: Unknown ( IP_PROTO ) ) ,
500
507
rx_buffer,
501
508
tx_buffer,
502
509
)
@@ -526,8 +533,8 @@ mod test {
526
533
tx_buffer : PacketBuffer < ' static > ,
527
534
) -> Socket < ' static > {
528
535
Socket :: new (
529
- IpVersion :: Ipv6 ,
530
- IpProtocol :: Unknown ( IP_PROTO ) ,
536
+ Some ( IpVersion :: Ipv6 ) ,
537
+ Some ( IpProtocol :: Unknown ( IP_PROTO ) ) ,
531
538
rx_buffer,
532
539
tx_buffer,
533
540
)
@@ -827,8 +834,8 @@ mod test {
827
834
#[ cfg( feature = "proto-ipv4" ) ]
828
835
{
829
836
let socket = Socket :: new (
830
- IpVersion :: Ipv4 ,
831
- IpProtocol :: Unknown ( ipv4_locals:: IP_PROTO + 1 ) ,
837
+ Some ( IpVersion :: Ipv4 ) ,
838
+ Some ( IpProtocol :: Unknown ( ipv4_locals:: IP_PROTO + 1 ) ) ,
832
839
buffer ( 1 ) ,
833
840
buffer ( 1 ) ,
834
841
) ;
@@ -839,8 +846,8 @@ mod test {
839
846
#[ cfg( feature = "proto-ipv6" ) ]
840
847
{
841
848
let socket = Socket :: new (
842
- IpVersion :: Ipv6 ,
843
- IpProtocol :: Unknown ( ipv6_locals:: IP_PROTO + 1 ) ,
849
+ Some ( IpVersion :: Ipv6 ) ,
850
+ Some ( IpProtocol :: Unknown ( ipv6_locals:: IP_PROTO + 1 ) ) ,
844
851
buffer ( 1 ) ,
845
852
buffer ( 1 ) ,
846
853
) ;
@@ -849,4 +856,94 @@ mod test {
849
856
assert ! ( !socket. accepts( & ipv4_locals:: HEADER_REPR ) ) ;
850
857
}
851
858
}
859
+
860
+ fn check_dispatch ( socket : & mut Socket < ' _ > , cx : & mut Context ) {
861
+ // Check dispatch returns Ok(()) and calls the emit closure
862
+ let mut emitted = false ;
863
+ assert_eq ! (
864
+ socket. dispatch( cx, |_, _| {
865
+ emitted = true ;
866
+ Ok ( ( ) )
867
+ } ) ,
868
+ Ok :: <_, ( ) >( ( ) )
869
+ ) ;
870
+ assert ! ( emitted) ;
871
+ }
872
+
873
+ #[ rstest]
874
+ #[ case:: ip( Medium :: Ip ) ]
875
+ #[ case:: ethernet( Medium :: Ethernet ) ]
876
+ #[ cfg( feature = "medium-ethernet" ) ]
877
+ #[ case:: ieee802154( Medium :: Ieee802154 ) ]
878
+ #[ cfg( feature = "medium-ieee802154" ) ]
879
+ fn test_unfiltered_sends_all ( #[ case] medium : Medium ) {
880
+ // Test a single unfiltered socket can send packets with different IP versions and next
881
+ // headers
882
+ let mut socket = Socket :: new ( None , None , buffer ( 0 ) , buffer ( 2 ) ) ;
883
+ #[ cfg( feature = "proto-ipv4" ) ]
884
+ {
885
+ let ( mut iface, _, _) = setup ( medium) ;
886
+ let cx = iface. context ( ) ;
887
+
888
+ let mut udp_packet = ipv4_locals:: PACKET_BYTES ;
889
+ Ipv4Packet :: new_unchecked ( & mut udp_packet) . set_next_header ( IpProtocol :: Udp ) ;
890
+
891
+ assert_eq ! ( socket. send_slice( & udp_packet) , Ok ( ( ) ) ) ;
892
+ check_dispatch ( & mut socket, cx) ;
893
+
894
+ let mut tcp_packet = ipv4_locals:: PACKET_BYTES ;
895
+ Ipv4Packet :: new_unchecked ( & mut tcp_packet) . set_next_header ( IpProtocol :: Tcp ) ;
896
+
897
+ assert_eq ! ( socket. send_slice( & tcp_packet[ ..] ) , Ok ( ( ) ) ) ;
898
+ check_dispatch ( & mut socket, cx) ;
899
+ }
900
+ #[ cfg( feature = "proto-ipv6" ) ]
901
+ {
902
+ let ( mut iface, _, _) = setup ( medium) ;
903
+ let cx = iface. context ( ) ;
904
+
905
+ let mut udp_packet = ipv6_locals:: PACKET_BYTES ;
906
+ Ipv6Packet :: new_unchecked ( & mut udp_packet) . set_next_header ( IpProtocol :: Udp ) ;
907
+
908
+ assert_eq ! ( socket. send_slice( & ipv6_locals:: PACKET_BYTES ) , Ok ( ( ) ) ) ;
909
+ check_dispatch ( & mut socket, cx) ;
910
+
911
+ let mut tcp_packet = ipv6_locals:: PACKET_BYTES ;
912
+ Ipv6Packet :: new_unchecked ( & mut tcp_packet) . set_next_header ( IpProtocol :: Tcp ) ;
913
+
914
+ assert_eq ! ( socket. send_slice( & tcp_packet[ ..] ) , Ok ( ( ) ) ) ;
915
+ check_dispatch ( & mut socket, cx) ;
916
+ }
917
+ }
918
+
919
+ #[ rstest]
920
+ #[ case:: proto( IpProtocol :: Icmp ) ]
921
+ #[ case:: proto( IpProtocol :: Tcp ) ]
922
+ #[ case:: proto( IpProtocol :: Udp ) ]
923
+ fn test_unfiltered_accepts_all ( #[ case] proto : IpProtocol ) {
924
+ // Test an unfiltered socket can accept packets with different IP versions and next headers
925
+ let socket = Socket :: new ( None , None , buffer ( 0 ) , buffer ( 0 ) ) ;
926
+ #[ cfg( feature = "proto-ipv4" ) ]
927
+ {
928
+ let header_repr = IpRepr :: Ipv4 ( Ipv4Repr {
929
+ src_addr : Ipv4Address :: new ( 10 , 0 , 0 , 1 ) ,
930
+ dst_addr : Ipv4Address :: new ( 10 , 0 , 0 , 2 ) ,
931
+ next_header : proto,
932
+ payload_len : 4 ,
933
+ hop_limit : 64 ,
934
+ } ) ;
935
+ assert ! ( socket. accepts( & header_repr) ) ;
936
+ }
937
+ #[ cfg( feature = "proto-ipv6" ) ]
938
+ {
939
+ let header_repr = IpRepr :: Ipv6 ( Ipv6Repr {
940
+ src_addr : Ipv6Address :: new ( 0xfe80 , 0 , 0 , 0 , 0 , 0 , 0 , 1 ) ,
941
+ dst_addr : Ipv6Address :: new ( 0xfe80 , 0 , 0 , 0 , 0 , 0 , 0 , 2 ) ,
942
+ next_header : proto,
943
+ payload_len : 4 ,
944
+ hop_limit : 64 ,
945
+ } ) ;
946
+ assert ! ( socket. accepts( & header_repr) ) ;
947
+ }
948
+ }
852
949
}
0 commit comments