diff --git a/CHANGELOG.md b/CHANGELOG.md
index 91468004..a62778cc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@
   is now a type alias for `GuestRegionContainer`).
 - \[[#338](https://github.com/rust-vmm/vm-memory/pull/338)\] Make `GuestMemoryAtomic` always implement `Clone`.
 - \[[#338](https://github.com/rust-vmm/vm-memory/pull/338)\] Make `GuestAddressSpace` a subtrait of `Clone`.
+- \[[#339](https://github.com/rust-vmm/vm-memory/pull/339)\] Add `GuestMemory::get_slices()`.
 
 ### Changed
 
@@ -22,6 +23,8 @@
   and `GuestRegionMmap::from_range` to be separate from the error type returned by `GuestRegionCollection` functions.
   Change return type of `GuestRegionMmap::new` from `Result` to `Option`.
 - \[[#324](https://github.com/rust-vmm/vm-memory/pull/324)\] `GuestMemoryRegion::bitmap()` now returns a `BitmapSlice`.
   Accessing the full bitmap is now possible only if the type of the memory region is known, for example with `MmapRegion::bitmap()`.
+- \[[#339](https://github.com/rust-vmm/vm-memory/pull/339)\] Fix `Bytes::read()` and `Bytes::write()` not to ignore `try_access()`'s `count` parameter.
+- \[[#339](https://github.com/rust-vmm/vm-memory/pull/339)\] Implement `Bytes::load()` and `Bytes::store()` with `try_access()` instead of `to_region_addr()`.
 
 ### Removed
 
diff --git a/src/bitmap/mod.rs b/src/bitmap/mod.rs
index cf1555b2..802ebef1 100644
--- a/src/bitmap/mod.rs
+++ b/src/bitmap/mod.rs
@@ -266,7 +266,9 @@ pub(crate) mod tests {
         let dirty_len = size_of_val(&val);
         let (region, region_addr) = m.to_region_addr(dirty_addr).unwrap();
 
-        let slice = m.get_slice(dirty_addr, dirty_len).unwrap();
+        let mut slices = m.get_slices(dirty_addr, dirty_len);
+        let slice = slices.next().unwrap().unwrap();
+        assert!(slices.next().is_none());
 
         assert!(range_is_clean(&region.bitmap(), 0, region.len() as usize));
         assert!(range_is_clean(slice.bitmap(), 0, dirty_len));
diff --git a/src/guest_memory.rs b/src/guest_memory.rs
index 62d0c531..b9d3b62e 100644
--- a/src/guest_memory.rs
+++ b/src/guest_memory.rs
@@ -44,6 +44,7 @@
 use std::convert::From;
 use std::fs::File;
 use std::io;
+use std::mem::size_of;
 use std::ops::{BitAnd, BitOr, Deref};
 use std::rc::Rc;
 use std::sync::atomic::Ordering;
@@ -455,6 +456,93 @@ pub trait GuestMemory {
             .ok_or(Error::InvalidGuestAddress(addr))
             .and_then(|(r, addr)| r.get_slice(addr, count))
     }
+
+    /// Returns an iterator over [`VolatileSlice`](struct.VolatileSlice.html)s, together covering
+    /// `count` bytes starting at `addr`.
+    ///
+    /// Iterating in this way is necessary because the given address range may be fragmented across
+    /// multiple [`GuestMemoryRegion`]s.
+    ///
+    /// The iterator’s items are wrapped in [`Result`], i.e. errors are reported on individual
+    /// items. If there is no such error, the cumulative length of all items will be equal to
+    /// `count`. If `count` is 0, an empty iterator will be returned.
+    fn get_slices<'a>(
+        &'a self,
+        addr: GuestAddress,
+        count: usize,
+    ) -> GuestMemorySliceIterator<'a, Self> {
+        GuestMemorySliceIterator {
+            mem: self,
+            addr,
+            count,
+        }
+    }
+}
+
+/// Iterates over [`VolatileSlice`]s that together form a guest memory area.
+///
+/// Returned by [`GuestMemory::get_slices()`].
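+///
+/// # Example
+///
+/// A minimal usage sketch (assuming the `backend-mmap` feature for `GuestMemoryMmap`; any
+/// [`GuestMemory`] implementation behaves the same way):
+///
+/// ```
+/// # use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap};
+/// // Two adjacent 0x400-byte regions, so ranges crossing 0x400 are fragmented.
+/// let mem = GuestMemoryMmap::<()>::from_ranges(&[
+///     (GuestAddress(0), 0x400),
+///     (GuestAddress(0x400), 0x400),
+/// ])
+/// .unwrap();
+///
+/// // 0x200 bytes starting at 0x300 come back as two slices (0x100 + 0x100).
+/// let mut total = 0;
+/// for slice in mem.get_slices(GuestAddress(0x300), 0x200) {
+///     total += slice.unwrap().len();
+/// }
+/// assert_eq!(total, 0x200);
+/// ```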
+#[derive(Debug)]
+pub struct GuestMemorySliceIterator<'a, M: GuestMemory + ?Sized> {
+    /// Underlying memory
+    mem: &'a M,
+    /// Next address in the guest memory area
+    addr: GuestAddress,
+    /// Remaining bytes in the guest memory area
+    count: usize,
+}
+
+impl<'a, M: GuestMemory + ?Sized> GuestMemorySliceIterator<'a, M> {
+    /// Helper function for [`<Self as Iterator>::next()`].
+    ///
+    /// Get the next slice (i.e. the one starting at `self.addr`, with a length of up to
+    /// `self.count`) and update the internal state.
+    ///
+    /// # Safety
+    ///
+    /// This function does not reset `self.count` to 0 in case of error, i.e. it will not stop
+    /// iterating. The actual behavior after an error is ill-defined, so the caller must check the
+    /// return value and, in case of an error, reset `self.count` to 0.
+    ///
+    /// (This is why this function exists: so that this resetting can be done in a single central
+    /// location.)
+    unsafe fn do_next(&mut self) -> Option<Result<VolatileSlice<'a, MS<'a, M>>>> {
+        if self.count == 0 {
+            return None;
+        }
+
+        let Some((region, start)) = self.mem.to_region_addr(self.addr) else {
+            return Some(Err(Error::InvalidGuestAddress(self.addr)));
+        };
+
+        let cap = region.len() - start.raw_value();
+        let len = std::cmp::min(cap, self.count as GuestUsize);
+
+        self.count -= len as usize;
+        self.addr = match self.addr.overflowing_add(len as GuestUsize) {
+            (x @ GuestAddress(0), _) | (x, false) => x,
+            (_, true) => return Some(Err(Error::GuestAddressOverflow)),
+        };
+
+        Some(region.get_slice(start, len as usize))
+    }
+}
+
+impl<'a, M: GuestMemory + ?Sized> Iterator for GuestMemorySliceIterator<'a, M> {
+    type Item = Result<VolatileSlice<'a, MS<'a, M>>>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // SAFETY:
+        // We reset `self.count` to 0 on error.
+        match unsafe { self.do_next() } {
+            Some(Ok(slice)) => Some(Ok(slice)),
+            other => {
+                // On error (or at the end), reset to 0 so iteration remains stopped.
+                self.count = 0;
+                other
+            }
+        }
+    }
 }
 
 impl<T: GuestMemory + ?Sized> Bytes<GuestAddress> for T {
@@ -464,8 +552,8 @@
         self.try_access(
             buf.len(),
             addr,
-            |offset, _count, caddr, region| -> Result<usize> {
-                region.write(&buf[offset..], caddr)
+            |offset, count, caddr, region| -> Result<usize> {
+                region.write(&buf[offset..(offset + count)], caddr)
             },
         )
     }
@@ -474,8 +562,8 @@
         self.try_access(
             buf.len(),
             addr,
-            |offset, _count, caddr, region| -> Result<usize> {
-                region.read(&mut buf[offset..], caddr)
+            |offset, count, caddr, region| -> Result<usize> {
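+                // `count` is the number of bytes `try_access()` wants handled within
+                // this region; bounding the buffer to it (rather than passing all of
+                // `buf[offset..]`) is what stops `read()` from ignoring that parameter.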
+                region.read(&mut buf[offset..(offset + count)], caddr)
             },
         )
     }
@@ -591,17 +679,62 @@
     }
 
     fn store<O: AtomicAccess>(&self, val: O, addr: GuestAddress, order: Ordering) -> Result<()> {
-        // `find_region` should really do what `to_region_addr` is doing right now, except
-        // it should keep returning a `Result`.
-        self.to_region_addr(addr)
-            .ok_or(Error::InvalidGuestAddress(addr))
-            .and_then(|(region, region_addr)| region.store(val, region_addr, order))
+        let expected = size_of::<O>();
+
+        let completed = self.try_access(
+            expected,
+            addr,
+            |offset, len, region_addr, region| -> Result<usize> {
+                assert_eq!(offset, 0);
+                if len < expected {
+                    return Err(Error::PartialBuffer {
+                        expected,
+                        completed: 0,
+                    });
+                }
+                region.store(val, region_addr, order).map(|()| expected)
+            },
+        )?;
+
+        if completed < expected {
+            Err(Error::PartialBuffer {
+                expected,
+                completed,
+            })
+        } else {
+            Ok(())
+        }
     }
 
     fn load<O: AtomicAccess>(&self, addr: GuestAddress, order: Ordering) -> Result<O> {
-        self.to_region_addr(addr)
-            .ok_or(Error::InvalidGuestAddress(addr))
-            .and_then(|(region, region_addr)| region.load(region_addr, order))
+        let expected = size_of::<O>();
+        let mut result = None::<O>;
+
+        let completed = self.try_access(
+            expected,
+            addr,
+            |offset, len, region_addr, region| -> Result<usize> {
+                assert_eq!(offset, 0);
+                if len < expected {
+                    return Err(Error::PartialBuffer {
+                        expected,
+                        completed: 0,
+                    });
+                }
+                result = Some(region.load(region_addr, order)?);
+                Ok(expected)
+            },
+        )?;
+
+        if completed < expected {
+            Err(Error::PartialBuffer {
+                expected,
+                completed,
+            })
+        } else {
+            // Must be set because `completed == expected`.
+            Ok(result.unwrap())
+        }
     }
 }
diff --git a/src/mmap/mod.rs b/src/mmap/mod.rs
index 2468d662..1ba59f54 100644
--- a/src/mmap/mod.rs
+++ b/src/mmap/mod.rs
@@ -624,6 +624,56 @@
         assert!(guest_mem.get_slice(GuestAddress(0xc00), 0x100).is_err());
     }
 
+    #[test]
+    fn test_guest_memory_get_slices() {
+        let start_addr1 = GuestAddress(0);
+        let start_addr2 = GuestAddress(0x800);
+        let start_addr3 = GuestAddress(0xc00);
+        let guest_mem = GuestMemoryMmap::from_ranges(&[
+            (start_addr1, 0x400),
+            (start_addr2, 0x400),
+            (start_addr3, 0x400),
+        ])
+        .unwrap();
+
+        // Same cases as `test_guest_memory_get_slice()`, just with `get_slices()`.
+        let slice_size = 0x200;
+        let mut slices = guest_mem.get_slices(GuestAddress(0x100), slice_size);
+        let slice = slices.next().unwrap().unwrap();
+        assert!(slices.next().is_none());
+        assert_eq!(slice.len(), slice_size);
+
+        let slice_size = 0x400;
+        let mut slices = guest_mem.get_slices(GuestAddress(0x800), slice_size);
+        let slice = slices.next().unwrap().unwrap();
+        assert!(slices.next().is_none());
+        assert_eq!(slice.len(), slice_size);
+
+        // Empty iterator.
+        assert!(guest_mem
+            .get_slices(GuestAddress(0x900), 0)
+            .next()
+            .is_none());
+
+        // Error cases: wrong size or base address.
+        let mut slices = guest_mem.get_slices(GuestAddress(0), 0x500);
+        assert_eq!(slices.next().unwrap().unwrap().len(), 0x400);
+        assert!(slices.next().unwrap().is_err());
+        assert!(slices.next().is_none());
+        let mut slices = guest_mem.get_slices(GuestAddress(0x600), 0x100);
+        assert!(slices.next().unwrap().is_err());
+        assert!(slices.next().is_none());
+        let mut slices = guest_mem.get_slices(GuestAddress(0x1000), 0x100);
+        assert!(slices.next().unwrap().is_err());
+        assert!(slices.next().is_none());
+
+        // Test the fragmented case.
+        let mut slices = guest_mem.get_slices(GuestAddress(0xa00), 0x400);
+        assert_eq!(slices.next().unwrap().unwrap().len(), 0x200);
+        assert_eq!(slices.next().unwrap().unwrap().len(), 0x200);
+        assert!(slices.next().is_none());
+    }
+
     #[test]
     fn test_atomic_accesses() {
         let region = GuestRegionMmap::from_range(GuestAddress(0), 0x1000, None).unwrap();
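A minimal end-to-end sketch of the behavior established above (assuming `vm-memory` with the `backend-mmap` feature enabled): `Bytes::write()` now splits a boundary-crossing access across regions via `try_access()`, each per-region callback handling exactly the `count` bytes requested, while `get_slices()` hands the caller that same fragmentation explicitly.

```rust
use vm_memory::{Bytes, GuestAddress, GuestMemory, GuestMemoryMmap};

fn main() {
    // Two adjacent 0x400-byte regions: [0, 0x400) and [0x400, 0x800).
    let mem = GuestMemoryMmap::<()>::from_ranges(&[
        (GuestAddress(0), 0x400),
        (GuestAddress(0x400), 0x400),
    ])
    .unwrap();

    // A write crossing the region boundary is split internally by try_access();
    // all 0x20 bytes arrive, 0x10 in each region.
    let data = [0xaa_u8; 0x20];
    assert_eq!(mem.write(&data, GuestAddress(0x3f0)).unwrap(), 0x20);

    // get_slices() makes the same fragmentation visible to the caller:
    // the range comes back as one 0x10-byte slice per region.
    let lens = mem
        .get_slices(GuestAddress(0x3f0), 0x20)
        .map(|slice| slice.map(|s| s.len()))
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
    assert_eq!(lens, vec![0x10, 0x10]);
}
```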