Skip to content

Commit 3744f1f

Browse files
authored
Fix axes configuration when adding existing zarr arrays (#1204)
* Change default behaviour for axes when creating zarrita array. * add typehint. * Update changelog. * Add test. * Fix type annotations for added test.
1 parent 8498db0 commit 3744f1f

File tree

5 files changed

+78
-7
lines changed

5 files changed

+78
-7
lines changed

webknossos/Changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ Removed the CZI installation extra from `pip install webknossos[all]` by default
8080

8181
### Fixed
8282
- Fixed an issue with merging annotations with compressed fallback layers.
83+
- Fixed an issue where adding a Zarr array with other axes than `cxyz` leads to an error. [#1204](https://github.com/scalableminds/webknossos-libs/pull/1204)
8384

8485

8586

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
from zarrita import Array
5+
6+
import webknossos as wk
7+
8+
9+
def test_add_mag_from_zarrarray(tmp_path: Path) -> None:
10+
dataset = wk.Dataset(
11+
tmp_path / "test_add_mag_from_zarrarray", voxel_size=(10, 10, 10)
12+
)
13+
layer = dataset.add_layer(
14+
"color",
15+
wk.COLOR_CATEGORY,
16+
data_format="zarr3",
17+
bounding_box=wk.BoundingBox((0, 0, 0), (16, 16, 16)),
18+
)
19+
zarr_mag_path = tmp_path / "zarr_data" / "mag1.zarr"
20+
zarr_data = np.random.randint(0, 255, (16, 16, 16), dtype="uint8")
21+
zarr_mag = Array.create(
22+
store=zarr_mag_path, shape=(16, 16, 16), chunk_shape=(8, 8, 8), dtype="uint8"
23+
)
24+
zarr_mag[:] = zarr_data
25+
26+
layer.add_mag_from_zarrarray("1", zarr_mag_path, extend_layer_bounding_box=False)
27+
28+
assert layer.get_mag("1").read().shape == (1, 16, 16, 16)
29+
assert layer.get_mag("1").info.num_channels == 1
30+
assert layer.get_mag("1").info.dimension_names == ("c", "x", "y", "z")
31+
assert (layer.get_mag("1").read()[0] == zarr_data).all()

webknossos/webknossos/dataset/_array.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -498,15 +498,43 @@ def info(self) -> ArrayInfo:
498498
from zarrita.sharding import ShardingCodec
499499

500500
zarray = self._zarray
501+
dimension_names: tuple[str, ...]
501502
if (names := getattr(zarray.metadata, "dimension_names", None)) is None:
502-
dimension_names = ("c", "x", "y", "z")
503+
if (shape := getattr(zarray.metadata, "shape", None)) is None:
504+
raise ValueError(
505+
"Unable to determine the shape of the Zarrita Array. Neither dimension_names nor shape are present in the metadata file zarr.json."
506+
)
507+
else:
508+
if len(shape) == 2:
509+
dimension_names = ("x", "y")
510+
num_channels = 1
511+
elif len(shape) == 3:
512+
dimension_names = ("x", "y", "z")
513+
num_channels = 1
514+
elif len(shape) == 4:
515+
dimension_names = ("c", "x", "y", "z")
516+
num_channels = shape[0]
517+
else:
518+
raise ValueError(
519+
"Unusual shape for Zarrita array, please specify the dimension names in the metadata file zarr.json."
520+
)
503521
else:
504522
dimension_names = names
523+
if (shape := getattr(zarray.metadata, "shape", None)) is None:
524+
shape = VecInt.ones(dimension_names)
525+
if "c" in dimension_names:
526+
num_channels = zarray.metadata.shape[dimension_names.index("c")]
527+
else:
528+
num_channels = 1
505529
x_index, y_index, z_index = (
506530
dimension_names.index("x"),
507531
dimension_names.index("y"),
508532
dimension_names.index("z"),
509533
)
534+
if "c" not in dimension_names:
535+
shape = (num_channels,) + shape
536+
dimension_names = ("c",) + dimension_names
537+
array_shape = VecInt(shape, axes=dimension_names)
510538
if isinstance(zarray, Array):
511539
if len(zarray.codec_pipeline.codecs) == 1 and isinstance(
512540
zarray.codec_pipeline.codecs[0], ShardingCodec
@@ -516,7 +544,7 @@ def info(self) -> ArrayInfo:
516544
chunk_shape = sharding_codec.configuration.chunk_shape
517545
return ArrayInfo(
518546
data_format=DataFormat.Zarr3,
519-
num_channels=zarray.metadata.shape[0],
547+
num_channels=num_channels,
520548
voxel_type=zarray.metadata.dtype,
521549
compression_mode=self._has_compression_codecs(
522550
sharding_codec.codec_pipeline.codecs
@@ -536,12 +564,13 @@ def info(self) -> ArrayInfo:
536564
chunk_shape[z_index],
537565
)
538566
),
567+
shape=array_shape,
539568
dimension_names=dimension_names,
540569
)
541570
chunk_shape = zarray.metadata.chunk_grid.configuration.chunk_shape
542571
return ArrayInfo(
543572
data_format=DataFormat.Zarr3,
544-
num_channels=zarray.metadata.shape[0],
573+
num_channels=num_channels,
545574
voxel_type=zarray.metadata.dtype,
546575
compression_mode=self._has_compression_codecs(
547576
zarray.codec_pipeline.codecs
@@ -550,13 +579,14 @@ def info(self) -> ArrayInfo:
550579
chunk_shape[x_index], chunk_shape[y_index], chunk_shape[z_index]
551580
)
552581
or Vec3Int.full(1),
582+
shape=array_shape,
553583
chunks_per_shard=Vec3Int.full(1),
554584
dimension_names=dimension_names,
555585
)
556586
else:
557587
return ArrayInfo(
558588
data_format=DataFormat.Zarr,
559-
num_channels=zarray.metadata.shape[0],
589+
num_channels=num_channels,
560590
voxel_type=zarray.metadata.dtype,
561591
compression_mode=zarray.metadata.compressor is not None,
562592
chunk_shape=Vec3Int(
@@ -565,6 +595,7 @@ def info(self) -> ArrayInfo:
565595
zarray.metadata.chunks[z_index],
566596
)
567597
or Vec3Int.full(1),
598+
shape=array_shape,
568599
chunks_per_shard=Vec3Int.full(1),
569600
dimension_names=dimension_names,
570601
)
@@ -634,9 +665,13 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "ZarritaArray":
634665
def read(self, bbox: NDBoundingBox) -> np.ndarray:
635666
shape = bbox.size.to_tuple()
636667
zarray = self._zarray
637-
slice_tuple = (slice(None),) + bbox.to_slices()
638668
with _blosc_disable_threading():
639-
data = zarray[slice_tuple]
669+
try:
670+
slice_tuple = (slice(None),) + bbox.to_slices()
671+
data = zarray[slice_tuple]
672+
except IndexError:
673+
# The data is stored without channel axis
674+
data = zarray[bbox.to_slices()]
640675

641676
shape_with_channels = (self.info.num_channels,) + shape
642677
if data.shape != shape_with_channels:

webknossos/webknossos/dataset/layer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,6 +1485,7 @@ def _setup_mag(self, mag: Mag, path: Optional[str] = None) -> None:
14851485
info.chunk_shape,
14861486
info.chunks_per_shard,
14871487
info.compression_mode,
1488+
info.shape,
14881489
False,
14891490
UPath(resolved_path),
14901491
)

webknossos/webknossos/dataset/mag_view.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(
125125
chunk_shape: Vec3Int,
126126
chunks_per_shard: Vec3Int,
127127
compression_mode: bool,
128+
shape: Optional[VecInt] = None,
128129
create: bool = False,
129130
path: Optional[UPath] = None,
130131
) -> None:
@@ -145,7 +146,9 @@ def __init__(
145146
layer.num_channels,
146147
*VecInt.ones(layer.bounding_box.axes),
147148
axes=("c",) + layer.bounding_box.axes,
148-
),
149+
)
150+
if shape is None
151+
else shape,
149152
dimension_names=("c",) + layer.bounding_box.axes,
150153
)
151154
if create:

0 commit comments

Comments
 (0)