@@ -81,6 +81,7 @@ pub trait VhostUserBackendReqHandler {
8181 fd : File ,
8282 ) -> Result < Option < File > > ;
8383 fn check_device_state ( & self ) -> Result < ( ) > ;
84+ fn get_shmem_config ( & self ) -> Result < VhostUserShMemConfig > ;
8485 #[ cfg( feature = "postcopy" ) ]
8586 fn postcopy_advice ( & self ) -> Result < File > ;
8687 #[ cfg( feature = "postcopy" ) ]
@@ -146,6 +147,7 @@ pub trait VhostUserBackendReqHandlerMut {
146147 fd : File ,
147148 ) -> Result < Option < File > > ;
148149 fn check_device_state ( & mut self ) -> Result < ( ) > ;
150+ fn get_shmem_config ( & self ) -> Result < VhostUserShMemConfig > ;
149151 #[ cfg( feature = "postcopy" ) ]
150152 fn postcopy_advice ( & mut self ) -> Result < File > ;
151153 #[ cfg( feature = "postcopy" ) ]
@@ -289,6 +291,10 @@ impl<T: VhostUserBackendReqHandlerMut> VhostUserBackendReqHandler for Mutex<T> {
289291 self . lock ( ) . unwrap ( ) . check_device_state ( )
290292 }
291293
294+ fn get_shmem_config ( & self ) -> Result < VhostUserShMemConfig > {
295+ self . lock ( ) . unwrap ( ) . get_shmem_config ( )
296+ }
297+
292298 #[ cfg( feature = "postcopy" ) ]
293299 fn postcopy_advice ( & self ) -> Result < File > {
294300 self . lock ( ) . unwrap ( ) . postcopy_advice ( )
@@ -679,6 +685,11 @@ impl<S: VhostUserBackendReqHandler> BackendReqHandler<S> {
679685 } ;
680686 self . send_reply_message ( & hdr, & msg) ?;
681687 }
688+ Ok ( FrontendReq :: GET_SHMEM_CONFIG ) => {
689+ self . check_proto_feature ( VhostUserProtocolFeatures :: SHMEM ) ?;
690+ let msg = self . backend . get_shmem_config ( ) . unwrap_or_default ( ) ;
691+ self . send_reply_message ( & hdr, & msg) ?;
692+ }
682693 #[ cfg( feature = "postcopy" ) ]
683694 Ok ( FrontendReq :: POSTCOPY_ADVISE ) => {
684695 self . check_proto_feature ( VhostUserProtocolFeatures :: PAGEFAULT ) ?;
@@ -1038,4 +1049,108 @@ mod tests {
10381049 handler. check_state ( ) . unwrap_err ( ) ;
10391050 assert ! ( handler. as_raw_fd( ) >= 0 ) ;
10401051 }
1052+
1053+ // Helper to send GET_SHMEM_CONFIG request and receive response
1054+ fn send_get_shmem_config_request (
1055+ mut endpoint : Endpoint < VhostUserMsgHeader < FrontendReq > > ,
1056+ ) -> VhostUserShMemConfig {
1057+ let hdr = VhostUserMsgHeader :: new ( FrontendReq :: GET_SHMEM_CONFIG , 0 , 0 ) ;
1058+ endpoint. send_message ( & hdr, & VhostUserEmpty , None ) . unwrap ( ) ;
1059+
1060+ let ( reply_hdr, reply_config, rfds) = endpoint. recv_body :: < VhostUserShMemConfig > ( ) . unwrap ( ) ;
1061+ assert_eq ! ( reply_hdr. get_code( ) . unwrap( ) , FrontendReq :: GET_SHMEM_CONFIG ) ;
1062+ assert ! ( reply_hdr. is_reply( ) ) ;
1063+ assert ! ( rfds. is_none( ) ) ;
1064+ reply_config
1065+ }
1066+
1067+ // Helper to create handler with SHMEM protocol feature enabled
1068+ fn create_handler_with_shmem (
1069+ backend : Arc < Mutex < DummyBackendReqHandler > > ,
1070+ p1 : UnixStream ,
1071+ ) -> BackendReqHandler < Mutex < DummyBackendReqHandler > > {
1072+ let mut handler = BackendReqHandler :: new (
1073+ Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p1) ,
1074+ backend,
1075+ ) ;
1076+ handler. acked_protocol_features = VhostUserProtocolFeatures :: SHMEM . bits ( ) ;
1077+ handler
1078+ }
1079+
1080+ #[ test]
1081+ fn test_get_shmem_config_multiple_regions ( ) {
1082+ let memory_sizes = [
1083+ 0x1000 , 0x2000 , 0x3000 , 0x4000 , 0x5000 , 0x6000 , 0x7000 , 0x8000 ,
1084+ ] ;
1085+ let config = VhostUserShMemConfig :: new ( 8 , & memory_sizes) ;
1086+
1087+ let ( p1, p2) = UnixStream :: pair ( ) . unwrap ( ) ;
1088+ let mut dummy_backend = DummyBackendReqHandler :: new ( ) ;
1089+ dummy_backend. set_shmem_config ( config) ;
1090+ let mut handler = create_handler_with_shmem ( Arc :: new ( Mutex :: new ( dummy_backend) ) , p1) ;
1091+
1092+ let handle = std:: thread:: spawn ( move || {
1093+ send_get_shmem_config_request ( Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream (
1094+ p2,
1095+ ) )
1096+ } ) ;
1097+
1098+ handler. handle_request ( ) . unwrap ( ) ;
1099+
1100+ let reply_config = handle. join ( ) . unwrap ( ) ;
1101+ assert_eq ! ( reply_config. nregions, 8 ) ;
1102+ for i in 0 ..8 {
1103+ assert_eq ! ( reply_config. memory_sizes[ i] , ( i as u64 + 1 ) * 0x1000 ) ;
1104+ }
1105+ for i in 8 ..256 {
1106+ assert_eq ! ( reply_config. memory_sizes[ i] , 0 ) ;
1107+ }
1108+ }
1109+
1110+ #[ test]
1111+ fn test_get_shmem_config_non_continuous_regions ( ) {
1112+ // Create a configuration with non-continuous regions
1113+ let memory_sizes = [ 0x10000 , 0 , 0x20000 , 0 , 0 , 0 , 0 , 0 ] ;
1114+ let config = VhostUserShMemConfig :: new ( 2 , & memory_sizes) ;
1115+
1116+ let ( p1, p2) = UnixStream :: pair ( ) . unwrap ( ) ;
1117+ let mut dummy_backend = DummyBackendReqHandler :: new ( ) ;
1118+ dummy_backend. set_shmem_config ( config) ;
1119+ let mut handler = create_handler_with_shmem ( Arc :: new ( Mutex :: new ( dummy_backend) ) , p1) ;
1120+
1121+ let handle = std:: thread:: spawn ( move || {
1122+ send_get_shmem_config_request ( Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream (
1123+ p2,
1124+ ) )
1125+ } ) ;
1126+
1127+ handler. handle_request ( ) . unwrap ( ) ;
1128+
1129+ let reply_config = handle. join ( ) . unwrap ( ) ;
1130+ assert_eq ! ( reply_config. nregions, 2 ) ;
1131+ assert_eq ! ( reply_config. memory_sizes[ 0 ] , 0x10000 ) ;
1132+ assert_eq ! ( reply_config. memory_sizes[ 1 ] , 0 ) ;
1133+ assert_eq ! ( reply_config. memory_sizes[ 2 ] , 0x20000 ) ;
1134+ for i in 3 ..256 {
1135+ assert_eq ! ( reply_config. memory_sizes[ i] , 0 ) ;
1136+ }
1137+ }
1138+
1139+ #[ test]
1140+ fn test_get_shmem_config_feature_not_negotiated ( ) {
1141+ // Test that the request fails when SHMEM protocol feature is not negotiated
1142+ let ( p1, p2) = UnixStream :: pair ( ) . unwrap ( ) ;
1143+ let backend = Arc :: new ( Mutex :: new ( DummyBackendReqHandler :: new ( ) ) ) ;
1144+ let mut handler = BackendReqHandler :: new (
1145+ Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p1) ,
1146+ backend,
1147+ ) ;
1148+ let mut frontend_endpoint = Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p2) ;
1149+
1150+ std:: thread:: spawn ( move || {
1151+ let hdr = VhostUserMsgHeader :: new ( FrontendReq :: GET_SHMEM_CONFIG , 0 , 0 ) ;
1152+ let _ = frontend_endpoint. send_message ( & hdr, & VhostUserEmpty , None ) ;
1153+ } ) ;
1154+ assert ! ( handler. handle_request( ) . is_err( ) ) ;
1155+ }
10411156}
0 commit comments