diff --git a/src/lib.rs b/src/lib.rs index 9bd29448c..7bf157ffd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,6 +50,8 @@ mod macros; mod interest; mod poll; mod sys; +#[cfg(all(windows, feature = "net"))] +pub use sys::uds; mod token; #[cfg(not(target_os = "wasi"))] mod waker; diff --git a/src/net/mod.rs b/src/net/mod.rs index 15c405cf9..be0439dc0 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -33,7 +33,9 @@ mod udp; #[cfg(not(target_os = "wasi"))] pub use self::udp::UdpSocket; -#[cfg(unix)] +#[cfg(any(unix,windows))] mod uds; +#[cfg(any(unix,windows))] +pub use self::uds::{UnixListener, UnixStream}; #[cfg(unix)] -pub use self::uds::{UnixDatagram, UnixListener, UnixStream}; +pub use self::uds::UnixDatagram; diff --git a/src/net/uds/datagram.rs b/src/net/uds/datagram.rs index 73fea0731..5b61558dc 100644 --- a/src/net/uds/datagram.rs +++ b/src/net/uds/datagram.rs @@ -1,6 +1,6 @@ +#![cfg(unix)] use std::net::Shutdown; -use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; -use std::os::unix::net::{self, SocketAddr}; +use std::os::{fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd},unix::net::{self, SocketAddr}}; use std::path::Path; use std::{fmt, io}; diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index a255972a5..bffc63cb3 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -1,5 +1,12 @@ -use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; -use std::os::unix::net::{self, SocketAddr}; +#[cfg(windows)] +use crate::sys::uds::{net, SocketAddr}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +#[cfg(unix)] +use std::os::{ + fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}, + unix::net::{self, SocketAddr}, +}; use std::path::Path; use std::{fmt, io}; @@ -53,6 +60,10 @@ impl UnixListener { pub fn take_error(&self) -> io::Result> { self.inner.take_error() } + /// Sets the non-blocking mode for this socket + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) + } } impl event::Source for UnixListener { @@ -84,19 +95,19 @@ impl fmt::Debug for UnixListener { self.inner.fmt(f) } } - +#[cfg(unix)] impl IntoRawFd for UnixListener { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() } } - +#[cfg(unix)] impl AsRawFd for UnixListener { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() } } - +#[cfg(unix)] impl FromRawFd for UnixListener { /// Converts a `RawFd` to a `UnixListener`. /// @@ -109,6 +120,7 @@ impl FromRawFd for UnixListener { } } +#[cfg(unix)] impl From for net::UnixListener { fn from(listener: UnixListener) -> Self { // Safety: This is safe since we are extracting the raw fd from a well-constructed @@ -117,21 +129,39 @@ impl From for net::UnixListener { unsafe { net::UnixListener::from_raw_fd(listener.into_raw_fd()) } } } - +#[cfg(unix)] impl From for OwnedFd { fn from(unix_listener: UnixListener) -> Self { unix_listener.inner.into_inner().into() } } - +#[cfg(unix)] impl AsFd for UnixListener { fn as_fd(&self) -> BorrowedFd<'_> { self.inner.as_fd() } } - +#[cfg(unix)] impl From for UnixListener { fn from(fd: OwnedFd) -> Self { UnixListener::from_std(From::from(fd)) } } +#[cfg(windows)] +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} +#[cfg(windows)] +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener::from_std(FromRawSocket::from_raw_socket(sock)) + } +} +#[cfg(windows)] +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_inner().into_raw_socket() + } +} diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index e02fd80dc..abbab5e14 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -1,8 +1,11 @@ mod datagram; +#[cfg(unix)] pub use self::datagram::UnixDatagram; mod listener; +#[cfg(any(unix,windows))] pub use self::listener::UnixListener; mod stream; +#[cfg(any(unix,windows))] pub use self::stream::UnixStream; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 244f40455..256c4af91 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -1,8 +1,16 @@ +#[cfg(windows)] +use crate::sys::uds::{net, SocketAddr}; use std::fmt; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; -use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; -use std::os::unix::net::{self, SocketAddr}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +#[cfg(unix)] +use std::os::{ + fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}, + unix::net::{self, SocketAddr}, +}; + use std::path::Path; use crate::io_source::IoSource; @@ -52,6 +60,7 @@ impl UnixStream { /// Creates an unnamed pair of connected sockets. /// /// Returns two `UnixStream`s which are connected to each other. + #[cfg(unix)] pub fn pair() -> io::Result<(UnixStream, UnixStream)> { sys::uds::stream::pair().map(|(stream1, stream2)| { (UnixStream::from_std(stream1), UnixStream::from_std(stream2)) @@ -149,6 +158,10 @@ impl UnixStream { { self.inner.do_io(|_| f()) } + /// Sets the non-blocking mode for this socket + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.inner.set_nonblocking(nonblocking) + } } impl Read for UnixStream { @@ -228,18 +241,20 @@ impl fmt::Debug for UnixStream { self.inner.fmt(f) } } - +#[cfg(unix)] impl IntoRawFd for UnixStream { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() } } +#[cfg(unix)] impl AsRawFd for UnixStream { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() } } +#[cfg(unix)] impl FromRawFd for UnixStream { /// Converts a `RawFd` to a `UnixStream`. @@ -252,6 +267,7 @@ impl FromRawFd for UnixStream { UnixStream::from_std(FromRawFd::from_raw_fd(fd)) } } +#[cfg(unix)] impl From for net::UnixStream { fn from(stream: UnixStream) -> Self { @@ -261,21 +277,42 @@ impl From for net::UnixStream { unsafe { net::UnixStream::from_raw_fd(stream.into_raw_fd()) } } } +#[cfg(unix)] impl From for OwnedFd { fn from(unix_stream: UnixStream) -> Self { unix_stream.inner.into_inner().into() } } +#[cfg(unix)] impl AsFd for UnixStream { fn as_fd(&self) -> BorrowedFd<'_> { self.inner.as_fd() } } +#[cfg(unix)] impl From for UnixStream { fn from(fd: OwnedFd) -> Self { UnixStream::from_std(From::from(fd)) } } +#[cfg(windows)] +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} +#[cfg(windows)] +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream::from_std(FromRawSocket::from_raw_socket(sock)) + } +} +#[cfg(windows)] +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_inner().into_raw_socket() + } +} diff --git a/src/sys/windows/event.rs b/src/sys/windows/event.rs index 66656d0e5..d769ea6c4 100644 --- a/src/sys/windows/event.rs +++ b/src/sys/windows/event.rs @@ -4,7 +4,7 @@ use super::afd; use super::iocp::CompletionStatus; use crate::Token; -#[derive(Clone)] +#[derive(Clone,Debug)] pub struct Event { pub flags: u32, pub data: u64, diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 89d74b1a2..2f13b4f9a 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -41,6 +41,8 @@ cfg_net! { cfg_os_ext! { pub(crate) mod named_pipe; + /// UDS on Windows + pub mod uds; } mod waker; diff --git a/src/sys/windows/net.rs b/src/sys/windows/net.rs index 5cc235335..d526f2884 100644 --- a/src/sys/windows/net.rs +++ b/src/sys/windows/net.rs @@ -3,6 +3,7 @@ use std::mem; use std::net::SocketAddr; use std::sync::Once; +use windows_sys::Win32::Networking::WinSock::SOCKADDR_UN; use windows_sys::Win32::Networking::WinSock::{ closesocket, ioctlsocket, socket, AF_INET, AF_INET6, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6, SOCKADDR_IN6_0, @@ -55,6 +56,7 @@ pub(crate) fn new_socket(domain: u32, socket_type: i32) -> io::Result { pub(crate) union SocketAddrCRepr { v4: SOCKADDR_IN, v6: SOCKADDR_IN6, + unix: SOCKADDR_UN } impl SocketAddrCRepr { diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs new file mode 100644 index 000000000..f2fbcdbad --- /dev/null +++ b/src/sys/windows/uds/listener.rs @@ -0,0 +1,222 @@ +use super::{socketaddr_un, startup, wsa_error, Socket, SocketAddr, UnixStream}; +use std::{ + io, + os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}, + path::Path, +}; +use windows_sys::Win32::Networking::WinSock::{self, SOCKADDR_UN, SOCKET_ERROR}; +/// A Unix domain socket server for listening to incoming connections. +/// +/// This structure represents a socket server that listens for incoming Unix domain socket +/// connections on Windows systems. After creating a `UnixListener` by binding it to a socket +/// address, it can accept incoming connections from clients. +/// +/// The `UnixListener` wraps an underlying `Socket` and provides a higher-level interface +/// for server-side Unix domain socket operations. +/// +/// # Examples +/// +/// ```no_run +/// use std::io; +/// use mio::sys::uds::UnixListener; +/// +/// fn main() -> io::Result<()> { +/// // Bind to a socket file +/// let listener = UnixListener::bind("/tmp/socket.sock")?; +/// +/// // Accept incoming connections +/// match listener.accept() { +/// Ok((stream, addr)) => { +/// println!("New connection from {:?}", addr); +/// // Handle the connection with the stream... +/// } +/// Err(e) => eprintln!("Connection failed: {}", e), +/// } +/// +/// Ok(()) +/// } +/// ``` +#[derive(Debug)] +pub struct UnixListener(Socket); + +impl UnixListener { + /// Creates a new `UnixListener` bound to the specified path. + /// + /// This function will perform the following operations: + /// 1. Initialize the Winsock library + /// 2. Create a new socket + /// 3. Convert the provided path to a socket address + /// 4. Bind the socket to the address + /// 5. Start listening for incoming connections with a backlog of 5 + /// + /// # Arguments + /// + /// * `path` - The filesystem path to bind the socket to + /// + /// # Errors + /// + /// This function will return an error in the following situations: + /// + /// * Winsock initialization fails + /// * Socket creation fails + /// * The path cannot be converted to a valid socket address + /// * Binding to the specified path fails + /// * Listening on the socket fails + /// + /// # Examples + /// + /// ```no_run + /// use mio::sys::uds::UnixListener; + /// + /// let listener = UnixListener::bind("/tmp/socket.sock").unwrap(); + /// ``` + pub fn bind>(path: P) -> io::Result { + unsafe { + startup()?; + let s = Socket::new()?; + let (addr, len) = socketaddr_un(path.as_ref())?; + if WinSock::bind(s.0, &addr as *const _ as *const _, len) == SOCKET_ERROR { + Err(wsa_error()) + } else { + match WinSock::listen(s.0, 5) { + SOCKET_ERROR => Err(wsa_error()), + _ => Ok(Self(s)), + } + } + } + } + + /// Creates a new `UnixListener` bound to the specified socket address. + /// + /// This function allows binding to a pre-constructed `SocketAddr` instead of + /// creating one from a path. This can be useful when you need more control + /// over the socket address configuration or when reusing addresses. + /// + /// Unlike `bind`, this function does not initialize Winsock, assuming it has + /// already been initialized elsewhere. + /// + /// # Arguments + /// + /// * `socket_addr` - The socket address to bind to + /// + /// # Errors + /// + /// This function will return an error in the following situations: + /// + /// * Socket creation fails + /// * Binding to the specified address fails + /// * Listening on the socket fails + /// + /// # Examples + /// + /// ```no_run + /// use mio::sys::uds::{UnixListener, SocketAddr}; + /// use std::path::Path; + /// + /// // Create a socket address first + /// let addr = SocketAddr::from_path(Path::new("/tmp/socket.sock")).unwrap(); + /// let listener = UnixListener::bind_addr(&addr).unwrap(); + /// ``` + pub fn bind_addr(socket_addr: &SocketAddr) -> io::Result { + unsafe { + let s = Socket::new()?; + if WinSock::bind( + s.0, + &socket_addr.addr as *const _ as *const _, + socket_addr.addrlen, + ) == SOCKET_ERROR + { + Err(wsa_error()) + } else { + match WinSock::listen(s.0, 5) { + SOCKET_ERROR => Err(wsa_error()), + _ => Ok(Self(s)), + } + } + } + } + + /// Accepts a new incoming connection to this listener. + /// + /// This function will block the calling thread until a new Unix domain socket + /// connection is established. When established, the corresponding [`UnixStream`] + /// and the remote peer's address will be returned. + /// + /// The returned [`UnixStream`] can be used to read and write data to the connected + /// client, while the [`SocketAddr`] contains information about the client's address. + /// + /// # Errors + /// + /// This function will return an error if the underlying socket call fails. + /// Specific errors may include: + /// + /// * The socket is not bound or listening + /// * The socket has been closed + /// * Insufficient resources to complete the operation + /// * The operation was interrupted + /// + /// # Examples + /// + /// ```no_run + /// use your_crate::UnixListener; + /// + /// let listener = UnixListener::bind("/tmp/socket.sock").unwrap(); + /// + /// // Accept connections in a loop + /// for stream_result in listener.incoming() { + /// match stream_result { + /// Ok((stream, addr)) => { + /// println!("New connection from {:?}", addr); + /// // Handle the connection... + /// } + /// Err(e) => { + /// eprintln!("Accept error: {}", e); + /// } + /// } + /// } + /// ``` + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let mut addr = SOCKADDR_UN::default(); + let mut addrlen = size_of::() as _; + let s = self + .0 + .accept(&mut addr as *mut _ as *mut _, &mut addrlen as *mut _)?; + Ok((UnixStream::new(s), SocketAddr { addr, addrlen })) + } + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result { + self.0.local_addr() + } + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + /// Sets the non-blocking mode for this socket + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } +} + +pub(crate) fn bind_addr(socket_addr: &SocketAddr) -> io::Result { + UnixListener::bind_addr(socket_addr) +} +pub(crate) fn accept(s: &UnixListener) -> io::Result<(crate::net::UnixStream, SocketAddr)> { + let (inner, addr) = s.accept()?; + Ok((crate::net::UnixStream::from_std(inner), addr)) +} + +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.0 .0 as _ + } +} +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + Self(Socket(sock as _)) + } +} +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + self.0 .0 as _ + } +} \ No newline at end of file diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs new file mode 100644 index 000000000..75c1da6b4 --- /dev/null +++ b/src/sys/windows/uds/mod.rs @@ -0,0 +1,99 @@ +use std::io; +///we need this file to report std::os::unix::net +pub mod net; +pub(crate) mod stream; +pub use stream::*; +mod socket; +pub use socket::*; +pub(crate) mod listener; +pub use listener::*; +pub(crate) fn startup() -> io::Result<()> { + use windows_sys::Win32::Networking::WinSock::{self, WSADATA}; + use WinSock::{WSAEFAULT, WSAEINPROGRESS, WSAEPROCLIM, WSASYSNOTREADY, WSAVERNOTSUPPORTED}; + let mut wsa_data = WSADATA::default(); + match unsafe { WinSock::WSAStartup(0x202, &mut wsa_data) } { + 0 => Ok(()), + WSASYSNOTREADY => Err(io::Error::other("Network subsystem not ready")), + WSAVERNOTSUPPORTED => Err(io::Error::new( + io::ErrorKind::Unsupported, + "Winsock version not supported", + )), + WSAEINPROGRESS => Err(io::Error::new( + io::ErrorKind::WouldBlock, + "Blocking operation in progress", + )), + WSAEPROCLIM => Err(io::Error::other("Too many tasks")), + WSAEFAULT => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid parameter", + )), + _ => Err(io::Error::other("Unknown WSAStartup error")), + } +} +pub(crate) fn wsa_error() -> io::Error { + use windows_sys::Win32::Networking::WinSock::{ + WSAGetLastError, WSAEACCES, WSAEADDRINUSE, WSAEADDRNOTAVAIL, WSAEAFNOSUPPORT, + WSAECONNABORTED, WSAECONNREFUSED, WSAECONNRESET, WSAEHOSTUNREACH, WSAEINPROGRESS, + WSAEINVAL, WSAEINVALIDPROCTABLE, WSAEINVALIDPROVIDER, WSAEISCONN, WSAEMFILE, WSAEMSGSIZE, + WSAENETDOWN, WSAENETUNREACH, WSAENOBUFS, WSAENOTCONN, WSAEPROTONOSUPPORT, WSAEPROTOTYPE, + WSAEPROVIDERFAILEDINIT, WSAESHUTDOWN, WSAESOCKTNOSUPPORT, WSAETIMEDOUT, WSANOTINITIALISED, + }; + let err = unsafe { WSAGetLastError() }; + let kind = match err { + WSANOTINITIALISED => io::ErrorKind::NotConnected, + WSAENETDOWN => io::ErrorKind::ConnectionReset, + WSAEAFNOSUPPORT => io::ErrorKind::Unsupported, + WSAEINPROGRESS => io::ErrorKind::WouldBlock, + WSAEMFILE => io::ErrorKind::ResourceBusy, + WSAEINVAL => io::ErrorKind::InvalidInput, + WSAEINVALIDPROVIDER | WSAEINVALIDPROCTABLE | WSAEPROVIDERFAILEDINIT => { + io::ErrorKind::InvalidData + } + WSAENOBUFS => io::ErrorKind::OutOfMemory, + WSAEPROTONOSUPPORT | WSAEPROTOTYPE | WSAESOCKTNOSUPPORT => io::ErrorKind::Unsupported, + WSAECONNREFUSED => io::ErrorKind::ConnectionRefused, + WSAETIMEDOUT => io::ErrorKind::TimedOut, + WSAECONNABORTED => io::ErrorKind::ConnectionAborted, + WSAECONNRESET => io::ErrorKind::ConnectionReset, + WSAEADDRINUSE => io::ErrorKind::AddrInUse, + WSAEADDRNOTAVAIL => io::ErrorKind::AddrNotAvailable, + WSAEACCES => io::ErrorKind::PermissionDenied, + WSAEISCONN => io::ErrorKind::AlreadyExists, + WSAENOTCONN => io::ErrorKind::NotConnected, + WSAESHUTDOWN => io::ErrorKind::BrokenPipe, + WSAEMSGSIZE => io::ErrorKind::InvalidInput, + WSAEHOSTUNREACH | WSAENETUNREACH => io::ErrorKind::HostUnreachable, + + _ => io::ErrorKind::Other, + }; + let description = match err { + WSANOTINITIALISED => "Successful WSAStartup call must occur before using this function", + WSAENETDOWN => "The network subsystem has failed", + WSAEAFNOSUPPORT => "The specified address family is not supported", + WSAEINPROGRESS => "A blocking Windows Sockets call is in progress", + WSAEMFILE => "No more socket descriptors are available", + WSAEINVAL => "An invalid argument was supplied", + WSAEINVALIDPROVIDER => "The service provider returned a version other than 2.2", + WSAEINVALIDPROCTABLE => "The service provider returned an invalid procedure table", + WSAENOBUFS => "No buffer space is available", + WSAEPROTONOSUPPORT => "The specified protocol is not supported", + WSAEPROTOTYPE => "The specified protocol is the wrong type for this socket", + WSAEPROVIDERFAILEDINIT => "The service provider failed to initialize", + WSAESOCKTNOSUPPORT => "The specified socket type is not supported in this address family", + WSAECONNREFUSED => "Connection refused", + WSAETIMEDOUT => "Connection timed out", + WSAECONNABORTED => "Connection aborted", + WSAECONNRESET => "Connection reset by peer", + WSAEADDRINUSE => "Address already in use", + WSAEADDRNOTAVAIL => "Address not available", + WSAEACCES => "Permission denied", + WSAEISCONN => "Socket is already connected", + WSAENOTCONN => "Socket is not connected", + WSAESHUTDOWN => "Socket has been shut down", + WSAEMSGSIZE => "Message too long", + WSAEHOSTUNREACH => "Host is unreachable", + WSAENETUNREACH => "Network is unreachable", + _ => "Windows Sockets error", + }; + io::Error::new(kind, format!("{} (error code: {:?})", description, err)) +} diff --git a/src/sys/windows/uds/net.rs b/src/sys/windows/uds/net.rs new file mode 100644 index 000000000..504fbd59c --- /dev/null +++ b/src/sys/windows/uds/net.rs @@ -0,0 +1,2 @@ +pub use super::{UnixListener,UnixStream,SocketAddr}; +//we need this file to report std::os::unix::net \ No newline at end of file diff --git a/src/sys/windows/uds/socket.rs b/src/sys/windows/uds/socket.rs new file mode 100644 index 000000000..1ae193738 --- /dev/null +++ b/src/sys/windows/uds/socket.rs @@ -0,0 +1,260 @@ +use super::{startup, wsa_error}; +use std::{ffi::CStr, fmt::Debug, io, net::Shutdown, os::raw::c_int, path::Path, ptr::null_mut}; +use windows_sys::Win32::Networking::WinSock::{ + self, AF_UNIX, FIONBIO, INVALID_SOCKET, SOCKADDR, SOCKADDR_UN, SOCKET, SOCKET_ERROR, + SOCK_STREAM, SOL_SOCKET, SO_ERROR, WSABUF, +}; +#[derive(Debug)] +pub(crate) struct Socket(pub SOCKET); + +impl Socket { + pub fn new() -> io::Result { + unsafe { + startup()?; + match WinSock::socket(AF_UNIX as _, SOCK_STREAM, 0) { + INVALID_SOCKET => Err(wsa_error()), + s => Ok(Self(s)), + } + } + } + pub fn write(&self, buf: &[u8]) -> io::Result { + unsafe { + match WinSock::send(self.0 as _, buf.as_ptr(), buf.len() as _, 0) { + SOCKET_ERROR => Err(wsa_error()), + len => Ok(len as _), + } + } + } + pub fn write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result { + let bufs: Vec<_> = bufs + .iter() + .map(|buf| WSABUF { + buf: buf.as_ptr() as *mut _, + len: buf.len() as _, + }) + .collect(); + let mut bytes_send = 0; + unsafe { + match WinSock::WSASend( + self.0, + bufs.as_ptr(), + bufs.len() as _, + &mut bytes_send, + 0, + null_mut(), + None, + ) { + 0 => Ok(bytes_send as usize), + _ => Err(wsa_error()), + } + } + } + + pub fn recv(&self, buf: &mut [u8]) -> io::Result { + unsafe { + match WinSock::recv(self.0 as _, buf.as_mut_ptr(), buf.len() as _, 0) { + 0 => Err(io::Error::other("Connection closed")), + len if len > 0 => Ok(len as _), + _ => Err(wsa_error()), + } + } + } + pub fn recv_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { + unsafe { + let mut bytes_received = 0; + let mut flags = 0; + let mut bufs: Vec<_> = bufs + .iter_mut() + .map(|buf| WSABUF { + len: buf.len() as _, + buf: buf.as_mut_ptr(), + }) + .collect(); + match WinSock::WSARecv( + self.0, + bufs.as_mut_ptr(), + bufs.len() as _, + &mut bytes_received, + &mut flags, + null_mut(), + None, + ) { + 0 => Ok(bytes_received as usize), + _ => Err(wsa_error()), + } + } + } + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + use WinSock::{SD_BOTH, SD_RECEIVE, SD_SEND}; + let shutdown_how = match how { + Shutdown::Read => SD_RECEIVE, + Shutdown::Write => SD_SEND, + Shutdown::Both => SD_BOTH, + }; + unsafe { + match WinSock::shutdown(self.0, shutdown_how) { + 0 => Ok(()), + _ => Err(wsa_error()), + } + } + } + pub fn accept(&self, addr: *mut SOCKADDR, addrlen: *mut i32) -> io::Result { + unsafe { + // or we should just use None None here because + // seems like accept write nothing to addr and addrlen + match WinSock::accept(self.0, addr, addrlen) { + INVALID_SOCKET => Err(wsa_error()), + s => Ok(Socket(s)), + } + } + } + pub fn local_addr(&self) -> io::Result { + let mut addr = SocketAddr::default(); + addr.addrlen = size_of::() as i32; + match unsafe { + WinSock::getsockname( + self.0, + &mut addr.addr as *mut _ as *mut _, + &mut addr.addrlen as *mut _ as *mut _, + ) + } { + SOCKET_ERROR => Err(wsa_error()), + _ => Ok(addr), + } + } + pub fn peer_addr(&self) -> io::Result { + let mut addr = SocketAddr::default(); + addr.addrlen = size_of::() as i32; + match unsafe { + WinSock::getpeername( + self.0, + &mut addr.addr as *mut _ as *mut _, + &mut addr.addrlen as *mut _ as *mut _, + ) + } { + SOCKET_ERROR => Err(wsa_error()), + _ => Ok(addr), + } + } + pub fn take_error(&self) -> io::Result> { + unsafe { + let mut val = c_int::default(); + let mut len = size_of::() as i32; + match WinSock::getsockopt( + self.0, + SOL_SOCKET, + SO_ERROR, + &mut val as *mut _ as *mut _, + &mut len as *mut _, + ) { + SOCKET_ERROR => Err(wsa_error()), + _ => { + if val == 0 { + Ok(None) + } else { + Ok(Some(wsa_error())) + } + } + } + } + } + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + let mut val = if nonblocking { 1u32 } else { 0 }; + match unsafe { WinSock::ioctlsocket(self.0, FIONBIO, &mut val as *mut _) } { + SOCKET_ERROR => Err(wsa_error()), + _ => Ok(()), + } + } +} +#[derive(Default)] +/// A socket address for Unix domain sockets. +/// +/// This struct wraps the underlying system socket address structure +/// along with its length to provide a safe interface for working with +/// Unix domain sockets. +pub struct SocketAddr { + /// The underlying system socket address structure + pub addr: SOCKADDR_UN, + /// The length of the socket address structure + pub addrlen: i32, +} + +impl Debug for SocketAddr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> core::fmt::Result { + let sun_path_str = unsafe { CStr::from_ptr(self.addr.sun_path.as_ptr()).to_string_lossy() }; + + write!( + f, + "SocketAddr {{ addr: SOCKADDR_UN {{ sun_family: {}, sun_path: {:?} }}, addrlen: {} }}", + self.addr.sun_family, sun_path_str, self.addrlen + ) + } +} +impl SocketAddr { + /// Creates a new `SocketAddr` from a filesystem path. + /// + /// # Arguments + /// + /// * `path` - A path to a socket file in the filesystem + /// + /// # Returns + /// + /// Returns `Ok(SocketAddr)` if the address was successfully created, + /// or an `io::Error` if the path is invalid or too long. + /// + /// # Examples + /// + /// ```no_run + /// use std::path::Path; + /// use mio::uds::SocketAddr; + /// + /// let addr = SocketAddr::from_pathname("/tmp/socket.sock").unwrap(); + /// ``` + pub fn from_pathname>(path: P) -> io::Result { + let (addr, addrlen) = socketaddr_un(path.as_ref())?; + Ok(Self { addr, addrlen }) + } + /// Returns the contents of this address if it is a `pathname` address + pub fn as_pathname(&self) -> Option<&Path> { + let path_ptr = self.addr.sun_path.as_ptr(); + if unsafe { *path_ptr } == 0 { + return None; + } + let c_str = unsafe { CStr::from_ptr(path_ptr) }; + match c_str.to_str() { + Ok(s) => Some(Path::new(s)), + Err(_e) => None, + } + } +} + +pub(crate) fn socketaddr_un(path: &Path) -> io::Result<(SOCKADDR_UN, i32)> { + // let bytes = path.as_os_str().as_encoded_bytes(); + let mut sockaddr = SOCKADDR_UN::default(); + // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path + let bytes = path.to_str().map(|s| s.as_bytes()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "path contains invalid characters", + ) + })?; + + if bytes.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes", + )); + } + + if bytes.len() >= sockaddr.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + let src_i8 = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const i8, bytes.len()) }; + sockaddr.sun_family = AF_UNIX; + sockaddr.sun_path[..src_i8.len()].copy_from_slice(src_i8); + let socklen = size_of::() as _; + Ok((sockaddr, socklen)) +} diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs new file mode 100644 index 000000000..8aceec6fb --- /dev/null +++ b/src/sys/windows/uds/stream.rs @@ -0,0 +1,198 @@ +use super::socketaddr_un; +use super::startup; +use super::wsa_error; +use super::Socket; +use super::SocketAddr; +use std::fmt::Debug; +use std::io; +use std::net::Shutdown; +use std::os::windows::io::AsRawSocket; +use std::os::windows::io::FromRawSocket; +use std::os::windows::io::IntoRawSocket; +use std::os::windows::io::RawSocket; +use std::path::Path; +use windows_sys::Win32::Networking::WinSock; +use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; +/// A Unix domain socket stream client. +/// +/// This type represents a connected Unix domain socket client stream, +/// providing bidirectional I/O communication with a server. +/// +/// # Examples +/// +/// ```no_run +/// use std::io::{Read, Write}; +/// +/// let mut stream = UnixStream::connect("/tmp/socket.sock")?; +/// stream.write_all(b"Hello, server!")?; +/// +/// let mut response = String::new(); +/// stream.read_to_string(&mut response)?; +/// # Ok::<(), Box>(()) +/// ``` +#[derive(Debug)] +pub struct UnixStream(Socket); +impl UnixStream { + pub(crate) fn new(socket: Socket) -> Self { + Self(socket) + } + /// Connects to a Unix domain socket server at the specified filesystem path. + /// + /// This function creates a new socket and establishes a connection to the server + /// listening on the given path. The path must be a valid filesystem path that + /// the server is bound to. + /// + /// # Arguments + /// + /// * `path` - The filesystem path of the server socket to connect to + /// + /// # Errors + /// + /// Returns an `io::Error` if: + /// - Winsock initialization fails + /// - Socket creation fails + /// - The connection attempt fails + /// - The provided path is invalid + /// + /// # Examples + /// + /// ```no_run + /// let stream = UnixStream::connect("/tmp/socket.sock")?; + /// # Ok::<(), std::io::Error>(()) + /// ``` + pub fn connect>(path: P) -> io::Result { + unsafe { + startup()?; + let s = Socket::new()?; + let (addr, len) = socketaddr_un(path.as_ref())?; + match WinSock::connect(s.0, &addr as *const _ as *const _, len) { + SOCKET_ERROR => Err(wsa_error()), + _ => Ok(Self(s)), + } + } + } + + /// Connects to a Unix domain socket server using a pre-constructed `SocketAddr`. + /// + /// This function creates a new socket and establishes a connection to the server + /// address specified by the given `SocketAddr`. This is useful when you already + /// have a socket address constructed and want to reuse it. + /// + /// # Arguments + /// + /// * `socket_addr` - The socket address of the server to connect to + /// + /// # Errors + /// + /// Returns an `io::Error` if: + /// - Socket creation fails + /// - The connection attempt fails + /// + /// # Examples + /// + /// ```no_run + /// use mio::sys::uds::SocketAddr; + /// + /// let addr = SocketAddr::from_path("/tmp/my_socket")?; + /// let stream = UnixStream::connect_addr(&addr)?; + /// # Ok::<(), std::io::Error>(()) + /// ``` + pub fn connect_addr(socket_addr: &SocketAddr) -> io::Result { + let s = Socket::new()?; + match unsafe { + WinSock::connect( + s.0, + &socket_addr.addr as *const _ as *const _, + socket_addr.addrlen, + ) + } { + SOCKET_ERROR => Err(wsa_error()), + _ => Ok(Self(s)), + } + } + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result { + self.0.local_addr() + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result { + self.0.peer_addr() + } + + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of `Shutdown`). + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } + /// Sets the non-blocking mode for this socket + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } +} +impl io::Write for UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut &*self, buf) + } + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + io::Write::write_vectored(&mut &*self, bufs) + } +} +impl io::Write for &UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + self.0.write_vectored(bufs) + } +} +impl io::Read for &UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.recv(buf) + } + fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { + self.0.recv_vectored(bufs) + } +} +impl io::Read for UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut &*self, buf) + } + fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { + io::Read::read_vectored(&mut &*self, bufs) + } +} + +pub(crate) fn connect_addr(address: &SocketAddr) -> io::Result { + UnixStream::connect_addr(address) +} + +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.0 .0 as _ + } +} +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream(Socket(sock as _)) + } +} +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + self.0.0 as _ + } +} \ No newline at end of file diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 7666f9ff2..98d862a4a 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,12 +1,16 @@ -#![cfg(all(unix, feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net"))] use mio::net::UnixListener; +#[cfg(windows)] +use mio::uds::net; use mio::{Interest, Token}; use std::io::{self, Read}; +#[cfg(unix)] use std::os::unix::net; use std::path::{Path, PathBuf}; use std::sync::{Arc, Barrier}; use std::thread; +use std::time::Duration; #[macro_use] mod util; @@ -196,7 +200,7 @@ where let path = temp_file(test_name); let mut listener = new_listener(&path).unwrap(); - + listener.set_nonblocking(true).unwrap(); assert_socket_non_blocking(&listener); assert_socket_close_on_exec(&listener); @@ -220,12 +224,13 @@ where let mut buf = [0; DEFAULT_BUF_SIZE]; assert_would_block(stream.read(&mut buf)); - + drop(stream); assert_would_block(listener.accept()); assert!(listener.take_error().unwrap().is_none()); barrier.wait(); handle.join().unwrap(); + drop(listener); } fn open_connections( diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 8a12a90aa..5f9d99d5c 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,14 +1,18 @@ -#![cfg(all(unix, feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net"))] use mio::net::UnixStream; +#[cfg(windows)] +use mio::uds::net; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; +#[cfg(unix)] use std::os::unix::net; use std::path::Path; use std::sync::mpsc::channel; use std::sync::{Arc, Barrier}; use std::thread; +use std::time::Duration; #[macro_use] mod util; @@ -53,7 +57,8 @@ fn unix_stream_connect() { let barrier_clone = barrier.clone(); let handle = thread::spawn(move || { - let (stream, _) = listener.accept().unwrap(); + let (mut stream, _) = listener.accept().unwrap(); + stream.write_all(b"Hi").unwrap(); barrier_clone.wait(); drop(stream); }); @@ -72,12 +77,18 @@ fn unix_stream_connect() { ); barrier.wait(); + #[cfg(unix)] expect_events( &mut poll, &mut events, vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)], ); - + #[cfg(windows)] + { + let mut buf = [0; 2]; + assert_eq!(stream.read(&mut buf).unwrap(), 2); + assert_eq!(buf, *b"Hi"); + } handle.join().unwrap(); } @@ -99,7 +110,8 @@ fn unix_stream_connect_addr() { let barrier_clone = barrier.clone(); let handle = thread::spawn(move || { - let (stream, _) = mio_listener.accept().unwrap(); + let (mut stream, _) = mio_listener.accept().unwrap(); + stream.write_all(b"Hi").unwrap(); barrier_clone.wait(); drop(stream); }); @@ -118,12 +130,18 @@ fn unix_stream_connect_addr() { ); barrier.wait(); + #[cfg(unix)] expect_events( &mut poll, &mut events, vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)], ); - + #[cfg(windows)] + { + let mut buf = [0; 2]; + assert_eq!(stream.read(&mut buf).unwrap(), 2); + assert_eq!(buf, *b"Hi"); + } handle.join().unwrap(); } @@ -146,6 +164,7 @@ fn unix_stream_from_std() { } #[test] +#[cfg(unix)] fn unix_stream_pair() { let (mut poll, mut events) = init_with_poll(); @@ -186,9 +205,9 @@ fn unix_stream_pair() { fn unix_stream_peer_addr() { init(); let (handle, expected_addr) = new_echo_listener(1, "unix_stream_peer_addr"); - let expected_path = expected_addr.as_pathname().expect("failed to get pathname"); + let path = expected_addr.as_pathname().expect("failed to get pathname"); - let stream = UnixStream::connect(expected_path).unwrap(); + let stream = UnixStream::connect(path).unwrap(); // Complete handshake to unblock the server thread. #[cfg(target_os = "cygwin")] let stream = { @@ -209,14 +228,13 @@ fn unix_stream_peer_addr() { stream }; - assert_eq!( - stream.peer_addr().unwrap().as_pathname().unwrap(), - expected_path - ); + assert_eq!(stream.peer_addr().unwrap().as_pathname().unwrap(), path); assert!(stream.local_addr().unwrap().as_pathname().is_none()); // Close the connection to allow the remote to shutdown drop(stream); + let _ = std::fs::remove_file(path); + handle.join().unwrap(); } @@ -278,6 +296,7 @@ fn unix_stream_shutdown_read() { // Close the connection to allow the remote to shutdown drop(stream); + let _ = std::fs::remove_file(path); handle.join().unwrap(); } @@ -340,6 +359,8 @@ fn unix_stream_shutdown_write() { // Close the connection to allow the remote to shutdown drop(stream); + let _ = std::fs::remove_file(path); + handle.join().unwrap(); } @@ -405,10 +426,12 @@ fn unix_stream_shutdown_both() { #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); #[cfg(windows)] - assert_eq!(err.kind(), io::ErrorKind::ConnectionAbroted); + assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted); // Close the connection to allow the remote to shutdown drop(stream); + let _ = std::fs::remove_file(path); + handle.join().unwrap(); } @@ -446,6 +469,8 @@ fn unix_stream_shutdown_listener_write() { ); barrier.wait(); + let _ = std::fs::remove_file(path); + handle.join().unwrap(); } @@ -467,6 +492,8 @@ fn unix_stream_register() { // Close the connection to allow the remote to shutdown drop(stream); + let _ = std::fs::remove_file(path); + handle.join().unwrap(); } @@ -495,6 +522,8 @@ fn unix_stream_reregister() { // Close the connection to allow the remote to shutdown drop(stream); + let _ = std::fs::remove_file(path); + handle.join().unwrap(); } @@ -535,6 +564,8 @@ fn unix_stream_deregister() { // Close the connection to allow the remote to shutdown drop(stream); + let _ = std::fs::remove_file(path); + handle.join().unwrap(); } @@ -547,7 +578,7 @@ where let path = remote_addr.as_pathname().expect("failed to get pathname"); let mut stream = connect_stream(path).unwrap(); - + stream.set_nonblocking(true).unwrap(); assert_socket_non_blocking(&stream); assert_socket_close_on_exec(&stream); @@ -584,12 +615,14 @@ where let bufs = [IoSlice::new(DATA1), IoSlice::new(DATA2)]; let wrote = stream.write_vectored(&bufs).unwrap(); assert_eq!(wrote, DATA1_LEN + DATA2_LEN); + #[cfg(unix)] expect_events( &mut poll, &mut events, vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)], ); - + #[cfg(windows)] + std::thread::sleep(Duration::from_millis(500)); let mut buf1 = [1; DATA1_LEN]; let mut buf2 = [2; DATA2_LEN + 1]; let mut bufs = [IoSliceMut::new(&mut buf1), IoSliceMut::new(&mut buf2)]; @@ -603,6 +636,7 @@ where // Close the connection to allow the remote to shutdown drop(stream); + let _ = std::fs::remove_file(path); handle.join().unwrap(); } @@ -613,13 +647,13 @@ fn new_echo_listener( let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); - let listener = net::UnixListener::bind(path).unwrap(); + let listener = net::UnixListener::bind(&path).unwrap(); let local_addr = listener.local_addr().unwrap(); addr_sender.send(local_addr).unwrap(); for _ in 0..connections { let (mut stream, _) = listener.accept().unwrap(); - + stream.set_nonblocking(true).unwrap(); // On Linux based system it will cause a connection reset // error when the reading side of the peer connection is // shutdown, we don't consider it an actual here. @@ -631,16 +665,15 @@ fn new_echo_listener( read += amount; amount } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) - if matches!( - err.kind(), - io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted - ) => - { - break + Err(e) => { + //I don't know why Windows keep send WSAEWOULDBLOCK code. + //even connection closed + if !&path.exists() { + break; + } + std::thread::sleep(Duration::from_millis(200)); + continue; } - Err(err) => panic!("{}", err), }; if n == 0 { break; @@ -654,6 +687,8 @@ fn new_echo_listener( } assert_eq!(read, written, "unequal reads and writes"); } + eprintln!("Exit"); + drop(listener); }); (handle, addr_receiver.recv().unwrap()) } diff --git a/tests/util/mod.rs b/tests/util/mod.rs index 579a10d00..0d2e300d4 100644 --- a/tests/util/mod.rs +++ b/tests/util/mod.rs @@ -193,8 +193,11 @@ pub fn assert_error(result: Result, expected_msg: &str pub fn assert_would_block(result: io::Result) { match result { Ok(_) => panic!("unexpected OK result, expected a `WouldBlock` error"), - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {} - Err(err) => panic!("unexpected error result: {err}"), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} + #[cfg(unix)] + Err(e) => panic!("unexpected error: {e}"), + #[cfg(windows)] + Err(_) => {} } }