Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vhost-user-backend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## [Unreleased]

### Added
- [[#339]](https://github.com/rust-vmm/vhost/pull/339) Add support for `GET_SHMEM_CONFIG` message

### Changed
### Deprecated
### Fixed
Expand Down
28 changes: 27 additions & 1 deletion vhost-user-backend/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use std::sync::{Arc, Mutex, RwLock};

use vhost::vhost_user::message::{
VhostTransferStateDirection, VhostTransferStatePhase, VhostUserProtocolFeatures,
VhostUserSharedMsg,
VhostUserShMemConfig, VhostUserSharedMsg,
};
use vhost::vhost_user::Backend;
use vm_memory::bitmap::Bitmap;
Expand Down Expand Up @@ -180,6 +180,13 @@ pub trait VhostUserBackend: Send + Sync {
"back end does not support state transfer",
))
}

fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"back end does not support shared memory regions",
))
}
}

/// Trait without interior mutability for vhost user backend servers to implement concrete services.
Expand Down Expand Up @@ -322,6 +329,13 @@ pub trait VhostUserBackendMut: Send + Sync {
"back end does not support state transfer",
))
}

fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"back end does not support shared memory regions",
))
}
}

impl<T: VhostUserBackend> VhostUserBackend for Arc<T> {
Expand Down Expand Up @@ -411,6 +425,10 @@ impl<T: VhostUserBackend> VhostUserBackend for Arc<T> {
fn check_device_state(&self) -> Result<()> {
self.deref().check_device_state()
}

fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.deref().get_shmem_config()
}
}

impl<T: VhostUserBackendMut> VhostUserBackend for Mutex<T> {
Expand Down Expand Up @@ -503,6 +521,10 @@ impl<T: VhostUserBackendMut> VhostUserBackend for Mutex<T> {
fn check_device_state(&self) -> Result<()> {
self.lock().unwrap().check_device_state()
}

fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.lock().unwrap().get_shmem_config()
}
}

impl<T: VhostUserBackendMut> VhostUserBackend for RwLock<T> {
Expand Down Expand Up @@ -595,6 +617,10 @@ impl<T: VhostUserBackendMut> VhostUserBackend for RwLock<T> {
fn check_device_state(&self) -> Result<()> {
self.read().unwrap().check_device_state()
}

fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.read().unwrap().get_shmem_config()
}
}

#[cfg(test)]
Expand Down
8 changes: 7 additions & 1 deletion vhost-user-backend/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::bitmap::{BitmapReplace, MemRegionBitmap, MmapLogReg};
use userfaultfd::{Uffd, UffdBuilder};
use vhost::vhost_user::message::{
VhostTransferStateDirection, VhostTransferStatePhase, VhostUserConfigFlags, VhostUserLog,
VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserSharedMsg,
VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserShMemConfig, VhostUserSharedMsg,
VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags,
VhostUserVringState,
};
Expand Down Expand Up @@ -677,6 +677,12 @@ where
.map_err(VhostUserError::ReqHandlerError)
}

fn get_shmem_config(&self) -> VhostUserResult<VhostUserShMemConfig> {
self.backend
.get_shmem_config()
.map_err(VhostUserError::ReqHandlerError)
}

