Skip to content

Commit 813c68a

Browse files
Remove custom JSON handling for the dataloader
1 parent 663c9c1 commit 813c68a

File tree

4 files changed

+47
-100
lines changed

4 files changed

+47
-100
lines changed

tests/data/test_dataloader.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from zea.data.datasets import Dataset
1515
from zea.data.file import File
1616
from zea.data.layers import Resizer
17-
from zea.data.utils import json_loads
1817
from zea.tools.hf import HFPath
1918

2019
from .. import DEFAULT_TEST_SEED
@@ -239,6 +238,7 @@ def test_h5_dataset_return_filename(
239238
validate = directory != "dummy_hdf5"
240239
directory = request.getfixturevalue(directory)
241240

241+
N_AXIS = 3 # n_frames, height, width
242242
dataset = Dataloader(
243243
directory,
244244
key=key,
@@ -259,18 +259,31 @@ def test_h5_dataset_return_filename(
259259

260260
_, file_dict = batch
261261

262-
assert len(file_dict) == batch_size, (
263-
"The file_dict should contain the same number of elements as the batch size"
264-
)
262+
# Check keys
263+
keys = ["filename", "fullpath", "indices"]
264+
for key in keys:
265+
assert key in file_dict, f"The file_dict should contain the key '{key}'"
265266

266-
file_dict = file_dict[0] # get the first file_dict of the batch
267-
file_dict = json_loads(file_dict)
267+
# Check batch size and types
268+
keys = ["filename", "fullpath"]
269+
for key in keys:
270+
assert len(file_dict[key]) == batch_size, (
271+
f"The file_dict['{key}'] should contain the same number of elements as the batch size"
272+
)
273+
for path in file_dict[key]:
274+
assert isinstance(path, str), f"Each path in file_dict['{key}'] should be a string"
268275

269-
filename = file_dict["filename"]
270-
assert isinstance(filename, str), "The filename should be a string"
271-
fullpath = file_dict["fullpath"]
272-
assert isinstance(fullpath, str), "The fullpath should be a string"
273-
assert "indices" in file_dict, "The file_dict should contain indices"
276+
# indices nests one deeper, because it has one element per axis (n_frames, height, width)
277+
indices = file_dict["indices"]
278+
assert len(indices) == N_AXIS, (
279+
f"The file_dict['indices'] should contain {N_AXIS} elements in this test"
280+
)
281+
282+
for idx in indices:
283+
assert len(idx) == batch_size, (
284+
"Each axis in file_dict['indices'] should contain the same number of elements "
285+
"as the batch size"
286+
)
274287

275288

276289
@pytest.mark.parametrize(

zea/data/dataloader.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from zea.data.datasets import Dataset, H5FileHandleCache, count_samples_per_directory
3535
from zea.data.file import File
3636
from zea.data.layers import Resizer
37-
from zea.data.utils import json_dumps
3837
from zea.utils import map_negative_indices
3938

4039
DEFAULT_NORMALIZATION_RANGE = (0, 1)
@@ -84,12 +83,12 @@ def generate_h5_indices(
8483
(
8584
"/folder/path_to_file.hdf5",
8685
"data/image",
87-
(range(0, 1), slice(None, 256, None), slice(None, 256, None)),
86+
(slice(0, 1, 1), slice(None, 256, None), slice(None, 256, None)),
8887
),
8988
(
9089
"/folder/path_to_file.hdf5",
9190
"data/image",
92-
(range(1, 2), slice(None, 256, None), slice(None, 256, None)),
91+
(slice(1, 2, 1), slice(None, 256, None), slice(None, 256, None)),
9392
),
9493
...,
9594
]
@@ -136,7 +135,7 @@ def axis_indices_files():
136135
# Optionally limit frames to load from each file
137136
n_frames_in_file = min(n_frames_in_file, limit_n_frames)
138137
indices = [
139-
list(range(i, i + block_size, frame_index_stride))
138+
slice(i, i + block_size, frame_index_stride)
140139
for i in range(0, n_frames_in_file - block_size + 1, block_step_size)
141140
]
142141
yield [indices]
@@ -293,18 +292,16 @@ def __getitem__(self, index: int):
293292
return self._data_cache[index]
294293

295294
file_name, key, indices = self.indices[index]
296-
file_cache = self._get_cache()
297-
file = file_cache.get_file(file_name)
295+
file_handle_cache = self._get_file_handle_cache()
296+
file = file_handle_cache.get_file(file_name)
298297
image = self._load(file, key, indices)
299298

300299
if self.return_filename:
301-
file_data = json_dumps(
302-
{
303-
"fullpath": file.filename,
304-
"filename": Path(file_name).stem,
305-
"indices": indices,
306-
}
307-
)
300+
file_data = {
301+
"fullpath": file.filename,
302+
"filename": Path(file_name).stem,
303+
"indices": indices,
304+
}
308305
result = (image, file_data)
309306
else:
310307
result = image
@@ -321,7 +318,7 @@ def __repr__(self) -> str:
321318

322319
# -- internals -------------------------------------------------------------
323320

324-
def _get_cache(self) -> H5FileHandleCache:
321+
def _get_file_handle_cache(self) -> H5FileHandleCache:
325322
"""Return the file-handle cache for the current thread."""
326323
if not hasattr(self._local, "cache"):
327324
self._local.cache = H5FileHandleCache()
@@ -335,14 +332,10 @@ def _load(self, file: File, key: str, indices):
335332
images = file.load_data(key, indices)
336333
except (OSError, IOError):
337334
# Invalidate cache entry and retry once
338-
cache = self._get_cache()
339335
fname = file.filename
340-
cache._file_handle_cache.pop(fname, None)
341-
try:
342-
file.close()
343-
except Exception:
344-
pass
345-
file = cache.get_file(fname)
336+
file_handle_cache = self._get_file_handle_cache()
337+
file_handle_cache.pop(fname)
338+
file = file_handle_cache.get_file(fname)
346339
images = file.load_data(key, indices)
347340

348341
if self.insert_frame_axis:

zea/data/datasets.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ def get_file(self, file_path) -> File:
104104

105105
return self._file_handle_cache[file_path]
106106

107+
def pop(self, file_path):
    """Evict *file_path* from the handle cache, closing it if it was cached.

    Missing entries are ignored, and any error raised while closing the
    handle is deliberately suppressed: eviction is best-effort and must
    never propagate a close failure to the caller.
    """
    handle = self._file_handle_cache.pop(file_path, None)
    if handle is None:
        return
    try:
        handle.close()
    except Exception:
        pass  # best-effort close: a failed close must not break eviction
115+
107116
def close(self):
108117
"""Close all cached file handles."""
109118
cache: OrderedDict = getattr(self, "_file_handle_cache", None)

zea/data/utils.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

0 commit comments

Comments
 (0)