 
 
 @pytest.fixture(scope="module")
-def data_manager():
+def feature_data_manager():
     """Create data manager for testing (reused across all tests in module)."""
     experiment_dict = {"ops0033_20250429": ["A/1/0", "A/2/0", "A/3/0"]}
     dm = data_loader.OpsDataManager(
@@ -20,13 +20,36 @@ def data_manager():
 
 
 @pytest.fixture(scope="module")
-def batch(data_manager):
+def basic_data_manager():
24+ """Create data manager for testing (reused across all tests in module)."""
25+ experiment_dict = {"ops0033_20250429" : ["A/1/0" , "A/2/0" , "A/3/0" ]}
26+ dm = data_loader .OpsDataManager (
27+ experiments = experiment_dict ,
28+ batch_size = 2 ,
29+ data_split = (1 , 0 , 0 ),
30+ out_channels = ["Phase2D" , "mCherry" ],
31+ initial_yx_patch_size = (256 , 256 ),
32+ verbose = False ,
33+ )
34+ dm .construct_dataloaders (num_workers = 1 , dataset_type = "basic" )
35+ return dm
+
+
+@pytest.fixture(scope="module")
+def feature_batch(feature_data_manager):
     """Get a single batch for testing (reused across all tests in module)."""
-    train_loader = data_manager.train_loader
+    train_loader = feature_data_manager.train_loader
     return next(iter(train_loader))
 
 
-def test_batch_keys_cellprofiler(batch):
+@pytest.fixture(scope="module")
+def basic_batch(basic_data_manager):
+    """Get a single batch for testing (reused across all tests in module)."""
+    train_loader = basic_data_manager.train_loader
+    return next(iter(train_loader))
+
+
+def test_batch_keys_cellprofiler(feature_batch):
     expected_keys = [
         "data",
         "cell_mask",
@@ -51,31 +74,84 @@ def test_batch_keys_cellprofiler(batch):
5174 "crop_info" : list ,
5275 }
5376
54- batch_keys = list (batch .keys ())
77+ batch_keys = list (feature_batch .keys ())
5578 for k , v in expected_keys .items ():
5679 assert k in batch_keys
5780
58- assert isinstance (batch [k ], v )
81+ assert isinstance (feature_batch [k ], v )
82+ return
83+
+
+# Test that the data returned is normalized
+def test_data_normalization(basic_batch):
+    data = basic_batch["data"]
+    # compute the per-channel mean over the batch and spatial dimensions
+    mean = torch.mean(data, dim=(0, 2, 3))
+
+    # assert that each per-channel mean is approximately 0
+    assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-1)
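+    # a loose atol is used here: with batch_size=2, the sample mean of a
+    # single batch only roughly approximates the normalization statistics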
+
+    return
+
+
+# Test that requesting different out channels works
+def test_out_channels(basic_data_manager, basic_batch):
+
+    shape = basic_batch["data"].shape
+    assert shape[1] == 2  # 2 out channels requested
+
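+    # the dataloaders are rebuilt after each out_channels change so that the
+    # new setting takes effect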
+    basic_data_manager.out_channels = ["Phase2D"]
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    shape_1 = batch["data"].shape
+    assert shape_1[1] == 1  # 1 out channel requested
+
+    basic_data_manager.out_channels = ["mCherry"]
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    shape_2 = batch["data"].shape
+    assert shape_2[1] == 1  # 1 out channel requested
 
     return
 
 
-# def test_batch_keys_basic(batch):
+# Test that turning masking on/off works
+def test_cell_masking(feature_data_manager, feature_batch):
+
+    data = feature_batch["data"]
+    cell_mask = feature_batch["cell_mask"]
+    # assert that data is zero wherever cell_mask is zero; count_nonzero
+    # (rather than sum) guards against signed values cancelling out
+    masked_data = data * (cell_mask == 0)
+    assert torch.count_nonzero(masked_data) == 0
 
-#     return
+    feature_data_manager.train_loader.dataset.use_cell_mask = False
+    batch = next(iter(feature_data_manager.train_loader))
+    data_no_mask = batch["data"]
+    # assert that the unmasked batch differs from the masked data
+    assert not torch.equal(data, data_no_mask)
+    # restore masking so later tests sharing the module-scoped fixture are unaffected
+    feature_data_manager.train_loader.dataset.use_cell_mask = True
 
+    return
 
-# def test_data_loader_consistancy(data_manager):
-#     dm, batch = data_manager
 
-#     new_data_manager, _ = create_data_manager()
+# Test that different patch sizes work
+def test_patch_size(basic_data_manager, basic_batch):
 
-#     batch_labels = batch["gene_label"].detach().cpu().numpy()
-#     total_indxs = batch["total_index"].detach().cpu().numpy()
+    shape = basic_batch["data"].shape
+    assert shape[2] == 128  # default final patch size
+    assert shape[3] == 128
 
-#     gene_names = dm.labels_df.iloc[total_indxs].gene_name.to_list()
-#     mapped_labels = np.asarray([new_data_manager.label_int_lut[a] for a in gene_names])
+    basic_data_manager.final_yx_patch_size = (256, 256)
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    shape_1 = batch["data"].shape
+    assert shape_1[2] == 256  # changed patch size
+    assert shape_1[3] == 256
 
-#     assert np.all(batch_labels == mapped_labels)
+    basic_data_manager.final_yx_patch_size = (64, 64)
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    shape_2 = batch["data"].shape
+    assert shape_2[2] == 64  # changed patch size again
+    assert shape_2[3] == 64
 
-#     return
+    return
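+
+
+# A minimal sketch of how the repeated mutate/rebuild/assert pattern above
+# could be consolidated with pytest parametrization; it assumes only the
+# basic_data_manager fixture and the construct_dataloaders API already used
+# in test_patch_size
+@pytest.mark.parametrize("patch_size", [(256, 256), (64, 64)])
+def test_patch_size_parametrized(basic_data_manager, patch_size):
+    basic_data_manager.final_yx_patch_size = patch_size
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    assert tuple(batch["data"].shape[2:]) == patch_size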