#[cfg(feature = "postcopy")]
fn postcopy_advice(&mut self) -> VhostUserResult<File> {
let mut uffd_builder = UffdBuilder::new();
Expand Down
1 change: 1 addition & 0 deletions vhost/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Added
- [[#251]](https://github.com/rust-vmm/vhost/pull/251) Add `SHMEM_MAP` and `SHMEM_UNMAP` support
- [[#339]](https://github.com/rust-vmm/vhost/pull/339) Add support for `GET_SHMEM_CONFIG` message

### Changed
### Deprecated
Expand Down
115 changes: 115 additions & 0 deletions vhost/src/vhost_user/backend_req_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pub trait VhostUserBackendReqHandler {
fd: File,
) -> Result<Option<File>>;
fn check_device_state(&self) -> Result<()>;
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig>;
#[cfg(feature = "postcopy")]
fn postcopy_advice(&self) -> Result<File>;
#[cfg(feature = "postcopy")]
Expand Down Expand Up @@ -146,6 +147,7 @@ pub trait VhostUserBackendReqHandlerMut {
fd: File,
) -> Result<Option<File>>;
fn check_device_state(&mut self) -> Result<()>;
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig>;
#[cfg(feature = "postcopy")]
fn postcopy_advice(&mut self) -> Result<File>;
#[cfg(feature = "postcopy")]
Expand Down Expand Up @@ -289,6 +291,10 @@ impl<T: VhostUserBackendReqHandlerMut> VhostUserBackendReqHandler for Mutex<T> {
self.lock().unwrap().check_device_state()
}

fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.lock().unwrap().get_shmem_config()
}

#[cfg(feature = "postcopy")]
fn postcopy_advice(&self) -> Result<File> {
self.lock().unwrap().postcopy_advice()
Expand Down Expand Up @@ -679,6 +685,11 @@ impl<S: VhostUserBackendReqHandler> BackendReqHandler<S> {
};
self.send_reply_message(&hdr, &msg)?;
}
Ok(FrontendReq::GET_SHMEM_CONFIG) => {
self.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?;
let msg = self.backend.get_shmem_config().unwrap_or_default();
self.send_reply_message(&hdr, &msg)?;
}
#[cfg(feature = "postcopy")]
Ok(FrontendReq::POSTCOPY_ADVISE) => {
self.check_proto_feature(VhostUserProtocolFeatures::PAGEFAULT)?;
Expand Down Expand Up @@ -1038,4 +1049,108 @@ mod tests {
handler.check_state().unwrap_err();
assert!(handler.as_raw_fd() >= 0);
}

// Helper to send GET_SHMEM_CONFIG request and receive response
fn send_get_shmem_config_request(
mut endpoint: Endpoint<VhostUserMsgHeader<FrontendReq>>,
) -> VhostUserShMemConfig {
let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0);
endpoint.send_message(&hdr, &VhostUserEmpty, None).unwrap();

let (reply_hdr, reply_config, rfds) = endpoint.recv_body::<VhostUserShMemConfig>().unwrap();
assert_eq!(reply_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG);
assert!(reply_hdr.is_reply());
assert!(rfds.is_none());
reply_config
}

// Helper to create handler with SHMEM protocol feature enabled
fn create_handler_with_shmem(
backend: Arc<Mutex<DummyBackendReqHandler>>,
p1: UnixStream,
) -> BackendReqHandler<Mutex<DummyBackendReqHandler>> {
let mut handler = BackendReqHandler::new(
Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(p1),
backend,
);
handler.acked_protocol_features = VhostUserProtocolFeatures::SHMEM.bits();
handler
}

#[test]
fn test_get_shmem_config_multiple_regions() {
let memory_sizes = [
0x1000, 0x2000, 0x3000, 0x4000, 0x5000, 0x6000, 0x7000, 0x8000,
];
let config = VhostUserShMemConfig::new(8, &memory_sizes);

let (p1, p2) = UnixStream::pair().unwrap();
let mut dummy_backend = DummyBackendReqHandler::new();
dummy_backend.set_shmem_config(config);
let mut handler = create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), p1);

let handle = std::thread::spawn(move || {
send_get_shmem_config_request(Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(
p2,
))
});

handler.handle_request().unwrap();

let reply_config = handle.join().unwrap();
assert_eq!(reply_config.nregions, 8);
for i in 0..8 {
assert_eq!(reply_config.memory_sizes[i], (i as u64 + 1) * 0x1000);
}
for i in 8..256 {
assert_eq!(reply_config.memory_sizes[i], 0);
}
}

#[test]
fn test_get_shmem_config_non_continuous_regions() {
// Create a configuration with non-continuous regions
let memory_sizes = [0x10000, 0, 0x20000, 0, 0, 0, 0, 0];
let config = VhostUserShMemConfig::new(2, &memory_sizes);

let (p1, p2) = UnixStream::pair().unwrap();
let mut dummy_backend = DummyBackendReqHandler::new();
dummy_backend.set_shmem_config(config);
let mut handler = create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), p1);

let handle = std::thread::spawn(move || {
send_get_shmem_config_request(Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(
p2,
))
});

handler.handle_request().unwrap();

let reply_config = handle.join().unwrap();
assert_eq!(reply_config.nregions, 2);
assert_eq!(reply_config.memory_sizes[0], 0x10000);
assert_eq!(reply_config.memory_sizes[1], 0);
assert_eq!(reply_config.memory_sizes[2], 0x20000);
for i in 3..256 {
assert_eq!(reply_config.memory_sizes[i], 0);
}
}

