From bfef1da81d5536063cfc9810e0d6984eb8f20f00 Mon Sep 17 00:00:00 2001 From: Albert Esteve Date: Thu, 11 Dec 2025 10:40:50 +0100 Subject: [PATCH] vhost-user: Add support for GET_SHMEM_CONFIG message Add support for GET_SHMEM_CONFIG message to retrieve VirtIO Shared Memory Regions configuration. This is useful when the frontend is unaware of specific backend type and configuration of the memory layout. Based on the patch [1] which is just waiting for being merged. [1] - https://lore.kernel.org/all/20251111091058.879669-1-aesteve@redhat.com/ Signed-off-by: Albert Esteve --- vhost-user-backend/CHANGELOG.md | 2 + vhost-user-backend/src/backend.rs | 28 ++++- vhost-user-backend/src/handler.rs | 8 +- vhost/CHANGELOG.md | 1 + vhost/src/vhost_user/backend_req_handler.rs | 115 ++++++++++++++++++++ vhost/src/vhost_user/dummy_backend.rs | 16 +++ vhost/src/vhost_user/frontend.rs | 54 +++++++++ vhost/src/vhost_user/message.rs | 40 +++++++ 8 files changed, 262 insertions(+), 2 deletions(-) diff --git a/vhost-user-backend/CHANGELOG.md b/vhost-user-backend/CHANGELOG.md index d1ede25f..9e977abd 100644 --- a/vhost-user-backend/CHANGELOG.md +++ b/vhost-user-backend/CHANGELOG.md @@ -3,6 +3,8 @@ ## [Unreleased] ### Added +- [[#339]](https://github.com/rust-vmm/vhost/pull/339) Add support for `GET_SHMEM_CONFIG` message + ### Changed ### Deprecated ### Fixed diff --git a/vhost-user-backend/src/backend.rs b/vhost-user-backend/src/backend.rs index 20e7daf3..71974463 100644 --- a/vhost-user-backend/src/backend.rs +++ b/vhost-user-backend/src/backend.rs @@ -25,7 +25,7 @@ use std::sync::{Arc, Mutex, RwLock}; use vhost::vhost_user::message::{ VhostTransferStateDirection, VhostTransferStatePhase, VhostUserProtocolFeatures, - VhostUserSharedMsg, + VhostUserShMemConfig, VhostUserSharedMsg, }; use vhost::vhost_user::Backend; use vm_memory::bitmap::Bitmap; @@ -180,6 +180,13 @@ pub trait VhostUserBackend: Send + Sync { "back end does not support state transfer", )) } + + fn get_shmem_config(&self) -> Result { + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "back end does not support shared memory regions", + )) + } } /// Trait without interior mutability for vhost user backend servers to implement concrete services. @@ -322,6 +329,13 @@ pub trait VhostUserBackendMut: Send + Sync { "back end does not support state transfer", )) } + + fn get_shmem_config(&self) -> Result { + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "back end does not support shared memory regions", + )) + } } impl VhostUserBackend for Arc { @@ -411,6 +425,10 @@ impl VhostUserBackend for Arc { fn check_device_state(&self) -> Result<()> { self.deref().check_device_state() } + + fn get_shmem_config(&self) -> Result { + self.deref().get_shmem_config() + } } impl VhostUserBackend for Mutex { @@ -503,6 +521,10 @@ impl VhostUserBackend for Mutex { fn check_device_state(&self) -> Result<()> { self.lock().unwrap().check_device_state() } + + fn get_shmem_config(&self) -> Result { + self.lock().unwrap().get_shmem_config() + } } impl VhostUserBackend for RwLock { @@ -595,6 +617,10 @@ impl VhostUserBackend for RwLock { fn check_device_state(&self) -> Result<()> { self.read().unwrap().check_device_state() } + + fn get_shmem_config(&self) -> Result { + self.read().unwrap().get_shmem_config() + } } #[cfg(test)] diff --git a/vhost-user-backend/src/handler.rs b/vhost-user-backend/src/handler.rs index 86b8c86f..80764cb8 100644 --- a/vhost-user-backend/src/handler.rs +++ b/vhost-user-backend/src/handler.rs @@ -18,7 +18,7 @@ use crate::bitmap::{BitmapReplace, MemRegionBitmap, MmapLogReg}; use userfaultfd::{Uffd, UffdBuilder}; use vhost::vhost_user::message::{ VhostTransferStateDirection, VhostTransferStatePhase, VhostUserConfigFlags, VhostUserLog, - VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserSharedMsg, + VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserShMemConfig, VhostUserSharedMsg, VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags, VhostUserVringState, }; @@ -677,6 +677,12 @@ where .map_err(VhostUserError::ReqHandlerError) } + fn get_shmem_config(&self) -> VhostUserResult { + self.backend + .get_shmem_config() + .map_err(VhostUserError::ReqHandlerError) + } + #[cfg(feature = "postcopy")] fn postcopy_advice(&mut self) -> VhostUserResult { let mut uffd_builder = UffdBuilder::new(); diff --git a/vhost/CHANGELOG.md b/vhost/CHANGELOG.md index f72ae24f..00af71cd 100644 --- a/vhost/CHANGELOG.md +++ b/vhost/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added - [[#251]](https://github.com/rust-vmm/vhost/pull/251) Add `SHMEM_MAP` and `SHMEM_UNMAP` support +- [[#339]](https://github.com/rust-vmm/vhost/pull/339) Add support for `GET_SHMEM_CONFIG` message ### Changed ### Deprecated diff --git a/vhost/src/vhost_user/backend_req_handler.rs b/vhost/src/vhost_user/backend_req_handler.rs index d74b0455..2f5defcc 100644 --- a/vhost/src/vhost_user/backend_req_handler.rs +++ b/vhost/src/vhost_user/backend_req_handler.rs @@ -81,6 +81,7 @@ pub trait VhostUserBackendReqHandler { fd: File, ) -> Result>; fn check_device_state(&self) -> Result<()>; + fn get_shmem_config(&self) -> Result; #[cfg(feature = "postcopy")] fn postcopy_advice(&self) -> Result; #[cfg(feature = "postcopy")] @@ -146,6 +147,7 @@ pub trait VhostUserBackendReqHandlerMut { fd: File, ) -> Result>; fn check_device_state(&mut self) -> Result<()>; + fn get_shmem_config(&self) -> Result; #[cfg(feature = "postcopy")] fn postcopy_advice(&mut self) -> Result; #[cfg(feature = "postcopy")] @@ -289,6 +291,10 @@ impl VhostUserBackendReqHandler for Mutex { self.lock().unwrap().check_device_state() } + fn get_shmem_config(&self) -> Result { + self.lock().unwrap().get_shmem_config() + } + #[cfg(feature = "postcopy")] fn postcopy_advice(&self) -> Result { self.lock().unwrap().postcopy_advice() @@ -679,6 +685,11 @@ impl BackendReqHandler { }; self.send_reply_message(&hdr, &msg)?; } + Ok(FrontendReq::GET_SHMEM_CONFIG) => { + self.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?; + let msg = self.backend.get_shmem_config().unwrap_or_default(); + self.send_reply_message(&hdr, &msg)?; + } #[cfg(feature = "postcopy")] Ok(FrontendReq::POSTCOPY_ADVISE) => { self.check_proto_feature(VhostUserProtocolFeatures::PAGEFAULT)?; @@ -1038,4 +1049,108 @@ mod tests { handler.check_state().unwrap_err(); assert!(handler.as_raw_fd() >= 0); } + + // Helper to send GET_SHMEM_CONFIG request and receive response + fn send_get_shmem_config_request( + mut endpoint: Endpoint>, + ) -> VhostUserShMemConfig { + let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0); + endpoint.send_message(&hdr, &VhostUserEmpty, None).unwrap(); + + let (reply_hdr, reply_config, rfds) = endpoint.recv_body::().unwrap(); + assert_eq!(reply_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG); + assert!(reply_hdr.is_reply()); + assert!(rfds.is_none()); + reply_config + } + + // Helper to create handler with SHMEM protocol feature enabled + fn create_handler_with_shmem( + backend: Arc>, + p1: UnixStream, + ) -> BackendReqHandler> { + let mut handler = BackendReqHandler::new( + Endpoint::>::from_stream(p1), + backend, + ); + handler.acked_protocol_features = VhostUserProtocolFeatures::SHMEM.bits(); + handler + } + + #[test] + fn test_get_shmem_config_multiple_regions() { + let memory_sizes = [ + 0x1000, 0x2000, 0x3000, 0x4000, 0x5000, 0x6000, 0x7000, 0x8000, + ]; + let config = VhostUserShMemConfig::new(8, &memory_sizes); + + let (p1, p2) = UnixStream::pair().unwrap(); + let mut dummy_backend = DummyBackendReqHandler::new(); + dummy_backend.set_shmem_config(config); + let mut handler = create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), p1); + + let handle = std::thread::spawn(move || { + send_get_shmem_config_request(Endpoint::>::from_stream( + p2, + )) + }); + + handler.handle_request().unwrap(); + + let reply_config = handle.join().unwrap(); + assert_eq!(reply_config.nregions, 8); + for i in 0..8 { + assert_eq!(reply_config.memory_sizes[i], (i as u64 + 1) * 0x1000); + } + for i in 8..256 { + assert_eq!(reply_config.memory_sizes[i], 0); + } + } + + #[test] + fn test_get_shmem_config_non_continuous_regions() { + // Create a configuration with non-continuous regions + let memory_sizes = [0x10000, 0, 0x20000, 0, 0, 0, 0, 0]; + let config = VhostUserShMemConfig::new(2, &memory_sizes); + + let (p1, p2) = UnixStream::pair().unwrap(); + let mut dummy_backend = DummyBackendReqHandler::new(); + dummy_backend.set_shmem_config(config); + let mut handler = create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), p1); + + let handle = std::thread::spawn(move || { + send_get_shmem_config_request(Endpoint::>::from_stream( + p2, + )) + }); + + handler.handle_request().unwrap(); + + let reply_config = handle.join().unwrap(); + assert_eq!(reply_config.nregions, 2); + assert_eq!(reply_config.memory_sizes[0], 0x10000); + assert_eq!(reply_config.memory_sizes[1], 0); + assert_eq!(reply_config.memory_sizes[2], 0x20000); + for i in 3..256 { + assert_eq!(reply_config.memory_sizes[i], 0); + } + } + + #[test] + fn test_get_shmem_config_feature_not_negotiated() { + // Test that the request fails when SHMEM protocol feature is not negotiated + let (p1, p2) = UnixStream::pair().unwrap(); + let backend = Arc::new(Mutex::new(DummyBackendReqHandler::new())); + let mut handler = BackendReqHandler::new( + Endpoint::>::from_stream(p1), + backend, + ); + let mut frontend_endpoint = Endpoint::>::from_stream(p2); + + std::thread::spawn(move || { + let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0); + let _ = frontend_endpoint.send_message(&hdr, &VhostUserEmpty, None); + }); + assert!(handler.handle_request().is_err()); + } } diff --git a/vhost/src/vhost_user/dummy_backend.rs b/vhost/src/vhost_user/dummy_backend.rs index a45d3b47..a4dd5c01 100644 --- a/vhost/src/vhost_user/dummy_backend.rs +++ b/vhost/src/vhost_user/dummy_backend.rs @@ -27,6 +27,7 @@ pub struct DummyBackendReqHandler { pub vring_enabled: [bool; MAX_QUEUE_NUM], pub inflight_file: Option, pub shared_file: Option, + pub shmem_config: Option, } impl DummyBackendReqHandler { @@ -37,6 +38,12 @@ impl DummyBackendReqHandler { } } + /// Set the shared memory configuration to be returned by `get_shmem_config` + pub fn set_shmem_config(&mut self, config: VhostUserShMemConfig) { + self.acked_protocol_features |= VhostUserProtocolFeatures::SHMEM.bits(); + self.shmem_config = Some(config); + } + /// Helper to check if VirtioFeature enabled fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> { if self.acked_features & feat.bits() != 0 { @@ -329,6 +336,15 @@ impl VhostUserBackendReqHandlerMut for DummyBackendReqHandler { ))) } + fn get_shmem_config(&self) -> Result { + self.shmem_config.ok_or_else(|| { + Error::ReqHandlerError(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "dummy back end does not support shared memory regions", + )) + }) + } + #[cfg(feature = "postcopy")] fn postcopy_advice(&mut self) -> Result { let file = tempfile::tempfile().unwrap(); diff --git a/vhost/src/vhost_user/frontend.rs b/vhost/src/vhost_user/frontend.rs index bda291f2..2d95f9ad 100644 --- a/vhost/src/vhost_user/frontend.rs +++ b/vhost/src/vhost_user/frontend.rs @@ -79,6 +79,9 @@ pub trait VhostUserFrontend: VhostBackend { /// Remove a guest memory mapping from vhost. fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; + /// Get the shared memory region configuration from the backend. + fn get_shmem_config(&mut self) -> Result; + /// Sends VHOST_USER_POSTCOPY_ADVISE msg to the backend /// initiating the beginning of the postcopy process. /// Backend will return a userfaultfd. @@ -567,6 +570,16 @@ impl VhostUserFrontend for Frontend { node.wait_for_ack(&hdr).map_err(|e| e.into()) } + fn get_shmem_config(&mut self) -> Result { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?; + + let hdr = node.send_request_header(FrontendReq::GET_SHMEM_CONFIG, None)?; + let config = node.recv_reply::(&hdr)?; + + Ok(config) + } + #[cfg(feature = "postcopy")] fn postcopy_advise(&mut self) -> Result { let mut node = self.node(); @@ -1201,4 +1214,45 @@ mod tests { let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1]; frontend.set_mem_table(&tables).unwrap_err(); } + + #[test] + fn test_frontend_get_shmem_config() { + let (mut frontend, mut peer) = create_pair2(); + + let expected_config = VhostUserShMemConfig::new(2, &[0x1000, 0x2000]); + let hdr = VhostUserMsgHeader::new( + FrontendReq::GET_SHMEM_CONFIG, + 0x4, + std::mem::size_of::() as u32, + ); + peer.send_message(&hdr, &expected_config, None).unwrap(); + + let config = frontend.get_shmem_config().unwrap(); + assert_eq!(config.nregions, 2); + assert_eq!(config.memory_sizes[0], 0x1000); + assert_eq!(config.memory_sizes[1], 0x2000); + + let (recv_hdr, rfds) = peer.recv_header().unwrap(); + assert_eq!(recv_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG); + assert!(rfds.is_none()); + } + + #[test] + fn test_frontend_get_shmem_config_no_regions() { + let (mut frontend, mut peer) = create_pair2(); + + let expected_config = VhostUserShMemConfig::default(); + let hdr = VhostUserMsgHeader::new( + FrontendReq::GET_SHMEM_CONFIG, + 0x4, + std::mem::size_of::() as u32, + ); + peer.send_message(&hdr, &expected_config, None).unwrap(); + + let config = frontend.get_shmem_config().unwrap(); + assert_eq!(config.nregions, 0); + for i in 0..256 { + assert_eq!(config.memory_sizes[i], 0); + } + } } diff --git a/vhost/src/vhost_user/message.rs b/vhost/src/vhost_user/message.rs index 360b6ddc..3ff3908f 100644 --- a/vhost/src/vhost_user/message.rs +++ b/vhost/src/vhost_user/message.rs @@ -169,6 +169,8 @@ enum_value! { /// After transferring state, check the backend for any errors that may have /// occurred during the transfer CHECK_DEVICE_STATE = 43, + /// Get shared memory regions configuration from the backend. + GET_SHMEM_CONFIG = 44, } } @@ -688,6 +690,44 @@ impl VhostUserSingleMemoryRegion { unsafe impl ByteValued for VhostUserSingleMemoryRegion {} impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {} +/// Get shared memory regions configuration. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct VhostUserShMemConfig { + /// Total number of shared memory regions + pub nregions: u32, + /// Padding for correct alignment + padding: u32, + /// Size of each memory region + pub memory_sizes: [u64; 256], +} + +impl Default for VhostUserShMemConfig { + fn default() -> Self { + Self { + nregions: 0, + padding: 0, + memory_sizes: [0; 256], + } + } +} + +impl VhostUserShMemConfig { + /// Create a new instance + pub fn new(nregions: u32, memory: &[u64]) -> Self { + let memory_sizes: [u64; 256] = std::array::from_fn(|i| *memory.get(i).unwrap_or(&0)); + Self { + nregions, + padding: 0, + memory_sizes, + } + } +} + +// SAFETY: Safe because all fields of VhostUserSingleMemoryRegion are POD. +unsafe impl ByteValued for VhostUserShMemConfig {} +impl VhostUserMsgValidator for VhostUserShMemConfig {} + /// Vring state descriptor. #[repr(C, packed)] #[derive(Copy, Clone, Default)]