Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
Walkthrough: Replaces the TensorFlow-specific HDF5 dataloader and retry utilities with a Grain-backed dataloader.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Codecov Report: ❌ Patch coverage is
Currently working on the new grain backend. These are expensive operations; wouldn't it be better to do the bare minimum I/O on CPU -> off-load to GPU (working on prefetching) -> do these transformations with Keras ops?
That would make a lot of sense if that is possible in an efficient way! I did indeed notice that the operations you mention take a lot of time in the current implementation.
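The split discussed above can be sketched in a few lines. This is a hypothetical illustration, not the actual zea API: the loader does only bare-minimum I/O per sample, and the batch-level transform (here a simple min-max normalization, standing in for on-device `keras.ops` work) runs afterwards on the whole batch.

```python
import numpy as np

def load_raw_batch(indices):
    """Bare-minimum 'I/O': return raw frames (stubbed here with random data)."""
    rng = np.random.default_rng(0)
    return rng.random((len(indices), 64, 64))

def normalize(batch):
    """Batch transform that could equally be expressed with keras.ops on-device."""
    lo, hi = batch.min(), batch.max()
    return (batch - lo) / (hi - lo)

# CPU does only the read; the transform runs once per batch, not per sample.
batch = normalize(load_raw_batch(range(4)))
```

With prefetching, the `load_raw_batch` step overlaps with the previous batch's transform, which is the efficiency win being discussed.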
Status update:
JAX has two tutorials on Grain-based dataloaders, for both CPU and GPU. See link.
Update of the last commit:
I would say that the grain loader is fully functional now; the only things left are:
@swpenninga is https://github.com/tue-bmd/zea/blob/795ce21bf4283c71ad35e5d60aa9c755298cb754/zea/data/utils.py still needed? I believe this was related to tf.data compatibility.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
docs/source/notebooks/data/zea_data_example.ipynb (1)
472-472: ⚠️ Potential issue | 🟡 Minor
Fix quote typo in `resize_type` options comment. Line 472's inline comment has mismatched quotes, which makes the option list harder to read.
🛠️ Proposed fix
- resize_type="resize", # or "center_crop or "random_crop"
+ resize_type="resize", # or "center_crop" or "random_crop"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/source/notebooks/data/zea_data_example.ipynb` at line 472, The inline comment for the parameter resize_type contains mismatched quotes; update the comment near the resize_type usage (look for the string "resize_type" in the notebook cell) so the options are consistently quoted, e.g. resize_type="resize", # or "center_crop" or "random_crop" (ensure all option names have matching double quotes and proper commas/spacing).
🧹 Nitpick comments (4)
docs/source/notebooks/metrics/myocardial_quality_example.ipynb (1)
65-65: Prefer importing `Dataloader` from the public `zea` API in docs. This keeps the example tied to the stable, user-facing surface instead of an internal module path.
♻️ Proposed change
- from zea.data.dataloader import Dataloader
+ from zea import Dataloader
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/source/notebooks/metrics/myocardial_quality_example.ipynb` at line 65, Update the notebook import to use the public API: replace the internal import "from zea.data.dataloader import Dataloader" with the public surface import "from zea import Dataloader" (referencing the Dataloader symbol) so examples rely on stable, user-facing module paths rather than internal module paths.
docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb (1)
73-73: Prefer the public `zea.Dataloader` import in user-facing notebooks. Using the top-level export keeps docs aligned with the stable public API and reduces coupling to internal module paths.
♻️ Proposed change
- from zea.data.dataloader import Dataloader
+ from zea import Dataloader
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb` at line 73, The notebook imports Dataloader from an internal path ("from zea.data.dataloader import Dataloader"); change this to use the public top-level export by importing Dataloader from zea (e.g., replace that import with "from zea import Dataloader") so docs use the stable public API and avoid coupling to internal module paths; update any cells that reference the Dataloader import to use the new import symbol.
docs/source/notebooks/metrics/lpips_example.ipynb (1)
81-81: Prefer the public `zea` API import in docs. At line 81, importing `Dataloader` from `zea.data.dataloader` couples the notebook to an internal module path. Since `Dataloader` is exported at the top level, docs should use the public surface for stability and clarity.
♻️ Proposed fix
- "from zea.data.dataloader import Dataloader\n",
+ "from zea import Dataloader\n",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/source/notebooks/metrics/lpips_example.ipynb` at line 81, Import uses internal path zea.data.dataloader; change to the public API by importing Dataloader from the top-level zea package (use "from zea import Dataloader") so the notebook relies on the stable public surface rather than an internal module path; update the import statement that currently references Dataloader in the notebook to the top-level export.
docs/source/notebooks/data/zea_data_example.ipynb (1)
12-12: Consider using the public `zea.Dataloader` API for notebook examples. `Dataloader` is exported at the top level (`zea.Dataloader`), but the notebook currently imports from the internal module path (`zea.data.dataloader.Dataloader`). Using the public API provides a more stable entrypoint and better protects against future refactors. This applies to lines 12, 94, and 383.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/source/notebooks/data/zea_data_example.ipynb` at line 12, The notebook imports the internal class zea.data.dataloader.Dataloader; replace those imports with the public top-level export zea.Dataloader to use the stable API. Update any import statements and references at the three locations mentioned (the cells around lines 12, 94, and 383) so they import or reference Dataloader via zea.Dataloader instead of zea.data.dataloader.Dataloader, ensuring all usages (constructor calls, type hints, and doc references) point to the top-level Dataloader symbol.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/source/notebooks/data/zea_data_example.ipynb`:
- Line 383: Update the explanatory sentence about dataset consistency: instead
of implying CAMUS fully solves consistency, state that
zea.data.dataloader.Dataloader works best with consistent sample shapes but some
public datasets (e.g., PICMUS) are highly variable and even CAMUS can contain
variable shapes; advise users they may still see variable-shape warnings and
should apply resizing/cropping or other preprocessing to obtain uniform shapes
before using Dataloader. Reference the class name zea.data.dataloader.Dataloader
and datasets PICMUS and CAMUS when rephrasing.
In `@zea/data/dataloader.py`:
- Line 259: The issue is that negative indices in additional_axes_iter are being
used before normalization, so the subsequent "initial" axis correction and
np.moveaxis calls operate on incorrect axes; update the Dataloader
initialization and the code paths around lines handling additional_axes_iter
(e.g., the attribute self.additional_axes_iter and the logic that adjusts
initial and calls np.moveaxis) to first normalize additional_axes_iter into
non-negative axes relative to the current ndim (e.g., map each axis a to a if
a>=0 else a+ndim) before performing the "initial" correction and any np.moveaxis
operations so negative values don't shift the wrong dimensions; apply the same
normalization in the other spot referenced (the block around the later
adjustments at ~370-373) so both code paths use normalized axes consistently.
- Line 435: The docstring stating "dataset_repetitions: Repeat dataset N times
(``None`` = infinite)" is wrong relative to the implementation; either update
the docstring to state that None yields a single-pass, or change the runtime
guard so that when dataset_repetitions is None the dataset is passed through
repeat() for infinite repetition. Locate the parameter named dataset_repetitions
in the DataLoader (or loader factory) function and either adjust the textual
description to "None = single-pass" or modify the conditional around repeat() to
call dataset.repeat() when dataset_repetitions is None (and keep the existing
behavior when an integer is provided).
- Around line 347-351: The current _get_cache() creates a per-thread
H5FileHandleCache in threading.local() so Dataloader.close() only sees the
caller thread's cache; change _get_cache() to register each created
H5FileHandleCache in a shared collection (e.g. a thread-safe weakref.WeakSet or
dict on the Dataloader instance) so that close() (and any other cleanup paths)
can iterate all registered caches and call their close() from the main thread;
implement the shared registry on the Dataloader (add e.g. self._all_caches and a
Lock), update _get_cache() to add new caches to that registry, and change
Dataloader.close() to iterate self._all_caches and close every H5FileHandleCache
to ensure handles opened by worker threads are released.
- Around line 286-288: The current truthy check on limit_n_samples incorrectly
treats 0 as "no limit"; update the conditional around self.indices slicing (the
block using limit_n_samples and self.indices in H5DataSource) to explicitly
check for None (e.g., if limit_n_samples is not None) and validate it (ensure
it's >= 0 and clamp to len(self.indices)); then slice self.indices =
self.indices[:limit_n_samples] so limit_n_samples=0 produces an empty datasource
while negative values raise/are handled appropriately.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: b8faaf8e-6ddb-45ec-8899-1e00bfb39393
⛔ Files ignored due to path filters (1)
`poetry.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (19)
- docs/source/notebooks/data/zea_data_example.ipynb
- docs/source/notebooks/metrics.rst
- docs/source/notebooks/metrics/lpips_example.ipynb
- docs/source/notebooks/metrics/myocardial_quality_example.ipynb
- docs/source/notebooks/models/hvae_model_example.ipynb
- docs/source/notebooks/models/left_ventricle_segmentation_example.ipynb
- docs/source/notebooks/models/taesd_autoencoder_example.ipynb
- docs/source/notebooks/models/unet_example.ipynb
- pyproject.toml
- tests/data/test_dataloader.py
- zea/__init__.py
- zea/backend/__init__.py
- zea/backend/tensorflow/__init__.py
- zea/backend/tensorflow/dataloader.py
- zea/backend/tensorflow/utils/callbacks.py
- zea/data/dataloader.py
- zea/data/utils.py
- zea/internal/utils.py
- zea/io_lib.py
💤 Files with no reviewable changes (6)
- zea/backend/tensorflow/utils/callbacks.py
- zea/backend/tensorflow/__init__.py
- zea/io_lib.py
- zea/data/utils.py
- zea/internal/utils.py
- zea/backend/tensorflow/dataloader.py
Actionable comments posted: 1
♻️ Duplicate comments (3)
zea/data/dataloader.py (3)
234-245: ⚠️ Potential issue | 🟠 Major
Normalize negative axes before the frame-axis correction. `additional_axes_iter=(-2,)` or `initial_frame_axis=-1` still breaks the `initial -= ...` adjustment, so `np.moveaxis()` can target the wrong dimension and corrupt the sample layout. Normalize both axes against the file rank once in `__init__` and reuse those normalized values everywhere below.
Suggested fix
  self.file_paths = _dataset.file_paths
  self.file_shapes = _dataset.load_file_shapes(key)
  _dataset.close()
+ rank = len(self.file_shapes[0])
+ self.initial_frame_axis = map_negative_indices([self.initial_frame_axis], rank)[0]
+ self.additional_axes_iter = list(
+     map_negative_indices(self.additional_axes_iter, rank)
+ )

  # Compute per-sample index table
  self.indices = generate_h5_indices(

Also applies to: 247-259, 267-269, 345-348
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@zea/data/dataloader.py` around lines 234 - 245, Normalize any negative axis indices early in the DataLoader initialization: compute the file rank from the discovered file shapes (self.file_shapes) and convert self.additional_axes_iter and self.initial_frame_axis to non-negative indices (e.g., by adding rank to negative values) immediately after loading self.file_shapes and before performing the frame-axis correction (the "initial -= ..." adjustment) so np.moveaxis() targets the correct dimensions; update all usages of additional_axes_iter and initial_frame_axis (including the logic in the blocks referenced around lines 247-259, 267-269, and 345-348) to use these normalized values consistently throughout the class.
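For reference, the `map_negative_indices` helper assumed in the suggested fix above is not shown in the diff; a minimal sketch of what such a helper could look like (illustrative, not confirmed zea API):

```python
def map_negative_indices(axes, ndim):
    """Map possibly-negative axis indices to non-negative ones for a given rank."""
    normalized = []
    for axis in axes:
        # Negative axes count from the end, as in NumPy: -1 -> ndim - 1.
        axis = axis + ndim if axis < 0 else axis
        if not 0 <= axis < ndim:
            raise ValueError(f"axis {axis} out of range for ndim={ndim}")
        normalized.append(axis)
    return normalized
```

Normalizing once up front means every later `initial -= ...` correction and `np.moveaxis()` call operates on the same non-negative axis values.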
280-282: ⚠️ Potential issue | 🟠 Major
`close()` still misses HDF5 handles opened on worker threads. Each Grain reader thread creates its own `H5FileHandleCache` in `threading.local()`, but `close()` can only see the cache on the caller thread. `Dataloader.close()` therefore leaves worker-thread file handles alive until thread teardown or GC.
Suggested fix
  # Thread-local file handle caches (one per thread)
  self._local = threading.local()
+ self._all_caches = []
+ self._all_caches_lock = threading.Lock()
@@
  def _get_cache(self) -> H5FileHandleCache:
      """Return the file-handle cache for the current thread."""
      if not hasattr(self._local, "cache"):
-         self._local.cache = H5FileHandleCache()
+         cache = H5FileHandleCache()
+         self._local.cache = cache
+         with self._all_caches_lock:
+             self._all_caches.append(cache)
      return self._local.cache
@@
  def close(self):
-     """Close file handles for the current thread."""
-     cache = getattr(self._local, "cache", None)
-     if cache is not None:
-         cache.close()
+     """Close file handles for all threads."""
+     with self._all_caches_lock:
+         caches, self._all_caches = self._all_caches, []
+     for cache in caches:
+         cache.close()

Also applies to: 322-326, 354-358
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@zea/data/dataloader.py` around lines 280 - 282, The thread-local HDF5 caches stored in self._local are only visible to the creating thread, so Dataloader.close() misses handles opened by worker threads; fix by registering each per-thread H5FileHandleCache in a shared, thread-safe registry when it's created (e.g., add the cache to a self._h5_caches set protected by a lock inside the code path that creates the cache in threading.local), use weakrefs if needed to avoid preventing GC, and then update Dataloader.close() to iterate over that shared registry and call the H5FileHandleCache.close() for every registered cache (and clear the registry); reference symbols: self._local, H5FileHandleCache, and close().
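The registry pattern in the suggested fix, reduced to a self-contained sketch (a plain dict stands in for `H5FileHandleCache`; class and method names here are illustrative):

```python
import threading

class CacheRegistry:
    """Per-thread caches plus a shared registry so close() reaches all threads."""

    def __init__(self):
        self._local = threading.local()
        self._all_caches = []
        self._lock = threading.Lock()

    def get_cache(self):
        # Create one cache per thread, but remember it in a shared list.
        if not hasattr(self._local, "cache"):
            cache = {}  # stands in for H5FileHandleCache
            self._local.cache = cache
            with self._lock:
                self._all_caches.append(cache)
        return self._local.cache

    def close(self):
        # Close caches created by *all* threads, not just the caller's.
        with self._lock:
            caches, self._all_caches = self._all_caches, []
        for cache in caches:
            cache.clear()

reg = CacheRegistry()
threads = [threading.Thread(target=reg.get_cache) for _ in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
n_registered = len(reg._all_caches)  # one cache per worker thread
reg.close()
```

Without the shared list, the main thread's `close()` would only see `self._local.cache` for itself, which is exactly the leak the comment describes.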
261-263: ⚠️ Potential issue | 🟡 Minor
Treat `limit_n_samples=0` as empty instead of "no limit". The truthy guard skips slicing for `0`, and negative values silently trim from the end. Use an explicit `is not None` check here.
Suggested fix
- if limit_n_samples:
+ if limit_n_samples is not None:
+     if limit_n_samples < 0:
+         raise ValueError("limit_n_samples must be >= 0")
      log.info(f"H5DataSource: Limiting to {limit_n_samples} / {len(self.indices)} samples.")
      self.indices = self.indices[:limit_n_samples]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@zea/data/dataloader.py` around lines 261 - 263, The guard `if limit_n_samples:` treats 0 as "no limit" and allows negative values to slice from the end; update the check in the H5DataSource handling (the block that logs and sets self.indices = self.indices[:limit_n_samples]) to `if limit_n_samples is not None:` and before slicing coerce/clamp the value (e.g., n = max(0, int(limit_n_samples))) so slicing uses a non-negative integer, then log using n and assign `self.indices = self.indices[:n]`.
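The difference between the truthy guard and the `is not None` guard can be seen in a few lines (generic sketch, not the zea code itself):

```python
indices = list(range(10))

# Buggy guard: 0 is falsy, so limit_n_samples=0 silently keeps everything.
limit_n_samples = 0
if limit_n_samples:
    indices = indices[:limit_n_samples]
kept_buggy = len(indices)  # still 10, not 0

# Fixed guard: only None means "no limit"; 0 yields an empty datasource.
indices = list(range(10))
if limit_n_samples is not None:
    if limit_n_samples < 0:
        raise ValueError("limit_n_samples must be >= 0")
    indices = indices[:limit_n_samples]
kept_fixed = len(indices)  # 0
```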
🧹 Nitpick comments (2)
tests/test_io_lib.py (1)
11-12: Remove unused constants. `MAX_RETRIES` and `INITIAL_DELAY` are no longer used after the `retry_on_io_error` tests were removed.
🧹 Proposed fix
  from . import DEFAULT_TEST_SEED

- MAX_RETRIES = 3
- INITIAL_DELAY = 0.01
-
  @pytest.fixture
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_io_lib.py` around lines 11 - 12, Remove the now-unused constants MAX_RETRIES and INITIAL_DELAY from tests/test_io_lib.py: delete their declarations and ensure there are no remaining references to MAX_RETRIES or INITIAL_DELAY elsewhere in the file; if any tests or helpers still reference them, update those usages or remove them accordingly and run the test suite to confirm no references remain.
tests/data/test_dataloader.py (1)
397-414: Add a negative-axis regression to this N-D test. This only exercises positive `additional_axes_iter`, so accepted inputs like `(-2,)` or `initial_frame_axis=-1` can still regress without this test failing. One negative-axis parametrization here would lock in the normalization fix.
Verify each finding against the current code and only fix it if needed. In `@tests/data/test_dataloader.py` around lines 397 - 414, The test instantiates Dataloader but only covers positive additional_axes_iter and positive axis params; add a negative-axis regression by adding at least one parametrization that uses negative axes (e.g., additional_axes_iter=(-2,) or initial_frame_axis=-1 or frame_axis=-1) to the test case so the code paths handling negative indexing in Dataloader (constructor handling of additional_axes_iter, initial_frame_axis, frame_axis, and the normalization logic) are exercised and prevent regressions; update the test parametrization for the dataset instantiation to include that negative-axis variant and keep the same assertions so the normalization fix is locked in.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@zea/data/dataloader.py`:
- Around line 612-615: The __iter__ currently mutates the private
self._shuffle_node._seed (undocumented Grain internals); instead avoid touching
_seed and rebuild a fresh shuffled dataset per epoch by calling
MapDataset.shuffle(...) with a new seed inside __iter__ (replace usage of
self._shuffle_node mutation with creating a new shuffle node from the base
dataset using MapDataset.shuffle(seed=...)), or alternatively constrain Grain in
pyproject.toml to a pinned compatible range; update references in the __iter__
implementation to use the new shuffle node (and remove direct accesses to _seed)
and ensure the RNG provides the seed for MapDataset.shuffle().
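The per-epoch reshuffling suggested above (rebuild a fresh shuffle with a new seed each epoch instead of mutating private state) follows a simple pattern. This sketch is generic Python, not the actual Grain `MapDataset.shuffle` API:

```python
import random

def epoch_order(n_samples, base_seed, epoch):
    # Derive a fresh, deterministic seed per epoch, then rebuild the
    # shuffled order instead of mutating internal shuffle state.
    seed = base_seed * 100003 + epoch
    order = list(range(n_samples))
    random.Random(seed).shuffle(order)
    return order
```

Each epoch yields a permutation of the same indices, and the order is reproducible for a given `(base_seed, epoch)` pair, which is the behavior reseeding via private `_seed` was trying to achieve.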
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 5c52433d-f8ce-4650-8bec-a7789f72c699
📒 Files selected for processing (4)
- tests/data/test_dataloader.py
- tests/test_io_lib.py
- zea/data/__init__.py
- zea/data/dataloader.py
💤 Files with no reviewable changes (1)
- zea/data/__init__.py
@wesselvannierop
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/data/test_dataloader.py (1)
423-435: ⚠️ Potential issue | 🟡 Minor
Comment contradicts the code: `return_centers=True` vs "must not return centers". Line 423 states the augmentation "must not return centers," but `return_centers=True` is set on line 430. Either update the comment to reflect the actual behavior, or set `return_centers=False` if centers should not be returned.
Verify each finding against the current code and only fix it if needed. In `@tests/data/test_dataloader.py` around lines 423 - 435, The comment and code disagree: the test's augmentation (variable augmentation using RandomCircleInclusion) is documented as "must not return centers" but the constructor sets return_centers=True; either change the parameter on RandomCircleInclusion to return_centers=False to match the comment, or edit the comment to state that centers are returned; update the RandomCircleInclusion(...) call or the preceding comment accordingly so the comment and the return_centers argument are consistent.
🧹 Nitpick comments (4)
tests/data/test_dataloader.py (4)
485-487: Prefer `keras.ops.convert_to_numpy()` for multi-backend compatibility. Same issue as in `test_random_circle_inclusion_augmentation`. Based on learnings, use `keras.ops.convert_to_numpy(images)` for Torch backend compatibility.
♻️ Suggested fix
- images_np = np.array(images)
+ images_np = keras.ops.convert_to_numpy(images)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/data/test_dataloader.py` around lines 485 - 487, The test currently converts a batch via numpy conversion using np.array(images), which breaks multi-backend compatibility for PyTorch; replace that conversion by calling keras.ops.convert_to_numpy(images) so the test uses the framework-agnostic helper (locate the lines where images = next(iter(dataset)) and images_np = np.array(images) and change the latter to use keras.ops.convert_to_numpy(images)). Ensure you import or reference keras.ops if not already available in the test file.
495-509: Test name mentions "warning" but doesn't verify warning emission. The test is named `test_skipped_files_warning` but only asserts `len(source) == 0`. If a warning is expected when files are skipped, use `pytest.warns()` to verify:

with pytest.warns(UserWarning, match="skipped"):
    source = H5DataSource(...)

If no warning is expected, consider renaming to `test_skipped_files_insufficient_frames`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/data/test_dataloader.py` around lines 495 - 509, The test test_skipped_files_warning currently only asserts len(source) == 0 but its name implies a warning should be emitted; wrap the H5DataSource instantiation in a pytest.warns(UserWarning, match="skipped") context to assert the expected warning is emitted (referencing H5DataSource in the test), or if no warning should be produced, rename the test to test_skipped_files_insufficient_frames to reflect behavior; update the assertion accordingly and keep the check for len(source) == 0.
450-452: Prefer `keras.ops.convert_to_numpy()` for multi-backend compatibility. Using `np.array(images)` may fail with the Torch backend. Based on learnings: "In tests, when comparing Keras tensors with NumPy operations, convert tensors to NumPy arrays first using `keras.ops.convert_to_numpy()` to ensure multi-backend compatibility."
♻️ Suggested fix
- images_np = np.array(images)
+ images_np = keras.ops.convert_to_numpy(images)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/data/test_dataloader.py` around lines 450 - 452, Replace the direct NumPy conversion of the batch with a backend-safe conversion: instead of using np.array(images) in the test (where images is produced by iter(dataset)), call keras.ops.convert_to_numpy(images) and assign that to images_np so the test works with both TensorFlow and Torch backends; update the reference to images_np wherever used afterwards.
382-417: Consider adding shape assertions for the n-dimensional dataset test. The test only verifies that iteration doesn't raise an exception. Adding assertions on the output shape would strengthen the test coverage:

batch = next(iter(dataset))
assert batch.shape[:1] == (batch_size,), "Batch dimension mismatch"
assert batch.shape[-3:-1] == image_size, "Spatial dimensions mismatch"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/data/test_dataloader.py` around lines 382 - 417, The test test_ndim_hdf5_dataset currently only ensures iteration doesn't raise; modify it to capture the returned batch from next(iter(dataset)) and assert expected shapes using batch_size and image_size: check the leading batch dimension equals batch_size and verify the spatial dimensions (using image_size) appear at the expected positions in the tensor returned by Dataloader; reference the Dataloader invocation and variables batch_size and image_size to locate where to add the assertions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@tests/data/test_dataloader.py`:
- Around line 423-435: The comment and code disagree: the test's augmentation
(variable augmentation using RandomCircleInclusion) is documented as "must not
return centers" but the constructor sets return_centers=True; either change the
parameter on RandomCircleInclusion to return_centers=False to match the comment,
or edit the comment to state that centers are returned; update the
RandomCircleInclusion(...) call or the preceding comment accordingly so the
comment and the return_centers argument are consistent.
---
Nitpick comments:
In `@tests/data/test_dataloader.py`:
- Around line 485-487: The test currently converts a batch via numpy conversion
using np.array(images), which breaks multi-backend compatibility for PyTorch;
replace that conversion by calling keras.ops.convert_to_numpy(images) so the
test uses the framework-agnostic helper (locate the lines where images =
next(iter(dataset)) and images_np = np.array(images) and change the latter to
use keras.ops.convert_to_numpy(images)). Ensure you import or reference
keras.ops if not already available in the test file.
- Around line 495-509: The test test_skipped_files_warning currently only
asserts len(source) == 0 but its name implies a warning should be emitted; wrap
the H5DataSource instantiation in a pytest.warns(UserWarning, match="skipped")
context to assert the expected warning is emitted (referencing H5DataSource in
the test), or if no warning should be produced, rename the test to
test_skipped_files_insufficient_frames to reflect behavior; update the assertion
accordingly and keep the check for len(source) == 0.
- Around line 450-452: Replace the direct NumPy conversion of the batch with a
backend-safe conversion: instead of using np.array(images) in the test (where
images is produced by iter(dataset)), call keras.ops.convert_to_numpy(images)
and assign that to images_np so the test works with both TensorFlow and Torch
backends; update the reference to images_np wherever used afterwards.
- Around line 382-417: The test test_ndim_hdf5_dataset currently only ensures
iteration doesn't raise; modify it to capture the returned batch from
next(iter(dataset)) and assert expected shapes using batch_size and image_size:
check the leading batch dimension equals batch_size and verify the spatial
dimensions (using image_size) appear at the expected positions in the tensor
returned by Dataloader; reference the Dataloader invocation and variables
batch_size and image_size to locate where to add the assertions.
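The shape assertions suggested for `test_ndim_hdf5_dataset` amount to checks like the following sketch, where the batch is a NumPy stand-in for whatever `next(iter(dataset))` returns (the shapes are illustrative, not taken from the test fixtures):

```python
import numpy as np

batch_size, image_size = 4, (32, 32)
# Stand-in for a batch yielded by the Dataloader; the batch dimension leads.
batch = np.zeros((batch_size, *image_size, 1))

assert batch.shape[0] == batch_size    # leading dim equals batch_size
assert batch.shape[1:3] == image_size  # spatial dims match image_size
print(batch.shape)  # (4, 32, 32, 1)
```

Asserting on concrete shapes like this catches axis-ordering regressions that a bare "iteration does not raise" test would miss.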
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: a25ab36d-fafa-457e-8b2e-4a3f30a48a7d
📒 Files selected for processing (1)
tests/data/test_dataloader.py
⚠️ Outside diff range comments (1)
zea/data/dataloader.py (1)
96-97: ⚠️ Potential issue | 🟡 Minor

`limit_n_frames=0` currently disables limiting

Line 96 treats `0` as falsy and resets to `np.inf`, which silently returns full data instead of zero frames.

Proposed fix

-    if not limit_n_frames:
-        limit_n_frames = np.inf
+    if limit_n_frames is None:
+        limit_n_frames = np.inf
+    else:
+        assert limit_n_frames >= 0, "`limit_n_frames` must be >= 0"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@zea/data/dataloader.py` around lines 96 - 97, The current falsy check "if not limit_n_frames:" treats 0 as "no limit" and returns all frames; change the logic to only treat an absent parameter as unlimited by replacing that check with an explicit None check (e.g., "if limit_n_frames is None: limit_n_frames = np.inf") in the dataloader code path where limit_n_frames is used so that limit_n_frames==0 correctly means zero frames; update any function signature default to None if necessary.
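The None-check pattern this fix proposes can be sketched in isolation; the function name and signature below are hypothetical stand-ins, not zea's actual code:

```python
import numpy as np

def resolve_limit(limit_n_frames=None):
    """Treat only None as 'no limit'; an explicit 0 means zero frames."""
    # `if not limit_n_frames:` would wrongly map 0 to np.inf as well.
    if limit_n_frames is None:
        return np.inf
    assert limit_n_frames >= 0, "`limit_n_frames` must be >= 0"
    return limit_n_frames

print(resolve_limit())   # inf (no limit requested)
print(resolve_limit(0))  # 0 (zero frames, not unlimited)
```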
♻️ Duplicate comments (2)
zea/data/dataloader.py (2)
342-345: ⚠️ Potential issue | 🟠 Major

Normalize `additional_axes_iter` before adjusting `initial`

Line 344 compares raw axes values; negative axes (`-1`, etc.) are treated as `< initial_frame_axis` and shift `initial` incorrectly, causing wrong `np.moveaxis` behavior.

Proposed fix

         self.file_paths = _dataset.file_paths
         self.file_shapes = _dataset.load_file_shapes(key)
         _dataset.close()
+        rank = len(self.file_shapes[0])
+        self.initial_frame_axis = map_negative_indices([self.initial_frame_axis], rank)[0]
+        self.additional_axes_iter = list(
+            map_negative_indices(self.additional_axes_iter, rank)
+        )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@zea/data/dataloader.py` around lines 342 - 345, The adjustment to self.initial_frame_axis uses raw values from self.additional_axes_iter which fails for negative axes; normalize axes to non-negative equivalents before comparison (e.g., for each ax in self.additional_axes_iter compute ax_norm = ax % images.ndim), ensure additional_axes_iter is materialized (list) if it can be a generator, then compute initial -= sum(ax_norm < self.initial_frame_axis for ax_norm in normalized_axes) and finally call np.moveaxis(images, initial, self.frame_axis); update the block that references self.initial_frame_axis, self.additional_axes_iter, images, and self.frame_axis accordingly.
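The failure mode and the fix can be illustrated with plain NumPy; `normalize_axes` here is a stand-in for zea's `map_negative_indices`, and the array shapes are made up for the example:

```python
import numpy as np

def normalize_axes(axes, rank):
    # Map negative axes (e.g. -1) to their non-negative equivalents.
    return [ax % rank for ax in axes]

images = np.zeros((4, 10, 32, 32))  # frames live on axis 1
initial_frame_axis = 1
additional_axes_iter = [-1]         # last axis, i.e. axis 3 of a rank-4 array

# Raw comparison: -1 < 1 is True, so `initial` would be shifted (wrong).
# After normalization: 3 < 1 is False, so no shift is applied.
normalized = normalize_axes(additional_axes_iter, images.ndim)
initial = initial_frame_axis - sum(ax < initial_frame_axis for ax in normalized)
print(normalized, initial)                    # [3] 1
print(np.moveaxis(images, initial, 0).shape)  # (10, 4, 32, 32)
```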
335-339: ⚠️ Potential issue | 🟠 Major

Use the original index key for cache invalidation/reopen

On lines 335-339, retry uses `file.filename` as cache key. If that differs from the original `file_name` key used in `get_file`, `pop()` can miss and leave stale entries.

Proposed fix

-        image = self._load(file, key, indices)
+        image = self._load(file_name, file, key, indices)
@@
-    def _load(self, file: File, key: str, indices):
+    def _load(self, file_name: str, file: File, key: str, indices):
@@
-            fname = file.filename
             file_handle_cache = self._get_file_handle_cache()
-            file_handle_cache.pop(fname)
-            file = file_handle_cache.get_file(fname)
+            file_handle_cache.pop(file_name)
+            file = file_handle_cache.get_file(file_name)
             images = file.load_data(key, indices)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@zea/data/dataloader.py` around lines 335 - 339, The retry path currently invalidates the file handle cache using file.filename (fname) which may differ from the original cache key used by get_file; update the retry logic in the block around _get_file_handle_cache(), get_file(...) and load_data(...) so it uses the original cache key (the variable named file_name or the same key passed into get_file) when calling file_handle_cache.pop(...), ensuring the cache entry is removed and reopened correctly; keep using file.load_data(key, indices) after re-acquiring the file handle.
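The key mismatch is easy to reproduce with a toy cache; both `FileHandleCache` and `FakeFile` below are hypothetical sketches, assuming only that the real handle reports a normalized path while the cache is keyed by the requested one:

```python
from pathlib import Path

class FileHandleCache(dict):
    """Hypothetical handle cache keyed by the *requested* file name."""
    def get_file(self, name, opener):
        if name not in self:
            self[name] = opener(name)
        return self[name]

class FakeFile:
    """Stands in for an h5py.File whose .filename reports a resolved path."""
    def __init__(self, name):
        self.filename = str(Path(name).resolve())

cache = FileHandleCache()
requested = "data/demo.h5"  # the key get_file() was called with
handle = cache.get_file(requested, FakeFile)

# Popping by the handle's own filename misses the entry entirely...
assert handle.filename not in cache
# ...while popping by the original key actually evicts the stale handle.
cache.pop(requested)
assert requested not in cache
```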
🧹 Nitpick comments (1)
tests/data/test_dataloader.py (1)
508-523: `test_skipped_files_warning` does not verify the warning path

Line 508 names this as a warning test, but it only checks `len(source) == 0`. Consider asserting the warning/log content too, so regressions in warning behavior are caught.

Suggested test adjustment

-def test_skipped_files_warning(tmp_path):
+def test_skipped_files_warning(tmp_path, caplog):
@@
-    source = H5DataSource(
+    with caplog.at_level("WARNING"):
+        source = H5DataSource(
         file_paths=tmp_path,
         key="data",
         n_frames=5,
         frame_index_stride=1,
         validate=False,
-    )
+        )
     assert len(source) == 0
+    assert any("Skipping" in rec.message for rec in caplog.records)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/data/test_dataloader.py` around lines 508 - 523, test_skipped_files_warning currently only asserts len(source) == 0 but doesn't assert that the expected warning was emitted; update the test (test_skipped_files_warning) to capture and assert the warning/log output when constructing H5DataSource (e.g., use pytest.warns with the expected Warning subclass or use pytest's caplog to capture a logger message) and assert the message contains the expected text (like "skipped" or "too few frames") so the warning path for H5DataSource is verified.
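Outside pytest, the same capture pattern works with a plain stdlib handler, which is essentially what `caplog` does under the hood; the logger name and `build_source` stand-in below are made up for illustration:

```python
import logging

class ListHandler(logging.Handler):
    """Collects records, mirroring what pytest's caplog fixture captures."""
    def __init__(self):
        super().__init__()
        self.records = []
    def emit(self, record):
        self.records.append(record)

logger = logging.getLogger("demo_source")  # hypothetical logger name

def build_source():
    """Stand-in for a data source that skips files with too few frames."""
    logger.warning("Skipping %s: too few frames", "a.h5")
    return []

handler = ListHandler()
logger.addHandler(handler)
logger.setLevel(logging.WARNING)

source = build_source()
assert len(source) == 0  # what the test checks today
assert any("Skipping" in r.getMessage() for r in handler.records)  # the missing check
```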
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@zea/data/dataloader.py`:
- Around line 96-97: The current falsy check "if not limit_n_frames:" treats 0
as "no limit" and returns all frames; change the logic to only treat an absent
parameter as unlimited by replacing that check with an explicit None check
(e.g., "if limit_n_frames is None: limit_n_frames = np.inf") in the dataloader
code path where limit_n_frames is used so that limit_n_frames==0 correctly means
zero frames; update any function signature default to None if necessary.
---
Duplicate comments:
In `@zea/data/dataloader.py`:
- Around line 342-345: The adjustment to self.initial_frame_axis uses raw values
from self.additional_axes_iter which fails for negative axes; normalize axes to
non-negative equivalents before comparison (e.g., for each ax in
self.additional_axes_iter compute ax_norm = ax % images.ndim), ensure
additional_axes_iter is materialized (list) if it can be a generator, then
compute initial -= sum(ax_norm < self.initial_frame_axis for ax_norm in
normalized_axes) and finally call np.moveaxis(images, initial, self.frame_axis);
update the block that references self.initial_frame_axis,
self.additional_axes_iter, images, and self.frame_axis accordingly.
- Around line 335-339: The retry path currently invalidates the file handle
cache using file.filename (fname) which may differ from the original cache key
used by get_file; update the retry logic in the block around
_get_file_handle_cache(), get_file(...) and load_data(...) so it uses the
original cache key (the variable named file_name or the same key passed into
get_file) when calling file_handle_cache.pop(...), ensuring the cache entry is
removed and reopened correctly; keep using file.load_data(key, indices) after
re-acquiring the file handle.
---
Nitpick comments:
In `@tests/data/test_dataloader.py`:
- Around line 508-523: test_skipped_files_warning currently only asserts
len(source) == 0 but doesn't assert that the expected warning was emitted;
update the test (test_skipped_files_warning) to capture and assert the
warning/log output when constructing H5DataSource (e.g., use pytest.warns with
the expected Warning subclass or use pytest's caplog to capture a logger
message) and assert the message contains the expected text (like "skipped" or
"too few frames") so the warning path for H5DataSource is verified.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ace08441-3be8-4a74-8a54-c8035f59f566
📒 Files selected for processing (4)
tests/data/test_dataloader.py
zea/data/dataloader.py
zea/data/datasets.py
zea/data/utils.py
💤 Files with no reviewable changes (1)
- zea/data/utils.py
Grain is a library for reading data for training and evaluating JAX models. It’s open source, fast and deterministic.
I was able to implement most of the functionality of the existing dataloader, making it a drop-in replacement for `zea.backend.tensorflow.make_dataloader`. Simply use `zea.Dataloader` instead. I would like to move the dataloader to Grain (from TensorFlow previously). There are a few advantages:
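One reason the swap is low-friction: Grain consumes plain random-access sources, so a class like `H5DataSource` only needs `__len__` and an integer-indexed `__getitem__`. A minimal sketch of that protocol (the `SquareSource` name and its contents are made up, not zea code):

```python
class SquareSource:
    """Minimal random-access data source following the protocol Grain
    consumes: __len__ plus integer-indexed __getitem__."""
    def __init__(self, n):
        self._n = n
    def __len__(self):
        return self._n
    def __getitem__(self, idx):
        return idx * idx  # in zea this would read frames from an HDF5 file

source = SquareSource(5)
print(len(source))                              # 5
print([source[i] for i in range(len(source))])  # [0, 1, 4, 9, 16]
```

Because the source is just indexable, Grain can shuffle, shard, and prefetch deterministically without the source knowing anything about the loader.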
Summary by CodeRabbit
New Features
Breaking Changes
Documentation
Chores
Tests