Skip to content

Commit 0efe61e

Browse files
authored
Merge pull request #55 from rust-vsock/jalil/add-sock-dgram-support
feat: implement UdpSocket like VsockSocket structure
2 parents 726f744 + e5722bc commit 0efe61e

File tree

1 file changed

+230
-48
lines changed

1 file changed

+230
-48
lines changed

src/lib.rs

Lines changed: 230 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ use libc::{
2424
use nix::{
2525
ioctl_read_bad,
2626
sys::socket::{
27-
self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
27+
self, bind, connect, getpeername, getsockname, listen, recv, recvfrom, send, sendto,
28+
shutdown, socket,
2829
sockopt::{ReceiveTimeout, SendTimeout, SocketError},
2930
AddressFamily, Backlog, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
3031
},
@@ -45,12 +46,54 @@ pub use libc::VMADDR_CID_LOCAL;
4546
pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR};
4647
pub use nix::sys::socket::{SockaddrLike, VsockAddr};
4748

48-
fn new_socket() -> Result<OwnedFd> {
49+
fn new_socket(ty: SockType) -> Result<OwnedFd> {
4950
#[cfg(not(target_os = "macos"))]
5051
let flags = SockFlag::SOCK_CLOEXEC;
5152
#[cfg(target_os = "macos")]
5253
let flags = SockFlag::empty();
53-
Ok(socket(AddressFamily::Vsock, SockType::Stream, flags, None)?)
54+
Ok(socket(AddressFamily::Vsock, ty, flags, None)?)
55+
}
56+
57+
fn default_send_msg_flags() -> MsgFlags {
58+
#[cfg(not(target_os = "macos"))]
59+
let flags = MsgFlags::MSG_NOSIGNAL;
60+
#[cfg(target_os = "macos")]
61+
let flags = MsgFlags::empty();
62+
flags
63+
}
64+
65+
/// Internal helper to turn a [`Duration`] into a [`timeval`]
66+
fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
67+
match dur {
68+
Some(dur) => {
69+
if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
70+
return Err(Error::new(
71+
ErrorKind::InvalidInput,
72+
"cannot set a zero duration timeout",
73+
));
74+
}
75+
76+
// https://github.com/rust-lang/libc/issues/1848
77+
#[cfg_attr(target_env = "musl", allow(deprecated))]
78+
let secs = if dur.as_secs() > libc::time_t::MAX as u64 {
79+
libc::time_t::MAX
80+
} else {
81+
dur.as_secs() as libc::time_t
82+
};
83+
let mut timeout = timeval {
84+
tv_sec: secs,
85+
tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
86+
};
87+
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
88+
timeout.tv_usec = 1;
89+
}
90+
Ok(timeout)
91+
}
92+
None => Ok(timeval {
93+
tv_sec: 0,
94+
tv_usec: 0,
95+
}),
96+
}
5497
}
5598

5699
/// An iterator that infinitely accepts connections on a VsockListener.
@@ -80,7 +123,7 @@ impl VsockListener {
80123
return Err(Error::other("requires a virtio socket address"));
81124
}
82125

83-
let socket = new_socket()?;
126+
let socket = new_socket(SockType::Stream)?;
84127

85128
bind(socket.as_raw_fd(), addr)?;
86129

@@ -142,7 +185,7 @@ impl VsockListener {
142185
}
143186

144187
/// An iterator over the connections being received on this listener.
145-
pub fn incoming(&self) -> Incoming {
188+
pub fn incoming(&self) -> Incoming<'_> {
146189
Incoming { listener: self }
147190
}
148191

@@ -174,7 +217,7 @@ impl AsRawFd for VsockListener {
174217
}
175218

176219
impl AsFd for VsockListener {
177-
fn as_fd(&self) -> BorrowedFd {
220+
fn as_fd(&self) -> BorrowedFd<'_> {
178221
self.socket.as_fd()
179222
}
180223
}
@@ -193,7 +236,179 @@ impl IntoRawFd for VsockListener {
193236
}
194237
}
195238

