Skip to content

Commit 0a3ee04

Browse files
normanrz and jstriebel
authored
use remotedataset in learned_segmenter example (#726)
* use remotedataset in learned_segmenter example
* various fixes & simplifications
* add downsampling

Co-authored-by: Jonathan Striebel <[email protected]>
1 parent 97e0d8b commit 0a3ee04

File tree

6 files changed

+1017
-95466
lines changed

6 files changed

+1017
-95466
lines changed

webknossos/examples/learned_segmenter.py

Lines changed: 33 additions & 29 deletions
Original file line number · Diff line number · Diff line change
@@ -1,5 +1,6 @@
11
import os
22
from functools import partial
3+
from tempfile import TemporaryDirectory
34
from time import gmtime, strftime
45

56
import numpy as np
@@ -18,23 +19,21 @@ def main() -> None:
1819
"https://webknossos.org/annotations/616457c2010000870032ced4"
1920
)
2021

21-
# Step 1: Download the dataset and our training data annotation from webKnossos to our local computer
22+
# Step 1: Download the training data annotation from webKnossos to our local computer
2223
training_data_bbox = annotation.user_bounding_boxes[0] # type: ignore[index]
2324
time_str = strftime("%Y-%m-%d_%H-%M-%S", gmtime())
2425
new_dataset_name = annotation.dataset_name + f"_segmented_{time_str}"
25-
dataset = wk.Dataset.download(
26-
annotation.dataset_name,
27-
organization_id="scalable_minds",
28-
path=new_dataset_name,
29-
webknossos_url="https://webknossos.org",
26+
dataset = annotation.get_remote_base_dataset(
27+
webknossos_url="https://webknossos.org"
3028
)
31-
dataset.name = new_dataset_name
3229

33-
volume_annotation = annotation.export_volume_layer_to_dataset(dataset)
34-
volume_annotation.bounding_box = training_data_bbox
30+
with annotation.temporary_volume_layer_copy() as volume_annotation_layer:
31+
mag = volume_annotation_layer.get_finest_mag().mag
32+
volume_annotation_data = volume_annotation_layer.mags[mag].read(
33+
absolute_bounding_box=training_data_bbox
34+
)
3535

36-
mag = wk.Mag(1)
37-
mag_view = dataset.layers["color"].mags[mag]
36+
color_mag_view = dataset.layers["color"].mags[mag]
3837

3938
# Step 2: Initialize a machine learning model to segment our dataset
4039
features_func = partial(
@@ -45,40 +44,45 @@ def main() -> None:
4544
# Step 3: Manipulate our data to fit the ML model and start training on
4645
# data from our annotated training data bounding box
4746
print("Starting training…")
48-
img_data_train = mag_view.read(
47+
img_data_train = color_mag_view.read(
4948
absolute_bounding_box=training_data_bbox
5049
) # wk data has dimensions (Channels, X, Y, Z)
5150
# move channels to last dimension, remove z dimension to match skimage's shape
5251
X_train = np.moveaxis(np.squeeze(img_data_train), 0, -1)
53-
Y_train = np.squeeze(volume_annotation.mags[mag].read())
52+
Y_train = np.squeeze(volume_annotation_data)
5453

5554
segmenter.fit(X_train, Y_train)
5655

5756
# Step 4: Use our trained model and predict a class for each pixel in the dataset
5857
# to get a full segmentation of the data
5958
print("Starting prediction…")
60-
X_predict = np.moveaxis(np.squeeze(mag_view.read()), 0, -1)
59+
X_predict = np.moveaxis(np.squeeze(color_mag_view.read()), 0, -1)
6160
Y_predicted = segmenter.predict(X_predict)
6261
segmentation = Y_predicted[:, :, None] # adds z dimension
6362
assert segmentation.max() < 256
6463
segmentation = segmentation.astype("uint8")
6564

6665
# Step 5: Bundle everying a webKnossos layer and upload to wK for viewing and further work
67-
segmentation_layer = dataset.add_layer(
68-
"segmentation",
69-
wk.SEGMENTATION_CATEGORY,
70-
segmentation.dtype,
71-
compressed=True,
72-
largest_segment_id=int(segmentation.max()),
73-
)
74-
segmentation_layer.bounding_box = dataset.layers["color"].bounding_box
75-
segmentation_layer.add_mag(mag, compress=True).write(segmentation)
76-
77-
remote_ds = dataset.upload(
78-
layers_to_link=[annotation.get_remote_base_dataset().get_layer("color")]
79-
if "PYTEST_CURRENT_TEST" not in os.environ
80-
else None
81-
)
66+
with TemporaryDirectory() as tempdir:
67+
new_dataset = wk.Dataset(
68+
tempdir, voxel_size=dataset.voxel_size, name=new_dataset_name
69+
)
70+
segmentation_layer = new_dataset.add_layer(
71+
"segmentation",
72+
wk.SEGMENTATION_CATEGORY,
73+
segmentation.dtype,
74+
compressed=True,
75+
largest_segment_id=int(segmentation.max()),
76+
)
77+
segmentation_layer.bounding_box = dataset.layers["color"].bounding_box
78+
segmentation_layer.add_mag(mag, compress=True).write(segmentation)
79+
segmentation_layer.downsample(sampling_mode="constant_z")
80+
81+
remote_ds = new_dataset.upload(
82+
layers_to_link=[dataset.layers["color"]]
83+
if "PYTEST_CURRENT_TEST" not in os.environ
84+
else None
85+
)
8286

8387
url = remote_ds.url
8488
print(f"Successfully uploaded {url}")

webknossos/poetry.lock

Lines changed: 43 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

webknossos/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -88,6 +88,7 @@ fastremap = "^1.12.2"
8888
pandas = "^1.3.4"
8989
pooch = "^1.5.2"
9090
s3fs = "^2022.2.0"
91+
scikit-learn = "^1.0.1"
9192
tabulate = "^0.8.9"
9293

9394
[tool.black]
@@ -123,6 +124,7 @@ disable = ["logging-format-interpolation","logging-fstring-interpolation","broad
123124

124125
[tool.pytest.ini_options]
125126
markers = ["with_vcr: Runs with VCR recording and optionally blocked network"]
127+
testpaths = ["tests"]
126128

127129
[build-system]
128130
requires = ["poetry>=1.1"]

0 commit comments

Comments
 (0)