Skip to content

Commit 6d11354

Browse files
committed
feat: implement UdpSocket like VsockSocket structure
This implementation mirrors the stdlib's `UdpSocket` implementation as closely as possible. I skipped some methods to keep the implementation simple, they can be added later if needed. This is ground work for implementing the feature requested in `tokio-vsock`: rust-vsock/tokio-vsock#67
1 parent 726f744 commit 6d11354

File tree

1 file changed

+207
-44
lines changed

1 file changed

+207
-44
lines changed

src/lib.rs

Lines changed: 207 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,54 @@ pub use libc::VMADDR_CID_LOCAL;
4545
pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR};
4646
pub use nix::sys::socket::{SockaddrLike, VsockAddr};
4747

48-
fn new_socket() -> Result<OwnedFd> {
48+
fn new_socket(ty: SockType) -> Result<OwnedFd> {
4949
#[cfg(not(target_os = "macos"))]
5050
let flags = SockFlag::SOCK_CLOEXEC;
5151
#[cfg(target_os = "macos")]
5252
let flags = SockFlag::empty();
53-
Ok(socket(AddressFamily::Vsock, SockType::Stream, flags, None)?)
53+
Ok(socket(AddressFamily::Vsock, ty, flags, None)?)
54+
}
55+
56+
fn default_send_msg_flags() -> MsgFlags {
57+
#[cfg(not(target_os = "macos"))]
58+
let flags = MsgFlags::MSG_NOSIGNAL;
59+
#[cfg(target_os = "macos")]
60+
let flags = MsgFlags::empty();
61+
flags
62+
}
63+
64+
/// Internal helper to turn a [`Duration`] into a [`timeval`]
65+
fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
66+
match dur {
67+
Some(dur) => {
68+
if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
69+
return Err(Error::new(
70+
ErrorKind::InvalidInput,
71+
"cannot set a zero duration timeout",
72+
));
73+
}
74+
75+
// https://github.com/rust-lang/libc/issues/1848
76+
#[cfg_attr(target_env = "musl", allow(deprecated))]
77+
let secs = if dur.as_secs() > libc::time_t::MAX as u64 {
78+
libc::time_t::MAX
79+
} else {
80+
dur.as_secs() as libc::time_t
81+
};
82+
let mut timeout = timeval {
83+
tv_sec: secs,
84+
tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
85+
};
86+
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
87+
timeout.tv_usec = 1;
88+
}
89+
Ok(timeout)
90+
}
91+
None => Ok(timeval {
92+
tv_sec: 0,
93+
tv_usec: 0,
94+
}),
95+
}
5496
}
5597

5698
/// An iterator that infinitely accepts connections on a VsockListener.
@@ -80,7 +122,7 @@ impl VsockListener {
80122
return Err(Error::other("requires a virtio socket address"));
81123
}
82124

83-
let socket = new_socket()?;
125+
let socket = new_socket(SockType::Stream)?;
84126

85127
bind(socket.as_raw_fd(), addr)?;
86128

@@ -194,6 +236,160 @@ impl IntoRawFd for VsockListener {
194236
}
195237