#[test]
fn test_get_shmem_config_feature_not_negotiated() {
// Test that the request fails when SHMEM protocol feature is not negotiated
let (p1, p2) = UnixStream::pair().unwrap();
let backend = Arc::new(Mutex::new(DummyBackendReqHandler::new()));
let mut handler = BackendReqHandler::new(
Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(p1),
backend,
);
let mut frontend_endpoint = Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(p2);

std::thread::spawn(move || {
let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0);
let _ = frontend_endpoint.send_message(&hdr, &VhostUserEmpty, None);
});
assert!(handler.handle_request().is_err());
}
}
16 changes: 16 additions & 0 deletions vhost/src/vhost_user/dummy_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct DummyBackendReqHandler {
pub vring_enabled: [bool; MAX_QUEUE_NUM],
pub inflight_file: Option<File>,
pub shared_file: Option<File>,
pub shmem_config: Option<VhostUserShMemConfig>,
}

impl DummyBackendReqHandler {
Expand All @@ -37,6 +38,12 @@ impl DummyBackendReqHandler {
}
}

/// Set the shared memory configuration to be returned by `get_shmem_config`
pub fn set_shmem_config(&mut self, config: VhostUserShMemConfig) {
self.acked_protocol_features |= VhostUserProtocolFeatures::SHMEM.bits();
self.shmem_config = Some(config);
}

/// Helper to check if VirtioFeature enabled
fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> {
if self.acked_features & feat.bits() != 0 {
Expand Down Expand Up @@ -329,6 +336,15 @@ impl VhostUserBackendReqHandlerMut for DummyBackendReqHandler {
)))
}

fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.shmem_config.ok_or_else(|| {
Error::ReqHandlerError(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"dummy back end does not support shared memory regions",
))
})
}

#[cfg(feature = "postcopy")]
fn postcopy_advice(&mut self) -> Result<File> {
let file = tempfile::tempfile().unwrap();
Expand Down
54 changes: 54 additions & 0 deletions vhost/src/vhost_user/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ pub trait VhostUserFrontend: VhostBackend {
/// Remove a guest memory mapping from vhost.
fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;

/// Get the shared memory region configuration from the backend.
fn get_shmem_config(&mut self) -> Result<VhostUserShMemConfig>;

/// Sends VHOST_USER_POSTCOPY_ADVISE msg to the backend
/// initiating the beginning of the postcopy process.
/// Backend will return a userfaultfd.
Expand Down Expand Up @@ -567,6 +570,16 @@ impl VhostUserFrontend for Frontend {
node.wait_for_ack(&hdr).map_err(|e| e.into())
}

fn get_shmem_config(&mut self) -> Result<VhostUserShMemConfig> {
let mut node = self.node();
node.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?;

let hdr = node.send_request_header(FrontendReq::GET_SHMEM_CONFIG, None)?;
let config = node.recv_reply::<VhostUserShMemConfig>(&hdr)?;

Ok(config)
}

#[cfg(feature = "postcopy")]
fn postcopy_advise(&mut self) -> Result<File> {
let mut node = self.node();
Expand Down Expand Up @@ -1201,4 +1214,45 @@ mod tests {
let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1];
frontend.set_mem_table(&tables).unwrap_err();
}

#[test]
fn test_frontend_get_shmem_config() {
let (mut frontend, mut peer) = create_pair2();

let expected_config = VhostUserShMemConfig::new(2, &[0x1000, 0x2000]);
let hdr = VhostUserMsgHeader::new(
FrontendReq::GET_SHMEM_CONFIG,
0x4,
std::mem::size_of::<VhostUserShMemConfig>() as u32,
);
peer.send_message(&hdr, &expected_config, None).unwrap();

let config = frontend.get_shmem_config().unwrap();
assert_eq!(config.nregions, 2);
assert_eq!(config.memory_sizes[0], 0x1000);
assert_eq!(config.memory_sizes[1], 0x2000);

let (recv_hdr, rfds) = peer.recv_header().unwrap();
assert_eq!(recv_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG);
assert!(rfds.is_none());
}

#[test]
fn test_frontend_get_shmem_config_no_regions() {
let (mut frontend, mut peer) = create_pair2();

let expected_config = VhostUserShMemConfig::default();
let hdr = VhostUserMsgHeader::new(
FrontendReq::GET_SHMEM_CONFIG,
0x4,
std::mem::size_of::<VhostUserShMemConfig>() as u32,
);
peer.send_message(&hdr, &expected_config, None).unwrap();

let config = frontend.get_shmem_config().unwrap();
assert_eq!(config.nregions, 0);
for i in 0..256 {
assert_eq!(config.memory_sizes[i], 0);
}
}
}
Loading