diff --git a/library/std/src/os/windows/mod.rs b/library/std/src/os/windows/mod.rs index f452403ee8426..9e078e3364093 100644 --- a/library/std/src/os/windows/mod.rs +++ b/library/std/src/os/windows/mod.rs @@ -29,6 +29,8 @@ pub mod ffi; pub mod fs; pub mod io; +#[cfg(windows)] +pub mod net; pub mod process; pub mod raw; pub mod thread; diff --git a/library/std/src/os/windows/net/addr.rs b/library/std/src/os/windows/net/addr.rs new file mode 100644 index 0000000000000..1289aedcc547d --- /dev/null +++ b/library/std/src/os/windows/net/addr.rs @@ -0,0 +1,83 @@ +#![unstable(feature = "windows_unix_domain_sockets", issue = "56533")] + +use crate::os::raw::{c_char, c_int}; +use crate::path::Path; +use crate::sys::c::{self, SOCKADDR}; +use crate::sys::cvt; +use crate::{io, mem}; +pub fn sockaddr_un(path: &Path) -> io::Result<(c::sockaddr_un, c_int)> { + let mut addr: c::sockaddr_un = unsafe { mem::zeroed() }; + addr.sun_family = c::AF_UNIX; + // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path + let bytes = path + .to_str() + .map(|s| s.as_bytes()) + .ok_or(io::Error::new(io::ErrorKind::InvalidInput, "path contains invalid characters"))?; + + if bytes.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes", + )); + } + + if bytes.len() >= addr.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + for (dst, src) in addr.sun_path.iter_mut().zip(bytes.iter()) { + *dst = *src as c_char; + } + // null byte for pathname addresses is already there because we zeroed the + // struct + + let mut len = sun_path_offset(&addr) + bytes.len(); + match bytes.first() { + Some(&0) | None => {} + Some(_) => len += 1, + } + Ok((addr, len as _)) +} +fn sun_path_offset(addr: &c::sockaddr_un) -> usize { + // Work with an actual instance of the type since using a null pointer is UB + let base = addr as *const _ as usize; + let path = &addr.sun_path as *const _ as usize; + path - base +} +#[allow(dead_code)] +pub struct SocketAddr { + addr: c::sockaddr_un, + len: c_int, +} +impl SocketAddr { + pub fn new(f: F) -> io::Result + where + F: FnOnce(*mut SOCKADDR, *mut c_int) -> c_int, + { + unsafe { + let mut addr: c::sockaddr_un = mem::zeroed(); + let mut len = mem::size_of::() as c_int; + cvt(f(&mut addr as *mut _ as *mut _, &mut len))?; + SocketAddr::from_parts(addr, len) + } + } + fn from_parts(addr: c::sockaddr_un, mut len: c_int) -> io::Result { + if len == 0 { + // When there is a datagram from unnamed unix socket + // linux returns zero bytes of address + len = sun_path_offset(&addr) as c_int; // i.e. zero-length address + } else if addr.sun_family != c::AF_UNIX { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "file descriptor did not correspond to a Unix socket", + )); + } + + Ok(SocketAddr { addr, len }) + } +} +pub fn from_sockaddr_un(addr: c::sockaddr_un, len: c_int) -> io::Result { + SocketAddr::from_parts(addr, len) +} diff --git a/library/std/src/os/windows/net/listener.rs b/library/std/src/os/windows/net/listener.rs new file mode 100644 index 0000000000000..815c798bddfa9 --- /dev/null +++ b/library/std/src/os/windows/net/listener.rs @@ -0,0 +1,94 @@ +#![unstable(feature = "windows_unix_domain_sockets", issue = "56533")] + +use core::mem; + +use super::sockaddr_un; +use crate::io; +use crate::os::raw::c_int; +use crate::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use crate::os::windows::net::{SocketAddr, UnixStream, from_sockaddr_un}; +use crate::path::Path; +use crate::sys::c::{self, bind, getsockname, listen}; +use crate::sys::cvt; +use crate::sys::net::Socket; + +pub struct UnixListener(Socket); + +impl UnixListener { + pub fn bind>(path: &Path) -> io::Result { + unsafe { + let inner = Socket::new_unix()?; + let (addr, len) = sockaddr_un(path)?; + cvt(bind(inner.as_raw(), &addr as *const _ as *const _, len))?; + cvt(listen(inner.as_raw(), 128))?; + Ok(UnixListener(inner)) + } + } + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let mut storage: c::sockaddr_un = unsafe { mem::zeroed() }; + let mut len = mem::size_of_val(&storage) as c_int; + let sock = self.0.accept(&mut storage as *mut _ as *mut _, &mut len)?; + let addr = from_sockaddr_un(storage, len)?; + Ok((UnixStream(sock), addr)) + } + pub fn incoming(&self) -> Incoming<'_> { + Incoming { listener: self } + } + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getsockname(self.0.as_raw() as _, addr, len) }) + } + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixListener) + } +} + +pub struct Incoming<'a> { + listener: &'a UnixListener, +} + +impl<'a> Iterator for Incoming<'a> { + type Item = io::Result; + + fn next(&mut self) -> Option> { + Some(self.listener.accept().map(|s| s.0)) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::MAX, None) + } +} + +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener(unsafe { Socket::from_raw_socket(sock) }) + } +} + +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} + +impl<'a> IntoIterator for &'a UnixListener { + type Item = io::Result; + type IntoIter = Incoming<'a>; + + fn into_iter(self) -> Incoming<'a> { + self.incoming() + } +} diff --git a/library/std/src/os/windows/net/mod.rs b/library/std/src/os/windows/net/mod.rs new file mode 100644 index 0000000000000..fe7ec8885907c --- /dev/null +++ b/library/std/src/os/windows/net/mod.rs @@ -0,0 +1,8 @@ +#![unstable(feature = "windows_unix_domain_sockets", issue = "56533")] + +mod addr; +mod listener; +mod stream; +pub use addr::*; +pub use listener::*; +pub use stream::*; diff --git a/library/std/src/os/windows/net/stream.rs b/library/std/src/os/windows/net/stream.rs new file mode 100644 index 0000000000000..b618123d4f633 --- /dev/null +++ b/library/std/src/os/windows/net/stream.rs @@ -0,0 +1,114 @@ +#![unstable(feature = "windows_unix_domain_sockets", issue = "56533")] + +use core::mem; +use core::time::Duration; + +use crate::io::{self, IoSlice}; +use crate::net::Shutdown; +use crate::os::windows::io::{ + AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, RawSocket, +}; +use crate::os::windows::net::{SocketAddr, sockaddr_un}; +use crate::path::Path; +use crate::sys::c::{SO_RCVTIMEO, SO_SNDTIMEO, connect, getpeername, getsockname}; +use crate::sys::cvt; +use crate::sys::net::Socket; + +pub struct UnixStream(pub Socket); +impl UnixStream { + pub fn connect>(path: P) -> io::Result { + unsafe { + let inner = Socket::new_unix()?; + let (addr, len) = sockaddr_un(path.as_ref())?; + cvt(connect(inner.as_raw() as _, &addr as *const _ as *const _, len))?; + Ok(UnixStream(inner)) + } + } + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getsockname(self.0.as_raw() as _, addr, len) }) + } + pub fn peer_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getpeername(self.0.as_raw() as _, addr, len) }) + } + pub fn read_timeout(&self) -> io::Result> { + self.0.timeout(SO_RCVTIMEO) + } + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.0.set_timeout(dur, SO_RCVTIMEO) + } + pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.0.set_timeout(dur, SO_SNDTIMEO) + } + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixStream) + } + pub fn write_timeout(&self) -> io::Result> { + self.0.timeout(SO_SNDTIMEO) + } +} + +impl io::Read for UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut &*self, buf) + } +} + +impl<'a> io::Read for &'a UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl io::Write for UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut &*self, buf) + } + + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } +} +impl<'a> io::Write for &'a UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write_vectored(&[IoSlice::new(buf)]) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsSocket for UnixStream { + fn as_socket(&self) -> BorrowedSocket<'_> { + self.0.as_socket() + } +} + +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + unsafe { UnixStream(Socket::from_raw_socket(sock)) } + } +} + +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} diff --git a/library/std/src/sys/net/connection/socket/windows.rs b/library/std/src/sys/net/connection/socket/windows.rs index 6dbebc5e276ec..3ec4b16ec620c 100644 --- a/library/std/src/sys/net/connection/socket/windows.rs +++ b/library/std/src/sys/net/connection/socket/windows.rs @@ -9,6 +9,7 @@ use crate::os::windows::io::{ AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket, }; use crate::sys::c; +use crate::sys::c::{AF_UNIX, INVALID_SOCKET, SOCK_STREAM, WSA_FLAG_OVERLAPPED, WSASocketW}; use crate::sys::pal::winsock::last_error; use crate::sys_common::{AsInner, FromInner, IntoInner}; use crate::time::Duration; @@ -117,6 +118,23 @@ pub use crate::sys::pal::winsock::{cvt, cvt_gai, cvt_r, startup as init}; pub struct Socket(OwnedSocket); impl Socket { + pub fn new_unix() -> io::Result { + let socket = unsafe { + match WSASocketW( + AF_UNIX as i32, + SOCK_STREAM, + 0, + ptr::null_mut(), + 0, + WSA_FLAG_OVERLAPPED, + ) { + INVALID_SOCKET => Err(last_error()), + n => Ok(Socket::from_raw(n)), + } + }?; + socket.0.set_no_inherit()?; + Ok(socket) + } pub fn new(family: c_int, ty: c_int) -> io::Result { let socket = unsafe { c::WSASocketW( diff --git a/library/std/src/sys/pal/windows/c.rs b/library/std/src/sys/pal/windows/c.rs index 25c1a82cc426a..9dec28a88886c 100644 --- a/library/std/src/sys/pal/windows/c.rs +++ b/library/std/src/sys/pal/windows/c.rs @@ -11,8 +11,16 @@ use core::ptr; mod windows_sys; pub use windows_sys::*; -pub type WCHAR = u16; +use crate::os::raw::c_char; +pub type WCHAR = u16; +pub const AF_UNIX: ADDRESS_FAMILY = 1; +#[derive(Clone, Copy)] +#[repr(C)] +pub struct sockaddr_un { + pub sun_family: ADDRESS_FAMILY, + pub sun_path: [c_char; 108], +} pub const INVALID_HANDLE_VALUE: HANDLE = ::core::ptr::without_provenance_mut(-1i32 as _); // https://learn.microsoft.com/en-us/cpp/c-runtime-library/exit-success-exit-failure?view=msvc-170 diff --git a/library/std/tests/net/windows_unix_socket.rs b/library/std/tests/net/windows_unix_socket.rs new file mode 100644 index 0000000000000..51bddf0b18ba3 --- /dev/null +++ b/library/std/tests/net/windows_unix_socket.rs @@ -0,0 +1,67 @@ +#![cfg(all(windows, feature = "windows_unix_domain_sockets"))] +#![unstable(feature = "windows_unix_domain_sockets", issue = "none")] + +use std::io::{Read, Write}; +use std::os::windows::net::{UnixListener, UnixStream}; +use std::path::Path; +use std::thread; + +#[test] +fn smoke_bind_connect() { + let tmp = std::env::temp_dir(); + let sock_path = tmp.join("rust-test-uds.sock"); + + let listener = UnixListener::bind(&sock_path).expect("bind failed"); + + let tx = thread::spawn(move || { + let mut stream = UnixStream::connect(&sock_path).expect("connect failed"); + stream.write_all(b"hello").expect("write failed"); + }); + + let (mut stream, _) = listener.accept().expect("accept failed"); + let mut buf = [0; 5]; + stream.read_exact(&mut buf).expect("read failed"); + assert_eq!(&buf, b"hello"); + + tx.join().unwrap; + + drop(listener); + let _ = std::fs::remove_file(&sock_path); +} + +#[test] +fn echo() { + let tmp = std::env::temp_dir(); + let sock_path = tmp.join("rust-test-uds-echo.sock"); + + let listener = UnixListener::bind(&sock_path).unwrap(); + + let tx = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap; + let mut buf = [0; 1024]; + loop { + let n = match stream.read(&mut buf) { + Ok(0) => return, + Ok(n) => n, + Err(e) => panic!("read error: {}", e), + }; + stream.write_all(&buf[..n]).unwrap; + } + }); + + let mut client = UnixStream::connect(&sock_path).unwrap; + client.write_all(b"echo").unwrap; + let mut buf = [0; 4]; + client.read_exact(&mut buf).unwrap; + assert_eq!(&buf, b"echo"); + + drop(client); + tx.join().unwrap; + let _ = std::fs::remove_file(&sock_path); +} + +#[test] +fn path_too_long() { + let long = "\\\\?\\".to_string() + &"a".repeat(300); + assert!(UnixListener::bind(long).is_err()); +}