Skip to content

Commit 4a0a9d8

Browse files
committed
poc
1 parent fa3155a commit 4a0a9d8

File tree

4 files changed

+224
-6
lines changed

4 files changed

+224
-6
lines changed

library/std/src/os/unix/net/addr.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::os::unix::ffi::OsStrExt;
66
use crate::path::Path;
77
use crate::sealed::Sealed;
88
use crate::sys::cvt;
9+
use crate::sys::net::SockaddrLike;
910
use crate::{fmt, io, mem, ptr};
1011

1112
// FIXME(#43348): Make libc adapt #[doc(cfg(...))] so we don't need these fake definitions here?
@@ -253,6 +254,27 @@ impl SocketAddr {
253254
}
254255
}
255256

257+
impl SockaddrLike for SocketAddr {
258+
unsafe fn from_storage(
259+
storage: &libc::sockaddr_storage,
260+
len: libc::socklen_t,
261+
) -> io::Result<Self> {
262+
let p = (storage as *const libc::sockaddr_storage).cast();
263+
SocketAddr::from_parts(*p, len)
264+
}
265+
266+
fn to_storage(&self, storage_ret: &mut libc::sockaddr_storage) -> libc::socklen_t {
267+
unsafe {
268+
crate::ptr::copy_nonoverlapping(
269+
&raw const self.addr,
270+
(storage_ret as *mut libc::sockaddr_storage).cast(),
271+
self.len as _,
272+
);
273+
self.len
274+
}
275+
}
276+
}
277+
256278
#[stable(feature = "unix_socket_abstract", since = "1.70.0")]
257279
impl Sealed for SocketAddr {}
258280

library/std/src/os/unix/net/ancillary.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub(super) fn recv_vectored_with_ancillary_from(
4848
msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
4949
}
5050

51-
let count = socket.recv_msg(&mut msg)?;
51+
let count = socket.recv_msg_(&mut msg)?;
5252

5353
ancillary.length = msg.msg_controllen as usize;
5454
ancillary.truncated = msg.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC;
@@ -83,7 +83,7 @@ pub(super) fn send_vectored_with_ancillary_to(
8383

8484
ancillary.truncated = false;
8585

86-
socket.send_msg(&mut msg)
86+
socket.send_msg_(&mut msg)
8787
}
8888
}
8989

library/std/src/sys/net/connection/socket/mod.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,46 @@ unsafe fn socket_addr_from_c(
198198
}
199199
}
200200

201+
// Structs that have sockaddr header and can be marshalled to and from sockaddr_storage
202+
pub(crate) trait SockaddrLike: Sized {
203+
// used in recvmsg to parse the received addr
204+
unsafe fn from_storage(storage: &c::sockaddr_storage, len: c::socklen_t) -> io::Result<Self>;
205+
206+
// used in sendmsg to write to a suckaddr_storage buffer
207+
fn to_storage(&self, storage_ret: &mut c::sockaddr_storage) -> c::socklen_t;
208+
}
209+
210+
impl SockaddrLike for SocketAddr {
211+
unsafe fn from_storage(storage: &c::sockaddr_storage, len: c::socklen_t) -> io::Result<Self> {
212+
socket_addr_from_c(storage as *const _, len as _)
213+
}
214+
215+
fn to_storage(&self, storage_ret: &mut c::sockaddr_storage) -> c::socklen_t {
216+
let (crep, len) = socket_addr_to_c(self);
217+
unsafe {
218+
crate::ptr::copy_nonoverlapping(
219+
&raw const crep,
220+
(storage_ret as *mut c::sockaddr_storage).cast(),
221+
len as _,
222+
);
223+
}
224+
len as _
225+
}
226+
}
227+
228+
impl SockaddrLike for () {
229+
unsafe fn from_storage(
230+
_storage: &libc::sockaddr_storage,
231+
_len: libc::socklen_t,
232+
) -> io::Result<Self> {
233+
Ok(())
234+
}
235+
236+
fn to_storage(&self, _storage_ret: &mut libc::sockaddr_storage) -> libc::socklen_t {
237+
0
238+
}
239+
}
240+
201241
////////////////////////////////////////////////////////////////////////////////
202242
// sockaddr and misc bindings
203243
////////////////////////////////////////////////////////////////////////////////

library/std/src/sys/net/connection/socket/unix.rs

