Skip to content

Commit a890951

Browse files
committed
wip fix chunks
1 parent 5a20f41 commit a890951

File tree

3 files changed

+59
-15
lines changed

3 files changed

+59
-15
lines changed

benchmarks/benchmark_image.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from spatialdata import SpatialData
1212
from xarray import DataArray
1313

14-
from spatialdata_io import image
14+
from spatialdata_io import image # type: ignore[attr-defined]
1515

1616
# =============================================================================
1717
# CONFIGURATION - Edit these paths to match your setup
@@ -73,13 +73,29 @@ def _convert_image(
7373
# sanity check
7474
if scale_factors is None:
7575
assert isinstance(sdata["image"], DataArray)
76+
if chunks is not None:
77+
assert (
78+
sdata["image"].chunksizes["x"][0] == chunks[0]
79+
or sdata["image"].chunksizes["x"][0] == sdata["image"].shape[2]
80+
)
81+
assert (
82+
sdata["image"].chunksizes["y"][0] == chunks[1]
83+
or sdata["image"].chunksizes["y"][0] == sdata["image"].shape[1]
84+
)
7685
else:
77-
assert len(sdata["image"].keys()) == len(scale_factors)
86+
assert len(sdata["image"].keys()) == len(scale_factors) + 1
87+
if chunks is not None:
88+
assert (
89+
sdata["image"]["scale0"]["image"].chunksizes["x"][0] == chunks[0]
90+
or sdata["image"]["scale0"]["image"].chunksizes["x"][0]
91+
== sdata["image"]["scale0"]["image"].shape[2]
92+
)
93+
assert (
94+
sdata["image"]["scale0"]["image"].chunksizes["y"][0] == chunks[1]
95+
or sdata["image"]["scale0"]["image"].chunksizes["y"][0]
96+
== sdata["image"]["scale0"]["image"].shape[1]
97+
)
7898

79-
if chunks is not None:
80-
# TODO: bug here!
81-
assert sdata["image"].chunksizes["x"] == chunks[0]
82-
assert sdata["image"].chunksizes["y"] == chunks[1]
8399
return sdata
84100

85101
def time_io(self, scale_factors: list[int] | None, use_tiff_memmap: bool, chunks: tuple[int, int]) -> None:
@@ -96,5 +112,27 @@ def peakmem_io(self, scale_factors: list[int] | None, use_tiff_memmap: bool, chu
96112
if __name__ == "__main__":
97113
# Run a single test case for quick verification
98114
bench = IOBenchmarkImage()
99-
bench.setup(None, True, (1000, 1000))
100-
bench.time_io(None, True, (1000, 1000))
115+
116+
# bench.setup()
117+
# bench.time_io(None, True, (5000, 5000))
118+
119+
# bench.setup()
120+
# bench.time_io(None, True, (1000, 1000))
121+
122+
# bench.setup()
123+
# bench.time_io(None, False, (5000, 5000))
124+
125+
# bench.setup()
126+
# bench.time_io(None, False, (1000, 1000))
127+
128+
# bench.setup()
129+
# bench.time_io([2, 2, 2], True, (5000, 5000))
130+
131+
# bench.setup()
132+
# bench.time_io([2, 2, 2], True, (1000, 1000))
133+
134+
bench.setup()
135+
bench.time_io([2, 2, 2], False, (5000, 5000))
136+
137+
# bench.setup()
138+
# bench.time_io([2, 2, 2], False, (1000, 1000))

benchmarks/benchmark_xenium.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,6 @@ def peakmem_io(self) -> None:
103103

104104

105105
if __name__ == "__main__":
106-
IOBenchmarkXenium().time_io()
106+
benchmark = IOBenchmarkXenium()
107+
benchmark.setup()
108+
benchmark.time_io()

src/spatialdata_io/readers/generic.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,15 @@ def _reader_func(slide: np.memmap, y0: int, x0: int, height: int, width: int) ->
131131
return _read_chunks(_reader_func, slide, coords=chunk_coords, n_channel=n_channel, dtype=slide.dtype)
132132

133133

134-
def _dask_image_imread(input: Path, data_axes: Sequence[str]) -> da.Array:
134+
def _dask_image_imread(input: Path, data_axes: Sequence[str], chunks: tuple[int, int] | None = None) -> da.Array:
135+
if set(data_axes) != {"c", "y", "x"}:
136+
raise NotImplementedError(f"Only 'c', 'y', 'x' axes are supported, got {data_axes}")
135137
image = imread(input)
136-
if len(image.shape) == len(data_axes) + 1 and image.shape[0] == 1:
137-
image = np.squeeze(image, axis=0)
138-
return image
138+
if image.ndim != len(data_axes):
139+
raise ValueError(f"Expected image with {len(data_axes)} dimensions, got {image.ndim}")
140+
image = image.transpose(*[data_axes.index(ax) for ax in ["c", "y", "x"]])
141+
chunks = (1,) + chunks
142+
return image.rechunk(chunks)
139143

140144

141145
def image(
@@ -187,11 +191,11 @@ def image(
187191
use_tiff_memmap = False
188192

189193
if input.suffix in [".tiff", ".tif"] and not use_tiff_memmap or input.suffix in [".png", ".jpg", ".jpeg"]:
190-
im = _dask_image_imread(input=input, data_axes=data_axes)
194+
im = _dask_image_imread(input=input, data_axes=data_axes, chunks=chunks)
191195

192196
if im is None:
193197
raise NotImplementedError(f"File format {input.suffix} not implemented")
194198

195199
return Image2DModel.parse(
196-
im, dims=data_axes, transformations={coordinate_system: Identity()}, scale_factors=scale_factors
200+
im, dims=data_axes, transformations={coordinate_system: Identity()}, scale_factors=scale_factors, chunks=chunks
197201
)

0 commit comments

Comments
 (0)