
Move dataloader to Grain#256

Open
tristan-deep wants to merge 20 commits into main from feature/grain-dataloader

Conversation

@tristan-deep
Collaborator

@tristan-deep tristan-deep commented Feb 9, 2026

Grain is a library for reading data for training and evaluating JAX models. It’s open source, fast and deterministic.

  • Implement grain dataloader
  • Benchmark dataloaders for common dataloading scenarios
  • Investigate HDF5 chunking @OisinNolan
  • Check if all functionality is preserved w.r.t. the old dataloader
  • Add tests for new dataloader / adjust existing tests
  • Remove old dataloader

I was able to implement most of the functionality of the existing dataloader, making it a drop-in replacement for zea.backend.tensorflow.make_dataloader. Simply use zea.Dataloader instead.

I would like to move the dataloader to Grain (previously TensorFlow). There are a few advantages:

  • Removes the heavy TensorFlow dependency, in case you would like to use the dataloader with a different backend
  • Simple; works with NumPy operations
  • Works for all backends
  • More modern dataloader
  • In the limited testing I did, it was faster (see benchmark todo)
  • Hopefully resolves the issue we found in HDF5 dataloaders using a Python generator #255
import zea

dataset_path = "hf://zeahub/camus-sample/val"
dataloader = zea.Dataloader(
    dataset_path,
    key="data/image_sc",
    batch_size=4,
    shuffle=True,
    clip_image_range=[-60, 0],
    image_range=[-60, 0],
    normalization_range=[0, 1],
    image_size=(256, 256),
    resize_type="resize",  # or "center_crop" or "random_crop"
    num_threads=16,
    seed=4,
)

for batch in dataloader:
    print("Batch shape:", batch.shape)
    break  # Just show the first batch

fig, _ = zea.visualize.plot_image_grid(batch)
fig.savefig("test.png")

Summary by CodeRabbit

  • New Features

    • Added a framework-agnostic Dataloader for HDF5 datasets, exposed at the package top level.
  • Breaking Changes

    • Removed the legacy TensorFlow-specific dataloader and the old H5Generator-style API.
    • Removed the legacy I/O retry utility and file-info JSON helpers.
  • Documentation

    • Updated example notebooks and docs to use the new Dataloader API.
  • Chores

    • Added Grain >= 0.2 to project dependencies.
  • Tests

    • Updated and expanded tests to cover the new dataloader and data-source behaviors.

@coderabbitai
Contributor

coderabbitai bot commented Feb 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Replaces the TensorFlow-specific HDF5 dataloader and retry utilities with a Grain-backed Dataloader and H5DataSource; updates notebooks, tests, package exports, and docs; removes legacy TensorFlow dataloader files; adds grain >= 0.2 to dependencies.

Changes

  • Notebook API Migration — docs/source/notebooks/data/zea_data_example.ipynb, docs/source/notebooks/metrics/lpips_example.ipynb, docs/source/notebooks/metrics/myocardial_quality_example.ipynb, docs/source/notebooks/models/hvae_model_example.ipynb, docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb, docs/source/notebooks/models/taesd_autoencoder_example.ipynb, docs/source/notebooks/models/unet_example.ipynb
    Replaced backend-specific make_dataloader imports/calls with from zea.data.dataloader import Dataloader and Dataloader(...); updated notebook text/headings to reference Dataloader.
  • Core Dataloader Refactor — zea/data/dataloader.py
    Removed the old generator-based implementation; added H5DataSource with a thread-local HDF5 handle cache and inline reopen-on-I/O-error; reimplemented Dataloader as a Grain MapDataset pipeline with a new API surface (dataset, to_iter_dataset(), close(), epoch-aware __iter__, __len__, improved __repr__, normalization/resize/clip/assert options).
  • Public API Update — zea/__init__.py
    Re-exported Dataloader at the package top level (from .data.dataloader import Dataloader).
  • TensorFlow Backend Cleanup — zea/backend/tensorflow/dataloader.py, zea/backend/tensorflow/utils/callbacks.py, zea/backend/__init__.py, zea/backend/tensorflow/__init__.py
    Deleted the TensorFlow-specific dataloader and placeholder callbacks file; removed the make_dataloader re-export and adjusted the backend module docstring/unused imports.
  • Utility Removals — zea/data/utils.py, zea/internal/utils.py, zea/io_lib.py
    Deleted JSON encode/decode and file-info helpers, removed the internal method-finding utility, and removed the retry_on_io_error decorator (and its related imports).
  • Tests & Docs Updates — tests/data/test_dataloader.py, tests/test_io_lib.py, docs/source/notebooks/metrics.rst
    Rewrote tests to exercise H5DataSource/Dataloader (new unit cases for caching, skipping, seeding, validation, normalization, summary); removed tests for retry_on_io_error; fixed an RST title underline.
  • Datasets (small change) — zea/data/datasets.py
    Added H5FileHandleCache.pop(file_path) to evict and close cached H5 handles (exceptions during close are swallowed).
  • Dependency — pyproject.toml
    Added dependency: grain >= 0.2.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Suggested labels

enhancement

Suggested reviewers

  • swpenninga
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 79.17%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check — ✅ Passed: the title 'Move dataloader to Grain' directly and clearly describes the primary change: migrating from a TensorFlow-based dataloader to a Grain-based implementation.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



Comment @coderabbitai help to get the list of available commands and usage tips.

@tristan-deep tristan-deep added efficiency Improvements made regarding code or tests efficiency data format Related to the zea data format saving and loading labels Feb 9, 2026
@tristan-deep tristan-deep added this to the v0.0.11 milestone Feb 9, 2026
@codecov

codecov bot commented Feb 9, 2026

Codecov Report

❌ Patch coverage is 88.14229% with 30 lines in your changes missing coverage. Please review.

Files with missing lines:
  • zea/data/dataloader.py — 84.71% patch coverage (19 missing, 5 partials) ⚠️
  • zea/data/datasets.py — 14.28% patch coverage (6 missing) ⚠️

📢 Thoughts on this report? Let us know!

@swpenninga
Contributor

Currently working on the new grain backend.
@tristan-deep I understand the dataloader is fully np based, however I was wondering about the current _numpy_translate() and _numpy_resize().

These are expensive operations, is it not better to do the bare minimum I/O on cpu -> off-load to GPU (working on prefetching) -> do these transformations with keras ops?

@tristan-deep
Collaborator Author

Currently working on the new grain backend. @tristan-deep I understand the dataloader is fully np based, however I was wondering about the current _numpy_translate() and _numpy_resize().

These are expensive operations, is it not better to do the bare minimum I/O on cpu -> off-load to GPU (working on prefetching) -> do these transformations with keras ops?

That would make a lot of sense actually if that is possible in an efficient way! I did notice indeed the operations you mention take a lot of time in the current implementation.

@swpenninga
Contributor

Status update:

  • There is now a wrapper around the Dataloader that prefetches to device, that works with every backend.

  • The expensive numpy transformations are removed and done on GPU with keras (can still be numpy + cpu if that is the backend).

  • Currently I'm trying to reformat the new Grain dataloader to have the exact same formatting as the old TensorFlow dataloader (such that test_dataloader.py passes 100% with just an import change).

  1. return_filename, resize_axes, assert_image_range and cache were trivially added.
  2. All shuffle tests fail, since Grain uses a seed that makes the shuffle reproducible (design choice).
  3. Big compatibility issue: many ultrasound datasets (at least the way we stored them) have different spatial shapes (height, width, depth) per sample. This is incompatible with batched resizing operations on GPU. We will need to choose:
     Option 1: resize on CPU -> move to GPU -> batch (slow).
     Option 2: store all datasets with the same spatial shapes -> move to GPU, batch, resize if necessary -> fast, but loses freedom.
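For illustration, option 1 can be sketched with plain NumPy: resize each variable-shape sample on the CPU to a common spatial shape, then batch. The resize_nn helper below is hypothetical (zea's actual dataloader exposes its own resize_type options); it just shows why per-sample resizing is the step that makes stacking possible.

```python
import numpy as np

def resize_nn(sample, target_hw):
    """Nearest-neighbor resize of a 2D array to target (height, width)."""
    h, w = sample.shape
    th, tw = target_hw
    rows = np.arange(th) * h // th  # source row index for each target row
    cols = np.arange(tw) * w // tw  # source column index for each target column
    return sample[np.ix_(rows, cols)]

# Samples with different spatial shapes, as in many ultrasound datasets.
samples = [np.zeros((300, 400)), np.zeros((256, 256)), np.zeros((512, 384))]

# Option 1: resize each sample on CPU to a common shape, then batch.
batch = np.stack([resize_nn(s, (256, 256)) for s in samples])
print(batch.shape)  # (3, 256, 256)
```

Option 2 would skip the per-sample step entirely, at the cost of requiring uniform shapes on disk.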

@tristan-deep
Collaborator Author

JAX has two tutorials, for both CPU- and GPU-based dataloaders using Grain. See link.

@tristan-deep tristan-deep modified the milestones: v0.0.11, v0.0.12 Mar 3, 2026
@tristan-deep tristan-deep mentioned this pull request Mar 3, 2026
@swpenninga
Contributor

Update of the last commit:

  • I removed the GPU dataloader transformations, in hindsight this seems pointless
  • I made sure all functionality is preserved w.r.t. the make_dataloader that uses H5GeneratorTF
  • Replaced the dataloader in the old tests with the new one
  • Replaced the dataloader in all the notebooks with the new one
  • All tests pass now

I would say that the Grain loader is fully functional now; the only things left are investigation of HDF5 chunking, benchmarking, and removal of the old dataloader.

@wesselvannierop
Collaborator

@swpenninga is https://github.com/tue-bmd/zea/blob/795ce21bf4283c71ad35e5d60aa9c755298cb754/zea/data/utils.py still needed? I believe this was related to tf.data compatibility

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
docs/source/notebooks/data/zea_data_example.ipynb (1)

472-472: ⚠️ Potential issue | 🟡 Minor

Fix quote typo in resize_type options comment.

Line 472’s inline comment has mismatched quotes, which makes the option list harder to read.

🛠️ Proposed fix
-    resize_type="resize",  # or "center_crop or "random_crop"
+    resize_type="resize",  # or "center_crop" or "random_crop"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/source/notebooks/data/zea_data_example.ipynb` at line 472, The inline
comment for the parameter resize_type contains mismatched quotes; update the
comment near the resize_type usage (look for the string "resize_type" in the
notebook cell) so the options are consistently quoted, e.g.
resize_type="resize",  # or "center_crop" or "random_crop" (ensure all option
names have matching double quotes and proper commas/spacing).
🧹 Nitpick comments (4)
docs/source/notebooks/metrics/myocardial_quality_example.ipynb (1)

65-65: Prefer importing Dataloader from the public zea API in docs.

This keeps the example tied to the stable, user-facing surface instead of an internal module path.

♻️ Proposed change
-from zea.data.dataloader import Dataloader
+from zea import Dataloader
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/source/notebooks/metrics/myocardial_quality_example.ipynb` at line 65,
Update the notebook import to use the public API: replace the internal import
"from zea.data.dataloader import Dataloader" with the public surface import
"from zea import Dataloader" (referencing the Dataloader symbol) so examples
rely on stable, user-facing module paths rather than internal module paths.
docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb (1)

73-73: Prefer the public zea.Dataloader import in user-facing notebooks.

Using the top-level export keeps docs aligned with the stable public API and reduces coupling to internal module paths.

♻️ Proposed change
-from zea.data.dataloader import Dataloader
+from zea import Dataloader
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb` at
line 73, The notebook imports Dataloader from an internal path ("from
zea.data.dataloader import Dataloader"); change this to use the public top-level
export by importing Dataloader from zea (e.g., replace that import with "from
zea import Dataloader") so docs use the stable public API and avoid coupling to
internal module paths; update any cells that reference Dataloader import to use
the new import symbol.
docs/source/notebooks/metrics/lpips_example.ipynb (1)

81-81: Prefer the public zea API import in docs.

At Line 81, importing Dataloader from zea.data.dataloader couples the notebook to an internal module path. Since Dataloader is exported at the top level, docs should use the public surface for stability and clarity.

♻️ Proposed fix
-    "from zea.data.dataloader import Dataloader\n",
+    "from zea import Dataloader\n",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/source/notebooks/metrics/lpips_example.ipynb` at line 81, Import uses
internal path zea.data.dataloader; change to the public API by importing
Dataloader from the top-level zea package (use "from zea import Dataloader") so
the notebook relies on the stable public surface rather than an internal module
path; update the import statement that currently references Dataloader in the
notebook to the top-level export.
docs/source/notebooks/data/zea_data_example.ipynb (1)

12-12: Consider using the public zea.Dataloader API for notebook examples.

Dataloader is exported at the top level (zea.Dataloader), but the notebook currently imports from the internal module path (zea.data.dataloader.Dataloader). Using the public API provides a more stable entrypoint and better protects against future refactors. This applies to lines 12, 94, and 383.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/source/notebooks/data/zea_data_example.ipynb` at line 12, The notebook
imports the internal class zea.data.dataloader.Dataloader; replace those imports
with the public top-level export zea.Dataloader to use the stable API. Update
any import statements and references at the three locations mentioned (the cells
around lines 12, 94, and 383) so they import or reference Dataloader via
zea.Dataloader instead of zea.data.dataloader.Dataloader, ensuring all usages
(constructor calls, type hints, and doc references) point to the top-level
Dataloader symbol.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/source/notebooks/data/zea_data_example.ipynb`:
- Line 383: Update the explanatory sentence about dataset consistency: instead
of implying CAMUS fully solves consistency, state that
zea.data.dataloader.Dataloader works best with consistent sample shapes but some
public datasets (e.g., PICMUS) are highly variable and even CAMUS can contain
variable shapes; advise users they may still see variable-shape warnings and
should apply resizing/cropping or other preprocessing to obtain uniform shapes
before using Dataloader. Reference the class name zea.data.dataloader.Dataloader
and datasets PICMUS and CAMUS when rephrasing.

In `@zea/data/dataloader.py`:
- Line 259: The issue is that negative indices in additional_axes_iter are being
used before normalization, so the subsequent "initial" axis correction and
np.moveaxis calls operate on incorrect axes; update the Dataloader
initialization and the code paths around lines handling additional_axes_iter
(e.g., the attribute self.additional_axes_iter and the logic that adjusts
initial and calls np.moveaxis) to first normalize additional_axes_iter into
non-negative axes relative to the current ndim (e.g., map each axis a to a if
a>=0 else a+ndim) before performing the "initial" correction and any np.moveaxis
operations so negative values don't shift the wrong dimensions; apply the same
normalization in the other spot referenced (the block around the later
adjustments at ~370-373) so both code paths use normalized axes consistently.
- Line 435: The docstring stating "dataset_repetitions: Repeat dataset N times
(``None`` = infinite)" is wrong relative to the implementation; either update
the docstring to state that None yields a single-pass, or change the runtime
guard so that when dataset_repetitions is None the dataset is passed through
repeat() for infinite repetition. Locate the parameter named dataset_repetitions
in the DataLoader (or loader factory) function and either adjust the textual
description to "None = single-pass" or modify the conditional around repeat() to
call dataset.repeat() when dataset_repetitions is None (and keep the existing
behavior when an integer is provided).
- Around line 347-351: The current _get_cache() creates a per-thread
H5FileHandleCache in threading.local() so Dataloader.close() only sees the
caller thread's cache; change _get_cache() to register each created
H5FileHandleCache in a shared collection (e.g. a thread-safe weakref.WeakSet or
dict on the Dataloader instance) so that close() (and any other cleanup paths)
can iterate all registered caches and call their close() from the main thread;
implement the shared registry on the Dataloader (add e.g. self._all_caches and a
Lock), update _get_cache() to add new caches to that registry, and change
Dataloader.close() to iterate self._all_caches and close every H5FileHandleCache
to ensure handles opened by worker threads are released.
- Around line 286-288: The current truthy check on limit_n_samples incorrectly
treats 0 as "no limit"; update the conditional around self.indices slicing (the
block using limit_n_samples and self.indices in H5DataSource) to explicitly
check for None (e.g., if limit_n_samples is not None) and validate it (ensure
it's >= 0 and clamp to len(self.indices)); then slice self.indices =
self.indices[:limit_n_samples] so limit_n_samples=0 produces an empty datasource
while negative values raise/are handled appropriately.
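The dataset_repetitions ambiguity flagged above is easy to pin down with a toy helper. This is a sketch under the "None = single pass" reading the review suggests; with_repetitions is a hypothetical name, not zea's API (the alternative resolution would return an infinite iterator, e.g. via itertools.cycle, when None is passed).

```python
import itertools

def with_repetitions(dataset, dataset_repetitions=None):
    """None = single pass; an integer N repeats the dataset N times."""
    if dataset_repetitions is None:
        return iter(dataset)
    return itertools.chain.from_iterable(
        itertools.repeat(dataset, dataset_repetitions)
    )

data = [1, 2, 3]
print(list(with_repetitions(data)))     # [1, 2, 3]
print(list(with_repetitions(data, 2)))  # [1, 2, 3, 1, 2, 3]
```

Whichever reading is chosen, the docstring and the repeat() guard should encode the same one.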


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b8faaf8e-6ddb-45ec-8899-1e00bfb39393

📥 Commits

Reviewing files that changed from the base of the PR and between 4b83401 and 9bcb82f.

⛔ Files ignored due to path filters (1)
  • poetry.lock is excluded by !**/*.lock
📒 Files selected for processing (19)
  • docs/source/notebooks/data/zea_data_example.ipynb
  • docs/source/notebooks/metrics.rst
  • docs/source/notebooks/metrics/lpips_example.ipynb
  • docs/source/notebooks/metrics/myocardial_quality_example.ipynb
  • docs/source/notebooks/models/hvae_model_example.ipynb
  • docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb
  • docs/source/notebooks/models/taesd_autoencoder_example.ipynb
  • docs/source/notebooks/models/unet_example.ipynb
  • pyproject.toml
  • tests/data/test_dataloader.py
  • zea/__init__.py
  • zea/backend/__init__.py
  • zea/backend/tensorflow/__init__.py
  • zea/backend/tensorflow/dataloader.py
  • zea/backend/tensorflow/utils/callbacks.py
  • zea/data/dataloader.py
  • zea/data/utils.py
  • zea/internal/utils.py
  • zea/io_lib.py
💤 Files with no reviewable changes (6)
  • zea/backend/tensorflow/utils/callbacks.py
  • zea/backend/tensorflow/__init__.py
  • zea/io_lib.py
  • zea/data/utils.py
  • zea/internal/utils.py
  • zea/backend/tensorflow/dataloader.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
zea/data/dataloader.py (3)

234-245: ⚠️ Potential issue | 🟠 Major

Normalize negative axes before the frame-axis correction.

additional_axes_iter=(-2,) or initial_frame_axis=-1 still breaks the initial -= ... adjustment, so np.moveaxis() can target the wrong dimension and corrupt sample layout. Normalize both axes against the file rank once in __init__ and reuse those normalized values everywhere below.

Suggested fix
         self.file_paths = _dataset.file_paths
         self.file_shapes = _dataset.load_file_shapes(key)
         _dataset.close()
+        rank = len(self.file_shapes[0])
+        self.initial_frame_axis = map_negative_indices([self.initial_frame_axis], rank)[0]
+        self.additional_axes_iter = list(
+            map_negative_indices(self.additional_axes_iter, rank)
+        )

         # Compute per-sample index table
         self.indices = generate_h5_indices(

Also applies to: 247-259, 267-269, 345-348

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@zea/data/dataloader.py` around lines 234 - 245, Normalize any negative axis
indices early in the DataLoader initialization: compute the file rank from the
discovered file shapes (self.file_shapes) and convert self.additional_axes_iter
and self.initial_frame_axis to non-negative indices (e.g., by adding rank to
negative values) immediately after loading self.file_shapes and before
performing the frame-axis correction (the "initial -= ..." adjustment) so
np.moveaxis() targets the correct dimensions; update all usages of
additional_axes_iter and initial_frame_axis (including the logic in the blocks
referenced around lines 247-259, 267-269, and 345-348) to use these normalized
values consistently throughout the class.
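A minimal standalone version of the normalization this fix relies on (normalize_axes below is a hypothetical stand-in for the map_negative_indices helper referenced in the suggested diff):

```python
import numpy as np

def normalize_axes(axes, ndim):
    """Map possibly-negative axis indices to non-negative ones for a given rank."""
    return tuple(a if a >= 0 else a + ndim for a in axes)

rank = 4  # e.g. (frames, height, width, channels)
print(normalize_axes((-2, 0), rank))  # (2, 0)

# Why it matters: feeding a raw negative axis into an "initial -= 1" style
# correction and then np.moveaxis would shift the wrong dimension.
x = np.zeros((5, 64, 48, 3))
moved = np.moveaxis(x, normalize_axes((-2,), x.ndim)[0], 0)
print(moved.shape)  # (48, 5, 64, 3)
```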

280-282: ⚠️ Potential issue | 🟠 Major

close() still misses HDF5 handles opened on worker threads.

Each Grain reader thread creates its own H5FileHandleCache in threading.local(), but close() can only see the cache on the caller thread. Dataloader.close() therefore leaves worker-thread file handles alive until thread teardown or GC.

Suggested fix
         # Thread-local file handle caches (one per thread)
         self._local = threading.local()
+        self._all_caches = []
+        self._all_caches_lock = threading.Lock()
@@
     def _get_cache(self) -> H5FileHandleCache:
         """Return the file-handle cache for the current thread."""
         if not hasattr(self._local, "cache"):
-            self._local.cache = H5FileHandleCache()
+            cache = H5FileHandleCache()
+            self._local.cache = cache
+            with self._all_caches_lock:
+                self._all_caches.append(cache)
         return self._local.cache
@@
     def close(self):
-        """Close file handles for the current thread."""
-        cache = getattr(self._local, "cache", None)
-        if cache is not None:
-            cache.close()
+        """Close file handles for all threads."""
+        with self._all_caches_lock:
+            caches, self._all_caches = self._all_caches, []
+        for cache in caches:
+            cache.close()

Also applies to: 322-326, 354-358

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@zea/data/dataloader.py` around lines 280 - 282, The thread-local HDF5 caches
stored in self._local are only visible to the creating thread, so
Dataloader.close() misses handles opened by worker threads; fix by registering
each per-thread H5FileHandleCache in a shared, thread-safe registry when it's
created (e.g., add the cache to a self._h5_caches set protected by a lock inside
the code path that creates the cache in threading.local), use weakrefs if needed
to avoid preventing GC, and then update Dataloader.close() to iterate over that
shared registry and call the H5FileHandleCache.close() for every registered
cache (and clear the registry); reference symbols: self._local,
H5FileHandleCache, and close().
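The registry pattern in this fix can be exercised in isolation with a toy stand-in for H5FileHandleCache (all class names below are illustrative, not zea's actual code):

```python
import threading

class Cache:
    """Toy stand-in for H5FileHandleCache."""
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True

class Loader:
    """Shared registry so close() can reach caches created on worker threads."""
    def __init__(self):
        self._local = threading.local()  # per-thread cache slot
        self._all_caches = []            # shared registry of every cache
        self._lock = threading.Lock()

    def _get_cache(self):
        # Each thread lazily creates its own cache and registers it.
        if not hasattr(self._local, "cache"):
            cache = Cache()
            self._local.cache = cache
            with self._lock:
                self._all_caches.append(cache)
        return self._local.cache

    def close(self):
        # Drain the registry and close every cache, whichever thread made it.
        with self._lock:
            caches, self._all_caches = self._all_caches, []
        for cache in caches:
            cache.close()

loader = Loader()
workers = [threading.Thread(target=loader._get_cache) for _ in range(4)]
for t in workers:
    t.start()
for t in workers:
    t.join()
loader._get_cache()  # plus the caller thread's cache
registered = list(loader._all_caches)
loader.close()
print(len(registered), all(c.closed for c in registered))  # 5 True
```

The review also suggests weakrefs (e.g. a weakref.WeakSet) if the registry should not keep caches alive on its own.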

261-263: ⚠️ Potential issue | 🟡 Minor

Treat limit_n_samples=0 as empty instead of “no limit”.

The truthy guard skips slicing for 0, and negative values silently trim from the end. Use an explicit is not None check here.

Suggested fix
-        if limit_n_samples:
+        if limit_n_samples is not None:
+            if limit_n_samples < 0:
+                raise ValueError("limit_n_samples must be >= 0")
             log.info(f"H5DataSource: Limiting to {limit_n_samples} / {len(self.indices)} samples.")
             self.indices = self.indices[:limit_n_samples]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@zea/data/dataloader.py` around lines 261 - 263, The guard `if
limit_n_samples:` treats 0 as "no limit" and allows negative values to slice
from the end; update the check in the H5DataSource handling (the block that logs
and sets self.indices = self.indices[:limit_n_samples]) to `if limit_n_samples
is not None:` and before slicing coerce/clamp the value (e.g., n = max(0,
int(limit_n_samples))) so slicing uses a non-negative integer, then log using n
and assign `self.indices = self.indices[:n]`.
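The truthiness pitfall is easy to reproduce in isolation; this sketch mirrors the suggested guard (limited is a hypothetical helper, not zea's code):

```python
indices = list(range(10))

def limited(indices, limit_n_samples=None):
    """Explicit None check so limit_n_samples=0 yields an empty datasource."""
    if limit_n_samples is not None:
        if limit_n_samples < 0:
            raise ValueError("limit_n_samples must be >= 0")
        return indices[:limit_n_samples]
    return indices

print(len(limited(indices)))     # 10: None means "no limit"
print(len(limited(indices, 0)))  # 0: empty, not "no limit"
print(len(limited(indices, 4)))  # 4
```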
🧹 Nitpick comments (2)
tests/test_io_lib.py (1)

11-12: Remove unused constants.

MAX_RETRIES and INITIAL_DELAY are no longer used after the retry_on_io_error tests were removed.

🧹 Proposed fix
 from . import DEFAULT_TEST_SEED
 
-MAX_RETRIES = 3
-INITIAL_DELAY = 0.01
-
 
 `@pytest.fixture`
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_io_lib.py` around lines 11 - 12, Remove the now-unused constants
MAX_RETRIES and INITIAL_DELAY from tests/test_io_lib.py: delete their
declarations and ensure there are no remaining references to MAX_RETRIES or
INITIAL_DELAY elsewhere in the file; if any tests or helpers still reference
them, update those usages or remove them accordingly and run the test suite to
confirm no references remain.
tests/data/test_dataloader.py (1)

397-414: Add a negative-axis regression to this N-D test.

This only exercises positive additional_axes_iter, so accepted inputs like (-2,) or initial_frame_axis=-1 can still regress without this test failing. One negative-axis parametrization here would lock in the normalization fix.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/data/test_dataloader.py` around lines 397 - 414, The test instantiates
Dataloader but only covers positive additional_axes_iter and positive axis
params; add a negative-axis regression by adding at least one parametrization
that uses negative axes (e.g., additional_axes_iter=(-2,) or
initial_frame_axis=-1 or frame_axis=-1) to the test case so the code paths
handling negative indexing in Dataloader (constructor handling of
additional_axes_iter, initial_frame_axis, frame_axis, and the normalization
logic) are exercised and prevent regressions; update the test parametrization
for the dataset instantiation to include that negative-axis variant and keep the
same assertions so the normalization fix is locked in.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@zea/data/dataloader.py`:
- Around line 612-615: The __iter__ currently mutates the private
self._shuffle_node._seed (undocumented Grain internals); instead avoid touching
_seed and rebuild a fresh shuffled dataset per epoch by calling
MapDataset.shuffle(...) with a new seed inside __iter__ (replace usage of
self._shuffle_node mutation with creating a new shuffle node from the base
dataset using MapDataset.shuffle(seed=...)), or alternatively constrain Grain in
pyproject.toml to a pinned compatible range; update references in the __iter__
implementation to use the new shuffle node (and remove direct accesses to _seed)
and ensure the RNG provides the seed for MapDataset.shuffle().
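Rebuilding the shuffle per epoch from a base seed, instead of poking a pipeline's private _seed attribute, can be sketched without Grain (EpochShuffled is illustrative only, not Grain's or zea's API):

```python
import random

class EpochShuffled:
    """Derive a fresh, deterministic shuffle per epoch from a base seed."""
    def __init__(self, data, seed=0):
        self.data = list(data)
        self.seed = seed
        self.epoch = 0

    def __iter__(self):
        order = list(self.data)
        # Mix the base seed with the epoch counter for a per-epoch ordering.
        random.Random(self.seed * 100003 + self.epoch).shuffle(order)
        self.epoch += 1
        return iter(order)

ds = EpochShuffled(range(6), seed=4)
first, second = list(ds), list(ds)
print(sorted(first) == sorted(second))  # True: same elements each epoch
# Re-creating the loader with the same base seed reproduces epoch orderings.
print(list(EpochShuffled(range(6), seed=4)) == first)  # True
```

With Grain's public API the analogous move would be constructing a new shuffled MapDataset per epoch with a derived seed, as the comment suggests.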

---

Duplicate comments:
In `@zea/data/dataloader.py`:
- Around line 234-245: Normalize any negative axis indices early in the
DataLoader initialization: compute the file rank from the discovered file shapes
(self.file_shapes) and convert self.additional_axes_iter and
self.initial_frame_axis to non-negative indices (e.g., by adding rank to
negative values) immediately after loading self.file_shapes and before
performing the frame-axis correction (the "initial -= ..." adjustment) so
np.moveaxis() targets the correct dimensions; update all usages of
additional_axes_iter and initial_frame_axis (including the logic in the blocks
referenced around lines 247-259, 267-269, and 345-348) to use these normalized
values consistently throughout the class.
- Around line 280-282: The thread-local HDF5 caches stored in self._local are
only visible to the creating thread, so Dataloader.close() misses handles opened
by worker threads; fix by registering each per-thread H5FileHandleCache in a
shared, thread-safe registry when it's created (e.g., add the cache to a
self._h5_caches set protected by a lock inside the code path that creates the
cache in threading.local), use weakrefs if needed to avoid preventing GC, and
then update Dataloader.close() to iterate over that shared registry and call the
H5FileHandleCache.close() for every registered cache (and clear the registry);
reference symbols: self._local, H5FileHandleCache, and close().
- Around line 261-263: The guard `if limit_n_samples:` treats 0 as "no limit"
and allows negative values to slice from the end; update the check in the
H5DataSource handling (the block that logs and sets self.indices =
self.indices[:limit_n_samples]) to `if limit_n_samples is not None:` and before
slicing coerce/clamp the value (e.g., n = max(0, int(limit_n_samples))) so
slicing uses a non-negative integer, then log using n and assign `self.indices =
self.indices[:n]`.
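The shared registry for thread-local caches described above can be sketched as follows. This is a minimal stand-in, assuming an `H5FileHandleCache`-like object with a `close()` method; the weak set lets caches of dead threads be garbage-collected while still letting `close()` reach every live one.

```python
import threading
import weakref

class _Cache:
    """Stand-in for the per-thread file-handle cache (assumed interface)."""
    def __init__(self):
        self.closed = False
    def close(self):
        self.closed = True

class Loader:
    def __init__(self):
        self._local = threading.local()
        self._lock = threading.Lock()
        # WeakSet: registry does not keep caches of finished threads alive.
        self._caches = weakref.WeakSet()

    def _get_cache(self):
        cache = getattr(self._local, "cache", None)
        if cache is None:
            cache = _Cache()
            self._local.cache = cache
            with self._lock:          # registry is shared across threads
                self._caches.add(cache)
        return cache

    def close(self):
        with self._lock:
            for cache in list(self._caches):
                cache.close()
            self._caches.clear()
```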

---

Nitpick comments:
In `@tests/data/test_dataloader.py`:
- Around line 397-414: The test instantiates Dataloader but only covers positive
additional_axes_iter and positive axis params; add a negative-axis regression by
adding at least one parametrization that uses negative axes (e.g.,
additional_axes_iter=(-2,) or initial_frame_axis=-1 or frame_axis=-1) to the
test case so the code paths handling negative indexing in Dataloader
(constructor handling of additional_axes_iter, initial_frame_axis, frame_axis,
and the normalization logic) are exercised and prevent regressions; update the
test parametrization for the dataset instantiation to include that negative-axis
variant and keep the same assertions so the normalization fix is locked in.

In `@tests/test_io_lib.py`:
- Around line 11-12: Remove the now-unused constants MAX_RETRIES and
INITIAL_DELAY from tests/test_io_lib.py: delete their declarations and ensure
there are no remaining references to MAX_RETRIES or INITIAL_DELAY elsewhere in
the file; if any tests or helpers still reference them, update those usages or
remove them accordingly and run the test suite to confirm no references remain.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 5c52433d-f8ce-4650-8bec-a7789f72c699

📥 Commits

Reviewing files that changed from the base of the PR and between 9bcb82f and 3bbdaed.

📒 Files selected for processing (4)
  • tests/data/test_dataloader.py
  • tests/test_io_lib.py
  • zea/data/__init__.py
  • zea/data/dataloader.py
💤 Files with no reviewable changes (1)
  • zea/data/__init__.py

@swpenninga
Contributor

swpenninga commented Mar 20, 2026

@swpenninga is https://github.com/tue-bmd/zea/blob/795ce21bf4283c71ad35e5d60aa9c755298cb754/zea/data/utils.py still needed? I believe this was related to tf.data compatibility.

@wesselvannierop
I see that json_dumps() is still being used in the dataloader __getitem__, and json_loads() is used once in the dataloader test. I don't know if we still need this; I thought we didn't use json.

Contributor

@coderabbitai coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/data/test_dataloader.py (1)

423-435: ⚠️ Potential issue | 🟡 Minor

Comment contradicts the code: return_centers=True vs "must not return centers".

Line 423 states the augmentation "must not return centers," but return_centers=True is set on line 430. Either update the comment to reflect the actual behavior, or set return_centers=False if centers should not be returned.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/data/test_dataloader.py` around lines 423 - 435, The comment and code
disagree: the test's augmentation (variable augmentation using
RandomCircleInclusion) is documented as "must not return centers" but the
constructor sets return_centers=True; either change the parameter on
RandomCircleInclusion to return_centers=False to match the comment, or edit the
comment to state that centers are returned; update the
RandomCircleInclusion(...) call or the preceding comment accordingly so the
comment and the return_centers argument are consistent.
🧹 Nitpick comments (4)
tests/data/test_dataloader.py (4)

485-487: Prefer keras.ops.convert_to_numpy() for multi-backend compatibility.

Same issue as in test_random_circle_inclusion_augmentation. Based on learnings, use keras.ops.convert_to_numpy(images) for Torch backend compatibility.

♻️ Suggested fix
-    images_np = np.array(images)
+    images_np = keras.ops.convert_to_numpy(images)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/data/test_dataloader.py` around lines 485 - 487, The test currently
converts a batch via numpy conversion using np.array(images), which breaks
multi-backend compatibility for PyTorch; replace that conversion by calling
keras.ops.convert_to_numpy(images) so the test uses the framework-agnostic
helper (locate the lines where images = next(iter(dataset)) and images_np =
np.array(images) and change the latter to use
keras.ops.convert_to_numpy(images)). Ensure you import or reference keras.ops if
not already available in the test file.

495-509: Test name mentions "warning" but doesn't verify warning emission.

The test is named test_skipped_files_warning but only asserts len(source) == 0. If a warning is expected when files are skipped, use pytest.warns() to verify:

with pytest.warns(UserWarning, match="skipped"):
    source = H5DataSource(...)

If no warning is expected, consider renaming to test_skipped_files_insufficient_frames.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/data/test_dataloader.py` around lines 495 - 509, The test
test_skipped_files_warning currently only asserts len(source) == 0 but its name
implies a warning should be emitted; wrap the H5DataSource instantiation in a
pytest.warns(UserWarning, match="skipped") context to assert the expected
warning is emitted (referencing H5DataSource in the test), or if no warning
should be produced, rename the test to test_skipped_files_insufficient_frames to
reflect behavior; update the assertion accordingly and keep the check for
len(source) == 0.

450-452: Prefer keras.ops.convert_to_numpy() for multi-backend compatibility.

Using np.array(images) may fail with the Torch backend. Based on learnings: "In tests, when comparing Keras tensors with NumPy operations, convert tensors to NumPy arrays first using keras.ops.convert_to_numpy() to ensure multi-backend compatibility."

♻️ Suggested fix
-    images_np = np.array(images)
+    images_np = keras.ops.convert_to_numpy(images)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/data/test_dataloader.py` around lines 450 - 452, Replace the direct
NumPy conversion of the batch with a backend-safe conversion: instead of using
np.array(images) in the test (where images is produced by iter(dataset)), call
keras.ops.convert_to_numpy(images) and assign that to images_np so the test
works with both TensorFlow and Torch backends; update the reference to images_np
wherever used afterwards.

382-417: Consider adding shape assertions for n-dimensional dataset test.

The test only verifies that iteration doesn't raise an exception. Adding assertions on the output shape would strengthen the test coverage:

batch = next(iter(dataset))
assert batch.shape[:1] == (batch_size,), "Batch dimension mismatch"
assert batch.shape[-3:-1] == image_size, "Spatial dimensions mismatch"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/data/test_dataloader.py` around lines 382 - 417, The test
test_ndim_hdf5_dataset currently only ensures iteration doesn't raise; modify it
to capture the returned batch from next(iter(dataset)) and assert expected
shapes using batch_size and image_size: check the leading batch dimension equals
batch_size and verify the spatial dimensions (using image_size) appear at the
expected positions in the tensor returned by Dataloader; reference the
Dataloader invocation and variables batch_size and image_size to locate where to
add the assertions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: a25ab36d-fafa-457e-8b2e-4a3f30a48a7d

📥 Commits

Reviewing files that changed from the base of the PR and between 1ae5174 and 7e301c8.

📒 Files selected for processing (1)
  • tests/data/test_dataloader.py

swpenninga previously approved these changes Mar 20, 2026

Contributor

@coderabbitai coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
zea/data/dataloader.py (1)

96-97: ⚠️ Potential issue | 🟡 Minor

limit_n_frames=0 currently disables limiting

Line 96 treats 0 as falsy and resets to np.inf, which silently returns full data instead of zero frames.

Proposed fix
-    if not limit_n_frames:
-        limit_n_frames = np.inf
+    if limit_n_frames is None:
+        limit_n_frames = np.inf
+    else:
+        assert limit_n_frames >= 0, "`limit_n_frames` must be >= 0"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@zea/data/dataloader.py` around lines 96 - 97, The current falsy check "if not
limit_n_frames:" treats 0 as "no limit" and returns all frames; change the logic
to only treat an absent parameter as unlimited by replacing that check with an
explicit None check (e.g., "if limit_n_frames is None: limit_n_frames = np.inf")
in the dataloader code path where limit_n_frames is used so that
limit_n_frames==0 correctly means zero frames; update any function signature
default to None if necessary.
♻️ Duplicate comments (2)
zea/data/dataloader.py (2)

342-345: ⚠️ Potential issue | 🟠 Major

Normalize additional_axes_iter before adjusting initial

Line 344 compares raw axes values; negative axes (-1, etc.) are treated as < initial_frame_axis and shift initial incorrectly, causing wrong np.moveaxis behavior.

Proposed fix
         self.file_paths = _dataset.file_paths
         self.file_shapes = _dataset.load_file_shapes(key)
         _dataset.close()
+        rank = len(self.file_shapes[0])
+        self.initial_frame_axis = map_negative_indices([self.initial_frame_axis], rank)[0]
+        self.additional_axes_iter = list(
+            map_negative_indices(self.additional_axes_iter, rank)
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@zea/data/dataloader.py` around lines 342 - 345, The adjustment to
self.initial_frame_axis uses raw values from self.additional_axes_iter which
fails for negative axes; normalize axes to non-negative equivalents before
comparison (e.g., for each ax in self.additional_axes_iter compute ax_norm = ax
% images.ndim), ensure additional_axes_iter is materialized (list) if it can be
a generator, then compute initial -= sum(ax_norm < self.initial_frame_axis for
ax_norm in normalized_axes) and finally call np.moveaxis(images, initial,
self.frame_axis); update the block that references self.initial_frame_axis,
self.additional_axes_iter, images, and self.frame_axis accordingly.
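The axis normalization suggested above can be illustrated with a short numpy sketch (hypothetical helper, not the zea implementation):

```python
import numpy as np

def normalize_axes(axes, rank):
    # -1 on a rank-4 array maps to 3; non-negative axes pass through.
    return [ax % rank for ax in axes]

# Example: frame axis 2 of a rank-4 file, iterating over axes (-1, 0).
rank = 4
initial = 2
iter_axes = normalize_axes([-1, 0], rank)
# With raw values, -1 < 2 would wrongly count as a preceding axis;
# after normalization only axis 0 precedes the frame axis.
initial -= sum(ax < initial for ax in iter_axes)
arr = np.moveaxis(np.zeros((5, 6, 7)), initial, -1)
```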

335-339: ⚠️ Potential issue | 🟠 Major

Use the original index key for cache invalidation/reopen

On Line 335-339, retry uses file.filename as cache key. If that differs from the original file_name key used in get_file, pop() can miss and leave stale entries.

Proposed fix
-        image = self._load(file, key, indices)
+        image = self._load(file_name, file, key, indices)
@@
-    def _load(self, file: File, key: str, indices):
+    def _load(self, file_name: str, file: File, key: str, indices):
@@
-            fname = file.filename
             file_handle_cache = self._get_file_handle_cache()
-            file_handle_cache.pop(fname)
-            file = file_handle_cache.get_file(fname)
+            file_handle_cache.pop(file_name)
+            file = file_handle_cache.get_file(file_name)
             images = file.load_data(key, indices)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@zea/data/dataloader.py` around lines 335 - 339, The retry path currently
invalidates the file handle cache using file.filename (fname) which may differ
from the original cache key used by get_file; update the retry logic in the
block around _get_file_handle_cache(), get_file(...) and load_data(...) so it
uses the original cache key (the variable named file_name or the same key passed
into get_file) when calling file_handle_cache.pop(...), ensuring the cache entry
is removed and reopened correctly; keep using file.load_data(key, indices) after
re-acquiring the file handle.
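The cache-key mismatch above comes down to evicting by the same key that the handle was stored under. A minimal sketch, with an assumed `FileHandleCache`-style interface (not the real `H5FileHandleCache`):

```python
class FileHandleCache:
    """Keyed handle cache. Invalidation must use the original key;
    deriving a key from the handle itself (e.g. file.filename) can
    miss when paths are normalized or aliased, leaving stale entries."""
    def __init__(self, opener):
        self._opener = opener
        self._handles = {}

    def get_file(self, key):
        if key not in self._handles:
            self._handles[key] = self._opener(key)
        return self._handles[key]

    def pop(self, key):
        self._handles.pop(key, None)

def reload(cache, file_name, loader):
    # Retry path: evict and reopen using the ORIGINAL key, then load.
    cache.pop(file_name)
    handle = cache.get_file(file_name)
    return loader(handle)
```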
🧹 Nitpick comments (1)
tests/data/test_dataloader.py (1)

508-523: test_skipped_files_warning does not verify the warning path

Line 508 names this as a warning test, but it only checks len(source) == 0. Consider asserting the warning/log content too, so regressions in warning behavior are caught.

Suggested test adjustment
-def test_skipped_files_warning(tmp_path):
+def test_skipped_files_warning(tmp_path, caplog):
@@
-    source = H5DataSource(
+    with caplog.at_level("WARNING"):
+        source = H5DataSource(
         file_paths=tmp_path,
         key="data",
         n_frames=5,
         frame_index_stride=1,
         validate=False,
-    )
+        )
     assert len(source) == 0
+    assert any("Skipping" in rec.message for rec in caplog.records)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/data/test_dataloader.py` around lines 508 - 523,
test_skipped_files_warning currently only asserts len(source) == 0 but doesn't
assert that the expected warning was emitted; update the test
(test_skipped_files_warning) to capture and assert the warning/log output when
constructing H5DataSource (e.g., use pytest.warns with the expected Warning
subclass or use pytest's caplog to capture a logger message) and assert the
message contains the expected text (like "skipped" or "too few frames") so the
warning path for H5DataSource is verified.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ace08441-3be8-4a74-8a54-c8035f59f566

📥 Commits

Reviewing files that changed from the base of the PR and between 7e301c8 and 813c68a.

📒 Files selected for processing (4)
  • tests/data/test_dataloader.py
  • zea/data/dataloader.py
  • zea/data/datasets.py
  • zea/data/utils.py
💤 Files with no reviewable changes (1)
  • zea/data/utils.py


Labels

data format: Related to the zea data format saving and loading
efficiency: Improvements made regarding code or tests efficiency

Projects

None yet

Development

Successfully merging this pull request may close these issues.

HDF5 dataloaders using a python generator

3 participants