Skip to content

Commit 297f02b

Browse files
authored
Fix broken RecvMsgOut parsing (#257)
* Fix broken RecvMsgOut parsing Signed-off-by: Alex Saveau <[email protected]> * Add payload truncation test Signed-off-by: Alex Saveau <[email protected]> --------- Signed-off-by: Alex Saveau <[email protected]>
1 parent cc8060a commit 297f02b

File tree

6 files changed

+208
-36
lines changed

6 files changed

+208
-36
lines changed

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

io-uring-test/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ anyhow = "1"
1313
tempfile = "3"
1414
once_cell = "1"
1515
socket2 = "0.5"
16+
semver = "1.0.21"
1617

1718
[features]
1819
direct-syscall = [ "io-uring/direct-syscall" ]

io-uring-test/src/main.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,20 @@ mod tests;
44

55
use io_uring::{cqueue, squeue, IoUring, Probe};
66
use std::cell::Cell;
7+
use std::ffi::CStr;
8+
use std::mem;
79

810
pub struct Test {
911
probe: Probe,
1012
target: Option<String>,
1113
count: Cell<usize>,
14+
kernel_version: semver::Version,
15+
}
16+
17+
impl Test {
18+
fn check_kernel_version(&self, min_version: &str) -> bool {
19+
self.kernel_version >= semver::Version::parse(min_version).unwrap()
20+
}
1221
}
1322

1423
fn main() -> anyhow::Result<()> {
@@ -63,6 +72,16 @@ fn test<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
6372
probe,
6473
target: std::env::args().nth(1),
6574
count: Cell::new(0),
75+
kernel_version: {
76+
let mut uname: libc::utsname = unsafe { mem::zeroed() };
77+
unsafe {
78+
assert!(libc::uname(&mut uname) >= 0);
79+
}
80+
81+
let version = unsafe { CStr::from_ptr(uname.release.as_ptr()) };
82+
let version = version.to_str().unwrap();
83+
semver::Version::parse(version).unwrap()
84+
},
6685
};
6786

6887
tests::queue::test_nop(&mut ring, &test)?;
@@ -132,6 +151,7 @@ fn test<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
132151
tests::net::test_shutdown(&mut ring, &test)?;
133152
tests::net::test_socket(&mut ring, &test)?;
134153
tests::net::test_udp_recvmsg_multishot(&mut ring, &test)?;
154+
tests::net::test_udp_recvmsg_multishot_trunc(&mut ring, &test)?;
135155
tests::net::test_udp_sendzc_with_dest(&mut ring, &test)?;
136156

137157
// queue

io-uring-test/src/tests/net.rs

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,6 +1436,142 @@ pub fn test_udp_recvmsg_multishot<S: squeue::EntryMarker, C: cqueue::EntryMarker
14361436

14371437
Ok(())
14381438
}
1439+
pub fn test_udp_recvmsg_multishot_trunc<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
1440+
ring: &mut IoUring<S, C>,
1441+
test: &Test,
1442+
) -> anyhow::Result<()> {
1443+
require!(
1444+
test;
1445+
test.probe.is_supported(opcode::RecvMsgMulti::CODE);
1446+
test.probe.is_supported(opcode::ProvideBuffers::CODE);
1447+
test.probe.is_supported(opcode::SendMsg::CODE);
1448+
test.check_kernel_version("6.6.0" /* 6.2 is totally broken and returns nonsense upon truncation */);
1449+
);
1450+
1451+
println!("test udp_recvmsg_multishot_trunc");
1452+
1453+
let server_socket: socket2::Socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap().into();
1454+
let server_addr = server_socket.local_addr()?;
1455+
1456+
const BUF_GROUP: u16 = 33;
1457+
const DATA: &[u8] = b"testfooo for me";
1458+
let mut buf1 = [0u8; 20]; // 20 = size_of::<io_uring_recvmsg_out>() + msghdr.msg_namelen
1459+
let mut buf2 = [0u8; 20 + DATA.len()];
1460+
let mut buf3 = [0u8; 20 + DATA.len()];
1461+
let mut buffers = [
1462+
buf1.as_mut_slice(),
1463+
buf2.as_mut_slice(),
1464+
buf3.as_mut_slice(),
1465+
];
1466+
1467+
for (index, buf) in buffers.iter_mut().enumerate() {
1468+
let provide_bufs_e = io_uring::opcode::ProvideBuffers::new(
1469+
(**buf).as_mut_ptr(),
1470+
buf.len() as i32,
1471+
1,
1472+
BUF_GROUP,
1473+
index as u16,
1474+
)
1475+
.build()
1476+
.user_data(11)
1477+
.into();
1478+
unsafe { ring.submission().push(&provide_bufs_e)? };
1479+
ring.submitter().submit_and_wait(1)?;
1480+
let cqes: Vec<io_uring::cqueue::Entry> = ring.completion().map(Into::into).collect();
1481+
assert_eq!(cqes.len(), 1);
1482+
assert_eq!(cqes[0].user_data(), 11);
1483+
assert_eq!(cqes[0].result(), 0);
1484+
assert_eq!(cqes[0].flags(), 0);
1485+
}
1486+
1487+
// This structure is actually only used for input arguments to the kernel
1488+
// (and only name length and control length are actually relevant).
1489+
let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() };
1490+
msghdr.msg_namelen = 4;
1491+
1492+
let recvmsg_e = opcode::RecvMsgMulti::new(
1493+
Fd(server_socket.as_raw_fd()),
1494+
&msghdr as *const _,
1495+
BUF_GROUP,
1496+
)
1497+
.flags(libc::MSG_TRUNC as u32)
1498+
.build()
1499+
.user_data(77)
1500+
.into();
1501+
unsafe { ring.submission().push(&recvmsg_e)? };
1502+
ring.submitter().submit().unwrap();
1503+
1504+
let client_socket: socket2::Socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap().into();
1505+
1506+
let data = [io::IoSlice::new(DATA)];
1507+
let mut msghdr1: libc::msghdr = unsafe { mem::zeroed() };
1508+
msghdr1.msg_name = server_addr.as_ptr() as *const _ as *mut _;
1509+
msghdr1.msg_namelen = server_addr.len();
1510+
msghdr1.msg_iov = data.as_ptr() as *const _ as *mut _;
1511+
msghdr1.msg_iovlen = 1;
1512+
1513+
let send_msgs = (0..2)
1514+
.map(|_| {
1515+
opcode::SendMsg::new(Fd(client_socket.as_raw_fd()), &msghdr1 as *const _)
1516+
.build()
1517+
.user_data(55)
1518+
.into()
1519+
})
1520+
.collect::<Vec<_>>();
1521+
unsafe { ring.submission().push_multiple(&send_msgs)? };
1522+
ring.submitter().submit().unwrap();
1523+
1524+
ring.submitter().submit_and_wait(4).unwrap();
1525+
let cqes: Vec<io_uring::cqueue::Entry> = ring.completion().map(Into::into).collect();
1526+
assert_eq!(cqes.len(), 4);
1527+
let mut i = 0;
1528+
for cqe in cqes {
1529+
let is_more = io_uring::cqueue::more(cqe.flags());
1530+
match cqe.user_data() {
1531+
// send notifications
1532+
55 => {
1533+
assert!(cqe.result() > 0);
1534+
assert!(!is_more);
1535+
}
1536+
// RecvMsgMulti
1537+
77 => {
1538+
assert!(cqe.result() > 0);
1539+
assert!(is_more);
1540+
let buf_id = io_uring::cqueue::buffer_select(cqe.flags()).unwrap();
1541+
let tmp_buf = &buffers[buf_id as usize];
1542+
let msg = types::RecvMsgOut::parse(tmp_buf, &msghdr);
1543+
1544+
match i {
1545+
0 => {
1546+
let msg = msg.unwrap();
1547+
assert!(msg.is_payload_truncated());
1548+
assert!(msg.is_name_data_truncated());
1549+
assert_eq!(DATA.len(), msg.incoming_payload_len() as usize);
1550+
assert!(msg.payload_data().is_empty());
1551+
assert!(4 < msg.incoming_name_len());
1552+
assert_eq!(4, msg.name_data().len());
1553+
}
1554+
1 => {
1555+
let msg = msg.unwrap();
1556+
assert!(!msg.is_payload_truncated());
1557+
assert!(msg.is_name_data_truncated());
1558+
assert_eq!(DATA.len(), msg.incoming_payload_len() as usize);
1559+
assert_eq!(DATA, msg.payload_data());
1560+
assert!(4 < msg.incoming_name_len());
1561+
assert_eq!(4, msg.name_data().len());
1562+
}
1563+
_ => unreachable!(),
1564+
}
1565+
i += 1;
1566+
}
1567+
_ => {
1568+
unreachable!()
1569+
}
1570+
}
1571+
}
1572+
1573+
Ok(())
1574+
}
14391575
pub fn test_udp_sendzc_with_dest<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
14401576
ring: &mut IoUring<S, C>,
14411577
test: &Test,

