import ast
import random
from typing import Callable, List, Literal, Optional

import zarr
import numpy as np
import pandas as pd
import torch
from iohub import open_ome_zarr

from torch.utils.data import Dataset
from monai.transforms import (
    CenterSpatialCropd,
    Compose,
    SpatialPadd,
    ToTensord,
)


class BaseDataset(Dataset):
    """
    Base PyTorch Dataset for loading and preprocessing microscopy image patches with associated labels.

    This dataset handles loading image patches from OME-Zarr stores, applying spatial transformations,
    and preparing data for training deep learning models on microscopy data. It supports single or
    multiple channel selection, cell masking, and flexible patch sizing.

    Attributes:
        stores (dict): Dictionary mapping store keys to OME-Zarr store objects.
        labels_df (pd.DataFrame): DataFrame containing crop information and labels for each sample.
        initial_yx_patch_size (tuple): Initial spatial size (height, width) for cropping patches.
        final_yx_patch_size (tuple): Final spatial size after padding/cropping transformations.
        out_channels (List[str] | Literal["random"] | Literal["all"]): Channel selection strategy.
        label_int_lut (dict): Lookup table mapping gene names to integer labels.
        mask_cell (bool): Whether to apply cell segmentation mask to image data.
        use_original_crop_size (bool): Whether to use original crop size without padding/cropping.
        transform (Compose): MONAI composition of transformations to apply to data.
    """

    def __init__(
        self,
        stores: dict,
        labels_df: pd.DataFrame,
        initial_yx_patch_size: tuple = (128, 128),
        final_yx_patch_size: tuple = (128, 128),
        out_channels: List[str] | Literal["random"] | Literal["all"] = "random",
        label_int_lut: Optional[dict] = None,  # string --> int
        mask_cell: bool = True,
        use_original_crop_size: bool = False,
    ):
        """
        Initialize the BaseDataset.

        Args:
            stores (dict): Dictionary mapping store keys to opened OME-Zarr store objects.
            labels_df (pd.DataFrame): DataFrame with columns including 'gene_name', 'bbox',
                'store_key', 'well', 'segmentation_id', and 'total_index'.
            initial_yx_patch_size (tuple, optional): Initial (height, width) size for extracting
                patches before transformations. Defaults to (128, 128).
            final_yx_patch_size (tuple, optional): Final (height, width) size after padding and
                center cropping. Defaults to (128, 128).
            out_channels (List[str] | Literal["random"] | Literal["all"], optional): Strategy for
                channel selection. "random" selects one random channel per sample, "all" uses all
                available channels, or provide a list of specific channel names. Defaults to "random".
            label_int_lut (Optional[dict], optional): Lookup table mapping gene names (str) to
                integer labels (int). If None, automatically generated from unique gene names in
                labels_df. Defaults to None.
            mask_cell (bool, optional): If True, multiply image data by segmentation mask to isolate
                individual cells. Defaults to True.
            use_original_crop_size (bool, optional): If True, skip padding/cropping transformations
                and use original bounding box size. Defaults to False.
        """
        self.stores = stores
        self.labels_df = labels_df
        self.initial_yx_patch_size = initial_yx_patch_size
        self.final_yx_patch_size = final_yx_patch_size
        self.out_channels = out_channels
        self.mask_cell = mask_cell
        self.use_original_crop_size = use_original_crop_size
        if label_int_lut is None:
            gene_labels = sorted(self.labels_df["gene_name"].unique())
            self.label_int_lut = {gene: i for i, gene in enumerate(gene_labels)}
        else:
            self.label_int_lut = label_int_lut

        if self.use_original_crop_size:
            self.transform = Compose(
                [
                    ToTensord(keys=["data", "mask"]),
                ]
            )
        else:
            self.transform = Compose(
                [
                    SpatialPadd(
                        keys=["data", "mask"],
                        spatial_size=self.initial_yx_patch_size,
                    ),
                    CenterSpatialCropd(
                        keys=["data", "mask"],
                        roi_size=self.final_yx_patch_size,
                    ),
                    ToTensord(
                        keys=["data", "mask"],
                    ),
                ]
            )
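        # Illustrative shape walk-through (assumed sizes, not values from the data):
        # with initial_yx_patch_size=(160, 160) and final_yx_patch_size=(128, 128),
        # a 90 x 140 crop is first padded to 160 x 160 by SpatialPadd, then
        # center-cropped to 128 x 128 by CenterSpatialCropd, and finally converted
        # to a torch.Tensor by ToTensord.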

    def _get_bbox(self, ci, final_shape):
        """
        Extract and optionally expand bounding box to match target shape.

        Parses the bounding box from crop info and pads it equally on all sides if it's
        smaller than the target shape. Padding is distributed symmetrically to keep the
        original crop centered.

        Args:
            ci: Row from labels_df containing crop information with a 'bbox' field.
            final_shape (tuple): Target (height, width) dimensions for the bounding box.

        Returns:
            tuple: Bounding box as (ymin, xmin, ymax, xmax). If use_original_crop_size is True,
                returns the original bbox. Otherwise, returns padded bbox matching final_shape.

        Note:
            The bbox format is (ymin, xmin, ymax, xmax) representing top-left and bottom-right
            corners in (y, x) coordinates.
        """
        bbox = ast.literal_eval(ci.bbox)

        if not self.use_original_crop_size:
            if len(final_shape) > 2:
                final_shape = final_shape[-2:]

            ymin, xmin, ymax, xmax = bbox
            target_height, target_width = final_shape

            # Calculate current bbox dimensions
            current_height = ymax - ymin
            current_width = xmax - xmin

            # Calculate padding needed
            height_padding = max(0, target_height - current_height)
            width_padding = max(0, target_width - current_width)

            # Distribute padding as evenly as possible on both sides,
            # using integer arithmetic so the expanded bbox exactly matches final_shape
            pad_top = height_padding // 2
            pad_bottom = height_padding - pad_top
            pad_left = width_padding // 2
            pad_right = width_padding - pad_left

            # Apply padding
            new_ymin = int(ymin - pad_top)
            new_ymax = int(ymax + pad_bottom)
            new_xmin = int(xmin - pad_left)
            new_xmax = int(xmax + pad_right)
            bbox = (new_ymin, new_xmin, new_ymax, new_xmax)

        return bbox
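    # Worked example for _get_bbox (illustrative numbers only): a stored bbox of
    # (100, 200, 200, 320) spans 100 x 120 pixels. With final_shape=(128, 128),
    # height_padding=28 and width_padding=8, so the bbox grows by 14 pixels on the
    # top and bottom and 4 pixels on the left and right, giving (86, 196, 214, 324),
    # i.e. a centered 128 x 128 region.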

    def _get_channels(self, ci):
        """
        Determine which channels to load based on the configured strategy.

        Retrieves available channel names from the OME-Zarr metadata and selects channels
        according to the out_channels strategy (random, all, or specific channels).

        Args:
            ci: Row from labels_df containing 'store_key' and 'well' fields to identify
                the data source.

        Returns:
            tuple: A tuple containing:
                - channel_names (list): List of channel name strings to load.
                - channel_index (list): List of integer indices corresponding to the channels
                  in the OME-Zarr store.

        Note:
            If out_channels is "random", selects one random channel per call.
            If out_channels is "all", returns all available channels.
            Otherwise, uses the specific channel names provided in out_channels.
        """
        attrs = self.stores[ci.store_key][ci.well].attrs.asdict()
        all_channel_names = [a["label"] for a in attrs["ome"]["omero"]["channels"]]

        if self.out_channels == "random":
            channel_names = [random.choice(all_channel_names)]
        elif self.out_channels == "all":
            channel_names = all_channel_names
        else:
            channel_names = self.out_channels
        channel_index = [all_channel_names.index(c) for c in channel_names]

        return channel_names, channel_index
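    # For example (hypothetical channel names): if the omero metadata lists
    # ["DAPI", "GFP", "RFP"] and out_channels=["GFP"], _get_channels returns
    # (["GFP"], [1]); with out_channels="all" it returns the full list and [0, 1, 2];
    # with out_channels="random" it returns a single randomly chosen name and index.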

    def __len__(self):
        return len(self.labels_df)

    def add_labels_to_batch(self, ci):
        """
        Extract label information from crop info for the current sample.

        Converts gene name to integer label using the lookup table and retrieves
        additional metadata for tracking.

        Args:
            ci: Row from labels_df containing 'gene_name' and 'total_index' fields.

        Returns:
            tuple: A tuple containing:
                - gene_label (int): Integer label for the gene name.
                - total_index: Unique identifier for this sample.
                - crop_info (dict): Dictionary representation of all crop metadata.
        """
        gene_label = self.label_int_lut[ci.gene_name]
        total_index = ci.total_index

        return gene_label, total_index, ci.to_dict()

    def add_mask_to_batch(self, ci, bbox):
        """
        Load and extract binary segmentation mask for a specific cell.

        Retrieves the segmentation mask from the OME-Zarr store, crops it to the bounding box,
        and creates a binary mask for the specific cell identified by segmentation_id.

        Args:
            ci: Row from labels_df containing 'store_key', 'well', and 'segmentation_id' fields.
            bbox (tuple): Bounding box as (ymin, xmin, ymax, xmax) defining the region to extract.

        Returns:
            np.ndarray: Binary mask of shape (1, height, width) where True indicates pixels
                belonging to the target cell (segmentation_id) and False elsewhere.
        """
        mask_fov = self.stores[ci.store_key][ci.well]["labels"]["seg"]["0"]
        mask = np.asarray(
            mask_fov[0:1, :, 0:1, slice(bbox[0], bbox[2]), slice(bbox[1], bbox[3])]
        ).copy()
        mask = np.squeeze(mask)
        mask = np.expand_dims(mask, axis=0)
        sc_mask = mask == ci.segmentation_id

        return sc_mask

    def add_data_to_batch(self, ci, bbox, channel_index):
        """
        Load and extract image data for specified channels and bounding box.

        Retrieves raw microscopy image data from the OME-Zarr store, crops it to the
        bounding box region, and extracts the specified channels.

        Args:
            ci: Row from labels_df containing 'store_key' and 'well' fields to identify
                the data source.
            bbox (tuple): Bounding box as (ymin, xmin, ymax, xmax) defining the region to extract.
            channel_index (list): List of integer indices specifying which channels to load.

        Returns:
            np.ndarray: Image data as float32 array with shape (C, height, width) where C is
                the number of channels. Single channel images are expanded to (1, height, width).
        """
        fov = self.stores[ci.store_key][ci.well]["0"]
        data = np.asarray(
            fov[
                0:1,
                channel_index,
                0:1,
                slice(bbox[0], bbox[2]),
                slice(bbox[1], bbox[3]),
            ]
        ).copy()
        data = np.squeeze(data)
        if len(data.shape) == 2:
            data = np.expand_dims(data, axis=0)

        return data.astype(np.float32)
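    # The 5D indexing in add_data_to_batch and add_mask_to_batch assumes the arrays
    # follow the usual OME-Zarr (T, C, Z, Y, X) axis order; e.g. for a single
    # timepoint/z-slice and a 128 x 128 bbox, the returned data array has shape
    # (len(channel_index), 128, 128) after squeezing the T and Z axes.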

    def __getitem__(self, index):
        """
        Load and preprocess a single sample from the dataset.

        This is the main data loading method called by PyTorch DataLoader. It orchestrates
        loading the image patch, segmentation mask, and labels, applies transformations,
        and returns a dictionary containing all sample data.

        Args:
            index (int): Index of the sample to retrieve from labels_df.

        Returns:
            dict: Dictionary containing:
                - 'data' (torch.Tensor): Image data of shape (C, H, W) or (1, C, H, W) if
                  final_yx_patch_size has 3 dimensions. Optionally masked by cell segmentation.
                - 'mask' (torch.Tensor): Binary segmentation mask of shape (1, H, W) or
                  (1, 1, H, W), with same dimensionality as data.
                - 'marker_label' (list): List of channel names loaded for this sample.
                - 'gene_label' (int): Integer label for the gene associated with this cell.
                - 'total_index' (int): Unique identifier for this sample.
                - 'crop_info' (dict): Complete metadata for this crop from labels_df.

        Note:
            If mask_cell is True, the returned data will be element-wise multiplied by the mask.
            Transformations (padding, cropping, tensor conversion) are applied based on the
            transform pipeline configured during initialization.
        """
        batch = {}
        ci = self.labels_df.iloc[index]  # crop info
        bbox = self._get_bbox(ci, self.initial_yx_patch_size)
        c_names, c_index = self._get_channels(ci)
        batch["marker_label"] = c_names

        gene_label, total_index, crop_info = self.add_labels_to_batch(ci)
        batch["gene_label"] = gene_label
        batch["total_index"] = int(total_index)
        batch["crop_info"] = crop_info

        batch["data"] = self.add_data_to_batch(ci, bbox, c_index)
        batch["mask"] = self.add_mask_to_batch(ci, bbox)

        if self.mask_cell:
            batch["data"] = batch["data"] * batch["mask"]

        if len(self.final_yx_patch_size) == 3:
            batch["data"] = np.expand_dims(batch["data"], axis=0)
            batch["mask"] = np.expand_dims(batch["mask"], axis=0)

        if self.transform is not None:
            batch = self.transform(batch)

        return batch
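

# Minimal usage sketch (the path, store key, and CSV below are hypothetical;
# the real stores/labels come from the upstream preprocessing, and each store
# must expose the plate[well]["0"] arrays and .attrs access used above, e.g.
# zarr groups opened from the same OME-Zarr plates):
#
#     stores = {"plate_0": zarr.open("plate_0.zarr", mode="r")}
#     labels_df = pd.read_csv("crops.csv")  # gene_name, bbox, store_key, well,
#                                           # segmentation_id, total_index, ...
#     dataset = BaseDataset(stores, labels_df, out_channels="all")
#     sample = dataset[0]
#     print(sample["data"].shape, sample["gene_label"], sample["marker_label"])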