Lines changed: 160 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
use libc::{MSG_PEEK, c_int, c_void, size_t, sockaddr, socklen_t};
1+
use libc::{
2+
CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, MSG_PEEK, c_int, c_uint, c_void, cmsghdr,
3+
iovec, msghdr, size_t, sockaddr, sockaddr_storage, socklen_t,
4+
};
25

36
#[cfg(not(any(target_os = "espidf", target_os = "nuttx")))]
47
use crate::ffi::CStr;
58
use crate::io::{self, BorrowedBuf, BorrowedCursor, IoSlice, IoSliceMut};
9+
use crate::mem::zeroed;
610
use crate::net::{Shutdown, SocketAddr};
711
use crate::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
12+
use crate::ptr::copy_nonoverlapping;
813
use crate::sys::fd::FileDesc;
9-
use crate::sys::net::{getsockopt, setsockopt};
14+
use crate::sys::net::{SockaddrLike, getsockopt, setsockopt};
1015
use crate::sys::pal::IsMinusOne;
1116
use crate::sys_common::{AsInner, FromInner, IntoInner};
1217
use crate::time::{Duration, Instant};
@@ -62,6 +67,63 @@ pub fn cvt_gai(err: c_int) -> io::Result<()> {
6267
))
6368
}
6469

70+
#[repr(C)]
71+
pub union CmsgIter<'buf> {
72+
_align: msghdr,
73+
inner: CmsgIterInner<'buf>,
74+
}
75+
76+
#[derive(Clone, Copy)]
77+
#[repr(C)]
78+
struct CmsgIterInner<'buf> {
79+
_padding: [u8; size_of::<usize>() + size_of::<socklen_t>() + size_of::<size_t>()],
80+
curr_cmsg: *mut cmsghdr,
81+
cmsg_buf: &'buf [u8],
82+
cmsg_buf_len: usize,
83+
}
84+
85+
#[repr(transparent)]
86+
pub struct CmsgBuf<'buf>(&'buf mut [u8]);
87+
88+
impl<'buf> CmsgBuf<'buf> {
89+
pub fn new(buf: &'buf mut [u8]) -> io::Result<Self> {
90+
if buf.as_ptr().align_offset(align_of::<cmsghdr>()) == 0 {
91+
Ok(CmsgBuf(buf))
92+
} else {
93+
Err(io::Error::new(io::ErrorKind::InvalidInput, "unaligned buffer"))
94+
}
95+
}
96+
97+
pub unsafe fn new_unchecked(buf: &'buf mut [u8]) -> Self {
98+
CmsgBuf(buf)
99+
}
100+
}
101+
102+
impl<'buf> Iterator for CmsgIter<'buf> {
103+
type Item = (size_t, c_int, c_int, &'buf [u8]);
104+
105+
fn next(&mut self) -> Option<Self::Item> {
106+
unsafe {
107+
if self.inner.curr_cmsg.is_null() {
108+
None
109+
} else {
110+
let curr = *self.inner.curr_cmsg;
111+
let data_ptr = CMSG_DATA(self.inner.curr_cmsg);
112+
let ptrdiff = data_ptr.offset_from_unsigned(self.inner.curr_cmsg as *const u8);
113+
let r = (
114+
curr.cmsg_len,
115+
curr.cmsg_level,
116+
curr.cmsg_type,
117+
crate::slice::from_raw_parts(data_ptr, curr.cmsg_len - ptrdiff),
118+
);
119+
self.inner.curr_cmsg =
120+
CMSG_NXTHDR(self as *mut _ as *mut msghdr, self.inner.curr_cmsg);
121+
Some(r)
122+
}
123+
}
124+
}
125+
}
126+
65127
impl Socket {
66128
pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result<Socket> {
67129
let fam = match *addr {
@@ -362,11 +424,52 @@ impl Socket {
362424
}
363425

364426
#[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
365-
pub fn recv_msg(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
427+
pub fn recv_msg_(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
366428
let n = cvt(unsafe { libc::recvmsg(self.as_raw_fd(), msg, libc::MSG_CMSG_CLOEXEC) })?;
367429
Ok(n as usize)
368430
}
369431

432+
// user is responsible of aligning cmsg_buf to cmsghdr's alignment (use align_of)
433+
pub fn recv_msg<'a, 'b, T>(
434+
&self,
435+
iov_buf: &mut [IoSliceMut<'_>],
436+
cmsg_buf: CmsgBuf<'a>,
437+
flags: c_int,
438+
) -> io::Result<(usize, T, c_int, CmsgIter<'b>)>
439+
where
440+
T: SockaddrLike,
441+
'a: 'b,
442+
{
443+
unsafe {
444+
let mut msg: msghdr = zeroed();
445+
let mut addr: sockaddr_storage = zeroed();
446+
msg.msg_name = (&raw mut addr).cast();
447+
msg.msg_namelen = mem::size_of_val(&addr) as _;
448+
449+
msg.msg_iovlen = iov_buf.len();
450+
msg.msg_iov = iov_buf.as_mut_ptr().cast();
451+
452+
msg.msg_controllen = cmsg_buf.0.len();
453+
if msg.msg_controllen != 0 {
454+
msg.msg_control = cmsg_buf.0.as_mut_ptr().cast();
455+
}
456+
457+
msg.msg_flags = 0;
458+
459+
let bytes = cvt(libc::recvmsg(self.as_raw_fd(), &raw mut msg, flags))? as usize;
460+
461+
let addr = SockaddrLike::from_storage(&addr, msg.msg_namelen)?;
462+
463+
let mut iter: CmsgIter<'_> = zeroed();
464+
iter.inner.cmsg_buf = cmsg_buf.0;
465+
iter.inner.cmsg_buf_len = msg.msg_controllen;
466+
let fst_cmsg = CMSG_FIRSTHDR((&raw const iter).cast());
467+
iter.inner.curr_cmsg = fst_cmsg;
468+
469+
Ok((bytes, addr, msg.msg_flags, iter))
470+
}
471+
}
472+
370473
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
371474
self.recv_from_with_flags(buf, MSG_PEEK)
372475
}
@@ -385,11 +488,64 @@ impl Socket {
385488
}
386489

387490
#[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
388-
pub fn send_msg(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
491+
pub fn send_msg_(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
389492
let n = cvt(unsafe { libc::sendmsg(self.as_raw_fd(), msg, 0) })?;
390493
Ok(n as usize)
391494
}
392495

496+
// user is responsible of aligning cmsg_buf to cmsghdr's alignment (use align_of)
497+
pub fn send_msg<T>(
498+
&self,
499+
addr: Option<&T>,
500+
iov: &[IoSlice<'_>],
501+
cmsgs: &[(c_int, c_int, &[u8])],
502+
cmsg_buf: CmsgBuf<'_>,
503+
flags: c_int,
504+
) -> io::Result<usize>
505+
where
506+
T: SockaddrLike,
507+
{
508+
unsafe {
509+
let mut msg: msghdr = zeroed();
510+
let mut addr_s: sockaddr_storage = zeroed();
511+
512+
if let Some(addr_) = addr {
513+
let len = addr_.to_storage(&mut addr_s);
514+
msg.msg_namelen = len;
515+
msg.msg_name = (&raw mut addr_s).cast();
516+
}
517+
518+
msg.msg_iovlen = iov.len();
519+
msg.msg_iov = iov.as_ptr().cast::<IoSlice<'_>>() as *mut iovec;
520+
521+
// cmsg
522+
msg.msg_controllen = cmsg_buf.0.len();
523+
msg.msg_control = cmsg_buf.0.as_mut_ptr().cast();
524+
let mut curr_cmsg_hdr = CMSG_FIRSTHDR(&raw const msg);
525+
for (cmsg_level, cmsg_type, cmsg_data) in cmsgs {
526+
if curr_cmsg_hdr.is_null() {
527+
return Err(io::Error::new(
528+
io::ErrorKind::InvalidInput,
529+
"cmsg_buf supplied is too small to hold all control messages",
530+
));
531+
}
532+
533+
(*curr_cmsg_hdr).cmsg_level = *cmsg_level;
534+
(*curr_cmsg_hdr).cmsg_type = *cmsg_type;
535+
(*curr_cmsg_hdr).cmsg_len = CMSG_LEN(cmsg_data.len() as c_uint) as usize;
536+
537+
let cmsg_data_ptr = CMSG_DATA(curr_cmsg_hdr);
538+
copy_nonoverlapping((*cmsg_data).as_ptr(), cmsg_data_ptr, cmsg_data.len());
539+
540+
curr_cmsg_hdr = CMSG_NXTHDR(&raw const msg, curr_cmsg_hdr as *const _);
541+
}
542+
543+
let bytes = cvt(libc::sendmsg(self.as_raw_fd(), &raw mut msg, flags))? as usize;
544+
545+
Ok(bytes)
546+
}
547+
}
548+
393549
pub fn set_timeout(&self, dur: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
394550
let timeout = match dur {
395551
Some(dur) => {

0 commit comments

Comments
 (0)