239+
/// A virtio sequential packet socket between a local and a remote host.
240+
///
241+
/// This is the vsock equivalent of [`std::net::UdpSocket`].
242+
#[derive(Debug)]
243+
pub struct VsockSocket {
244+
socket: OwnedFd,
245+
}
246+
247+
impl VsockSocket {
248+
/// Bind to an address and listen for connections.
249+
///
250+
/// Analogous to [`std::net::UdpSocket::bind`]
251+
pub fn bind<A: SockaddrLike>(addr: &A) -> Result<Self> {
252+
if addr.family() != Some(AddressFamily::Vsock) {
253+
return Err(Error::other("requires a virtio socket address"));
254+
}
255+
256+
let socket = new_socket(SockType::Datagram)?;
257+
258+
bind(socket.as_raw_fd(), addr)?;
259+
260+
Ok(Self { socket })
261+
}
262+
263+
/// Bind to a specified cid and port and listen for connections.
264+
pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<Self> {
265+
Self::bind(&VsockAddr::new(cid, port))
266+
}
267+
268+
/// Receive a message from a remote host.
269+
///
270+
/// Analogous to [`std::net::UdpSocket::recv_from`]
271+
///
272+
/// # Returns
273+
///
274+
/// The number of bytes read and the address of the remote host.
275+
pub fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, VsockAddr)> {
276+
recvfrom(self.socket.as_raw_fd(), buf)
277+
// UNWRAP SAFETY: recvfrom should always return peer address when SockType == SockType::Datagram
278+
.map(|(size, addr)| (size, addr.expect("recv_from didn't return peer address")))
279+
.map_err(nix::Error::into)
280+
}
281+
282+
/// Send a message to a remote host.
283+
///
284+
/// Analogous to [`std::net::UdpSocket::send_to`]
285+
pub fn send_to<A: SockaddrLike>(&self, buf: &[u8], addr: &A) -> Result<usize> {
286+
sendto(self.socket.as_raw_fd(), buf, addr, default_send_msg_flags())
287+
.map_err(nix::Error::into)
288+
}
289+
290+
/// Send a message to a remote host with specified cid and port.
291+
pub fn send_to_with_cid_port(&self, buf: &[u8], cid: u32, port: u32) -> Result<usize> {
292+
self.send_to(buf, &VsockAddr::new(cid, port))
293+
}
294+
295+
/// Virtio socket address of the remote peer associated with this connection.
296+
pub fn peer_addr(&self) -> Result<VsockAddr> {
297+
Ok(getpeername(self.socket.as_raw_fd())?)
298+
}
299+
300+
/// Virtio socket address of the local address associated with this connection.
301+
pub fn local_addr(&self) -> Result<VsockAddr> {
302+
Ok(getsockname(self.socket.as_raw_fd())?)
303+
}
304+
305+
/// Create a new independently owned handle to the underlying socket.
306+
pub fn try_clone(&self) -> Result<Self> {
307+
Ok(Self {
308+
socket: self.socket.try_clone()?,
309+
})
310+
}
311+
312+
/// Set the timeout on read operations.
313+
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
314+
let timeout = timeval_from_duration(dur)?.into();
315+
Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
316+
}
317+
318+
/// Set the timeout on write operations.
319+
pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
320+
let timeout = timeval_from_duration(dur)?.into();
321+
Ok(SendTimeout.set(&self.socket, &timeout)?)
322+
}
323+
324+
/// Retrieve the latest error associated with the underlying socket.
325+
pub fn take_error(&self) -> Result<Option<Error>> {
326+
let error = SocketError.get(&self.socket)?;
327+
Ok(if error == 0 {
328+
None
329+
} else {
330+
Some(Error::from_raw_os_error(error))
331+
})
332+
}
333+
334+
/// Open a connection to a remote host (you need to bind to an address with [`Self::bind`]
335+
/// first).
336+
///
337+
/// Allows you to send and receive messages from this host directly through [`Self::send`] and
338+
/// [`Self::recv`].
339+
///
340+
/// Analogous to [`std::net::UdpSocket::connect`]
341+
pub fn connect<A: SockaddrLike>(&self, addr: &A) -> Result<()> {
342+
if addr.family() != Some(AddressFamily::Vsock) {
343+
return Err(Error::other("requires a virtio socket address"));
344+
}
345+
346+
connect(self.socket.as_raw_fd(), addr).map_err(nix::Error::into)
347+
}
348+
349+
/// Open a connection to a remote host with specified cid and port (you need to bind to an
350+
/// address with [`Self::bind`] first).
351+
///
352+
/// Allows you to send and receive messages from this host directly through [`Self::send`] and
353+
/// [`Self::recv`].
354+
pub fn connect_with_cid_port(&self, cid: u32, port: u32) -> Result<()> {
355+
self.connect(&VsockAddr::new(cid, port))
356+
}
357+
358+
/// Send data to the connected remote host.
359+
///
360+
/// Analogous to [`std::net::UdpSocket::send`]
361+
pub fn send(&self, buf: &[u8]) -> Result<usize> {
362+
send(self.socket.as_raw_fd(), buf, default_send_msg_flags()).map_err(nix::Error::into)
363+
}
364+
365+
/// Receive data from the connected remote host.
366+
///
367+
/// Analogous to [`std::net::UdpSocket::recv`]
368+
pub fn recv(&self, buf: &mut [u8]) -> Result<usize> {
369+
recv(self.socket.as_raw_fd(), buf, MsgFlags::empty()).map_err(nix::Error::into)
370+
}
371+
372+
/// Move this stream in and out of nonblocking mode.
373+
pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
374+
let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
375+
if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
376+
Err(Error::last_os_error())
377+
} else {
378+
Ok(())
379+
}
380+
}
381+
}
382+
383+
impl AsFd for VsockSocket {
384+
fn as_fd(&self) -> BorrowedFd<'_> {
385+
self.socket.as_fd()
386+
}
387+
}
388+
389+
impl AsRawFd for VsockSocket {
390+
fn as_raw_fd(&self) -> RawFd {
391+
self.socket.as_raw_fd()
392+
}
393+
}
394+
395+
impl FromRawFd for VsockSocket {
396+
unsafe fn from_raw_fd(fd: RawFd) -> Self {
397+
Self {
398+
socket: OwnedFd::from_raw_fd(fd),
399+
}
400+
}
401+
}
402+
403+
impl IntoRawFd for VsockSocket {
404+
fn into_raw_fd(self) -> RawFd {
405+
self.socket.into_raw_fd()
406+
}
407+
}
408+
196409
/// A virtio stream between a local and a remote socket.
410+
///
411+
/// This is the vsock equivalent of [`std::net::TcpStream`].
197412
#[derive(Debug)]
198413
pub struct VsockStream {
199414
socket: OwnedFd,
@@ -206,7 +421,7 @@ impl VsockStream {
206421
return Err(Error::other("requires a virtio socket address"));
207422
}
208423

209-
let socket = new_socket()?;
424+
let socket = new_socket(SockType::Stream)?;
210425
connect(socket.as_raw_fd(), addr)?;
211426
Ok(Self { socket })
212427
}
@@ -245,13 +460,13 @@ impl VsockStream {
245460

246461
/// Set the timeout on read operations.
247462
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
248-
let timeout = Self::timeval_from_duration(dur)?.into();
463+
let timeout = timeval_from_duration(dur)?.into();
249464
Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
250465
}
251466

