Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 230 additions & 48 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use libc::{
use nix::{
ioctl_read_bad,
sys::socket::{
self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
self, bind, connect, getpeername, getsockname, listen, recv, recvfrom, send, sendto,
shutdown, socket,
sockopt::{ReceiveTimeout, SendTimeout, SocketError},
AddressFamily, Backlog, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
},
Expand All @@ -45,12 +46,54 @@ pub use libc::VMADDR_CID_LOCAL;
pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR};
pub use nix::sys::socket::{SockaddrLike, VsockAddr};

fn new_socket() -> Result<OwnedFd> {
fn new_socket(ty: SockType) -> Result<OwnedFd> {
#[cfg(not(target_os = "macos"))]
let flags = SockFlag::SOCK_CLOEXEC;
#[cfg(target_os = "macos")]
let flags = SockFlag::empty();
Ok(socket(AddressFamily::Vsock, SockType::Stream, flags, None)?)
Ok(socket(AddressFamily::Vsock, ty, flags, None)?)
}

fn default_send_msg_flags() -> MsgFlags {
#[cfg(not(target_os = "macos"))]
let flags = MsgFlags::MSG_NOSIGNAL;
#[cfg(target_os = "macos")]
let flags = MsgFlags::empty();
flags
}

/// Internal helper to turn a [`Duration`] into a [`timeval`]
fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
match dur {
Some(dur) => {
if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
return Err(Error::new(
ErrorKind::InvalidInput,
"cannot set a zero duration timeout",
));
}

// https://github.com/rust-lang/libc/issues/1848
#[cfg_attr(target_env = "musl", allow(deprecated))]
let secs = if dur.as_secs() > libc::time_t::MAX as u64 {
libc::time_t::MAX
} else {
dur.as_secs() as libc::time_t
};
let mut timeout = timeval {
tv_sec: secs,
tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
};
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
timeout.tv_usec = 1;
}
Ok(timeout)
}
None => Ok(timeval {
tv_sec: 0,
tv_usec: 0,
}),
}
}

/// An iterator that infinitely accepts connections on a VsockListener.
Expand Down Expand Up @@ -80,7 +123,7 @@ impl VsockListener {
return Err(Error::other("requires a virtio socket address"));
}

let socket = new_socket()?;
let socket = new_socket(SockType::Stream)?;

bind(socket.as_raw_fd(), addr)?;

Expand Down Expand Up @@ -142,7 +185,7 @@ impl VsockListener {
}

/// An iterator over the connections being received on this listener.
pub fn incoming(&self) -> Incoming {
pub fn incoming(&self) -> Incoming<'_> {
Incoming { listener: self }
}

Expand Down Expand Up @@ -174,7 +217,7 @@ impl AsRawFd for VsockListener {
}

impl AsFd for VsockListener {
fn as_fd(&self) -> BorrowedFd {
fn as_fd(&self) -> BorrowedFd<'_> {
self.socket.as_fd()
}
}
Expand All @@ -193,7 +236,179 @@ impl IntoRawFd for VsockListener {
}
}

/// A virtio sequential packet socket between a local and a remote host.
///
/// This is the vsock equivalent of [`std::net::UdpSocket`].
#[derive(Debug)]
pub struct VsockSocket {
socket: OwnedFd,
}

impl VsockSocket {
/// Bind to an address and listen for connections.
///
/// Analogous to [`std::net::UdpSocket::bind`]
pub fn bind<A: SockaddrLike>(addr: &A) -> Result<Self> {
if addr.family() != Some(AddressFamily::Vsock) {
return Err(Error::other("requires a virtio socket address"));
}

let socket = new_socket(SockType::Datagram)?;

bind(socket.as_raw_fd(), addr)?;

Ok(Self { socket })
}

/// Bind to a specified cid and port and listen for connections.
pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<Self> {
Self::bind(&VsockAddr::new(cid, port))
}

/// Receive a message from a remote host.
///
/// Analogous to [`std::net::UdpSocket::recv_from`]
///
/// # Returns
///
/// The number of bytes read and the address of the remote host.
pub fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, VsockAddr)> {
recvfrom(self.socket.as_raw_fd(), buf)
// UNWRAP SAFETY: recvfrom should always return peer address when SockType == SockType::Datagram
.map(|(size, addr)| (size, addr.expect("recv_from didn't return peer address")))
.map_err(nix::Error::into)
}

/// Send a message to a remote host.
///
/// Analogous to [`std::net::UdpSocket::send_to`]
pub fn send_to<A: SockaddrLike>(&self, buf: &[u8], addr: &A) -> Result<usize> {
sendto(self.socket.as_raw_fd(), buf, addr, default_send_msg_flags())
.map_err(nix::Error::into)
}

/// Send a message to a remote host with specified cid and port.
pub fn send_to_with_cid_port(&self, buf: &[u8], cid: u32, port: u32) -> Result<usize> {
self.send_to(buf, &VsockAddr::new(cid, port))
}

/// Virtio socket address of the remote peer associated with this connection.
pub fn peer_addr(&self) -> Result<VsockAddr> {
Ok(getpeername(self.socket.as_raw_fd())?)
}

/// Virtio socket address of the local address associated with this connection.
pub fn local_addr(&self) -> Result<VsockAddr> {
Ok(getsockname(self.socket.as_raw_fd())?)
}

/// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> Result<Self> {
Ok(Self {
socket: self.socket.try_clone()?,
})
}

