Skip to content

Commit 49ccb52

Browse files
authored
Add list_bounding_boxes() for zarr arrays (#1238)
* add list_bounding_boxes() for zarr arrays * rm test_perf * changelog * types * yield
1 parent ac4e005 commit 49ccb52

File tree

4 files changed

+90
-27
lines changed

4 files changed

+90
-27
lines changed

webknossos/Changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ For upgrade instructions, please check the respective _Breaking Changes_ section
1515
### Breaking Changes
1616

1717
### Added
18+
- Added `list_bounding_boxes()` for Zarr-based datasets. [#1238](https://github.com/scalableminds/webknossos-libs/pull/1238)
1819

1920
### Changed
2021

webknossos/tests/dataset/test_dataset.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from webknossos.utils import (
3636
copytree,
3737
get_executor_for_args,
38+
is_remote_path,
3839
named_partial,
3940
rmtree,
4041
snake_to_camel_case,
@@ -1989,7 +1990,7 @@ def test_bounding_box_on_disk(data_format: DataFormat, output_path: Path) -> Non
19891990
for offset in write_positions:
19901991
mag.write(absolute_offset=offset * mag.mag.to_vec3_int(), data=write_data)
19911992

1992-
if data_format in (DataFormat.Zarr, DataFormat.Zarr3):
1993+
if is_remote_path(output_path):
19931994
with pytest.warns(UserWarning, match=".*can be slow.*"):
19941995
bounding_boxes_on_disk = list(mag.get_bounding_boxes_on_disk())
19951996

@@ -2051,22 +2052,13 @@ def test_compression(data_format: DataFormat, output_path: Path) -> None:
20512052
compressed_dataset_path = (
20522053
REMOTE_TESTOUTPUT_DIR / f"simple_{data_format}_dataset_compressed"
20532054
)
2054-
if data_format in (DataFormat.Zarr, DataFormat.Zarr3):
2055-
with pytest.warns(UserWarning, match=".*can be slow.*"):
2056-
mag1.compress(
2057-
target_path=compressed_dataset_path,
2058-
)
2059-
else:
2055+
with pytest.warns(UserWarning, match=".*can be slow.*"):
20602056
mag1.compress(
20612057
target_path=compressed_dataset_path,
20622058
)
20632059
mag1 = Dataset.open(compressed_dataset_path).get_layer("color").get_mag(1)
20642060
else:
2065-
if data_format in (DataFormat.Zarr, DataFormat.Zarr3):
2066-
with pytest.warns(UserWarning, match=".*can be slow.*"):
2067-
mag1.compress()
2068-
else:
2069-
mag1.compress()
2061+
mag1.compress()
20702062

20712063
assert mag1._is_compressed()
20722064
assert mag1.info.data_format == data_format

webknossos/webknossos/dataset/_array.py

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
import warnings
33
from abc import ABC, abstractmethod
4+
from collections.abc import Iterable
45
from dataclasses import dataclass
56
from functools import lru_cache
67
from os.path import relpath
@@ -11,6 +12,7 @@
1112
Dict,
1213
Iterator,
1314
List,
15+
Literal,
1416
Optional,
1517
Tuple,
1618
Type,
@@ -223,7 +225,7 @@ def _list_files(self) -> Iterator[Path]:
223225
for filename in self._wkw_dataset.list_files()
224226
)
225227

226-
def list_bounding_boxes(self) -> Iterator[NDBoundingBox]:
228+
def list_bounding_boxes(self) -> Iterator[BoundingBox]:
227229
def _extract_num(s: str) -> int:
228230
match = re.search("[0-9]+", s)
229231
assert match is not None
@@ -592,9 +594,62 @@ def write(self, bbox: NDBoundingBox, data: np.ndarray) -> None:
592594
)
593595
array[requested_domain].write(data).result()
594596

595-
def list_bounding_boxes(self) -> Iterator[BoundingBox]:
597+
def _chunk_key_encoding(self) -> tuple[Literal["default", "v2"], Literal["/", "."]]:
596598
raise NotImplementedError
597599

600+
def _list_bounding_boxes(
601+
self, kvstore: Any, shard_shape: Vec3Int, shape: VecInt
602+
) -> Iterator[BoundingBox]:
603+
_type, separator = self._chunk_key_encoding()
604+
605+
def _try_parse_ints(vec: Iterable[Any]) -> Optional[list[int]]:
606+
output = []
607+
for value in vec:
608+
try:
609+
output.append(int(value))
610+
except ValueError: # noqa: PERF203
611+
return None
612+
return output
613+
614+
keys = kvstore.list().result()
615+
for key in keys:
616+
key_parts = key.decode("utf-8").split(separator)
617+
if _type == "default":
618+
if key_parts[0] != "c":
619+
continue
620+
key_parts = key_parts[1:]
621+
if len(key_parts) != self._array.ndim:
622+
continue
623+
chunk_coords_list = _try_parse_ints(key_parts)
624+
if chunk_coords_list is None:
625+
continue
626+
627+
if shape.axes[0] == "c":
628+
chunk_coords = Vec3Int(chunk_coords_list[1:])
629+
else:
630+
chunk_coords = Vec3Int(chunk_coords_list)
631+
632+
yield BoundingBox(chunk_coords * shard_shape, shard_shape)
633+
634+
def list_bounding_boxes(self) -> Iterator[BoundingBox]:
635+
kvstore = self._array.kvstore
636+
637+
if kvstore.spec().to_json()["driver"] == "s3":
638+
raise NotImplementedError(
639+
"list_bounding_boxes() is not supported for s3 arrays."
640+
)
641+
642+
_, _, shard_shape, _, shape = self._get_array_dimensions(self._array)
643+
644+
if shape.axes != ("c", "x", "y", "z") and shape.axes != ("x", "y", "z"):
645+
raise NotImplementedError(
646+
"list_bounding_boxes() is not supported for non 3-D arrays."
647+
)
648+
649+
# This needs to be a separate function because we need the NotImplementedError
650+
# to be raised immediately and not part of the iterator.
651+
return self._list_bounding_boxes(kvstore, shard_shape, shape)
652+
598653
def close(self) -> None:
599654
if self._cached_array is not None:
600655
self._cached_array = None
@@ -724,12 +779,10 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "Zarr3Array":
724779
"configuration": {"endian": "little"},
725780
},
726781
{
727-
"name": "blosc",
782+
"name": "zstd",
728783
"configuration": {
729-
"cname": "zstd",
730-
"clevel": 5,
731-
"shuffle": "shuffle",
732-
"typesize": array_info.voxel_type.itemsize,
784+
"level": 5,
785+
"checksum": True,
733786
},
734787
},
735788
]
@@ -761,6 +814,17 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "Zarr3Array":
761814
).result()
762815
return cls(path, _array)
763816