io-uring-test/src/utils.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ macro_rules! require {
66
$test:expr;
77
$( $cond:expr ; )*
88
) => {
9+
let test = $test;
910
let mut cond = true;
1011

11-
if let Some(target) = $test.target.as_ref() {
12+
if let Some(target) = test.target.as_ref() {
1213
cond &= function_name!().contains(target);
1314
}
1415

@@ -20,7 +21,7 @@ macro_rules! require {
2021
return Ok(());
2122
}
2223

23-
$test.count.set($test.count.get() + 1);
24+
test.count.set(test.count.get() + 1);
2425
}
2526
}
2627

src/types.rs

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub(crate) mod sealed {
4343
use crate::sys;
4444
use crate::util::{cast_ptr, unwrap_nonzero, unwrap_u32};
4545
use bitflags::bitflags;
46+
use std::convert::TryFrom;
4647
use std::marker::PhantomData;
4748
use std::num::NonZeroU32;
4849
use std::os::unix::io::RawFd;
@@ -377,10 +378,7 @@ pub struct RecvMsgOut<'buf> {
377378
/// If it is smaller, it gets 0-padded to fill the whole field. In either case,
378379
/// this fixed amount of space is reserved in the result buffer.
379380
msghdr_name_len: usize,
380-
/// The fixed length of the control field, in bytes.
381-
///
382-
/// This follows the same semantics as the field above, but for control data.
383-
msghdr_control_len: usize,
381+
384382
name_data: &'buf [u8],
385383
control_data: &'buf [u8],
386384
payload_data: &'buf [u8],
@@ -396,7 +394,15 @@ impl<'buf> RecvMsgOut<'buf> {
396394
/// (only `msg_namelen` and `msg_controllen` fields are relevant).
397395
#[allow(clippy::result_unit_err)]
398396
pub fn parse(buffer: &'buf [u8], msghdr: &libc::msghdr) -> Result<Self, ()> {
399-
if buffer.len() < std::mem::size_of::<sys::io_uring_recvmsg_out>() {
397+
let msghdr_name_len = usize::try_from(msghdr.msg_namelen).unwrap();
398+
let msghdr_control_len = usize::try_from(msghdr.msg_controllen).unwrap();
399+
400+
if Self::DATA_START
401+
.checked_add(msghdr_name_len)
402+
.and_then(|acc| acc.checked_add(msghdr_control_len))
403+
.map(|header_len| buffer.len() < header_len)
404+
.unwrap_or(true)
405+
{
400406
return Err(());
401407
}
402408
// SAFETY: buffer (minimum) length is checked here above.
@@ -407,45 +413,36 @@ impl<'buf> RecvMsgOut<'buf> {
407413
.read_unaligned()
408414
};
409415

410-
let msghdr_name_len = msghdr.msg_namelen as _;
411-
let msghdr_control_len = msghdr.msg_controllen as _;
412-
413-
// Check total length upfront, so that further logic here
414-
// below can safely use unchecked/saturating math.
415-
let length_overflow = Some(Self::DATA_START)
416-
.and_then(|acc| acc.checked_add(msghdr_name_len))
417-
.and_then(|acc| acc.checked_add(msghdr_control_len))
418-
.and_then(|acc| acc.checked_add(header.payloadlen as usize))
419-
.map(|total_len| total_len > buffer.len())
420-
.unwrap_or(true);
421-
if length_overflow {
422-
return Err(());
423-
}
424-
416+
// min is used because the header may indicate the true size of the data
417+
// while what we received was truncated.
425418
let (name_data, control_start) = {
426419
let name_start = Self::DATA_START;
427-
let name_size = usize::min(header.namelen as usize, msghdr_name_len);
428-
let name_data_end = name_start.saturating_add(name_size);
429-
let name_data = &buffer[name_start..name_data_end];
430-
let name_field_end = name_start.saturating_add(msghdr_name_len);
431-
(name_data, name_field_end)
420+
let name_data_end =
421+
name_start + usize::min(usize::try_from(header.namelen).unwrap(), msghdr_name_len);
422+
let name_field_end = name_start + msghdr_name_len;
423+
(&buffer[name_start..name_data_end], name_field_end)
432424
};
433425
let (control_data, payload_start) = {
434-
let control_size = usize::min(header.controllen as usize, msghdr_control_len);
435-
let control_data_end = control_start.saturating_add(control_size);
436-
let control_data = &buffer[control_start..control_data_end];
437-
let control_field_end = control_start.saturating_add(msghdr_control_len);
438-
(control_data, control_field_end)
426+
let control_data_end = control_start
427+
+ usize::min(
428+
usize::try_from(header.controllen).unwrap(),
429+
msghdr_control_len,
430+
);
431+
let control_field_end = control_start + msghdr_control_len;
432+
(&buffer[control_start..control_data_end], control_field_end)
439433
};
440434
let payload_data = {
441-
let payload_data_end = payload_start.saturating_add(header.payloadlen as usize);
435+
let payload_data_end = payload_start
436+
+ usize::min(
437+
usize::try_from(header.payloadlen).unwrap(),
438+
buffer.len() - payload_start,
439+
);
442440
&buffer[payload_start..payload_data_end]
443441
};
444442

445443
Ok(Self {
446444
header,
447445
msghdr_name_len,
448-
msghdr_control_len,
449446
name_data,
450447
control_data,
451448
payload_data,
@@ -490,7 +487,7 @@ impl<'buf> RecvMsgOut<'buf> {
490487
/// When `true`, data returned by `control_data()` is truncated and
491488
/// incomplete.
492489
pub fn is_control_data_truncated(&self) -> bool {
493-
self.header.controllen as usize > self.msghdr_control_len
490+
(self.header.flags & u32::try_from(libc::MSG_CTRUNC).unwrap()) != 0
494491
}
495492

496493
/// Message control data, with the same semantics as `msghdr.msg_control`.
@@ -503,14 +500,24 @@ impl<'buf> RecvMsgOut<'buf> {
503500
/// When `true`, data returned by `payload_data()` is truncated and
504501
/// incomplete.
505502
pub fn is_payload_truncated(&self) -> bool {
506-
self.header.flags & (libc::MSG_TRUNC as u32) != 0
503+
(self.header.flags & u32::try_from(libc::MSG_TRUNC).unwrap()) != 0
507504
}
508505

509506
/// Message payload, as buffered by the kernel.
510507
pub fn payload_data(&self) -> &[u8] {
511508
self.payload_data
512509
}
513510

511+
/// Return the length of the incoming `payload` data.
512+
///
513+
/// This may be larger than the size of the content returned by
514+
/// `payload_data()`, if the kernel could not fit all the incoming
515+
/// data in the provided buffer size. In that case, payload data in
516+
/// the result buffer gets truncated.
517+
pub fn incoming_payload_len(&self) -> u32 {
518+
self.header.payloadlen
519+
}
520+
514521
/// Message flags, with the same semantics as `msghdr.msg_flags`.
515522
pub fn flags(&self) -> u32 {
516523
self.header.flags

0 commit comments

Comments
 (0)