/// Set the timeout on read operations.
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
let timeout = timeval_from_duration(dur)?.into();
Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
}

/// Set the timeout on write operations.
pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
let timeout = timeval_from_duration(dur)?.into();
Ok(SendTimeout.set(&self.socket, &timeout)?)
}

/// Retrieve the latest error associated with the underlying socket.
pub fn take_error(&self) -> Result<Option<Error>> {
let error = SocketError.get(&self.socket)?;
Ok(if error == 0 {
None
} else {
Some(Error::from_raw_os_error(error))
})
}

/// Open a connection to a remote host (you need to bind to an address with [`Self::bind`]
/// first).
///
/// Allows you to send and receive messages from this host directly through [`Self::send`] and
/// [`Self::recv`].
///
/// Analogous to [`std::net::UdpSocket::connect`]
pub fn connect<A: SockaddrLike>(&self, addr: &A) -> Result<()> {
if addr.family() != Some(AddressFamily::Vsock) {
return Err(Error::other("requires a virtio socket address"));
}

connect(self.socket.as_raw_fd(), addr).map_err(nix::Error::into)
}

/// Open a connection to a remote host with specified cid and port (you need to bind to an
/// address with [`Self::bind`] first).
///
/// Allows you to send and receive messages from this host directly through [`Self::send`] and
/// [`Self::recv`].
pub fn connect_with_cid_port(&self, cid: u32, port: u32) -> Result<()> {
self.connect(&VsockAddr::new(cid, port))
}

/// Send data to the connected remote host.
///
/// Analogous to [`std::net::UdpSocket::send`]
pub fn send(&self, buf: &[u8]) -> Result<usize> {
send(self.socket.as_raw_fd(), buf, default_send_msg_flags()).map_err(nix::Error::into)
}

/// Receive data from the connected remote host.
///
/// Analogous to [`std::net::UdpSocket::recv`]
pub fn recv(&self, buf: &mut [u8]) -> Result<usize> {
recv(self.socket.as_raw_fd(), buf, MsgFlags::empty()).map_err(nix::Error::into)
}

/// Move this stream in and out of nonblocking mode.
pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
Err(Error::last_os_error())
} else {
Ok(())
}
}
}

impl AsFd for VsockSocket {
fn as_fd(&self) -> BorrowedFd<'_> {
self.socket.as_fd()
}
}

impl AsRawFd for VsockSocket {
fn as_raw_fd(&self) -> RawFd {
self.socket.as_raw_fd()
}
}

impl FromRawFd for VsockSocket {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Self {
socket: OwnedFd::from_raw_fd(fd),
}
}
}

impl IntoRawFd for VsockSocket {
fn into_raw_fd(self) -> RawFd {
self.socket.into_raw_fd()
}
}

/// A virtio stream between a local and a remote socket.
///
/// This is the vsock equivalent of [`std::net::TcpStream`].
#[derive(Debug)]
pub struct VsockStream {
socket: OwnedFd,
Expand All @@ -206,7 +421,7 @@ impl VsockStream {
return Err(Error::other("requires a virtio socket address"));
}

let socket = new_socket()?;
let socket = new_socket(SockType::Stream)?;
connect(socket.as_raw_fd(), addr)?;
Ok(Self { socket })
}
Expand Down Expand Up @@ -245,13 +460,13 @@ impl VsockStream {

/// Set the timeout on read operations.
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
let timeout = Self::timeval_from_duration(dur)?.into();
let timeout = timeval_from_duration(dur)?.into();
Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
}

/// Set the timeout on write operations.
pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
let timeout = Self::timeval_from_duration(dur)?.into();
let timeout = timeval_from_duration(dur)?.into();
Ok(SendTimeout.set(&self.socket, &timeout)?)
}

Expand All @@ -274,39 +489,6 @@ impl VsockStream {
Ok(())
}
}

fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
match dur {
Some(dur) => {
if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
return Err(Error::new(
ErrorKind::InvalidInput,
"cannot set a zero duration timeout",
));
}

// https://github.com/rust-lang/libc/issues/1848
#[cfg_attr(target_env = "musl", allow(deprecated))]
let secs = if dur.as_secs() > libc::time_t::MAX as u64 {
libc::time_t::MAX
} else {
dur.as_secs() as libc::time_t
};
let mut timeout = timeval {
tv_sec: secs,
tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
};
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
timeout.tv_usec = 1;
}
Ok(timeout)
}
None => Ok(timeval {
tv_sec: 0,
tv_usec: 0,
}),
}
}
}

impl Read for VsockStream {
Expand All @@ -333,11 +515,11 @@ impl Read for &VsockStream {

impl Write for &VsockStream {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
#[cfg(not(target_os = "macos"))]
let flags = MsgFlags::MSG_NOSIGNAL;
#[cfg(target_os = "macos")]
let flags = MsgFlags::empty();
Ok(send(self.socket.as_raw_fd(), buf, flags)?)
Ok(send(
self.socket.as_raw_fd(),
buf,
default_send_msg_flags(),
)?)
}

fn flush(&mut self) -> Result<()> {
Expand All @@ -352,7 +534,7 @@ impl AsRawFd for VsockStream {
}

impl AsFd for VsockStream {
fn as_fd(&self) -> BorrowedFd {
fn as_fd(&self) -> BorrowedFd<'_> {
self.socket.as_fd()
}
}
Expand Down