
Commit efe2983

Authored and committed by bonzini and andreeaflorescu
bytes: avoid torn writes with memcpy
Some memcpy implementations copy bytes one at a time, which is slow and also breaks the virtio specification by splitting writes to the fields of the virtio descriptor. So, reimplement memcpy in Rust and copy in larger pieces, according to the largest possible alignment allowed by the pointer values.

Signed-off-by: Serban Iorga <[email protected]>
Signed-off-by: Andreea Florescu <[email protected]>
Signed-off-by: Alexandru Agache <[email protected]>
Signed-off-by: Paolo Bonzini <[email protected]>
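For context, here is a minimal standalone sketch (not part of the patch) of the alignment computation the commit relies on; it uses `wrapping_neg()`, which is equivalent to the `!addr + 1` form that appears in the diff below:

    // Illustrative only: the largest power of two that divides an address is
    // isolated by ANDing the address with its two's complement.
    fn alignment(addr: usize) -> usize {
        // Same as the C idiom `addr & -addr`.
        addr & addr.wrapping_neg()
    }

    fn main() {
        assert_eq!(alignment(0x1000), 0x1000); // page-aligned pointer
        assert_eq!(alignment(0x1004), 4);      // eligible for u32-sized chunks
        assert_eq!(alignment(0x1006), 2);      // u16-sized chunks at most
        assert_eq!(alignment(0x1007), 1);      // byte-by-byte only
    }

Copying in chunks no wider than the common alignment of source and destination keeps a naturally aligned field, such as a virtio descriptor member, within a single access instead of tearing it across several byte writes.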
1 parent 257d7c7 commit efe2983

File tree (2 files changed: 200 additions & 15 deletions)

  src/guest_memory.rs
  src/volatile_memory.rs


src/guest_memory.rs

Lines changed: 125 additions & 0 deletions
@@ -864,9 +864,14 @@ impl<T: GuestMemory> Bytes<GuestAddress> for T {
 mod tests {
     use super::*;
     #[cfg(feature = "backend-mmap")]
+    use crate::bytes::ByteValued;
+    #[cfg(feature = "backend-mmap")]
     use crate::{GuestAddress, GuestMemoryMmap};
     #[cfg(feature = "backend-mmap")]
     use std::io::Cursor;
+    #[cfg(feature = "backend-mmap")]
+    use std::time::{Duration, Instant};
+
     use vmm_sys_util::tempfile::TempFile;

     #[cfg(feature = "backend-mmap")]
@@ -905,4 +910,124 @@ mod tests {
             .unwrap()
         );
     }
+
+    // Runs the provided closure in a loop, until at least `duration` time units have elapsed.
+    #[cfg(feature = "backend-mmap")]
+    fn loop_timed<F>(duration: Duration, mut f: F)
+    where
+        F: FnMut() -> (),
+    {
+        // We check the time every `CHECK_PERIOD` iterations.
+        const CHECK_PERIOD: u64 = 1_000_000;
+        let start_time = Instant::now();
+
+        loop {
+            for _ in 0..CHECK_PERIOD {
+                f();
+            }
+            if start_time.elapsed() >= duration {
+                break;
+            }
+        }
+    }
+
+    // Helper method for the following test. It spawns a writer and a reader thread, which
+    // simultaneously try to access an object that is placed at the junction of two memory regions.
+    // The part of the object that's continuously accessed is a member of type T. The writer
+    // flips all the bits of the member with every write, while the reader checks that every byte
+    // has the same value (and thus it did not do a non-atomic access). The test succeeds if
+    // no mismatch is detected after performing accesses for a pre-determined amount of time.
+    #[cfg(feature = "backend-mmap")]
+    fn non_atomic_access_helper<T>()
+    where
+        T: ByteValued
+            + std::fmt::Debug
+            + From<u8>
+            + Into<u128>
+            + std::ops::Not<Output = T>
+            + PartialEq,
+    {
+        use std::mem;
+        use std::thread;
+
+        // A dummy type that's always going to have the same alignment as the first member,
+        // and then adds some bytes at the end.
+        #[derive(Clone, Copy, Debug, Default, PartialEq)]
+        struct Data<T> {
+            val: T,
+            some_bytes: [u8; 7],
+        }
+
+        // Some sanity checks.
+        assert_eq!(mem::align_of::<T>(), mem::align_of::<Data<T>>());
+        assert_eq!(mem::size_of::<T>(), mem::align_of::<T>());
+
+        unsafe impl<T: ByteValued> ByteValued for Data<T> {}
+
+        // Start of first guest memory region.
+        let start = GuestAddress(0);
+        let region_len = 1 << 12;
+
+        // The address where we start writing/reading a Data<T> value.
+        let data_start = GuestAddress((region_len - mem::size_of::<T>()) as u64);
+
+        let mem = GuestMemoryMmap::from_ranges(&[
+            (start, region_len),
+            (start.unchecked_add(region_len as u64), region_len),
+        ])
+        .unwrap();
+
+        // Need to clone this and move it into the new thread we create.
+        let mem2 = mem.clone();
+        // Just some bytes.
+        let some_bytes = [1u8, 2, 4, 16, 32, 64, 128];
+
+        let mut data = Data {
+            val: T::from(0u8),
+            some_bytes,
+        };
+
+        // Simple check that cross-region write/read is ok.
+        mem.write_obj(data, data_start).unwrap();
+        let read_data = mem.read_obj::<Data<T>>(data_start).unwrap();
+        assert_eq!(read_data, data);
+
+        let t = thread::spawn(move || {
+            let mut count: u64 = 0;
+
+            loop_timed(Duration::from_secs(3), || {
+                let data = mem2.read_obj::<Data<T>>(data_start).unwrap();
+
+                // Every time data is written to memory by the other thread, the value of
+                // data.val alternates between 0 and T::MAX, so the inner bytes should always
+                // have the same value. If they don't match, it means we read a partial value,
+                // so the access was not atomic.
+                let bytes = data.val.into().to_le_bytes();
+                for i in 1..mem::size_of::<T>() {
+                    if bytes[0] != bytes[i] {
+                        panic!(
+                            "val bytes don't match {:?} after {} iterations",
+                            &bytes[..mem::size_of::<T>()],
+                            count
+                        );
+                    }
+                }
+                count += 1;
+            });
+        });
+
+        // Write the object while flipping the bits of data.val over and over again.
+        loop_timed(Duration::from_secs(3), || {
+            mem.write_obj(data, data_start).unwrap();
+            data.val = !data.val;
+        });
+
+        t.join().unwrap()
+    }
+
+    #[cfg(feature = "backend-mmap")]
+    #[test]
+    fn test_non_atomic_access() {
+        non_atomic_access_helper::<u16>()
+    }
 }
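Usage note (assuming the crate's usual cargo setup): the new regression test only builds with the mmap backend enabled, so it can be run on its own with `cargo test --features backend-mmap test_non_atomic_access`. Both the reader and the writer side loop for about three seconds (see `loop_timed` above), so the test needs a few seconds of wall-clock time.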

src/volatile_memory.rs

Lines changed: 75 additions & 15 deletions
@@ -463,6 +463,57 @@ impl<'a> VolatileSlice<'a> {
     }
 }

+// Return the largest value that `addr` is aligned to. Forcing this function to return 1 will
+// cause test_non_atomic_access to fail.
+fn alignment(addr: usize) -> usize {
+    // Rust is silly and does not let me write addr & -addr.
+    addr & (!addr + 1)
+}
+
+// Has the same safety requirements as `read_volatile` + `write_volatile`, namely:
+// - `src_addr` and `dst_addr` must be valid for reads/writes.
+// - `src_addr` and `dst_addr` must be properly aligned with respect to `align`.
+// - `src_addr` must point to a properly initialized value, which is true here because
+//   we're only using integer primitives.
+unsafe fn copy_single(align: usize, src_addr: usize, dst_addr: usize) {
+    match align {
+        8 => write_volatile(dst_addr as *mut u64, read_volatile(src_addr as *const u64)),
+        4 => write_volatile(dst_addr as *mut u32, read_volatile(src_addr as *const u32)),
+        2 => write_volatile(dst_addr as *mut u16, read_volatile(src_addr as *const u16)),
+        1 => write_volatile(dst_addr as *mut u8, read_volatile(src_addr as *const u8)),
+        _ => unreachable!(),
+    }
+}
+
+fn copy_slice(dst: &mut [u8], src: &[u8]) -> usize {
+    let total = min(src.len(), dst.len());
+    let mut left = total;
+
+    let mut src_addr = src.as_ptr() as usize;
+    let mut dst_addr = dst.as_ptr() as usize;
+    let align = min(alignment(src_addr), alignment(dst_addr));
+
+    let mut copy_aligned_slice = |min_align| {
+        while align >= min_align && left >= min_align {
+            // Safe because we check alignment beforehand, the memory areas are valid for
+            // reads/writes, and the source always contains a valid value.
+            unsafe { copy_single(min_align, src_addr, dst_addr) };
+            src_addr += min_align;
+            dst_addr += min_align;
+            left -= min_align;
+        }
+    };
+
+    if size_of::<usize>() > 4 {
+        copy_aligned_slice(8);
+    }
+    copy_aligned_slice(4);
+    copy_aligned_slice(2);
+    copy_aligned_slice(1);
+
+    total
+}
+
 impl Bytes<usize> for VolatileSlice<'_> {
     type E = Error;

@@ -482,13 +533,12 @@ impl Bytes<usize> for VolatileSlice<'_> {
         if addr >= self.size {
             return Err(Error::OutOfBounds { addr });
         }
-        unsafe {
-            // Guest memory can't strictly be modeled as a slice because it is
-            // volatile. Writing to it with what compiles down to a memcpy
-            // won't hurt anything as long as we get the bounds checks right.
-            let mut slice: &mut [u8] = &mut self.as_mut_slice()[addr..];
-            slice.write(buf).map_err(Error::IOError)
-        }
+
+        // Guest memory can't strictly be modeled as a slice because it is
+        // volatile. Writing to it with what is essentially a fancy memcpy
+        // won't hurt anything as long as we get the bounds checks right.
+        let slice = unsafe { self.as_mut_slice() }.split_at_mut(addr).1;
+        Ok(copy_slice(slice, buf))
     }

     /// # Examples
@@ -504,17 +554,16 @@ impl Bytes<usize> for VolatileSlice<'_> {
     /// assert!(res.is_ok());
     /// assert_eq!(res.unwrap(), 14);
     /// ```
-    fn read(&self, mut buf: &mut [u8], addr: usize) -> Result<usize> {
+    fn read(&self, buf: &mut [u8], addr: usize) -> Result<usize> {
         if addr >= self.size {
             return Err(Error::OutOfBounds { addr });
         }
-        unsafe {
-            // Guest memory can't strictly be modeled as a slice because it is
-            // volatile. Writing to it with what compiles down to a memcpy
-            // won't hurt anything as long as we get the bounds checks right.
-            let slice: &[u8] = &self.as_slice()[addr..];
-            buf.write(slice).map_err(Error::IOError)
-        }
+
+        // Guest memory can't strictly be modeled as a slice because it is
+        // volatile. Writing to it with what is essentially a fancy memcpy
+        // won't hurt anything as long as we get the bounds checks right.
+        let slice = unsafe { self.as_slice() }.split_at(addr).1;
+        Ok(copy_slice(buf, slice))
     }

     /// # Examples
@@ -1560,4 +1609,15 @@ mod tests {
             }
         );
     }
+
+    #[test]
+    fn alignment() {
+        let a = [0u8; 64];
+        let a = &a[a.as_ptr().align_offset(32)] as *const u8 as usize;
+        assert!(super::alignment(a) >= 32);
+        assert_eq!(super::alignment(a + 9), 1);
+        assert_eq!(super::alignment(a + 30), 2);
+        assert_eq!(super::alignment(a + 12), 4);
+        assert_eq!(super::alignment(a + 8), 8);
+    }
 }
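To make the chunking strategy above concrete, here is a hypothetical, self-contained model of how `copy_slice` selects access widths once the common pointer alignment is known (the `chunk_sizes` helper is illustrative and not part of the crate):

    // Illustrative model: the common alignment caps the chunk width, and narrower
    // chunks mop up whatever remains.
    fn chunk_sizes(align: usize, mut len: usize) -> Vec<usize> {
        let mut widths = Vec::new();
        // The real code only attempts 8-byte chunks when usize is 64 bits wide.
        for width in [8usize, 4, 2, 1] {
            while align >= width && len >= width {
                widths.push(width);
                len -= width;
            }
        }
        widths
    }

    fn main() {
        // Two 4-byte-aligned buffers, 10 bytes: two u32 accesses, then one u16.
        assert_eq!(chunk_sizes(4, 10), vec![4, 4, 2]);
        // A u16 at 2-byte-aligned addresses is copied with a single access, never torn.
        assert_eq!(chunk_sizes(2, 2), vec![2]);
    }

In the patch, each of those chunks is moved by `copy_single`, which performs one volatile read and one volatile write of the corresponding integer width.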