817+
def _chunk_key_encoding(self) -> tuple[Literal["default", "v2"], Literal["/", "."]]:
818+
metadata = self._array.spec().to_json()["metadata"]
819+
chunk_key_encoding = metadata["chunk_key_encoding"]
820+
_type = chunk_key_encoding["name"]
821+
separator = chunk_key_encoding.get("configuration", {}).get(
822+
"separator", "/" if _type == "default" else "."
823+
)
824+
assert _type in ["default", "v2"]
825+
assert separator in ["/", "."]
826+
return _type, separator
827+
764828

765829
class Zarr2Array(TensorStoreArray):
766830
data_format = DataFormat.Zarr
@@ -812,10 +876,8 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "Zarr2Array":
812876
"order": "F",
813877
"compressor": (
814878
{
815-
"id": "blosc",
816-
"cname": "zstd",
817-
"clevel": 5,
818-
"shuffle": 1,
879+
"id": "zstd",
880+
"level": 5,
819881
}
820882
if array_info.compression_mode
821883
else None
@@ -827,3 +889,9 @@ def create(cls, path: Path, array_info: ArrayInfo) -> "Zarr2Array":
827889
}
828890
).result()
829891
return cls(path, _array)
892+
893+
def _chunk_key_encoding(self) -> tuple[Literal["default", "v2"], Literal["/", "."]]:
894+
metadata = self._array.spec().to_json()["metadata"]
895+
separator = metadata.get("dimension_separator", ".")
896+
assert separator in ["/", "."]
897+
return "v2", separator

webknossos/webknossos/dataset/mag_view.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from cluster_tools import Executor
1111
from upath import UPath
1212

13-
from webknossos.dataset.data_format import DataFormat
14-
1513
from ..geometry import Mag, NDBoundingBox, Vec3Int, Vec3IntLike, VecInt
1614
from ..utils import (
1715
LazyPath,
@@ -679,9 +677,13 @@ def compress(
679677
)
680678
)
681679
with get_executor_for_args(args, executor) as executor:
682-
if self.layer.data_format == DataFormat.WKW:
680+
try:
681+
bbox_iterator = self._array.list_bounding_boxes()
682+
except NotImplementedError:
683+
bbox_iterator = None
684+
if bbox_iterator is not None:
683685
job_args = []
684-
for i, bbox in enumerate(self._array.list_bounding_boxes()):
686+
for i, bbox in enumerate(bbox_iterator):
685687
bbox = bbox.from_mag_to_mag1(self._mag).intersected_with(
686688
self.layer.bounding_box, dont_assert=True
687689
)

0 commit comments

Comments
 (0)