@@ -796,9 +796,14 @@ impl<T: GuestMemory> Bytes<GuestAddress> for T {
796
796
mod tests {
797
797
use super :: * ;
798
798
#[ cfg( feature = "backend-mmap" ) ]
799
+ use crate :: bytes:: ByteValued ;
800
+ #[ cfg( feature = "backend-mmap" ) ]
799
801
use crate :: { GuestAddress , GuestMemoryMmap } ;
800
802
#[ cfg( feature = "backend-mmap" ) ]
801
803
use std:: io:: Cursor ;
804
+ #[ cfg( feature = "backend-mmap" ) ]
805
+ use std:: time:: { Duration , Instant } ;
806
+
802
807
use vmm_sys_util:: tempfile:: TempFile ;
803
808
804
809
#[ cfg( feature = "backend-mmap" ) ]
@@ -837,4 +842,124 @@ mod tests {
837
842
. unwrap( )
838
843
) ;
839
844
}
845
+
846
+ // Runs the provided closure in a loop, until at least `duration` time units have elapsed.
847
+ #[ cfg( feature = "backend-mmap" ) ]
848
+ fn loop_timed < F > ( duration : Duration , mut f : F )
849
+ where
850
+ F : FnMut ( ) -> ( ) ,
851
+ {
852
+ // We check the time every `CHECK_PERIOD` iterations.
853
+ const CHECK_PERIOD : u64 = 1_000_000 ;
854
+ let start_time = Instant :: now ( ) ;
855
+
856
+ loop {
857
+ for _ in 0 ..CHECK_PERIOD {
858
+ f ( ) ;
859
+ }
860
+ if start_time. elapsed ( ) >= duration {
861
+ break ;
862
+ }
863
+ }
864
+ }
865
+
866
+ // Helper method for the following test. It spawns a writer and a reader thread, which
867
+ // simultaneously try to access an object that is placed at the junction of two memory regions.
868
+ // The part of the object that's continuously accessed is a member of type T. The writer
869
+ // flips all the bits of the member with every write, while the reader checks that every byte
870
+ // has the same value (and thus it did not do a non-atomic access). The test succeeds if
871
+ // no mismatch is detected after performing accesses for a pre-determined amount of time.
872
+ #[ cfg( feature = "backend-mmap" ) ]
873
+ fn non_atomic_access_helper < T > ( )
874
+ where
875
+ T : ByteValued
876
+ + std:: fmt:: Debug
877
+ + From < u8 >
878
+ + Into < u128 >
879
+ + std:: ops:: Not < Output = T >
880
+ + PartialEq ,
881
+ {
882
+ use std:: mem;
883
+ use std:: thread;
884
+
885
+ // A dummy type that's always going to have the same alignment as the first member,
886
+ // and then adds some bytes at the end.
887
+ #[ derive( Clone , Copy , Debug , Default , PartialEq ) ]
888
+ struct Data < T > {
889
+ val : T ,
890
+ some_bytes : [ u8 ; 7 ] ,
891
+ }
892
+
893
+ // Some sanity checks.
894
+ assert_eq ! ( mem:: align_of:: <T >( ) , mem:: align_of:: <Data <T >>( ) ) ;
895
+ assert_eq ! ( mem:: size_of:: <T >( ) , mem:: align_of:: <T >( ) ) ;
896
+
897
+ unsafe impl < T : ByteValued > ByteValued for Data < T > { }
898
+
899
+ // Start of first guest memory region.
900
+ let start = GuestAddress ( 0 ) ;
901
+ let region_len = 1 << 12 ;
902
+
903
+ // The address where we start writing/reading a Data<T> value.
904
+ let data_start = GuestAddress ( ( region_len - mem:: size_of :: < T > ( ) ) as u64 ) ;
905
+
906
+ let mem = GuestMemoryMmap :: from_ranges ( & [
907
+ ( start, region_len) ,
908
+ ( start. unchecked_add ( region_len as u64 ) , region_len) ,
909
+ ] )
910
+ . unwrap ( ) ;
911
+
912
+ // Need to clone this and move it into the new thread we create.
913
+ let mem2 = mem. clone ( ) ;
914
+ // Just some bytes.
915
+ let some_bytes = [ 1u8 , 2 , 4 , 16 , 32 , 64 , 128 ] ;
916
+
917
+ let mut data = Data {
918
+ val : T :: from ( 0u8 ) ,
919
+ some_bytes,
920
+ } ;
921
+
922
+ // Simple check that cross-region write/read is ok.
923
+ mem. write_obj ( data, data_start) . unwrap ( ) ;
924
+ let read_data = mem. read_obj :: < Data < T > > ( data_start) . unwrap ( ) ;
925
+ assert_eq ! ( read_data, data) ;
926
+
927
+ let t = thread:: spawn ( move || {
928
+ let mut count: u64 = 0 ;
929
+
930
+ loop_timed ( Duration :: from_secs ( 3 ) , || {
931
+ let data = mem2. read_obj :: < Data < T > > ( data_start) . unwrap ( ) ;
932
+
933
+ // Every time data is written to memory by the other thread, the value of
934
+ // data.val alternates between 0 and T::MAX, so the inner bytes should always
935
+ // have the same value. If they don't match, it means we read a partial value,
936
+ // so the access was not atomic.
937
+ let bytes = data. val . into ( ) . to_le_bytes ( ) ;
938
+ for i in 1 ..mem:: size_of :: < T > ( ) {
939
+ if bytes[ 0 ] != bytes[ i] {
940
+ panic ! (
941
+ "val bytes don't match {:?} after {} iterations" ,
942
+ & bytes[ ..mem:: size_of:: <T >( ) ] ,
943
+ count
944
+ ) ;
945
+ }
946
+ }
947
+ count += 1 ;
948
+ } ) ;
949
+ } ) ;
950
+
951
+ // Write the object while flipping the bits of data.val over and over again.
952
+ loop_timed ( Duration :: from_secs ( 3 ) , || {
953
+ mem. write_obj ( data, data_start) . unwrap ( ) ;
954
+ data. val = !data. val ;
955
+ } ) ;
956
+
957
+ t. join ( ) . unwrap ( )
958
+ }
959
+
960
+ #[ cfg( feature = "backend-mmap" ) ]
961
+ #[ test]
962
+ fn test_non_atomic_access ( ) {
963
+ non_atomic_access_helper :: < u16 > ( )
964
+ }
840
965
}
0 commit comments