@@ -24,7 +24,8 @@ use libc::{
2424use nix:: {
2525 ioctl_read_bad,
2626 sys:: socket:: {
27- self , bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
27+ self , bind, connect, getpeername, getsockname, listen, recv, recvfrom, send, sendto,
28+ shutdown, socket,
2829 sockopt:: { ReceiveTimeout , SendTimeout , SocketError } ,
2930 AddressFamily , Backlog , GetSockOpt , MsgFlags , SetSockOpt , SockFlag , SockType ,
3031 } ,
@@ -45,12 +46,54 @@ pub use libc::VMADDR_CID_LOCAL;
4546pub use libc:: { VMADDR_CID_ANY , VMADDR_CID_HOST , VMADDR_CID_HYPERVISOR } ;
4647pub use nix:: sys:: socket:: { SockaddrLike , VsockAddr } ;
4748
48- fn new_socket ( ) -> Result < OwnedFd > {
49+ fn new_socket ( ty : SockType ) -> Result < OwnedFd > {
4950 #[ cfg( not( target_os = "macos" ) ) ]
5051 let flags = SockFlag :: SOCK_CLOEXEC ;
5152 #[ cfg( target_os = "macos" ) ]
5253 let flags = SockFlag :: empty ( ) ;
53- Ok ( socket ( AddressFamily :: Vsock , SockType :: Stream , flags, None ) ?)
54+ Ok ( socket ( AddressFamily :: Vsock , ty, flags, None ) ?)
55+ }
56+
57+ fn default_send_msg_flags ( ) -> MsgFlags {
58+ #[ cfg( not( target_os = "macos" ) ) ]
59+ let flags = MsgFlags :: MSG_NOSIGNAL ;
60+ #[ cfg( target_os = "macos" ) ]
61+ let flags = MsgFlags :: empty ( ) ;
62+ flags
63+ }
64+
65+ /// Internal helper to turn a [`Duration`] into a [`timeval`]
66+ fn timeval_from_duration ( dur : Option < Duration > ) -> Result < timeval > {
67+ match dur {
68+ Some ( dur) => {
69+ if dur. as_secs ( ) == 0 && dur. subsec_nanos ( ) == 0 {
70+ return Err ( Error :: new (
71+ ErrorKind :: InvalidInput ,
72+ "cannot set a zero duration timeout" ,
73+ ) ) ;
74+ }
75+
76+ // https://github.com/rust-lang/libc/issues/1848
77+ #[ cfg_attr( target_env = "musl" , allow( deprecated) ) ]
78+ let secs = if dur. as_secs ( ) > libc:: time_t:: MAX as u64 {
79+ libc:: time_t:: MAX
80+ } else {
81+ dur. as_secs ( ) as libc:: time_t
82+ } ;
83+ let mut timeout = timeval {
84+ tv_sec : secs,
85+ tv_usec : i64:: from ( dur. subsec_micros ( ) ) as suseconds_t ,
86+ } ;
87+ if timeout. tv_sec == 0 && timeout. tv_usec == 0 {
88+ timeout. tv_usec = 1 ;
89+ }
90+ Ok ( timeout)
91+ }
92+ None => Ok ( timeval {
93+ tv_sec : 0 ,
94+ tv_usec : 0 ,
95+ } ) ,
96+ }
5497}
5598
5699/// An iterator that infinitely accepts connections on a VsockListener.
@@ -80,7 +123,7 @@ impl VsockListener {
80123 return Err ( Error :: other ( "requires a virtio socket address" ) ) ;
81124 }
82125
83- let socket = new_socket ( ) ?;
126+ let socket = new_socket ( SockType :: Stream ) ?;
84127
85128 bind ( socket. as_raw_fd ( ) , addr) ?;
86129
@@ -142,7 +185,7 @@ impl VsockListener {
142185 }
143186
144187 /// An iterator over the connections being received on this listener.
145- pub fn incoming ( & self ) -> Incoming {
188+ pub fn incoming ( & self ) -> Incoming < ' _ > {
146189 Incoming { listener : self }
147190 }
148191
@@ -174,7 +217,7 @@ impl AsRawFd for VsockListener {
174217}
175218
176219impl AsFd for VsockListener {
177- fn as_fd ( & self ) -> BorrowedFd {
220+ fn as_fd ( & self ) -> BorrowedFd < ' _ > {
178221 self . socket . as_fd ( )
179222 }
180223}
@@ -193,7 +236,179 @@ impl IntoRawFd for VsockListener {
193236 }
194237}
195238
239+ /// A virtio sequential packet socket between a local and a remote host.
240+ ///
241+ /// This is the vsock equivalent of [`std::net::UdpSocket`].
242+ #[ derive( Debug ) ]
243+ pub struct VsockSocket {
244+ socket : OwnedFd ,
245+ }
246+
247+ impl VsockSocket {
248+ /// Bind to an address and listen for connections.
249+ ///
250+ /// Analogous to [`std::net::UdpSocket::bind`]
251+ pub fn bind < A : SockaddrLike > ( addr : & A ) -> Result < Self > {
252+ if addr. family ( ) != Some ( AddressFamily :: Vsock ) {
253+ return Err ( Error :: other ( "requires a virtio socket address" ) ) ;
254+ }
255+
256+ let socket = new_socket ( SockType :: Datagram ) ?;
257+
258+ bind ( socket. as_raw_fd ( ) , addr) ?;
259+
260+ Ok ( Self { socket } )
261+ }
262+
263+ /// Bind to a specified cid and port and listen for connections.
264+ pub fn bind_with_cid_port ( cid : u32 , port : u32 ) -> Result < Self > {
265+ Self :: bind ( & VsockAddr :: new ( cid, port) )
266+ }
267+
268+ /// Receive a message from a remote host.
269+ ///
270+ /// Analogous to [`std::net::UdpSocket::recv_from`]
271+ ///
272+ /// # Returns
273+ ///
274+ /// The number of bytes read and the address of the remote host.
275+ pub fn recv_from ( & self , buf : & mut [ u8 ] ) -> Result < ( usize , VsockAddr ) > {
276+ recvfrom ( self . socket . as_raw_fd ( ) , buf)
277+ // UNWRAP SAFETY: recvfrom should always return peer address when SockType == SockType::Datagram
278+ . map ( |( size, addr) | ( size, addr. expect ( "recv_from didn't return peer address" ) ) )
279+ . map_err ( nix:: Error :: into)
280+ }
281+
282+ /// Send a message to a remote host.
283+ ///
284+ /// Analogous to [`std::net::UdpSocket::send_to`]
285+ pub fn send_to < A : SockaddrLike > ( & self , buf : & [ u8 ] , addr : & A ) -> Result < usize > {
286+ sendto ( self . socket . as_raw_fd ( ) , buf, addr, default_send_msg_flags ( ) )
287+ . map_err ( nix:: Error :: into)
288+ }
289+
290+ /// Send a message to a remote host with specified cid and port.
291+ pub fn send_to_with_cid_port ( & self , buf : & [ u8 ] , cid : u32 , port : u32 ) -> Result < usize > {
292+ self . send_to ( buf, & VsockAddr :: new ( cid, port) )
293+ }
294+
295+ /// Virtio socket address of the remote peer associated with this connection.
296+ pub fn peer_addr ( & self ) -> Result < VsockAddr > {
297+ Ok ( getpeername ( self . socket . as_raw_fd ( ) ) ?)
298+ }
299+
300+ /// Virtio socket address of the local address associated with this connection.
301+ pub fn local_addr ( & self ) -> Result < VsockAddr > {
302+ Ok ( getsockname ( self . socket . as_raw_fd ( ) ) ?)
303+ }
304+
305+ /// Create a new independently owned handle to the underlying socket.
306+ pub fn try_clone ( & self ) -> Result < Self > {
307+ Ok ( Self {
308+ socket : self . socket . try_clone ( ) ?,
309+ } )
310+ }
311+
312+ /// Set the timeout on read operations.
313+ pub fn set_read_timeout ( & self , dur : Option < Duration > ) -> Result < ( ) > {
314+ let timeout = timeval_from_duration ( dur) ?. into ( ) ;
315+ Ok ( ReceiveTimeout . set ( & self . socket , & timeout) ?)
316+ }
317+
318+ /// Set the timeout on write operations.
319+ pub fn set_write_timeout ( & self , dur : Option < Duration > ) -> Result < ( ) > {
320+ let timeout = timeval_from_duration ( dur) ?. into ( ) ;
321+ Ok ( SendTimeout . set ( & self . socket , & timeout) ?)
322+ }
323+
324+ /// Retrieve the latest error associated with the underlying socket.
325+ pub fn take_error ( & self ) -> Result < Option < Error > > {
326+ let error = SocketError . get ( & self . socket ) ?;
327+ Ok ( if error == 0 {
328+ None
329+ } else {
330+ Some ( Error :: from_raw_os_error ( error) )
331+ } )
332+ }
333+
334+ /// Open a connection to a remote host (you need to bind to an address with [`Self::bind`]
335+ /// first).
336+ ///
337+ /// Allows you to send and receive messages from this host directly through [`Self::send`] and
338+ /// [`Self::recv`].
339+ ///
340+ /// Analogous to [`std::net::UdpSocket::connect`]
341+ pub fn connect < A : SockaddrLike > ( & self , addr : & A ) -> Result < ( ) > {
342+ if addr. family ( ) != Some ( AddressFamily :: Vsock ) {
343+ return Err ( Error :: other ( "requires a virtio socket address" ) ) ;
344+ }
345+
346+ connect ( self . socket . as_raw_fd ( ) , addr) . map_err ( nix:: Error :: into)
347+ }
348+
349+ /// Open a connection to a remote host with specified cid and port (you need to bind to an
350+ /// address with [`Self::bind`] first).
351+ ///
352+ /// Allows you to send and receive messages from this host directly through [`Self::send`] and
353+ /// [`Self::recv`].
354+ pub fn connect_with_cid_port ( & self , cid : u32 , port : u32 ) -> Result < ( ) > {
355+ self . connect ( & VsockAddr :: new ( cid, port) )
356+ }
357+
358+ /// Send data to the connected remote host.
359+ ///
360+ /// Analogous to [`std::net::UdpSocket::send`]
361+ pub fn send ( & self , buf : & [ u8 ] ) -> Result < usize > {
362+ send ( self . socket . as_raw_fd ( ) , buf, default_send_msg_flags ( ) ) . map_err ( nix:: Error :: into)
363+ }
364+
365+ /// Receive data from the connected remote host.
366+ ///
367+ /// Analogous to [`std::net::UdpSocket::recv`]
368+ pub fn recv ( & self , buf : & mut [ u8 ] ) -> Result < usize > {
369+ recv ( self . socket . as_raw_fd ( ) , buf, MsgFlags :: empty ( ) ) . map_err ( nix:: Error :: into)
370+ }
371+
372+ /// Move this stream in and out of nonblocking mode.
373+ pub fn set_nonblocking ( & self , nonblocking : bool ) -> Result < ( ) > {
374+ let mut nonblocking: i32 = if nonblocking { 1 } else { 0 } ;
375+ if unsafe { ioctl ( self . socket . as_raw_fd ( ) , FIONBIO , & mut nonblocking) } < 0 {
376+ Err ( Error :: last_os_error ( ) )
377+ } else {
378+ Ok ( ( ) )
379+ }
380+ }
381+ }
382+
383+ impl AsFd for VsockSocket {
384+ fn as_fd ( & self ) -> BorrowedFd < ' _ > {
385+ self . socket . as_fd ( )
386+ }
387+ }
388+
389+ impl AsRawFd for VsockSocket {
390+ fn as_raw_fd ( & self ) -> RawFd {
391+ self . socket . as_raw_fd ( )
392+ }
393+ }
394+
395+ impl FromRawFd for VsockSocket {
396+ unsafe fn from_raw_fd ( fd : RawFd ) -> Self {
397+ Self {
398+ socket : OwnedFd :: from_raw_fd ( fd) ,
399+ }
400+ }
401+ }
402+
403+ impl IntoRawFd for VsockSocket {
404+ fn into_raw_fd ( self ) -> RawFd {
405+ self . socket . into_raw_fd ( )
406+ }
407+ }
408+
196409/// A virtio stream between a local and a remote socket.
410+ ///
411+ /// This is the vsock equivalent of [`std::net::TcpStream`].
197412#[ derive( Debug ) ]
198413pub struct VsockStream {
199414 socket : OwnedFd ,
@@ -206,7 +421,7 @@ impl VsockStream {
206421 return Err ( Error :: other ( "requires a virtio socket address" ) ) ;
207422 }
208423
209- let socket = new_socket ( ) ?;
424+ let socket = new_socket ( SockType :: Stream ) ?;
210425 connect ( socket. as_raw_fd ( ) , addr) ?;
211426 Ok ( Self { socket } )
212427 }
@@ -245,13 +460,13 @@ impl VsockStream {
245460
246461 /// Set the timeout on read operations.
247462 pub fn set_read_timeout ( & self , dur : Option < Duration > ) -> Result < ( ) > {
248- let timeout = Self :: timeval_from_duration ( dur) ?. into ( ) ;
463+ let timeout = timeval_from_duration ( dur) ?. into ( ) ;
249464 Ok ( ReceiveTimeout . set ( & self . socket , & timeout) ?)
250465 }
251466
252467 /// Set the timeout on write operations.
253468 pub fn set_write_timeout ( & self , dur : Option < Duration > ) -> Result < ( ) > {
254- let timeout = Self :: timeval_from_duration ( dur) ?. into ( ) ;
469+ let timeout = timeval_from_duration ( dur) ?. into ( ) ;
255470 Ok ( SendTimeout . set ( & self . socket , & timeout) ?)
256471 }
257472
@@ -274,39 +489,6 @@ impl VsockStream {
274489 Ok ( ( ) )
275490 }
276491 }
277-
278- fn timeval_from_duration ( dur : Option < Duration > ) -> Result < timeval > {
279- match dur {
280- Some ( dur) => {
281- if dur. as_secs ( ) == 0 && dur. subsec_nanos ( ) == 0 {
282- return Err ( Error :: new (
283- ErrorKind :: InvalidInput ,
284- "cannot set a zero duration timeout" ,
285- ) ) ;
286- }
287-
288- // https://github.com/rust-lang/libc/issues/1848
289- #[ cfg_attr( target_env = "musl" , allow( deprecated) ) ]
290- let secs = if dur. as_secs ( ) > libc:: time_t:: MAX as u64 {
291- libc:: time_t:: MAX
292- } else {
293- dur. as_secs ( ) as libc:: time_t
294- } ;
295- let mut timeout = timeval {
296- tv_sec : secs,
297- tv_usec : i64:: from ( dur. subsec_micros ( ) ) as suseconds_t ,
298- } ;
299- if timeout. tv_sec == 0 && timeout. tv_usec == 0 {
300- timeout. tv_usec = 1 ;
301- }
302- Ok ( timeout)
303- }
304- None => Ok ( timeval {
305- tv_sec : 0 ,
306- tv_usec : 0 ,
307- } ) ,
308- }
309- }
310492}
311493
312494impl Read for VsockStream {
@@ -333,11 +515,11 @@ impl Read for &VsockStream {
333515
334516impl Write for & VsockStream {
335517 fn write ( & mut self , buf : & [ u8 ] ) -> Result < usize > {
336- # [ cfg ( not ( target_os = "macos" ) ) ]
337- let flags = MsgFlags :: MSG_NOSIGNAL ;
338- # [ cfg ( target_os = "macos" ) ]
339- let flags = MsgFlags :: empty ( ) ;
340- Ok ( send ( self . socket . as_raw_fd ( ) , buf , flags ) ?)
518+ Ok ( send (
519+ self . socket . as_raw_fd ( ) ,
520+ buf ,
521+ default_send_msg_flags ( ) ,
522+ ) ?)
341523 }
342524
343525 fn flush ( & mut self ) -> Result < ( ) > {
@@ -352,7 +534,7 @@ impl AsRawFd for VsockStream {
352534}
353535
354536impl AsFd for VsockStream {
355- fn as_fd ( & self ) -> BorrowedFd {
537+ fn as_fd ( & self ) -> BorrowedFd < ' _ > {
356538 self . socket . as_fd ( )
357539 }
358540}
0 commit comments