Skip to content

Commit dd3ed58

Browse files
committed
poc
1 parent fa3155a commit dd3ed58

File tree

4 files changed

+223
-6
lines changed

4 files changed

+223
-6
lines changed

library/std/src/os/unix/net/addr.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::os::unix::ffi::OsStrExt;
66
use crate::path::Path;
77
use crate::sealed::Sealed;
88
use crate::sys::cvt;
9+
use crate::sys::net::SockaddrLike;
910
use crate::{fmt, io, mem, ptr};
1011

1112
// FIXME(#43348): Make libc adapt #[doc(cfg(...))] so we don't need these fake definitions here?
@@ -253,6 +254,27 @@ impl SocketAddr {
253254
}
254255
}
255256

257+
impl SockaddrLike for SocketAddr {
258+
unsafe fn from_storage(
259+
storage: &libc::sockaddr_storage,
260+
len: libc::socklen_t,
261+
) -> io::Result<Self> {
262+
let p = (storage as *const libc::sockaddr_storage).cast();
263+
SocketAddr::from_parts(*p, len)
264+
}
265+
266+
fn to_storage(&self, storage_ret: &mut libc::sockaddr_storage) -> libc::socklen_t {
267+
unsafe {
268+
crate::ptr::copy_nonoverlapping(
269+
&raw const self.addr,
270+
(storage_ret as *mut libc::sockaddr_storage).cast(),
271+
self.len as _,
272+
);
273+
self.len
274+
}
275+
}
276+
}
277+
256278
#[stable(feature = "unix_socket_abstract", since = "1.70.0")]
257279
impl Sealed for SocketAddr {}
258280

library/std/src/os/unix/net/ancillary.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub(super) fn recv_vectored_with_ancillary_from(
4848
msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
4949
}
5050

51-
let count = socket.recv_msg(&mut msg)?;
51+
let count = socket.recv_msg_(&mut msg)?;
5252

5353
ancillary.length = msg.msg_controllen as usize;
5454
ancillary.truncated = msg.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC;
@@ -83,7 +83,7 @@ pub(super) fn send_vectored_with_ancillary_to(
8383

8484
ancillary.truncated = false;
8585

86-
socket.send_msg(&mut msg)
86+
socket.send_msg_(&mut msg)
8787
}
8888
}
8989

library/std/src/sys/net/connection/socket/mod.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,46 @@ unsafe fn socket_addr_from_c(
198198
}
199199
}
200200

201+
// Structs that have sockaddr header and can be marshalled to and from sockaddr_storage
202+
pub(crate) trait SockaddrLike: Sized {
203+
// used in recvmsg to parse the received addr
204+
unsafe fn from_storage(storage: &c::sockaddr_storage, len: c::socklen_t) -> io::Result<Self>;
205+
206+
// used in sendmsg to write to a suckaddr_storage buffer
207+
fn to_storage(&self, storage_ret: &mut c::sockaddr_storage) -> c::socklen_t;
208+
}
209+
210+
impl SockaddrLike for SocketAddr {
211+
unsafe fn from_storage(storage: &c::sockaddr_storage, len: c::socklen_t) -> io::Result<Self> {
212+
socket_addr_from_c(storage as *const _, len as _)
213+
}
214+
215+
fn to_storage(&self, storage_ret: &mut c::sockaddr_storage) -> c::socklen_t {
216+
let (crep, len) = socket_addr_to_c(self);
217+
unsafe {
218+
crate::ptr::copy_nonoverlapping(
219+
&raw const crep,
220+
(storage_ret as *mut c::sockaddr_storage).cast(),
221+
len as _,
222+
);
223+
}
224+
len as _
225+
}
226+
}
227+
228+
impl SockaddrLike for () {
229+
unsafe fn from_storage(
230+
_storage: &libc::sockaddr_storage,
231+
_len: libc::socklen_t,
232+
) -> io::Result<Self> {
233+
Ok(())
234+
}
235+
236+
fn to_storage(&self, _storage_ret: &mut libc::sockaddr_storage) -> libc::socklen_t {
237+
0
238+
}
239+
}
240+
201241
////////////////////////////////////////////////////////////////////////////////
202242
// sockaddr and misc bindings
203243
////////////////////////////////////////////////////////////////////////////////

library/std/src/sys/net/connection/socket/unix.rs