252467
/// Set the timeout on write operations.
253468
pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
254-
let timeout = Self::timeval_from_duration(dur)?.into();
469+
let timeout = timeval_from_duration(dur)?.into();
255470
Ok(SendTimeout.set(&self.socket, &timeout)?)
256471
}
257472

@@ -274,39 +489,6 @@ impl VsockStream {
274489
Ok(())
275490
}
276491
}
277-
278-
fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
279-
match dur {
280-
Some(dur) => {
281-
if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
282-
return Err(Error::new(
283-
ErrorKind::InvalidInput,
284-
"cannot set a zero duration timeout",
285-
));
286-
}
287-
288-
// https://github.com/rust-lang/libc/issues/1848
289-
#[cfg_attr(target_env = "musl", allow(deprecated))]
290-
let secs = if dur.as_secs() > libc::time_t::MAX as u64 {
291-
libc::time_t::MAX
292-
} else {
293-
dur.as_secs() as libc::time_t
294-
};
295-
let mut timeout = timeval {
296-
tv_sec: secs,
297-
tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
298-
};
299-
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
300-
timeout.tv_usec = 1;
301-
}
302-
Ok(timeout)
303-
}
304-
None => Ok(timeval {
305-
tv_sec: 0,
306-
tv_usec: 0,
307-
}),
308-
}
309-
}
310492
}
311493

312494
impl Read for VsockStream {
@@ -333,11 +515,11 @@ impl Read for &VsockStream {
333515

334516
impl Write for &VsockStream {
335517
fn write(&mut self, buf: &[u8]) -> Result<usize> {
336-
#[cfg(not(target_os = "macos"))]
337-
let flags = MsgFlags::MSG_NOSIGNAL;
338-
#[cfg(target_os = "macos")]
339-
let flags = MsgFlags::empty();
340-
Ok(send(self.socket.as_raw_fd(), buf, flags)?)
518+
Ok(send(
519+
self.socket.as_raw_fd(),
520+
buf,
521+
default_send_msg_flags(),
522+
)?)
341523
}
342524

343525
fn flush(&mut self) -> Result<()> {
@@ -352,7 +534,7 @@ impl AsRawFd for VsockStream {
352534
}
353535

354536
impl AsFd for VsockStream {
355-
fn as_fd(&self) -> BorrowedFd {
537+
fn as_fd(&self) -> BorrowedFd<'_> {
356538
self.socket.as_fd()
357539
}
358540
}

0 commit comments

Comments
 (0)