196238
/// A virtio stream between a local and a remote socket.
239+
///
240+
/// This is the vsock equivalent of [`std::net::UdpSocket`].
241+
#[derive(Debug)]
242+
pub struct VsockSocket {
243+
socket: OwnedFd,
244+
}
245+
246+
impl VsockSocket {
247+
/// Bind to an address and listen for connections.
248+
pub fn bind<A: SockaddrLike>(addr: &A) -> Result<Self> {
249+
if addr.family() != Some(AddressFamily::Vsock) {
250+
return Err(Error::other("requires a virtio socket address"));
251+
}
252+
253+
let socket = new_socket(SockType::Datagram)?;
254+
255+
bind(socket.as_raw_fd(), addr)?;
256+
257+
Ok(Self { socket })
258+
}
259+
260+
/// Bind to a remote host with specified cid and port.
261+
pub fn bind_with_cid_port(&self, cid: u32, port: u32) -> Result<Self> {
262+
Self::bind(&VsockAddr::new(cid, port))
263+
}
264+
265+
/// Receive a message from a remote host.
266+
///
267+
/// # Returns
268+
///
269+
/// The number of bytes read and the address of the remote host.
270+
pub fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, VsockAddr)> {
271+
nix::sys::socket::recvfrom(self.socket.as_raw_fd(), buf)
272+
// UNWRAP SAFETY: recvfrom should always return peer address when SockType == SockType::Datagram
273+
.map(|(size, addr)| (size, addr.expect("recv_from didn't return peer address")))
274+
.map_err(nix::Error::into)
275+
}
276+
277+
/// Send a message to a remote host.
278+
pub fn send_to<A: SockaddrLike>(&self, buf: &[u8], addr: &A) -> Result<usize> {
279+
nix::sys::socket::sendto(self.socket.as_raw_fd(), buf, addr, default_send_msg_flags())
280+
.map_err(nix::Error::into)
281+
}
282+
283+
/// Send a message to a remote host with specified cid and port.
284+
pub fn send_to_with_cid_port(&self, buf: &[u8], cid: u32, port: u32) -> Result<usize> {
285+
self.send_to(buf, &VsockAddr::new(cid, port))
286+
}
287+
288+
/// Virtio socket address of the remote peer associated with this connection.
289+
pub fn peer_addr(&self) -> Result<VsockAddr> {
290+
Ok(getpeername(self.socket.as_raw_fd())?)
291+
}
292+
293+
/// Virtio socket address of the local address associated with this connection.
294+
pub fn local_addr(&self) -> Result<VsockAddr> {
295+
Ok(getsockname(self.socket.as_raw_fd())?)
296+
}
297+
298+
/// Create a new independently owned handle to the underlying socket.
299+
pub fn try_clone(&self) -> Result<Self> {
300+
Ok(Self {
301+
socket: self.socket.try_clone()?,
302+
})
303+
}
304+
305+
/// Set the timeout on read operations.
306+
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
307+
let timeout = timeval_from_duration(dur)?.into();
308+
Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
309+
}
310+
311+
/// Set the timeout on write operations.
312+
pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
313+
let timeout = timeval_from_duration(dur)?.into();
314+
Ok(SendTimeout.set(&self.socket, &timeout)?)
315+
}
316+
317+
/// Retrieve the latest error associated with the underlying socket.
318+
pub fn take_error(&self) -> Result<Option<Error>> {
319+
let error = SocketError.get(&self.socket)?;
320+
Ok(if error == 0 {
321+
None
322+
} else {
323+
Some(Error::from_raw_os_error(error))
324+
})
325+
}
326+
327+
/// Open a connection to a remote host.
328+
pub fn connect<A: SockaddrLike>(&self, addr: &A) -> Result<()> {
329+
if addr.family() != Some(AddressFamily::Vsock) {
330+
return Err(Error::other("requires a virtio socket address"));
331+
}
332+
333+
connect(self.socket.as_raw_fd(), addr).map_err(nix::Error::into)
334+
}
335+
336+
/// Open a connection to a remote host with specified cid and port.
337+
pub fn connect_with_cid_port(&self, cid: u32, port: u32) -> Result<()> {
338+
self.connect(&VsockAddr::new(cid, port))
339+
}
340+
341+
/// Send data to the connected remote host.
342+
pub fn send(&self, buf: &[u8]) -> Result<usize> {
343+
nix::sys::socket::send(self.socket.as_raw_fd(), buf, default_send_msg_flags())
344+
.map_err(nix::Error::into)
345+
}
346+
347+
/// Receive data from the connected remote host.
348+
pub fn recv(&self, buf: &mut [u8]) -> Result<usize> {
349+
nix::sys::socket::recv(self.socket.as_raw_fd(), buf, MsgFlags::empty())
350+
.map_err(nix::Error::into)
351+
}
352+
353+
/// Move this stream in and out of nonblocking mode.
354+
pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
355+
let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
356+
if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
357+
Err(Error::last_os_error())
358+
} else {
359+
Ok(())
360+
}
361+
}
362+
}
363+
364+
impl AsFd for VsockSocket {
365+
fn as_fd(&self) -> BorrowedFd<'_> {
366+
self.socket.as_fd()
367+
}
368+
}
369+
370+
impl AsRawFd for VsockSocket {
371+
fn as_raw_fd(&self) -> RawFd {
372+
self.socket.as_raw_fd()
373+
}
374+
}
375+
376+
impl FromRawFd for VsockSocket {
377+
unsafe fn from_raw_fd(fd: RawFd) -> Self {
378+
Self {
379+
socket: OwnedFd::from_raw_fd(fd),
380+
}
381+
}
382+
}
383+
384+
impl IntoRawFd for VsockSocket {
385+
fn into_raw_fd(self) -> RawFd {
386+
self.socket.into_raw_fd()
387+
}
388+
}
389+
390+
/// A virtio stream between a local and a remote socket.
391+
///
392+
/// This is the vsock equivalent of [`std::net::TcpStream`].
197393
#[derive(Debug)]
198394
pub struct VsockStream {
199395
socket: OwnedFd,
@@ -206,7 +402,7 @@ impl VsockStream {
206402
return Err(Error::other("requires a virtio socket address"));
207403
}
208404

209-
let socket = new_socket()?;
405+
let socket = new_socket(SockType::Stream)?;
210406
connect(socket.as_raw_fd(), addr)?;
211407
Ok(Self { socket })
212408
}
@@ -245,13 +441,13 @@ impl VsockStream {
245441

246442
/// Set the timeout on read operations.
247443
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
248-
let timeout = Self::timeval_from_duration(dur)?.into();
444+
let timeout = timeval_from_duration(dur)?.into();
249445
Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
250446
}
251447

