@@ -894,7 +894,7 @@ def test_get_confidences(self):
894894 return_gross = True ,
895895 rng = rng
896896 )
897- _ , synthetic_gross_ss = synth .generate (fg_seeds_ss , bg_seeds_ss [0 ])
897+ _ , synthetic_gross_ss = synth .generate (fg_seeds_ss , bg_seeds_ss [0 ], verbose = False )
898898 synthetic_gross_ss .drop_sources (bg_seeds_ss .sources .columns .levels [2 ])
899899 synthetic_gross_ss .sources = synthetic_gross_ss .sources [fg_seeds_ss .sources .columns ]
900900 synthetic_gross_ss .prediction_probas = pd .DataFrame (
@@ -907,7 +907,8 @@ def test_get_confidences(self):
907907 bg_cps = synth .bg_cps
908908 )
909909
910- _ , synthetic_mixed_gross_ss = synth .generate (mixed_fg_seeds_ss , bg_seeds_ss [0 ])
910+ _ , synthetic_mixed_gross_ss = synth .generate (mixed_fg_seeds_ss , bg_seeds_ss [0 ],
911+ verbose = False )
911912 synthetic_mixed_gross_ss .drop_sources (bg_seeds_ss .sources .columns .levels [2 ])
912913 synthetic_mixed_gross_ss .sources = synthetic_mixed_gross_ss .sources [
913914 fg_seeds_ss .sources .columns
@@ -975,6 +976,60 @@ def test_get_confidences(self):
975976 bg_cps = None
976977 )
977978
979+ def test_get_energy_roi_masks (self ):
980+ from riid .data .sampleset import _get_energy_roi_masks
981+ ROIS1 = [
982+ (0 , 2 ),
983+ (4 , 8 ),
984+ ]
985+ ROIS2 = [
986+ (0 , 0.75 ),
987+ (2.0 , 2.5 ),
988+ (3.0 , 5.0 ),
989+ ]
990+ ENERGIES = np .array ([
991+ [0 , 1 , 2 , 3 , 4 , 5 , 6 ],
992+ [0 , 0.5 , 1.0 , 1.5 , 2.0 , 2.5 , 3.0 ]
993+ ])
994+ EXPECTED_MASKS1 = np .array ([
995+ [True , True , False , False , True , True , True ],
996+ [True , True , True , True , False , False , False ],
997+ ])
998+ EXPECTED_MASKS2 = np .array ([
999+ [True , False , True , True , True , False , False ],
1000+ [True , True , False , False , True , False , True ],
1001+ ])
1002+ masks1 = _get_energy_roi_masks (ROIS1 , ENERGIES )
1003+ masks2 = _get_energy_roi_masks (ROIS2 , ENERGIES )
1004+
1005+ self .assertTrue (np .array_equal (masks1 , EXPECTED_MASKS1 ))
1006+ self .assertTrue (np .array_equal (masks2 , EXPECTED_MASKS2 ))
1007+
1008+ def test_as_regions (self ):
1009+ from riid .data .sampleset import _get_energy_roi_masks
1010+ ROIS = [
1011+ (0 , 100 ),
1012+ (500 , 550 ),
1013+ (2400 , 2600 ),
1014+ ]
1015+ ss1 = get_dummy_seeds (n_channels = 1000 )
1016+ ss2 = get_dummy_seeds (n_channels = 1000 ).as_ecal (0 , 2500 , 0 , 0 , 0 )
1017+ ss3 = get_dummy_seeds (n_channels = 500 ).as_ecal (20 , 2000 , 0 , 0 , 0 )
1018+
1019+ with self .assertRaises (ValueError ):
1020+ ss1 .as_regions ([])
1021+
1022+ ss_mixed_ecal = SampleSet ()
1023+ ss_mixed_ecal .concat ([ss1 , ss2 ])
1024+ with self .assertRaises (ValueError ):
1025+ ss_mixed_ecal .as_regions (ROIS )
1026+
1027+ for ss in [ss1 , ss2 , ss3 ]:
1028+ channel_energies = ss .get_channel_energies (0 )
1029+ ss_channels_expected = _get_energy_roi_masks (ROIS , channel_energies ).sum ()
1030+ rois = ss .as_regions (ROIS )
1031+ self .assertEqual (rois .n_channels , ss_channels_expected )
1032+
9781033 def _assert_row_labels (self , level , actual , expected ):
9791034 for i , (a , e ) in enumerate (zip (actual , expected )):
9801035 self .assertEqual (
0 commit comments