Commit 82a7300

Author: Alexander Hillsley
bugfix [data_loader]: fix normalization in the dataloader
Previous: normalized by the mean of all pixels across both channels, which produced a tiny range of values for the phase channel and a very unusual distribution for the fluorescent channel.
Fix: for the phase channel, do nothing; values are already zero-centered with a std of roughly 0.25. For the fluorescent channel, take the log of the image to squash the long tail, then zero-mean and standardize.
Parent: 1a81328
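In short, normalization is now per channel rather than global. A condensed sketch of the new _normalize_data logic (the helper name here is illustrative; the actual diff is below), assuming a numpy array of shape (C, H, W) and a matching list of channel names:

import numpy as np

def normalize_per_channel(channel_names, data):
    # Phase2D is already zero-centered with a std of roughly 0.25, so it
    # passes through untouched; fluorescent channels get log-squashed,
    # then zero-meaned and standardized.
    out = []
    for i, ch in enumerate(channel_names):
        img = data[i]
        if ch == "Phase2D":
            out.append(img)
        else:
            log_img = np.log1p(img)  # log(1 + x) squashes the long intensity tail
            out.append((log_img - log_img.mean()) / log_img.std())
    return np.stack(out, axis=0)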

File tree: 2 files changed, +113 -62 lines


src/ops_model/data/data_loader.py (20 additions, 45 deletions)
@@ -211,46 +211,20 @@ def __init__(
         )
         return

-    def _normalize_data(self, ci, channel_names, data, masks):
-
-        # Temporary Fix
-        # normalize per crop and squash all values between -1 and 1
-        data_shift = data - np.mean(data)
-        lo, hi = np.percentile(data_shift, [1, 99.5])
-        scale = max(abs(lo), abs(hi))  # symmetric mapping
-        data_norm = np.clip(data_shift, -scale, scale) / scale
-        if self.cell_masks:
-            data_norm = data_norm * masks
-
-        # fov_attrs = self.stores[ci.store_key][
-        #     ci.tile_pheno
-        # ].zattrs.asdict()  # can create dict for all tiles at beginning
-
-        # # TODO: need a real measure of dataset background
-        # bg = [np.percentile(data, 1)]
-
-        # iqrs = [
-        #     fov_attrs["normalization"][i]["fov_statistics"]["iqr"]
-        #     for i in channel_names
-        # ]
-        # means = [
-        #     fov_attrs["normalization"][i]["fov_statistics"]["mean"]
-        #     for i in channel_names
-        # ]
-
-        # data_bg_sub = np.clip(data - np.expand_dims(bg, (1, 2)), a_min=0, a_max=None)
-
-        # if self.cell_masks:
-        #     data_bg_sub = data_bg_sub * masks
-
-        # data_iqr = (data_bg_sub - np.expand_dims(means, (1, 2))) / (
-        #     np.expand_dims(iqrs, (1, 2)) + 1e-6
-        # )
-
-        # # TODO: Need to fix to work with multiple channels
-        # lo, hi = np.percentile(data_iqr, [1, 99.5])
-        # scale = max(abs(lo), abs(hi))  # symmetric mapping
-        # data_norm = np.clip(data_iqr, -scale, scale) / scale
+    def _normalize_data(self, channel_names, data):
+        img_list = []
+        for ch in channel_names:
+            print(ch)
+            if ch == "Phase2D":
+                img_list.append(data[0])
+            else:
+                # apply log normalization
+                img = data[channel_names.index(ch)]
+                log_img = np.log1p(img)
+                img_norm = (log_img - log_img.mean()) / log_img.std()
+                img_list.append(img_norm)
+
+        data_norm = np.stack(img_list, axis=0)

         return data_norm

@@ -282,9 +256,7 @@ def __getitem__(self, index):
         gene_label = self.label_int_lut[ci.gene_name]
         total_index = ci.total_index

-        channel_names, channel_index = self._get_channels(
-            ci, well
-        )  # probably doesn't have to be done per dataset
+        channel_names, channel_index = self._get_channels(ci, well)

         data = np.asarray(
             fov[0, channel_index, 0, slice(bbox[0], bbox[2]), slice(bbox[1], bbox[3])]
@@ -295,7 +267,10 @@ def __getitem__(self, index):
         ).copy()
         sc_mask = mask == ci.segmentation_id

-        data_norm = self._normalize_data(ci, channel_names, data, sc_mask)
+        data_norm = self._normalize_data(channel_names, data)
+
+        if self.cell_masks:
+            data_norm = data_norm * sc_mask

         batch = {
             "data": data_norm.astype(np.float32),
@@ -455,7 +430,7 @@ def construct_dataloaders(
         self,
         num_workers: int = 1,
         shuffle: bool = True,
-        dataset_type: Literal["basic", "triplet"] = "basic",
+        dataset_type: Literal["basic", "triplet", "cell_profile"] = "basic",
         triplet_kwargs: dict = None,
         basic_kwargs: dict = None,
         cp_kwargs: dict = None,
tests/test_dataloader.py (93 additions, 17 deletions)
@@ -4,7 +4,7 @@


 @pytest.fixture(scope="module")
-def data_manager():
+def feature_data_manager():
     """Create data manager for testing (reused across all tests in module)."""
     experiment_dict = {"ops0033_20250429": ["A/1/0", "A/2/0", "A/3/0"]}
     dm = data_loader.OpsDataManager(
@@ -20,13 +20,36 @@ def data_manager():


 @pytest.fixture(scope="module")
-def batch(data_manager):
+def basic_data_manager():
+    """Create data manager for testing (reused across all tests in module)."""
+    experiment_dict = {"ops0033_20250429": ["A/1/0", "A/2/0", "A/3/0"]}
+    dm = data_loader.OpsDataManager(
+        experiments=experiment_dict,
+        batch_size=2,
+        data_split=(1, 0, 0),
+        out_channels=["Phase2D", "mCherry"],
+        initial_yx_patch_size=(256, 256),
+        verbose=False,
+    )
+    dm.construct_dataloaders(num_workers=1, dataset_type="basic")
+    return dm
+
+
+@pytest.fixture(scope="module")
+def feature_batch(feature_data_manager):
     """Get a single batch for testing (reused across all tests in module)."""
-    train_loader = data_manager.train_loader
+    train_loader = feature_data_manager.train_loader
     return next(iter(train_loader))


-def test_batch_keys_cellprofiler(batch):
+@pytest.fixture(scope="module")
+def basic_batch(basic_data_manager):
+    """Get a single batch for testing (reused across all tests in module)."""
+    train_loader = basic_data_manager.train_loader
+    return next(iter(train_loader))
+
+
+def test_batch_keys_cellprofiler(feature_batch):
     expected_keys = [
         "data",
         "cell_mask",
@@ -51,31 +74,84 @@ def test_batch_keys_cellprofiler(batch):
         "crop_info": list,
     }

-    batch_keys = list(batch.keys())
+    batch_keys = list(feature_batch.keys())
     for k, v in expected_keys.items():
         assert k in batch_keys

-        assert isinstance(batch[k], v)
+        assert isinstance(feature_batch[k], v)
+    return
+
+
+# Test that the data returned is normalized
+def test_data_normalization(basic_batch):
+    data = basic_batch["data"]
+    # compute mean over all but batch and channel dimensions
+    mean = torch.mean(data, dim=(0, 2, 3))
+
+    # assert that mean is approximately 0
+    assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-1)
+
+    return
+
+
+# test that requesting different out channels works
+def test_out_channels(basic_data_manager, basic_batch):
+
+    shape = basic_batch["data"].shape
+    assert shape[1] == 2  # 2 out channels requested
+
+    basic_data_manager.out_channels = ["Phase2D"]
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    shape_1 = batch["data"].shape
+    assert shape_1[1] == 1  # 1 out channel requested
+
+    basic_data_manager.out_channels = ["mCherry"]
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    shape_2 = batch["data"].shape
+    assert shape_2[1] == 1  # 1 out channel requested

     return


-# def test_batch_keys_basic(batch):
+# Test that turning masking on/off works
+def test_cell_masking(feature_data_manager, feature_batch):
+
+    data = feature_batch["data"]
+    cell_mask = feature_batch["cell_mask"]
+    # assert that where cell_mask is 0, data is also 0
+    masked_data = data * (cell_mask == 0)
+    assert torch.sum(masked_data) == 0

-#     return
+    feature_data_manager.train_loader.dataset.use_cell_mask = False
+    batch = next(iter(feature_data_manager.train_loader))
+    data_no_mask = batch["data"]
+    # assert that data_no_mask is not equal to data everywhere
+    assert not torch.equal(data, data_no_mask)

+    return

-# def test_data_loader_consistancy(data_manager):
-#     dm, batch = data_manager

-#     new_data_manager, _ = create_data_manager()
+# Test that different patch sizes work
+def test_patch_size(basic_data_manager, basic_batch):

-#     batch_labels = batch["gene_label"].detach().cpu().numpy()
-#     total_indxs = batch["total_index"].detach().cpu().numpy()
+    shape = basic_batch["data"].shape
+    assert shape[2] == 128  # initial patch size
+    assert shape[3] == 128

-#     gene_names = dm.labels_df.iloc[total_indxs].gene_name.to_list()
-#     mapped_labels = np.asarray([new_data_manager.label_int_lut[a] for a in gene_names])
+    basic_data_manager.final_yx_patch_size = (256, 256)
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    shape_1 = batch["data"].shape
+    assert shape_1[2] == 256  # changed patch size
+    assert shape_1[3] == 256

-#     assert np.all(batch_labels == mapped_labels)
+    basic_data_manager.final_yx_patch_size = (64, 64)
+    basic_data_manager.construct_dataloaders(num_workers=1, dataset_type="basic")
+    batch = next(iter(basic_data_manager.train_loader))
+    shape_2 = batch["data"].shape
+    assert shape_2[2] == 64  # changed patch size again
+    assert shape_2[3] == 64

-#     return
+    return