252448
/// Set the timeout on write operations.
253449
pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
254-
let timeout = Self::timeval_from_duration(dur)?.into();
450+
let timeout = timeval_from_duration(dur)?.into();
255451
Ok(SendTimeout.set(&self.socket, &timeout)?)
256452
}
257453

@@ -274,39 +470,6 @@ impl VsockStream {
274470
Ok(())
275471
}
276472
}
277-
278-
fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
279-
match dur {
280-
Some(dur) => {
281-
if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
282-
return Err(Error::new(
283-
ErrorKind::InvalidInput,
284-
"cannot set a zero duration timeout",
285-
));
286-
}
287-
288-
// https://github.com/rust-lang/libc/issues/1848
289-
#[cfg_attr(target_env = "musl", allow(deprecated))]
290-
let secs = if dur.as_secs() > libc::time_t::MAX as u64 {
291-
libc::time_t::MAX
292-
} else {
293-
dur.as_secs() as libc::time_t
294-
};
295-
let mut timeout = timeval {
296-
tv_sec: secs,
297-
tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
298-
};
299-
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
300-
timeout.tv_usec = 1;
301-
}
302-
Ok(timeout)
303-
}
304-
None => Ok(timeval {
305-
tv_sec: 0,
306-
tv_usec: 0,
307-
}),
308-
}
309-
}
310473
}
311474

312475
impl Read for VsockStream {
@@ -333,11 +496,11 @@ impl Read for &VsockStream {
333496

334497
impl Write for &VsockStream {
335498
fn write(&mut self, buf: &[u8]) -> Result<usize> {
336-
#[cfg(not(target_os = "macos"))]
337-
let flags = MsgFlags::MSG_NOSIGNAL;
338-
#[cfg(target_os = "macos")]
339-
let flags = MsgFlags::empty();
340-
Ok(send(self.socket.as_raw_fd(), buf, flags)?)
499+
Ok(send(
500+
self.socket.as_raw_fd(),
501+
buf,
502+
default_send_msg_flags(),
503+
)?)
341504
}
342505

343506
fn flush(&mut self) -> Result<()> {

0 commit comments

Comments
 (0)