diff --git a/vhost-user-backend/CHANGELOG.md b/vhost-user-backend/CHANGELOG.md index d1ede25f..9e977abd 100644 --- a/vhost-user-backend/CHANGELOG.md +++ b/vhost-user-backend/CHANGELOG.md @@ -3,6 +3,8 @@ ## [Unreleased] ### Added +- [[#339]](https://github.com/rust-vmm/vhost/pull/339) Add support for `GET_SHMEM_CONFIG` message + ### Changed ### Deprecated ### Fixed diff --git a/vhost-user-backend/src/backend.rs b/vhost-user-backend/src/backend.rs index 20e7daf3..71974463 100644 --- a/vhost-user-backend/src/backend.rs +++ b/vhost-user-backend/src/backend.rs @@ -25,7 +25,7 @@ use std::sync::{Arc, Mutex, RwLock}; use vhost::vhost_user::message::{ VhostTransferStateDirection, VhostTransferStatePhase, VhostUserProtocolFeatures, - VhostUserSharedMsg, + VhostUserShMemConfig, VhostUserSharedMsg, }; use vhost::vhost_user::Backend; use vm_memory::bitmap::Bitmap; @@ -180,6 +180,13 @@ pub trait VhostUserBackend: Send + Sync { "back end does not support state transfer", )) } + + fn get_shmem_config(&self) -> Result { + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "back end does not support shared memory regions", + )) + } } /// Trait without interior mutability for vhost user backend servers to implement concrete services. @@ -322,6 +329,13 @@ pub trait VhostUserBackendMut: Send + Sync { "back end does not support state transfer", )) } + + fn get_shmem_config(&self) -> Result { + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "back end does not support shared memory regions", + )) + } } impl VhostUserBackend for Arc { @@ -411,6 +425,10 @@ impl VhostUserBackend for Arc { fn check_device_state(&self) -> Result<()> { self.deref().check_device_state() } + + fn get_shmem_config(&self) -> Result { + self.deref().get_shmem_config() + } } impl VhostUserBackend for Mutex { @@ -503,6 +521,10 @@ impl VhostUserBackend for Mutex { fn check_device_state(&self) -> Result<()> { self.lock().unwrap().check_device_state() } + + fn get_shmem_config(&self) -> Result { + self.lock().unwrap().get_shmem_config() + } } impl VhostUserBackend for RwLock { @@ -595,6 +617,10 @@ impl VhostUserBackend for RwLock { fn check_device_state(&self) -> Result<()> { self.read().unwrap().check_device_state() } + + fn get_shmem_config(&self) -> Result { + self.read().unwrap().get_shmem_config() + } } #[cfg(test)] diff --git a/vhost-user-backend/src/handler.rs b/vhost-user-backend/src/handler.rs index 86b8c86f..80764cb8 100644 --- a/vhost-user-backend/src/handler.rs +++ b/vhost-user-backend/src/handler.rs @@ -18,7 +18,7 @@ use crate::bitmap::{BitmapReplace, MemRegionBitmap, MmapLogReg}; use userfaultfd::{Uffd, UffdBuilder}; use vhost::vhost_user::message::{ VhostTransferStateDirection, VhostTransferStatePhase, VhostUserConfigFlags, VhostUserLog, - VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserSharedMsg, + VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserShMemConfig, VhostUserSharedMsg, VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags, VhostUserVringState, }; @@ -677,6 +677,12 @@ where .map_err(VhostUserError::ReqHandlerError) } + fn get_shmem_config(&self) -> VhostUserResult { + self.backend + .get_shmem_config() + .map_err(VhostUserError::ReqHandlerError) + } + #[cfg(feature = "postcopy")] fn postcopy_advice(&mut self) -> VhostUserResult { let mut uffd_builder = UffdBuilder::new(); diff --git a/vhost/CHANGELOG.md b/vhost/CHANGELOG.md index f72ae24f..00af71cd 100644 --- a/vhost/CHANGELOG.md +++ b/vhost/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added - [[#251]](https://github.com/rust-vmm/vhost/pull/251) Add `SHMEM_MAP` and `SHMEM_UNMAP` support +- [[#339]](https://github.com/rust-vmm/vhost/pull/339) Add support for `GET_SHMEM_CONFIG` message ### Changed ### Deprecated diff --git a/vhost/src/vhost_user/backend_req_handler.rs b/vhost/src/vhost_user/backend_req_handler.rs index d74b0455..2f5defcc 100644 --- a/vhost/src/vhost_user/backend_req_handler.rs +++ b/vhost/src/vhost_user/backend_req_handler.rs @@ -81,6 +81,7 @@ pub trait VhostUserBackendReqHandler { fd: File, ) -> Result>; fn check_device_state(&self) -> Result<()>; + fn get_shmem_config(&self) -> Result; #[cfg(feature = "postcopy")] fn postcopy_advice(&self) -> Result; #[cfg(feature = "postcopy")] @@ -146,6 +147,7 @@ pub trait VhostUserBackendReqHandlerMut { fd: File, ) -> Result>; fn check_device_state(&mut self) -> Result<()>; + fn get_shmem_config(&self) -> Result; #[cfg(feature = "postcopy")] fn postcopy_advice(&mut self) -> Result; #[cfg(feature = "postcopy")] @@ -289,6 +291,10 @@ impl VhostUserBackendReqHandler for Mutex { self.lock().unwrap().check_device_state() } + fn get_shmem_config(&self) -> Result { + self.lock().unwrap().get_shmem_config() + } + #[cfg(feature = "postcopy")] fn postcopy_advice(&self) -> Result { self.lock().unwrap().postcopy_advice() @@ -679,6 +685,11 @@ impl BackendReqHandler { }; self.send_reply_message(&hdr, &msg)?; } + Ok(FrontendReq::GET_SHMEM_CONFIG) => { + self.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?; + let msg = self.backend.get_shmem_config().unwrap_or_default(); + self.send_reply_message(&hdr, &msg)?; + } #[cfg(feature = "postcopy")] Ok(FrontendReq::POSTCOPY_ADVISE) => { self.check_proto_feature(VhostUserProtocolFeatures::PAGEFAULT)?; @@ -1038,4 +1049,108 @@ mod tests { handler.check_state().unwrap_err(); assert!(handler.as_raw_fd() >= 0); } + + // Helper to send GET_SHMEM_CONFIG request and receive response + fn send_get_shmem_config_request( + mut endpoint: Endpoint>, + ) -> VhostUserShMemConfig { + let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0); + endpoint.send_message(&hdr, &VhostUserEmpty, None).unwrap(); + + let (reply_hdr, reply_config, rfds) = endpoint.recv_body::().unwrap(); + assert_eq!(reply_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG); + assert!(reply_hdr.is_reply()); + assert!(rfds.is_none()); + reply_config + } + + // Helper to create handler with SHMEM protocol feature enabled + fn create_handler_with_shmem( + backend: Arc>, + p1: UnixStream, + ) -> BackendReqHandler> { + let mut handler = BackendReqHandler::new( + Endpoint::>::from_stream(p1), + backend, + ); + handler.acked_protocol_features = VhostUserProtocolFeatures::SHMEM.bits(); + handler + } + + #[test] + fn test_get_shmem_config_multiple_regions() { + let memory_sizes = [ + 0x1000, 0x2000, 0x3000, 0x4000, 0x5000, 0x6000, 0x7000, 0x8000, + ]; + let config = VhostUserShMemConfig::new(8, &memory_sizes); + + let (p1, p2) = UnixStream::pair().unwrap(); + let mut dummy_backend = DummyBackendReqHandler::new(); + dummy_backend.set_shmem_config(config); + let mut handler = create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), p1); + + let handle = std::thread::spawn(move || { + send_get_shmem_config_request(Endpoint::>::from_stream( + p2, + )) + }); + + handler.handle_request().unwrap(); + + let reply_config = handle.join().unwrap(); + assert_eq!(reply_config.nregions, 8); + for i in 0..8 { + assert_eq!(reply_config.memory_sizes[i], (i as u64 + 1) * 0x1000); + } + for i in 8..256 { + assert_eq!(reply_config.memory_sizes[i], 0); + } + } + + #[test] + fn test_get_shmem_config_non_continuous_regions() { + // Create a configuration with non-continuous regions + let memory_sizes = [0x10000, 0, 0x20000, 0, 0, 0, 0, 0]; + let config = VhostUserShMemConfig::new(2, &memory_sizes); + + let (p1, p2) = UnixStream::pair().unwrap(); + let mut dummy_backend = DummyBackendReqHandler::new(); + dummy_backend.set_shmem_config(config); + let mut handler = create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), p1); + + let handle = std::thread::spawn(move || { + send_get_shmem_config_request(Endpoint::>::from_stream( + p2, + )) + }); + + handler.handle_request().unwrap(); + + let reply_config = handle.join().unwrap(); + assert_eq!(reply_config.nregions, 2); + assert_eq!(reply_config.memory_sizes[0], 0x10000); + assert_eq!(reply_config.memory_sizes[1], 0); + assert_eq!(reply_config.memory_sizes[2], 0x20000); + for i in 3..256 { + assert_eq!(reply_config.memory_sizes[i], 0); + } + } + + #[test] + fn test_get_shmem_config_feature_not_negotiated() { + // Test that the request fails when SHMEM protocol feature is not negotiated + let (p1, p2) = UnixStream::pair().unwrap(); + let backend = Arc::new(Mutex::new(DummyBackendReqHandler::new())); + let mut handler = BackendReqHandler::new( + Endpoint::>::from_stream(p1), + backend, + ); + let mut frontend_endpoint = Endpoint::>::from_stream(p2); + + std::thread::spawn(move || { + let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0); + let _ = frontend_endpoint.send_message(&hdr, &VhostUserEmpty, None); + }); + assert!(handler.handle_request().is_err()); + } } diff --git a/vhost/src/vhost_user/dummy_backend.rs b/vhost/src/vhost_user/dummy_backend.rs index a45d3b47..a4dd5c01 100644 --- a/vhost/src/vhost_user/dummy_backend.rs +++ b/vhost/src/vhost_user/dummy_backend.rs @@ -27,6 +27,7 @@ pub struct DummyBackendReqHandler { pub vring_enabled: [bool; MAX_QUEUE_NUM], pub inflight_file: Option, pub shared_file: Option, + pub shmem_config: Option, } impl DummyBackendReqHandler { @@ -37,6 +38,12 @@ impl DummyBackendReqHandler { } } + /// Set the shared memory configuration to be returned by `get_shmem_config` + pub fn set_shmem_config(&mut self, config: VhostUserShMemConfig) { + self.acked_protocol_features |= VhostUserProtocolFeatures::SHMEM.bits(); + self.shmem_config = Some(config); + } + /// Helper to check if VirtioFeature enabled fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> { if self.acked_features & feat.bits() != 0 { @@ -329,6 +336,15 @@ impl VhostUserBackendReqHandlerMut for DummyBackendReqHandler { ))) } + fn get_shmem_config(&self) -> Result { + self.shmem_config.ok_or_else(|| { + Error::ReqHandlerError(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "dummy back end does not support shared memory regions", + )) + }) + } + #[cfg(feature = "postcopy")] fn postcopy_advice(&mut self) -> Result { let file = tempfile::tempfile().unwrap(); diff --git a/vhost/src/vhost_user/frontend.rs b/vhost/src/vhost_user/frontend.rs index bda291f2..2d95f9ad 100644 --- a/vhost/src/vhost_user/frontend.rs +++ b/vhost/src/vhost_user/frontend.rs @@ -79,6 +79,9 @@ pub trait VhostUserFrontend: VhostBackend { /// Remove a guest memory mapping from vhost. fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; + /// Get the shared memory region configuration from the backend. + fn get_shmem_config(&mut self) -> Result; + /// Sends VHOST_USER_POSTCOPY_ADVISE msg to the backend /// initiating the beginning of the postcopy process. /// Backend will return a userfaultfd. @@ -567,6 +570,16 @@ impl VhostUserFrontend for Frontend { node.wait_for_ack(&hdr).map_err(|e| e.into()) } + fn get_shmem_config(&mut self) -> Result { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?; + + let hdr = node.send_request_header(FrontendReq::GET_SHMEM_CONFIG, None)?; + let config = node.recv_reply::(&hdr)?; + + Ok(config) + } + #[cfg(feature = "postcopy")] fn postcopy_advise(&mut self) -> Result { let mut node = self.node(); @@ -1201,4 +1214,45 @@ mod tests { let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1]; frontend.set_mem_table(&tables).unwrap_err(); } + + #[test] + fn test_frontend_get_shmem_config() { + let (mut frontend, mut peer) = create_pair2(); + + let expected_config = VhostUserShMemConfig::new(2, &[0x1000, 0x2000]); + let hdr = VhostUserMsgHeader::new( + FrontendReq::GET_SHMEM_CONFIG, + 0x4, + std::mem::size_of::() as u32, + ); + peer.send_message(&hdr, &expected_config, None).unwrap(); + + let config = frontend.get_shmem_config().unwrap(); + assert_eq!(config.nregions, 2); + assert_eq!(config.memory_sizes[0], 0x1000); + assert_eq!(config.memory_sizes[1], 0x2000); + + let (recv_hdr, rfds) = peer.recv_header().unwrap(); + assert_eq!(recv_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG); + assert!(rfds.is_none()); + } + + #[test] + fn test_frontend_get_shmem_config_no_regions() { + let (mut frontend, mut peer) = create_pair2(); + + let expected_config = VhostUserShMemConfig::default(); + let hdr = VhostUserMsgHeader::new( + FrontendReq::GET_SHMEM_CONFIG, + 0x4, + std::mem::size_of::() as u32, + ); + peer.send_message(&hdr, &expected_config, None).unwrap(); + + let config = frontend.get_shmem_config().unwrap(); + assert_eq!(config.nregions, 0); + for i in 0..256 { + assert_eq!(config.memory_sizes[i], 0); + } + } } diff --git a/vhost/src/vhost_user/message.rs b/vhost/src/vhost_user/message.rs index 360b6ddc..3ff3908f 100644 --- a/vhost/src/vhost_user/message.rs +++ b/vhost/src/vhost_user/message.rs @@ -169,6 +169,8 @@ enum_value! { /// After transferring state, check the backend for any errors that may have /// occurred during the transfer CHECK_DEVICE_STATE = 43, + /// Get shared memory regions configuration from the backend. + GET_SHMEM_CONFIG = 44, } } @@ -688,6 +690,44 @@ impl VhostUserSingleMemoryRegion { unsafe impl ByteValued for VhostUserSingleMemoryRegion {} impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {} +/// Get shared memory regions configuration. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct VhostUserShMemConfig { + /// Total number of shared memory regions + pub nregions: u32, + /// Padding for correct alignment + padding: u32, + /// Size of each memory region + pub memory_sizes: [u64; 256], +} + +impl Default for VhostUserShMemConfig { + fn default() -> Self { + Self { + nregions: 0, + padding: 0, + memory_sizes: [0; 256], + } + } +} + +impl VhostUserShMemConfig { + /// Create a new instance + pub fn new(nregions: u32, memory: &[u64]) -> Self { + let memory_sizes: [u64; 256] = std::array::from_fn(|i| *memory.get(i).unwrap_or(&0)); + Self { + nregions, + padding: 0, + memory_sizes, + } + } +} + +// SAFETY: Safe because all fields of VhostUserSingleMemoryRegion are POD. +unsafe impl ByteValued for VhostUserShMemConfig {} +impl VhostUserMsgValidator for VhostUserShMemConfig {} + /// Vring state descriptor. #[repr(C, packed)] #[derive(Copy, Clone, Default)]