Skip to content

Commit db95d6d

Browse files
Fix upsampling (#1287)
* wip fixing upsample ds. * wip replace offset size combinations with bboxes. * rewrite with chunk method. * lint and typecheck. * remove assertion. * fix tests. * Update readme and fix tests. * Update webknossos/Changelog.md --------- Co-authored-by: Daniel <[email protected]>
1 parent d428571 commit db95d6d

File tree

4 files changed

+24
-40
lines changed

4 files changed

+24
-40
lines changed

webknossos/Changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@ For upgrade instructions, please check the respective _Breaking Changes_ section
1313
[Commits](https://github.com/scalableminds/webknossos-libs/compare/v2.1.0...HEAD)
1414

1515
### Breaking Changes
16+
- If `buffer_shape` is passed to `Layer.upsample()` it must be a multiple of the `shard_shape`. [#1287](https://github.com/scalableminds/webknossos-libs/pull/1287)
1617

1718
### Added
1819

1920
### Changed
2021
- Using `x-auth-header` to send tokens to datastore. [#1270](https://github.com/scalableminds/webknossos-libs/pull/1270)
2122

2223
### Fixed
24+
- Fixed an issue with upsampling views. [#1287](https://github.com/scalableminds/webknossos-libs/pull/1287)
2325

2426

2527
## [2.1.0](https://github.com/scalableminds/webknossos-libs/releases/tag/v2.1.0) - 2025-04-01

webknossos/tests/dataset/test_upsampling.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def test_upsampling(tmp_path: Path) -> None:
4242
finest_mag=Mag(1),
4343
compress=False,
4444
sampling_mode=SamplingModes.ANISOTROPIC,
45-
buffer_shape=(64, 64, 64),
4645
)
4746

4847
assert layer.get_mag("2").read().mean() == layer.get_mag("1").read().mean()
@@ -92,7 +91,7 @@ def upsample_test_helper(tmp_path: Path, use_compress: bool) -> None:
9291
0,
9392
),
9493
[0.5, 0.5, 1.0],
95-
BUFFER_SHAPE,
94+
mag1.info.shard_shape,
9695
)
9796

9897
assert np.any(source_buffer != 0)
@@ -138,7 +137,7 @@ def test_upsample_multi_channel(tmp_path: Path) -> None:
138137
upsample_cube_job(
139138
(mag2.get_view(), layer.get_mag("1").get_view(), 0),
140139
[0.5, 0.5, 0.5],
141-
BUFFER_SHAPE,
140+
mag2.info.shard_shape,
142141
)
143142

144143
channels = [

webknossos/webknossos/dataset/_downsampling_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ class InterpolationModes(Enum):
2828
MIN = 6
2929

3030

31-
DEFAULT_BUFFER_SHAPE = Vec3Int.full(256)
32-
33-
3431
def determine_buffer_shape(array_info: ArrayInfo) -> Vec3Int:
35-
return DEFAULT_BUFFER_SHAPE.pairmin(array_info.shard_shape)
32+
return array_info.shard_shape
3633

3734

3835
def calculate_mags_to_downsample(

webknossos/webknossos/dataset/_upsampling_utils.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import logging
2-
import math
3-
from itertools import product
42

53
import numpy as np
64

75
from ..geometry import Vec3Int
6+
from .data_format import DataFormat
87
from .view import View
98

109

@@ -32,55 +31,42 @@ def upsample_cube_job(
3231
assert all(1 >= f for f in mag_factors), (
3332
f"mag_factors ({mag_factors}) for upsampling must be smaller than 1"
3433
)
34+
if (
35+
target_view._data_format != DataFormat.WKW
36+
and target_view._is_compressed() == False
37+
):
38+
assert buffer_shape % target_view.info.shard_shape == Vec3Int.zeros(), (
39+
f"buffer_shape ({buffer_shape}) must be divisible by shard_shape ({target_view.info.shard_shape})"
40+
)
41+
inverse_factors = [int(1 / f) for f in mag_factors]
3542

3643
try:
3744
num_channels = target_view.info.num_channels
38-
target_bbox_in_mag = target_view.bounding_box.in_mag(target_view.mag)
39-
shape = (num_channels,) + target_bbox_in_mag.size.to_tuple()
40-
shape_xyz = target_bbox_in_mag.size_xyz
41-
file_buffer = np.zeros(shape, target_view.get_dtype())
42-
43-
tiles = product(
44-
*list(
45-
[
46-
list(range(0, math.ceil(length)))
47-
for length in shape_xyz.to_np() / buffer_shape.to_np()
48-
]
49-
)
50-
)
5145

52-
for tile in tiles:
53-
target_offset = Vec3Int(tile) * buffer_shape
54-
source_offset = target_offset * source_view.mag
55-
source_size = source_view.bounding_box.size_xyz
56-
source_size = (buffer_shape * source_view.mag).pairmin(
57-
source_size - source_offset
58-
)
59-
60-
bbox = source_view.bounding_box.offset(source_offset).with_size_xyz(
61-
source_size
62-
)
46+
for chunk in target_view.bounding_box.chunk(
47+
buffer_shape * target_view.mag, buffer_shape * target_view.mag
48+
):
49+
shape = (num_channels,) + (chunk.in_mag(target_view.mag)).size.to_tuple()
50+
file_buffer = np.zeros(shape, dtype=target_view.get_dtype())
6351
cube_buffer_channels = source_view.read_xyz(
64-
absolute_bounding_box=bbox,
52+
absolute_bounding_box=chunk,
6553
)
6654

6755
for channel_index in range(num_channels):
6856
cube_buffer = cube_buffer_channels[channel_index]
6957

7058
if not np.all(cube_buffer == 0):
71-
# Upsample the buffer
72-
inverse_factors = [int(1 / f) for f in mag_factors]
7359
data_cube = upsample_cube(cube_buffer, inverse_factors)
7460

75-
buffer_bbox = target_view.bounding_box.with_topleft_xyz(
76-
target_offset * inverse_factors
77-
).with_size_xyz(data_cube.shape)
61+
buffer_bbox = chunk.with_topleft_xyz(Vec3Int.zeros()).with_size_xyz(
62+
data_cube.shape
63+
)
7864
data_cube = buffer_bbox.xyz_array_to_bbox_shape(data_cube)
7965
file_buffer[(channel_index,) + buffer_bbox.to_slices_xyz()] = (
8066
data_cube
8167
)
8268

83-
target_view.write(file_buffer, absolute_bounding_box=target_view.bounding_box)
69+
target_view.write(file_buffer, absolute_bounding_box=chunk)
8470

8571
except Exception as exc:
8672
logging.error(

0 commit comments

Comments
 (0)