@@ -847,9 +847,14 @@ impl<T: GuestMemory> Bytes<GuestAddress> for T {
847
847
mod tests {
848
848
use super :: * ;
849
849
#[ cfg( feature = "backend-mmap" ) ]
850
+ use crate :: bytes:: ByteValued ;
851
+ #[ cfg( feature = "backend-mmap" ) ]
850
852
use crate :: { GuestAddress , GuestMemoryMmap } ;
851
853
#[ cfg( feature = "backend-mmap" ) ]
852
854
use std:: io:: Cursor ;
855
+ #[ cfg( feature = "backend-mmap" ) ]
856
+ use std:: time:: { Duration , Instant } ;
857
+
853
858
use vmm_sys_util:: tempfile:: TempFile ;
854
859
855
860
#[ cfg( feature = "backend-mmap" ) ]
@@ -888,4 +893,124 @@ mod tests {
888
893
. unwrap( )
889
894
) ;
890
895
}
896
+
897
+ // Runs the provided closure in a loop, until at least `duration` time units have elapsed.
898
+ #[ cfg( feature = "backend-mmap" ) ]
899
+ fn loop_timed < F > ( duration : Duration , mut f : F )
900
+ where
901
+ F : FnMut ( ) -> ( ) ,
902
+ {
903
+ // We check the time every `CHECK_PERIOD` iterations.
904
+ const CHECK_PERIOD : u64 = 1_000_000 ;
905
+ let start_time = Instant :: now ( ) ;
906
+
907
+ loop {
908
+ for _ in 0 ..CHECK_PERIOD {
909
+ f ( ) ;
910
+ }
911
+ if start_time. elapsed ( ) >= duration {
912
+ break ;
913
+ }
914
+ }
915
+ }
916
+
917
+ // Helper method for the following test. It spawns a writer and a reader thread, which
918
+ // simultaneously try to access an object that is placed at the junction of two memory regions.
919
+ // The part of the object that's continuously accessed is a member of type T. The writer
920
+ // flips all the bits of the member with every write, while the reader checks that every byte
921
+ // has the same value (and thus it did not do a non-atomic access). The test succeeds if
922
+ // no mismatch is detected after performing accesses for a pre-determined amount of time.
923
+ #[ cfg( feature = "backend-mmap" ) ]
924
+ fn non_atomic_access_helper < T > ( )
925
+ where
926
+ T : ByteValued
927
+ + std:: fmt:: Debug
928
+ + From < u8 >
929
+ + Into < u128 >
930
+ + std:: ops:: Not < Output = T >
931
+ + PartialEq ,
932
+ {
933
+ use std:: mem;
934
+ use std:: thread;
935
+
936
+ // A dummy type that's always going to have the same alignment as the first member,
937
+ // and then adds some bytes at the end.
938
+ #[ derive( Clone , Copy , Debug , Default , PartialEq ) ]
939
+ struct Data < T > {
940
+ val : T ,
941
+ some_bytes : [ u8 ; 7 ] ,
942
+ }
943
+
944
+ // Some sanity checks.
945
+ assert_eq ! ( mem:: align_of:: <T >( ) , mem:: align_of:: <Data <T >>( ) ) ;
946
+ assert_eq ! ( mem:: size_of:: <T >( ) , mem:: align_of:: <T >( ) ) ;
947
+
948
+ unsafe impl < T : ByteValued > ByteValued for Data < T > { }
949
+
950
+ // Start of first guest memory region.
951
+ let start = GuestAddress ( 0 ) ;
952
+ let region_len = 1 << 12 ;
953
+
954
+ // The address where we start writing/reading a Data<T> value.
955
+ let data_start = GuestAddress ( ( region_len - mem:: size_of :: < T > ( ) ) as u64 ) ;
956
+
957
+ let mem = GuestMemoryMmap :: from_ranges ( & [
958
+ ( start, region_len) ,
959
+ ( start. unchecked_add ( region_len as u64 ) , region_len) ,
960
+ ] )
961
+ . unwrap ( ) ;
962
+
963
+ // Need to clone this and move it into the new thread we create.
964
+ let mem2 = mem. clone ( ) ;
965
+ // Just some bytes.
966
+ let some_bytes = [ 1u8 , 2 , 4 , 16 , 32 , 64 , 128 ] ;
967
+
968
+ let mut data = Data {
969
+ val : T :: from ( 0u8 ) ,
970
+ some_bytes,
971
+ } ;
972
+
973
+ // Simple check that cross-region write/read is ok.
974
+ mem. write_obj ( data, data_start) . unwrap ( ) ;
975
+ let read_data = mem. read_obj :: < Data < T > > ( data_start) . unwrap ( ) ;
976
+ assert_eq ! ( read_data, data) ;
977
+
978
+ let t = thread:: spawn ( move || {
979
+ let mut count: u64 = 0 ;
980
+
981
+ loop_timed ( Duration :: from_secs ( 3 ) , || {
982
+ let data = mem2. read_obj :: < Data < T > > ( data_start) . unwrap ( ) ;
983
+
984
+ // Every time data is written to memory by the other thread, the value of
985
+ // data.val alternates between 0 and T::MAX, so the inner bytes should always
986
+ // have the same value. If they don't match, it means we read a partial value,
987
+ // so the access was not atomic.
988
+ let bytes = data. val . into ( ) . to_le_bytes ( ) ;
989
+ for i in 1 ..mem:: size_of :: < T > ( ) {
990
+ if bytes[ 0 ] != bytes[ i] {
991
+ panic ! (
992
+ "val bytes don't match {:?} after {} iterations" ,
993
+ & bytes[ ..mem:: size_of:: <T >( ) ] ,
994
+ count
995
+ ) ;
996
+ }
997
+ }
998
+ count += 1 ;
999
+ } ) ;
1000
+ } ) ;
1001
+
1002
+ // Write the object while flipping the bits of data.val over and over again.
1003
+ loop_timed ( Duration :: from_secs ( 3 ) , || {
1004
+ mem. write_obj ( data, data_start) . unwrap ( ) ;
1005
+ data. val = !data. val ;
1006
+ } ) ;
1007
+
1008
+ t. join ( ) . unwrap ( )
1009
+ }
1010
+
1011
+ #[ cfg( feature = "backend-mmap" ) ]
1012
+ #[ test]
1013
+ fn test_non_atomic_access ( ) {
1014
+ non_atomic_access_helper :: < u16 > ( )
1015
+ }
891
1016
}
0 commit comments