Lines changed: 159 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
use libc::{MSG_PEEK, c_int, c_void, size_t, sockaddr, socklen_t};
1+
use libc::{
2+
CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, MSG_PEEK, c_int, c_uint, c_void, cmsghdr,
3+
iovec, msghdr, size_t, sockaddr, sockaddr_storage, socklen_t,
4+
};
25

36
#[cfg(not(any(target_os = "espidf", target_os = "nuttx")))]
47
use crate::ffi::CStr;
58
use crate::io::{self, BorrowedBuf, BorrowedCursor, IoSlice, IoSliceMut};
9+
use crate::mem::zeroed;
610
use crate::net::{Shutdown, SocketAddr};
711
use crate::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
12+
use crate::ptr::copy_nonoverlapping;
813
use crate::sys::fd::FileDesc;
9-
use crate::sys::net::{getsockopt, setsockopt};
14+
use crate::sys::net::{SockaddrLike, getsockopt, setsockopt};
1015
use crate::sys::pal::IsMinusOne;
1116
use crate::sys_common::{AsInner, FromInner, IntoInner};
1217
use crate::time::{Duration, Instant};
@@ -62,6 +67,64 @@ pub fn cvt_gai(err: c_int) -> io::Result<()> {
6267
))
6368
}
6469

70+
#[repr(C)]
71+
pub union CmsgIter<'buf> {
72+
_align: msghdr,
73+
inner: CmsgIterInner<'buf>,
74+
}
75+
76+
#[derive(Clone, Copy)]
77+
#[repr(C)]
78+
struct CmsgIterInner<'buf> {
79+
_padding: [u8; size_of::<usize>() + size_of::<socklen_t>() + size_of::<size_t>()],
80+
curr_cmsg: *mut cmsghdr,
81+
cmsg_buf: &'buf [u8],
82+
cmsg_buf_len: usize,
83+
}
84+
85+
#[repr(transparent)]
86+
pub struct CmsgBuf<'buf>(&'buf mut [u8]);
87+
88+
impl<'buf> CmsgBuf<'buf> {
89+
// fails if buf isn't aligned to alignof(cmsghdr)
90+
pub fn new(buf: &'buf mut [u8]) -> io::Result<Self> {
91+
if buf.as_ptr().align_offset(align_of::<cmsghdr>()) == 0 {
92+
Ok(CmsgBuf(buf))
93+
} else {
94+
Err(io::Error::new(io::ErrorKind::InvalidInput, "unaligned buffer"))
95+
}
96+
}
97+
98+
pub unsafe fn new_unchecked(buf: &'buf mut [u8]) -> Self {
99+
CmsgBuf(buf)
100+
}
101+
}
102+
103+
impl<'buf> Iterator for CmsgIter<'buf> {
104+
type Item = (size_t, c_int, c_int, &'buf [u8]);
105+
106+
fn next(&mut self) -> Option<Self::Item> {
107+
unsafe {
108+
if self.inner.curr_cmsg.is_null() {
109+
None
110+
} else {
111+
let curr = *self.inner.curr_cmsg;
112+
let data_ptr = CMSG_DATA(self.inner.curr_cmsg);
113+
let ptrdiff = data_ptr.offset_from_unsigned(self.inner.curr_cmsg as *const u8);
114+
let r = (
115+
curr.cmsg_len,
116+
curr.cmsg_level,
117+
curr.cmsg_type,
118+
crate::slice::from_raw_parts(data_ptr, curr.cmsg_len - ptrdiff),
119+
);
120+
self.inner.curr_cmsg =
121+
CMSG_NXTHDR(self as *mut _ as *mut msghdr, self.inner.curr_cmsg);
122+
Some(r)
123+
}
124+
}
125+
}
126+
}
127+
65128
impl Socket {
66129
pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result<Socket> {
67130
let fam = match *addr {
@@ -362,11 +425,51 @@ impl Socket {
362425
}
363426

364427
#[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
365-
pub fn recv_msg(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
428+
pub fn recv_msg_(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
366429
let n = cvt(unsafe { libc::recvmsg(self.as_raw_fd(), msg, libc::MSG_CMSG_CLOEXEC) })?;
367430
Ok(n as usize)
368431
}
369432

433+
pub fn recv_msg<'a, 'b, T>(
434+
&self,
435+
iov_buf: &mut [IoSliceMut<'_>],
436+
cmsg_buf: CmsgBuf<'a>,
437+
flags: c_int,
438+
) -> io::Result<(usize, T, c_int, CmsgIter<'b>)>
439+
where
440+
T: SockaddrLike,
441+
'a: 'b,
442+
{
443+
unsafe {
444+
let mut msg: msghdr = zeroed();
445+
let mut addr: sockaddr_storage = zeroed();
446+
msg.msg_name = (&raw mut addr).cast();
447+
msg.msg_namelen = mem::size_of_val(&addr) as _;
448+
449+
msg.msg_iovlen = iov_buf.len();
450+
msg.msg_iov = iov_buf.as_mut_ptr().cast();
451+
452+
msg.msg_controllen = cmsg_buf.0.len();
453+
if msg.msg_controllen != 0 {
454+
msg.msg_control = cmsg_buf.0.as_mut_ptr().cast();
455+
}
456+
457+
msg.msg_flags = 0;
458+
459+
let bytes = cvt(libc::recvmsg(self.as_raw_fd(), &raw mut msg, flags))? as usize;
460+
461+
let addr = SockaddrLike::from_storage(&addr, msg.msg_namelen)?;
462+
463+
let mut iter: CmsgIter<'_> = zeroed();
464+
iter.inner.cmsg_buf = cmsg_buf.0;
465+
iter.inner.cmsg_buf_len = msg.msg_controllen;
466+
let fst_cmsg = CMSG_FIRSTHDR((&raw const iter).cast());
467+
iter.inner.curr_cmsg = fst_cmsg;
468+
469+
Ok((bytes, addr, msg.msg_flags, iter))
470+
}
471+
}
472+
370473
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
371474
self.recv_from_with_flags(buf, MSG_PEEK)
372475
}
@@ -385,11 +488,63 @@ impl Socket {
385488
}
386489

387490
#[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
388-
pub fn send_msg(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
491+
pub fn send_msg_(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
389492
let n = cvt(unsafe { libc::sendmsg(self.as_raw_fd(), msg, 0) })?;
390493
Ok(n as usize)
391494
}
392495

496+
pub fn send_msg<T>(
497+
&self,
498+
addr: Option<&T>,
499+
iov: &[IoSlice<'_>],
500+
cmsgs: &[(c_int, c_int, &[u8])],
501+
cmsg_buf: CmsgBuf<'_>,
502+
flags: c_int,
503+
) -> io::Result<usize>
504+
where
505+
T: SockaddrLike,
506+
{
507+
unsafe {
508+
let mut msg: msghdr = zeroed();
509+
let mut addr_s: sockaddr_storage = zeroed();
510+
511+
if let Some(addr_) = addr {
512+
let len = addr_.to_storage(&mut addr_s);
513+
msg.msg_namelen = len;
514+
msg.msg_name = (&raw mut addr_s).cast();
515+
}
516+
517+
msg.msg_iovlen = iov.len();
518+
msg.msg_iov = iov.as_ptr().cast::<IoSlice<'_>>() as *mut iovec;
519+
520+
// cmsg
521+
msg.msg_controllen = cmsg_buf.0.len();
522+
msg.msg_control = cmsg_buf.0.as_mut_ptr().cast();
523+
let mut curr_cmsg_hdr = CMSG_FIRSTHDR(&raw const msg);
524+
for (cmsg_level, cmsg_type, cmsg_data) in cmsgs {
525+
if curr_cmsg_hdr.is_null() {
526+
return Err(io::Error::new(
527+
io::ErrorKind::InvalidInput,
528+
"cmsg_buf supplied is too small to hold all control messages",
529+
));
530+
}
531+
532+
(*curr_cmsg_hdr).cmsg_level = *cmsg_level;
533+
(*curr_cmsg_hdr).cmsg_type = *cmsg_type;
534+
(*curr_cmsg_hdr).cmsg_len = CMSG_LEN(cmsg_data.len() as c_uint) as usize;
535+
536+
let cmsg_data_ptr = CMSG_DATA(curr_cmsg_hdr);
537+
copy_nonoverlapping((*cmsg_data).as_ptr(), cmsg_data_ptr, cmsg_data.len());
538+
539+
curr_cmsg_hdr = CMSG_NXTHDR(&raw const msg, curr_cmsg_hdr as *const _);
540+
}
541+
542+
let bytes = cvt(libc::sendmsg(self.as_raw_fd(), &raw mut msg, flags))? as usize;
543+
544+
Ok(bytes)
545+
}
546+
}
547+
393548
pub fn set_timeout(&self, dur: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
394549
let timeout = match dur {
395550
Some(dur) => {

0 commit comments

Comments
 (0)