@@ -45,12 +45,54 @@ pub use libc::VMADDR_CID_LOCAL;
4545pub use libc:: { VMADDR_CID_ANY , VMADDR_CID_HOST , VMADDR_CID_HYPERVISOR } ;
4646pub use nix:: sys:: socket:: { SockaddrLike , VsockAddr } ;
4747
48- fn new_socket ( ) -> Result < OwnedFd > {
48+ fn new_socket ( ty : SockType ) -> Result < OwnedFd > {
4949 #[ cfg( not( target_os = "macos" ) ) ]
5050 let flags = SockFlag :: SOCK_CLOEXEC ;
5151 #[ cfg( target_os = "macos" ) ]
5252 let flags = SockFlag :: empty ( ) ;
53- Ok ( socket ( AddressFamily :: Vsock , SockType :: Stream , flags, None ) ?)
53+ Ok ( socket ( AddressFamily :: Vsock , ty, flags, None ) ?)
54+ }
55+
56+ fn default_send_msg_flags ( ) -> MsgFlags {
57+ #[ cfg( not( target_os = "macos" ) ) ]
58+ let flags = MsgFlags :: MSG_NOSIGNAL ;
59+ #[ cfg( target_os = "macos" ) ]
60+ let flags = MsgFlags :: empty ( ) ;
61+ flags
62+ }
63+
64+ /// Internal helper to turn a [`Duration`] into a [`timeval`]
65+ fn timeval_from_duration ( dur : Option < Duration > ) -> Result < timeval > {
66+ match dur {
67+ Some ( dur) => {
68+ if dur. as_secs ( ) == 0 && dur. subsec_nanos ( ) == 0 {
69+ return Err ( Error :: new (
70+ ErrorKind :: InvalidInput ,
71+ "cannot set a zero duration timeout" ,
72+ ) ) ;
73+ }
74+
75+ // https://github.com/rust-lang/libc/issues/1848
76+ #[ cfg_attr( target_env = "musl" , allow( deprecated) ) ]
77+ let secs = if dur. as_secs ( ) > libc:: time_t:: MAX as u64 {
78+ libc:: time_t:: MAX
79+ } else {
80+ dur. as_secs ( ) as libc:: time_t
81+ } ;
82+ let mut timeout = timeval {
83+ tv_sec : secs,
84+ tv_usec : i64:: from ( dur. subsec_micros ( ) ) as suseconds_t ,
85+ } ;
86+ if timeout. tv_sec == 0 && timeout. tv_usec == 0 {
87+ timeout. tv_usec = 1 ;
88+ }
89+ Ok ( timeout)
90+ }
91+ None => Ok ( timeval {
92+ tv_sec : 0 ,
93+ tv_usec : 0 ,
94+ } ) ,
95+ }
5496}
5597
5698/// An iterator that infinitely accepts connections on a VsockListener.
@@ -80,7 +122,7 @@ impl VsockListener {
80122 return Err ( Error :: other ( "requires a virtio socket address" ) ) ;
81123 }
82124
83- let socket = new_socket ( ) ?;
125+ let socket = new_socket ( SockType :: Stream ) ?;
84126
85127 bind ( socket. as_raw_fd ( ) , addr) ?;
86128
@@ -194,6 +236,160 @@ impl IntoRawFd for VsockListener {
194236}
195237
196238/// A virtio stream between a local and a remote socket.
239+ ///
240+ /// This is the vsock equivalent of [`std::net::UdpSocket`].
241+ #[ derive( Debug ) ]
242+ pub struct VsockSocket {
243+ socket : OwnedFd ,
244+ }
245+
246+ impl VsockSocket {
247+ /// Bind to an address and listen for connections.
248+ pub fn bind < A : SockaddrLike > ( addr : & A ) -> Result < Self > {
249+ if addr. family ( ) != Some ( AddressFamily :: Vsock ) {
250+ return Err ( Error :: other ( "requires a virtio socket address" ) ) ;
251+ }
252+
253+ let socket = new_socket ( SockType :: Datagram ) ?;
254+
255+ bind ( socket. as_raw_fd ( ) , addr) ?;
256+
257+ Ok ( Self { socket } )
258+ }
259+
260+ /// Bind to a remote host with specified cid and port.
261+ pub fn bind_with_cid_port ( & self , cid : u32 , port : u32 ) -> Result < Self > {
262+ Self :: bind ( & VsockAddr :: new ( cid, port) )
263+ }
264+
265+ /// Receive a message from a remote host.
266+ ///
267+ /// # Returns
268+ ///
269+ /// The number of bytes read and the address of the remote host.
270+ pub fn recv_from ( & self , buf : & mut [ u8 ] ) -> Result < ( usize , VsockAddr ) > {
271+ nix:: sys:: socket:: recvfrom ( self . socket . as_raw_fd ( ) , buf)
272+ // UNWRAP SAFETY: recvfrom should always return peer address when SockType == SockType::Datagram
273+ . map ( |( size, addr) | ( size, addr. expect ( "recv_from didn't return peer address" ) ) )
274+ . map_err ( nix:: Error :: into)
275+ }
276+
277+ /// Send a message to a remote host.
278+ pub fn send_to < A : SockaddrLike > ( & self , buf : & [ u8 ] , addr : & A ) -> Result < usize > {
279+ nix:: sys:: socket:: sendto ( self . socket . as_raw_fd ( ) , buf, addr, default_send_msg_flags ( ) )
280+ . map_err ( nix:: Error :: into)
281+ }
282+
283+ /// Send a message to a remote host with specified cid and port.
284+ pub fn send_to_with_cid_port ( & self , buf : & [ u8 ] , cid : u32 , port : u32 ) -> Result < usize > {
285+ self . send_to ( buf, & VsockAddr :: new ( cid, port) )
286+ }
287+
288+ /// Virtio socket address of the remote peer associated with this connection.
289+ pub fn peer_addr ( & self ) -> Result < VsockAddr > {
290+ Ok ( getpeername ( self . socket . as_raw_fd ( ) ) ?)
291+ }
292+
293+ /// Virtio socket address of the local address associated with this connection.
294+ pub fn local_addr ( & self ) -> Result < VsockAddr > {
295+ Ok ( getsockname ( self . socket . as_raw_fd ( ) ) ?)
296+ }
297+
298+ /// Create a new independently owned handle to the underlying socket.
299+ pub fn try_clone ( & self ) -> Result < Self > {
300+ Ok ( Self {
301+ socket : self . socket . try_clone ( ) ?,
302+ } )
303+ }
304+
305+ /// Set the timeout on read operations.
306+ pub fn set_read_timeout ( & self , dur : Option < Duration > ) -> Result < ( ) > {
307+ let timeout = timeval_from_duration ( dur) ?. into ( ) ;
308+ Ok ( ReceiveTimeout . set ( & self . socket , & timeout) ?)
309+ }
310+
311+ /// Set the timeout on write operations.
312+ pub fn set_write_timeout ( & self , dur : Option < Duration > ) -> Result < ( ) > {
313+ let timeout = timeval_from_duration ( dur) ?. into ( ) ;
314+ Ok ( SendTimeout . set ( & self . socket , & timeout) ?)
315+ }
316+
317+ /// Retrieve the latest error associated with the underlying socket.
318+ pub fn take_error ( & self ) -> Result < Option < Error > > {
319+ let error = SocketError . get ( & self . socket ) ?;
320+ Ok ( if error == 0 {
321+ None
322+ } else {
323+ Some ( Error :: from_raw_os_error ( error) )
324+ } )
325+ }
326+
327+ /// Open a connection to a remote host.
328+ pub fn connect < A : SockaddrLike > ( & self , addr : & A ) -> Result < ( ) > {
329+ if addr. family ( ) != Some ( AddressFamily :: Vsock ) {
330+ return Err ( Error :: other ( "requires a virtio socket address" ) ) ;
331+ }
332+
333+ connect ( self . socket . as_raw_fd ( ) , addr) . map_err ( nix:: Error :: into)
334+ }
335+
336+ /// Open a connection to a remote host with specified cid and port.
337+ pub fn connect_with_cid_port ( & self , cid : u32 , port : u32 ) -> Result < ( ) > {
338+ self . connect ( & VsockAddr :: new ( cid, port) )
339+ }
340+
341+ /// Send data to the connected remote host.
342+ pub fn send ( & self , buf : & [ u8 ] ) -> Result < usize > {
343+ nix:: sys:: socket:: send ( self . socket . as_raw_fd ( ) , buf, default_send_msg_flags ( ) )
344+ . map_err ( nix:: Error :: into)
345+ }
346+
347+ /// Receive data from the connected remote host.
348+ pub fn recv ( & self , buf : & mut [ u8 ] ) -> Result < usize > {
349+ nix:: sys:: socket:: recv ( self . socket . as_raw_fd ( ) , buf, MsgFlags :: empty ( ) )
350+ . map_err ( nix:: Error :: into)
351+ }
352+
353+ /// Move this stream in and out of nonblocking mode.
354+ pub fn set_nonblocking ( & self , nonblocking : bool ) -> Result < ( ) > {
355+ let mut nonblocking: i32 = if nonblocking { 1 } else { 0 } ;
356+ if unsafe { ioctl ( self . socket . as_raw_fd ( ) , FIONBIO , & mut nonblocking) } < 0 {
357+ Err ( Error :: last_os_error ( ) )
358+ } else {
359+ Ok ( ( ) )
360+ }
361+ }
362+ }
363+
364+ impl AsFd for VsockSocket {
365+ fn as_fd ( & self ) -> BorrowedFd < ' _ > {
366+ self . socket . as_fd ( )
367+ }
368+ }
369+
370+ impl AsRawFd for VsockSocket {
371+ fn as_raw_fd ( & self ) -> RawFd {
372+ self . socket . as_raw_fd ( )
373+ }
374+ }
375+
376+ impl FromRawFd for VsockSocket {
377+ unsafe fn from_raw_fd ( fd : RawFd ) -> Self {
378+ Self {
379+ socket : OwnedFd :: from_raw_fd ( fd) ,
380+ }
381+ }
382+ }
383+
384+ impl IntoRawFd for VsockSocket {
385+ fn into_raw_fd ( self ) -> RawFd {
386+ self . socket . into_raw_fd ( )
387+ }
388+ }
389+
390+ /// A virtio stream between a local and a remote socket.
391+ ///
392+ /// This is the vsock equivalent of [`std::net::TcpStream`].
197393#[ derive( Debug ) ]
198394pub struct VsockStream {
199395 socket : OwnedFd ,
@@ -206,7 +402,7 @@ impl VsockStream {
206402 return Err ( Error :: other ( "requires a virtio socket address" ) ) ;
207403 }
208404
209- let socket = new_socket ( ) ?;
405+ let socket = new_socket ( SockType :: Stream ) ?;
210406 connect ( socket. as_raw_fd ( ) , addr) ?;
211407 Ok ( Self { socket } )
212408 }
@@ -245,13 +441,13 @@ impl VsockStream {
245441
246442 /// Set the timeout on read operations.
247443 pub fn set_read_timeout ( & self , dur : Option < Duration > ) -> Result < ( ) > {
248- let timeout = Self :: timeval_from_duration ( dur) ?. into ( ) ;
444+ let timeout = timeval_from_duration ( dur) ?. into ( ) ;
249445 Ok ( ReceiveTimeout . set ( & self . socket , & timeout) ?)
250446 }
251447
252448 /// Set the timeout on write operations.
253449 pub fn set_write_timeout ( & self , dur : Option < Duration > ) -> Result < ( ) > {
254- let timeout = Self :: timeval_from_duration ( dur) ?. into ( ) ;
450+ let timeout = timeval_from_duration ( dur) ?. into ( ) ;
255451 Ok ( SendTimeout . set ( & self . socket , & timeout) ?)
256452 }
257453
@@ -274,39 +470,6 @@ impl VsockStream {
274470 Ok ( ( ) )
275471 }
276472 }
277-
278- fn timeval_from_duration ( dur : Option < Duration > ) -> Result < timeval > {
279- match dur {
280- Some ( dur) => {
281- if dur. as_secs ( ) == 0 && dur. subsec_nanos ( ) == 0 {
282- return Err ( Error :: new (
283- ErrorKind :: InvalidInput ,
284- "cannot set a zero duration timeout" ,
285- ) ) ;
286- }
287-
288- // https://github.com/rust-lang/libc/issues/1848
289- #[ cfg_attr( target_env = "musl" , allow( deprecated) ) ]
290- let secs = if dur. as_secs ( ) > libc:: time_t:: MAX as u64 {
291- libc:: time_t:: MAX
292- } else {
293- dur. as_secs ( ) as libc:: time_t
294- } ;
295- let mut timeout = timeval {
296- tv_sec : secs,
297- tv_usec : i64:: from ( dur. subsec_micros ( ) ) as suseconds_t ,
298- } ;
299- if timeout. tv_sec == 0 && timeout. tv_usec == 0 {
300- timeout. tv_usec = 1 ;
301- }
302- Ok ( timeout)
303- }
304- None => Ok ( timeval {
305- tv_sec : 0 ,
306- tv_usec : 0 ,
307- } ) ,
308- }
309- }
310473}
311474
312475impl Read for VsockStream {
@@ -333,11 +496,11 @@ impl Read for &VsockStream {
333496
334497impl Write for & VsockStream {
335498 fn write ( & mut self , buf : & [ u8 ] ) -> Result < usize > {
336- # [ cfg ( not ( target_os = "macos" ) ) ]
337- let flags = MsgFlags :: MSG_NOSIGNAL ;
338- # [ cfg ( target_os = "macos" ) ]
339- let flags = MsgFlags :: empty ( ) ;
340- Ok ( send ( self . socket . as_raw_fd ( ) , buf , flags ) ?)
499+ Ok ( send (
500+ self . socket . as_raw_fd ( ) ,
501+ buf ,
502+ default_send_msg_flags ( ) ,
503+ ) ?)
341504 }
342505
343506 fn flush ( & mut self ) -> Result < ( ) > {
0 commit comments