Skip to content

Commit 2d6f90e

Browse files
author
Alexander Hillsley
committed
Merge branch 'dataloader'
2 parents 128c5f8 + 97079f2 commit 2d6f90e

File tree

2 files changed

+467
-0
lines changed

2 files changed

+467
-0
lines changed

src/ops_model/data/base_dataset.py

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
import ast
2+
import random
3+
from typing import Callable, List, Literal, Optional
4+
5+
import zarr
6+
import numpy as np
7+
import pandas as pd
8+
import torch
9+
from iohub import open_ome_zarr
10+
11+
from torch.utils.data import Dataset
12+
from monai.transforms import (
13+
CenterSpatialCropd,
14+
Compose,
15+
SpatialPadd,
16+
ToTensord,
17+
)
18+
19+
20+
class BaseDataset(Dataset):
21+
"""
22+
Base PyTorch Dataset for loading and preprocessing microscopy image patches with associated labels.
23+
24+
This dataset handles loading image patches from OME-Zarr stores, applying spatial transformations,
25+
and preparing data for training deep learning models on microscopy data. It supports single or
26+
multiple channel selection, cell masking, and flexible patch sizing.
27+
28+
Attributes:
29+
stores (dict): Dictionary mapping store keys to OME-Zarr store objects.
30+
labels_df (pd.DataFrame): DataFrame containing crop information and labels for each sample.
31+
initial_yx_patch_size (tuple): Initial spatial size (height, width) for cropping patches.
32+
final_yx_patch_size (tuple): Final spatial size after padding/cropping transformations.
33+
out_channels (List[str] | Literal["random"] | Literal["all"]): Channel selection strategy.
34+
label_int_lut (dict): Lookup table mapping gene names to integer labels.
35+
mask_cell (bool): Whether to apply cell segmentation mask to image data.
36+
use_original_crop_size (bool): Whether to use original crop size without padding/cropping.
37+
transform (Compose): MONAI composition of transformations to apply to data.
38+
"""
39+
40+
def __init__(
41+
self,
42+
stores: dict,
43+
labels_df: pd.DataFrame,
44+
initial_yx_patch_size: tuple = (128, 128),
45+
final_yx_patch_size: tuple = (128, 128),
46+
out_channels: List[str] | Literal["random"] | Literal["all"] = "random",
47+
label_int_lut: Optional[dict] = None, # string --> int
48+
mask_cell: bool = True,
49+
use_original_crop_size: bool = False,
50+
):
51+
"""
52+
Initialize the BaseDataset.
53+
54+
Args:
55+
stores (dict): Dictionary mapping store keys to opened OME-Zarr store objects.
56+
labels_df (pd.DataFrame): DataFrame with columns including 'gene_name', 'bbox',
57+
'store_key', 'well', 'segmentation_id', and 'total_index'.
58+
initial_yx_patch_size (tuple, optional): Initial (height, width) size for extracting
59+
patches before transformations. Defaults to (128, 128).
60+
final_yx_patch_size (tuple, optional): Final (height, width) size after padding and
61+
center cropping. Defaults to (128, 128).
62+
out_channels (List[str] | Literal["random"] | Literal["all"], optional): Strategy for
63+
channel selection. "random" selects one random channel per sample, "all" uses all
64+
available channels, or provide a list of specific channel names. Defaults to "random".
65+
label_int_lut (Optional[dict], optional): Lookup table mapping gene names (str) to
66+
integer labels (int). If None, automatically generated from unique gene names in
67+
labels_df. Defaults to None.
68+
mask_cell (bool, optional): If True, multiply image data by segmentation mask to isolate
69+
individual cells. Defaults to True.
70+
use_original_crop_size (bool, optional): If True, skip padding/cropping transformations
71+
and use original bounding box size. Defaults to False.
72+
"""
73+
self.stores = stores
74+
self.labels_df = labels_df
75+
self.initial_yx_patch_size = initial_yx_patch_size
76+
self.final_yx_patch_size = final_yx_patch_size
77+
self.out_channels = out_channels
78+
self.mask_cell = mask_cell
79+
self.use_original_crop_size = use_original_crop_size
80+
if label_int_lut is None:
81+
gene_labels = sorted(self.labels_df["gene_name"].unique())
82+
self.label_int_lut = {gene: i for i, gene in enumerate(gene_labels)}
83+
else:
84+
self.label_int_lut = label_int_lut
85+
86+
if self.use_original_crop_size:
87+
self.transform = Compose(
88+
[
89+
ToTensord(keys=["data", "mask"]),
90+
]
91+
)
92+
else:
93+
self.transform = Compose(
94+
[
95+
SpatialPadd(
96+
keys=["data", "mask"],
97+
spatial_size=self.initial_yx_patch_size,
98+
),
99+
CenterSpatialCropd(
100+
keys=["data", "mask"], roi_size=(self.final_yx_patch_size)
101+
),
102+
ToTensord(
103+
keys=["data", "mask"],
104+
),
105+
]
106+
)
107+
108+
return
109+
110+
def _get_bbox(self, ci, final_shape):
111+
"""
112+
Extract and optionally expand bounding box to match target shape.
113+
114+
Parses the bounding box from crop info and pads it equally on all sides if it's
115+
smaller than the target shape. Padding is distributed symmetrically to keep the
116+
original crop centered.
117+
118+
Args:
119+
ci: Row from labels_df containing crop information with a 'bbox' field.
120+
final_shape (tuple): Target (height, width) dimensions for the bounding box.
121+
122+
Returns:
123+
tuple: Bounding box as (ymin, xmin, ymax, xmax). If use_original_crop_size is True,
124+
returns the original bbox. Otherwise, returns padded bbox matching final_shape.
125+
126+
Note:
127+
The bbox format is (ymin, xmin, ymax, xmax) representing top-left and bottom-right
128+
corners in (y, x) coordinates.
129+
"""
130+
bbox = ast.literal_eval(ci.bbox)
131+
132+
if not self.use_original_crop_size:
133+
134+
if len(final_shape) > 2:
135+
final_shape = final_shape[-2:]
136+
137+
ymin, xmin, ymax, xmax = bbox
138+
target_height, target_width = final_shape
139+
140+
# Calculate current bbox dimensions
141+
current_height = ymax - ymin
142+
current_width = xmax - xmin
143+
144+
# Calculate padding needed
145+
height_padding = max(0, target_height - current_height)
146+
width_padding = max(0, target_width - current_width)
147+
148+
# Distribute padding equally on both sides
149+
pad_top = height_padding / 2
150+
pad_bottom = height_padding / 2
151+
pad_left = width_padding / 2
152+
pad_right = width_padding / 2
153+
154+
# Apply padding
155+
new_ymin = int(ymin - pad_top)
156+
new_ymax = int(ymax + pad_bottom)
157+
new_xmin = int(xmin - pad_left)
158+
new_xmax = int(xmax + pad_right)
159+
bbox = (new_ymin, new_xmin, new_ymax, new_xmax)
160+
161+
return bbox
162+
163+
def _get_channels(self, ci):
164+
"""
165+
Determine which channels to load based on the configured strategy.
166+
167+
Retrieves available channel names from the OME-Zarr metadata and selects channels
168+
according to the out_channels strategy (random, all, or specific channels).
169+
170+
Args:
171+
ci: Row from labels_df containing 'store_key' and 'well' fields to identify
172+
the data source.
173+
174+
Returns:
175+
tuple: A tuple containing:
176+
- channel_names (list): List of channel name strings to load.
177+
- channel_index (list): List of integer indices corresponding to the channels
178+
in the OME-Zarr store.
179+
180+
Note:
181+
If out_channels is "random", selects one random channel per call.
182+
If out_channels is "all", returns all available channels.
183+
Otherwise, uses the specific channel names provided in out_channels.
184+
"""
185+
186+
attrs = self.stores[ci.store_key][ci.well].attrs.asdict()
187+
all_channel_names = [a["label"] for a in attrs["ome"]["omero"]["channels"]]
188+
189+
if self.out_channels == "random":
190+
channel_names = [random.choice(all_channel_names)]
191+
if self.out_channels == "all":
192+
channel_names = all_channel_names
193+
else:
194+
channel_names = self.out_channels
195+
channel_index = [all_channel_names.index(c) for c in channel_names]
196+
197+
return channel_names, channel_index
198+
199+
def __len__(self):
200+
return len(self.labels_df)
201+
202+
def add_labels_to_batch(self, ci):
203+
"""
204+
Extract label information from crop info for the current sample.
205+
206+
Converts gene name to integer label using the lookup table and retrieves
207+
additional metadata for tracking.
208+
209+
Args:
210+
ci: Row from labels_df containing 'gene_name' and 'total_index' fields.
211+
212+
Returns:
213+
tuple: A tuple containing:
214+
- gene_label (int): Integer label for the gene name.
215+
- total_index: Unique identifier for this sample.
216+
- crop_info (dict): Dictionary representation of all crop metadata.
217+
"""
218+
gene_label = self.label_int_lut[ci.gene_name]
219+
total_index = ci.total_index
220+
221+
return gene_label, total_index, ci.to_dict()
222+
223+
def add_mask_to_batch(self, ci, bbox):
224+
"""
225+
Load and extract binary segmentation mask for a specific cell.
226+
227+
Retrieves the segmentation mask from the OME-Zarr store, crops it to the bounding box,
228+
and creates a binary mask for the specific cell identified by segmentation_id.
229+
230+
Args:
231+
ci: Row from labels_df containing 'store_key', 'well', and 'segmentation_id' fields.
232+
bbox (tuple): Bounding box as (ymin, xmin, ymax, xmax) defining the region to extract.
233+
234+
Returns:
235+
np.ndarray: Binary mask of shape (1, height, width) where True indicates pixels
236+
belonging to the target cell (segmentation_id) and False elsewhere.
237+
"""
238+
mask_fov = self.stores[ci.store_key][ci.well]["labels"]["seg"]["0"]
239+
mask = np.asarray(
240+
mask_fov[0:1, :, 0:1, slice(bbox[0], bbox[2]), slice(bbox[1], bbox[3])]
241+
).copy()
242+
mask = np.squeeze(mask)
243+
mask = np.expand_dims(mask, axis=0)
244+
sc_mask = mask == ci.segmentation_id
245+
246+
return sc_mask
247+
248+
def add_data_to_batch(self, ci, bbox, channel_index):
249+
"""
250+
Load and extract image data for specified channels and bounding box.
251+
252+
Retrieves raw microscopy image data from the OME-Zarr store, crops it to the
253+
bounding box region, and extracts the specified channels.
254+
255+
Args:
256+
ci: Row from labels_df containing 'store_key' and 'well' fields to identify
257+
the data source.
258+
bbox (tuple): Bounding box as (ymin, xmin, ymax, xmax) defining the region to extract.
259+
channel_index (list): List of integer indices specifying which channels to load.
260+
261+
Returns:
262+
np.ndarray: Image data as float32 array with shape (C, height, width) where C is
263+
the number of channels. Single channel images are expanded to (1, height, width).
264+
"""
265+
fov = self.stores[ci.store_key][ci.well]["0"]
266+
data = np.asarray(
267+
fov[
268+
0:1,
269+
channel_index,
270+
0:1,
271+
slice(bbox[0], bbox[2]),
272+
slice(bbox[1], bbox[3]),
273+
]
274+
).copy()
275+
data = np.squeeze(data)
276+
if len(data.shape) == 2:
277+
data = np.expand_dims(data, axis=0)
278+
279+
return data.astype(np.float32)
280+
281+
def __getitem__(self, index):
282+
"""
283+
Load and preprocess a single sample from the dataset.
284+
285+
This is the main data loading method called by PyTorch DataLoader. It orchestrates
286+
loading the image patch, segmentation mask, and labels, applies transformations,
287+
and returns a dictionary containing all sample data.
288+
289+
Args:
290+
index (int): Index of the sample to retrieve from labels_df.
291+
292+
Returns:
293+
dict: Dictionary containing:
294+
- 'data' (torch.Tensor): Image data of shape (C, H, W) or (1, C, H, W) if
295+
final_yx_patch_size has 3 dimensions. Optionally masked by cell segmentation.
296+
- 'mask' (torch.Tensor): Binary segmentation mask of shape (1, H, W) or
297+
(1, 1, H, W), with same dimensionality as data.
298+
- 'marker_label' (list): List of channel names loaded for this sample.
299+
- 'gene_label' (int): Integer label for the gene associated with this cell.
300+
- 'total_index' (int): Unique identifier for this sample.
301+
- 'crop_info' (dict): Complete metadata for this crop from labels_df.
302+
303+
Note:
304+
If mask_cell is True, the returned data will be element-wise multiplied by the mask.
305+
Transformations (padding, cropping, tensor conversion) are applied based on the
306+
transform pipeline configured during initialization.
307+
"""
308+
batch = {}
309+
ci = self.labels_df.iloc[index] # crop info
310+
bbox = self._get_bbox(ci, self.initial_yx_patch_size)
311+
c_names, c_index = self._get_channels(ci)
312+
batch["marker_label"] = c_names
313+
314+
gene_label, total_index, crop_info = self.add_labels_to_batch(ci)
315+
batch["gene_label"] = gene_label
316+
batch["total_index"] = int(total_index)
317+
batch["crop_info"] = crop_info
318+
319+
batch["data"] = self.add_data_to_batch(ci, bbox, c_index)
320+
batch["mask"] = self.add_mask_to_batch(ci, bbox)
321+
322+
if self.mask_cell:
323+
batch["data"] = batch["data"] * batch["mask"]
324+
325+
if len(self.final_yx_patch_size) == 3:
326+
batch["data"] = np.expand_dims(batch["data"], axis=0)
327+
batch["mask"] = np.expand_dims(batch["mask"], axis=0)
328+
329+
if self.transform is not None:
330+
batch = self.transform(batch)
331+
332+
return batch

0 commit comments

Comments
 (0)