Skip to content

Commit 6b49347

Browse files
committed
vhost_user: support variable shmem_config regions
Add serialize/deserialize methods to VhostUserShMemConfig to be able to send and receive variable payloads in the replies from the backend to the SHMEM_CONFIG messages. Signed-off-by: Albert Esteve <aesteve@redhat.com>
1 parent eccb86a commit 6b49347

File tree

3 files changed

+64
-10
lines changed

3 files changed

+64
-10
lines changed

vhost/src/vhost_user/backend_req_handler.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,10 @@ impl<S: VhostUserBackendReqHandler> BackendReqHandler<S> {
687687
}
688688
Ok(FrontendReq::GET_SHMEM_CONFIG) => {
689689
self.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?;
690-
let msg = self.backend.get_shmem_config().unwrap_or_default();
691-
self.send_reply_message(&hdr, &msg)?;
690+
let config = self.backend.get_shmem_config().unwrap_or_default();
691+
let msg = VhostUserU64::new(config.nregions as u64);
692+
let payload = config.payload();
693+
self.send_reply_with_payload(&hdr, &msg, &payload)?;
692694
}
693695
#[cfg(feature = "postcopy")]
694696
Ok(FrontendReq::POSTCOPY_ADVISE) => {
@@ -1057,11 +1059,13 @@ mod tests {
10571059
let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0);
10581060
endpoint.send_message(&hdr, &VhostUserEmpty, None).unwrap();
10591061

1060-
let (reply_hdr, reply_config, rfds) = endpoint.recv_body::<VhostUserShMemConfig>().unwrap();
1062+
let (reply_hdr, _count, payload, rfds) =
1063+
endpoint.recv_payload_into_buf::<VhostUserU64>().unwrap();
10611064
assert_eq!(reply_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG);
10621065
assert!(reply_hdr.is_reply());
10631066
assert!(rfds.is_none());
1064-
reply_config
1067+
1068+
VhostUserShMemConfig::from_payload(&payload).unwrap()
10651069
}
10661070

10671071
// Helper to create handler with SHMEM protocol feature enabled

vhost/src/vhost_user/frontend.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,9 @@ impl VhostUserFrontend for Frontend {
575575
node.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?;
576576

577577
let hdr = node.send_request_header(FrontendReq::GET_SHMEM_CONFIG, None)?;
578-
let config = node.recv_reply::<VhostUserShMemConfig>(&hdr)?;
578+
let (_count, payload, _) = node.recv_reply_with_payload::<VhostUserU64>(&hdr)?;
579+
let config = VhostUserShMemConfig::from_payload(&payload)
580+
.map_err(|_| VhostUserError::InvalidMessage)?;
579581

580582
Ok(config)
581583
}
@@ -1210,12 +1212,15 @@ mod tests {
12101212
let (mut frontend, mut peer) = create_pair2();
12111213

12121214
let expected_config = VhostUserShMemConfig::new(2, &[0x1000, 0x2000]);
1215+
let count = VhostUserU64::new(expected_config.nregions as u64);
1216+
let payload = expected_config.payload();
12131217
let hdr = VhostUserMsgHeader::new(
12141218
FrontendReq::GET_SHMEM_CONFIG,
12151219
0x4,
1216-
std::mem::size_of::<VhostUserShMemConfig>() as u32,
1220+
(std::mem::size_of::<VhostUserU64>() + payload.len()) as u32,
12171221
);
1218-
peer.send_message(&hdr, &expected_config, None).unwrap();
1222+
peer.send_message_with_payload(&hdr, &count, &payload, None)
1223+
.unwrap();
12191224

12201225
let config = frontend.get_shmem_config().unwrap();
12211226
assert_eq!(config.nregions, 2);
@@ -1232,12 +1237,15 @@ mod tests {
12321237
let (mut frontend, mut peer) = create_pair2();
12331238

12341239
let expected_config = VhostUserShMemConfig::default();
1240+
let count = VhostUserU64::new(expected_config.nregions as u64);
1241+
let payload = expected_config.payload();
12351242
let hdr = VhostUserMsgHeader::new(
12361243
FrontendReq::GET_SHMEM_CONFIG,
12371244
0x4,
1238-
std::mem::size_of::<VhostUserShMemConfig>() as u32,
1245+
(std::mem::size_of::<VhostUserU64>() + payload.len()) as u32,
12391246
);
1240-
peer.send_message(&hdr, &expected_config, None).unwrap();
1247+
peer.send_message_with_payload(&hdr, &count, &payload, None)
1248+
.unwrap();
12411249

12421250
let config = frontend.get_shmem_config().unwrap();
12431251
assert_eq!(config.nregions, 0);

vhost/src/vhost_user/message.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {}
697697
#[repr(C)]
698698
#[derive(Debug, Clone, Copy)]
699699
pub struct VhostUserShMemConfig {
700-
/// Total number of shared memory regions
700+
/// Total number of shared memory regions sent
701701
pub nregions: u32,
702702
/// Padding for correct alignment
703703
padding: u32,
@@ -725,6 +725,48 @@ impl VhostUserShMemConfig {
725725
memory_sizes,
726726
}
727727
}
728+
729+
/// Serialize memory_sizes to bytes for the wire protocol payload
730+
pub fn payload(&self) -> Vec<u8> {
731+
let num_elements = self
732+
.memory_sizes
733+
.iter()
734+
.rposition(|&x| x != 0)
735+
.map(|i| i + 1)
736+
.unwrap_or(0);
737+
738+
let mut payload = Vec::with_capacity(num_elements * 8);
739+
for i in 0..num_elements {
740+
payload.extend_from_slice(&self.memory_sizes[i].to_ne_bytes());
741+
}
742+
payload
743+
}
744+
745+
/// Deserialize from payload bytes
746+
pub fn from_payload(payload: &[u8]) -> Result<Self> {
747+
if !payload.len().is_multiple_of(8) {
748+
return Err(Error::InvalidMessage);
749+
}
750+
751+
let num_elements = payload.len() / 8;
752+
if num_elements > 256 {
753+
return Err(Error::InvalidMessage);
754+
}
755+
756+
let mut memory_sizes = [0u64; 256];
757+
for (i, chunk) in payload.chunks_exact(8).enumerate() {
758+
memory_sizes[i] = u64::from_ne_bytes(chunk.try_into().unwrap());
759+
}
760+
761+
// Count non-zero elements to determine nregions
762+
let nregions = memory_sizes.iter().filter(|&&x| x != 0).count() as u32;
763+
764+
Ok(Self {
765+
nregions,
766+
padding: 0,
767+
memory_sizes,
768+
})
769+
}
728770
}
729771

730772
// SAFETY: Safe because all fields of VhostUserSingleMemoryRegion are POD.

0 commit comments

Comments
 (0)