@@ -864,9 +864,14 @@ impl<T: GuestMemory> Bytes<GuestAddress> for T {
864
864
mod tests {
865
865
use super :: * ;
866
866
#[ cfg( feature = "backend-mmap" ) ]
867
+ use crate :: bytes:: ByteValued ;
868
+ #[ cfg( feature = "backend-mmap" ) ]
867
869
use crate :: { GuestAddress , GuestMemoryMmap } ;
868
870
#[ cfg( feature = "backend-mmap" ) ]
869
871
use std:: io:: Cursor ;
872
+ #[ cfg( feature = "backend-mmap" ) ]
873
+ use std:: time:: { Duration , Instant } ;
874
+
870
875
use vmm_sys_util:: tempfile:: TempFile ;
871
876
872
877
#[ cfg( feature = "backend-mmap" ) ]
@@ -905,4 +910,124 @@ mod tests {
905
910
. unwrap( )
906
911
) ;
907
912
}
913
+
914
+ // Runs the provided closure in a loop, until at least `duration` time units have elapsed.
915
+ #[ cfg( feature = "backend-mmap" ) ]
916
+ fn loop_timed < F > ( duration : Duration , mut f : F )
917
+ where
918
+ F : FnMut ( ) -> ( ) ,
919
+ {
920
+ // We check the time every `CHECK_PERIOD` iterations.
921
+ const CHECK_PERIOD : u64 = 1_000_000 ;
922
+ let start_time = Instant :: now ( ) ;
923
+
924
+ loop {
925
+ for _ in 0 ..CHECK_PERIOD {
926
+ f ( ) ;
927
+ }
928
+ if start_time. elapsed ( ) >= duration {
929
+ break ;
930
+ }
931
+ }
932
+ }
933
+
934
+ // Helper method for the following test. It spawns a writer and a reader thread, which
935
+ // simultaneously try to access an object that is placed at the junction of two memory regions.
936
+ // The part of the object that's continuously accessed is a member of type T. The writer
937
+ // flips all the bits of the member with every write, while the reader checks that every byte
938
+ // has the same value (and thus it did not do a non-atomic access). The test succeeds if
939
+ // no mismatch is detected after performing accesses for a pre-determined amount of time.
940
+ #[ cfg( feature = "backend-mmap" ) ]
941
+ fn non_atomic_access_helper < T > ( )
942
+ where
943
+ T : ByteValued
944
+ + std:: fmt:: Debug
945
+ + From < u8 >
946
+ + Into < u128 >
947
+ + std:: ops:: Not < Output = T >
948
+ + PartialEq ,
949
+ {
950
+ use std:: mem;
951
+ use std:: thread;
952
+
953
+ // A dummy type that's always going to have the same alignment as the first member,
954
+ // and then adds some bytes at the end.
955
+ #[ derive( Clone , Copy , Debug , Default , PartialEq ) ]
956
+ struct Data < T > {
957
+ val : T ,
958
+ some_bytes : [ u8 ; 7 ] ,
959
+ }
960
+
961
+ // Some sanity checks.
962
+ assert_eq ! ( mem:: align_of:: <T >( ) , mem:: align_of:: <Data <T >>( ) ) ;
963
+ assert_eq ! ( mem:: size_of:: <T >( ) , mem:: align_of:: <T >( ) ) ;
964
+
965
+ unsafe impl < T : ByteValued > ByteValued for Data < T > { }
966
+
967
+ // Start of first guest memory region.
968
+ let start = GuestAddress ( 0 ) ;
969
+ let region_len = 1 << 12 ;
970
+
971
+ // The address where we start writing/reading a Data<T> value.
972
+ let data_start = GuestAddress ( ( region_len - mem:: size_of :: < T > ( ) ) as u64 ) ;
973
+
974
+ let mem = GuestMemoryMmap :: from_ranges ( & [
975
+ ( start, region_len) ,
976
+ ( start. unchecked_add ( region_len as u64 ) , region_len) ,
977
+ ] )
978
+ . unwrap ( ) ;
979
+
980
+ // Need to clone this and move it into the new thread we create.
981
+ let mem2 = mem. clone ( ) ;
982
+ // Just some bytes.
983
+ let some_bytes = [ 1u8 , 2 , 4 , 16 , 32 , 64 , 128 ] ;
984
+
985
+ let mut data = Data {
986
+ val : T :: from ( 0u8 ) ,
987
+ some_bytes,
988
+ } ;
989
+
990
+ // Simple check that cross-region write/read is ok.
991
+ mem. write_obj ( data, data_start) . unwrap ( ) ;
992
+ let read_data = mem. read_obj :: < Data < T > > ( data_start) . unwrap ( ) ;
993
+ assert_eq ! ( read_data, data) ;
994
+
995
+ let t = thread:: spawn ( move || {
996
+ let mut count: u64 = 0 ;
997
+
998
+ loop_timed ( Duration :: from_secs ( 3 ) , || {
999
+ let data = mem2. read_obj :: < Data < T > > ( data_start) . unwrap ( ) ;
1000
+
1001
+ // Every time data is written to memory by the other thread, the value of
1002
+ // data.val alternates between 0 and T::MAX, so the inner bytes should always
1003
+ // have the same value. If they don't match, it means we read a partial value,
1004
+ // so the access was not atomic.
1005
+ let bytes = data. val . into ( ) . to_le_bytes ( ) ;
1006
+ for i in 1 ..mem:: size_of :: < T > ( ) {
1007
+ if bytes[ 0 ] != bytes[ i] {
1008
+ panic ! (
1009
+ "val bytes don't match {:?} after {} iterations" ,
1010
+ & bytes[ ..mem:: size_of:: <T >( ) ] ,
1011
+ count
1012
+ ) ;
1013
+ }
1014
+ }
1015
+ count += 1 ;
1016
+ } ) ;
1017
+ } ) ;
1018
+
1019
+ // Write the object while flipping the bits of data.val over and over again.
1020
+ loop_timed ( Duration :: from_secs ( 3 ) , || {
1021
+ mem. write_obj ( data, data_start) . unwrap ( ) ;
1022
+ data. val = !data. val ;
1023
+ } ) ;
1024
+
1025
+ t. join ( ) . unwrap ( )
1026
+ }
1027
+
1028
+ #[ cfg( feature = "backend-mmap" ) ]
1029
+ #[ test]
1030
+ fn test_non_atomic_access ( ) {
1031
+ non_atomic_access_helper :: < u16 > ( )
1032
+ }
908
1033
}
0 commit comments