
Commit 1383834

Fix transform graph (#1077)
Summary: fix the issue of the partitions of a dask dataframe collapsing due to the dask graph optimization introduced with dask expressions. The fix relies on a change in dask introduced in version 2025.12.0, so the lower bound of the dask dependency is raised accordingly.

* test: transforming points with multiple partitions
* fix: transform points data with multiple partitions
* test: add a test for the dask tune context manager
* docs: add a note explaining the workaround, in case users run into the partition-collapse problem caused by dask graph optimization in their own code
* chore: raise the dask lower bound to 2025.12.0, since dask 2025.2.0 neither keeps the partition sizes consistent nor allows disabling graph optimization; turning optimization off was only introduced in dask 2025.12.0
* docs: add the dask issue
* test: test that compute can be performed
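As a quick sanity check for the raised lower bound, a minimal sketch (not part of the commit; availability of `packaging` is assumed) that verifies the installed dask version supports disabling graph optimization:

```python
from importlib.metadata import version

from packaging.version import Version

# The workaround added in this commit relies on dask >= 2025.12.0, where graph
# ("tune") optimization can be switched off via the dask config.
if Version(version("dask")) < Version("2025.12.0"):
    raise RuntimeError("dask >= 2025.12.0 is required to disable graph optimization")
```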
1 parent bc5b2ca commit 1383834

File tree

7 files changed: +86 −6 lines changed


.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - {os: windows-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"}
+          - {os: windows-latest, python: "3.11", dask-version: "2025.12.0", name: "Dask 2025.12.0"}
           - {os: windows-latest, python: "3.13", dask-version: "latest", name: "Dask latest"}
           - {os: ubuntu-latest, python: "3.11", dask-version: "latest", name: "Dask latest"}
           - {os: ubuntu-latest, python: "3.13", dask-version: "latest", name: "Dask latest"}

docs/index.md

Lines changed: 17 additions & 0 deletions
@@ -14,6 +14,23 @@ SpatialData is a data framework that comprises a FAIR storage format and a colle

 Please see our publication {cite}`marconatoSpatialDataOpenUniversal2024` for citation and to learn more.

+:::{note}
+With dask >= 2025.2.0, users can run into an error as described in [#1064](https://github.com/scverse/spatialdata/issues/1064). While we have implemented fixes in SpatialData,
+users may still hit this error when performing operations on the `Points` data themselves. To prevent it, use the context manager we provide.
+
+```python
+from spatialdata import disable_dask_tune_optimization
+import contextlib
+...
+
+with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext():
+    <your operation on points dask dataframe>
+```
+
+This disables dask graph optimization if the dataframe has more than one partition and keeps it enabled otherwise. It works around
+the problem discussed in this [dask issue](https://github.com/dask/dask/issues/12193). We are looking into an upstream fix.
+:::
+
 [//]: # "numfocus-fiscal-sponsor-attribution"

 spatialdata is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
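As a purely illustrative instantiation of the placeholder in the note above, assuming a toy points dataframe with `x`/`y` columns (none of these names come from the library):

```python
import contextlib

import dask.dataframe as dd
import pandas as pd

from spatialdata import disable_dask_tune_optimization

# Toy points table split over several partitions, standing in for an element such as sdata["points_0"].
data = dd.from_pandas(
    pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0], "y": [4.0, 5.0, 6.0, 7.0]}),
    npartitions=2,
)

with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext():
    # Example operation on the points dask dataframe: shift the coordinates.
    shifted = data.assign(x=data["x"] + 10.0, y=data["y"] + 10.0)
    result = shifted.compute()

print(result)
```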

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ dependencies = [
     "annsel>=0.1.2",
     "click",
     "dask-image",
-    "dask>=2025.2.0,<2026.1.2",
+    "dask>=2025.12.0,<2026.1.2",
     "distributed<2026.1.2",
     "datashader",
     "fsspec[s3,http]",

src/spatialdata/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,7 @@
     "transformations",
     "datasets",
     "dataloader",
+    "disable_dask_tune_optimization",
     "concatenate",
     "rasterize",
     "rasterize_bins",
@@ -72,5 +73,5 @@
 from spatialdata._io._utils import get_dask_backing_files
 from spatialdata._io.format import SpatialDataFormatType
 from spatialdata._io.io_zarr import read_zarr
-from spatialdata._utils import get_pyramid_levels, unpad_raster
+from spatialdata._utils import disable_dask_tune_optimization, get_pyramid_levels, unpad_raster
 from spatialdata.config import settings

src/spatialdata/_core/operations/transform.py

Lines changed: 11 additions & 2 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import contextlib
 import itertools
 import warnings
 from functools import singledispatch
@@ -17,6 +18,7 @@

 from spatialdata._core.spatialdata import SpatialData
 from spatialdata._types import ArrayLike
+from spatialdata._utils import disable_dask_tune_optimization
 from spatialdata.models import SpatialElement, get_axes_names, get_model
 from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channel_names
 from spatialdata.transformations._utils import _get_scale, compute_coordinates, scale_radii
@@ -439,8 +441,15 @@ def _(
     )
     axes = get_axes_names(data)
     arrays = []
-    for ax in axes:
-        arrays.append(data[ax].to_dask_array(lengths=True).reshape(-1, 1))
+
+    # Workaround to prevent partition collapse and missing-dependency problems for now.
+    with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext():
+        for ax in axes:
+            # TODO: we have to pass the lengths explicitly, as automatic determination with dask graph optimization
+            # leads to a collapse of the partitions. However, this causes a missing-dependency problem, which for now
+            # is prevented by disabling optimization while performing this operation.
+            arrays.append(data[ax].to_dask_array(lengths=[len(part) for part in data.partitions]).reshape(-1, 1))
+
     xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(len(data)), "dim": list(axes)})
     xtransformed = transformation._transform_coordinates(xdata)
     transformed = data.drop(columns=list(axes)).copy()
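The same pattern can be applied outside of `transform`; a small sketch with a made-up two-partition dataframe, where only the explicit-lengths call and the context manager mirror the change above:

```python
import contextlib

import dask.dataframe as dd
import pandas as pd

from spatialdata import disable_dask_tune_optimization

data = dd.from_pandas(
    pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0], "y": [0.0, 1.0, 2.0, 3.0]}),
    npartitions=2,
)

with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext():
    # Passing the per-partition lengths explicitly keeps the partitioning intact
    # instead of letting dask re-derive (and potentially collapse) the chunks.
    lengths = [len(part) for part in data.partitions]
    x_column = data["x"].to_dask_array(lengths=lengths).reshape(-1, 1)

print(x_column.chunks)  # ((2, 2), (1,)) for the two partitions above
```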

src/spatialdata/_utils.py

Lines changed: 13 additions & 0 deletions
@@ -2,13 +2,15 @@
 import re
 import warnings
 from collections.abc import Callable, Generator
+from contextlib import contextmanager
 from itertools import islice
 from typing import Any, TypeVar

 import numpy as np
 import pandas as pd
 from anndata import AnnData
 from dask import array as da
+from dask import config
 from dask.array import Array as DaskArray
 from xarray import DataArray, Dataset, DataTree

@@ -20,6 +22,17 @@
 RT = TypeVar("RT")


+@contextmanager
+def disable_dask_tune_optimization() -> Generator[None, None, None]:
+    """Prevent dask graph optimization when performing operations on dask dataframes with npartitions > 1."""
+    old_setting = config.config["optimization"]["tune"]["active"]
+    config.set({"optimization.tune.active": False})
+    try:
+        yield
+    finally:
+        config.set({"optimization.tune.active": old_setting})
+
+
 def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike:
     if isinstance(array, list):
         array = np.array(array)
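A brief usage sketch of the context manager, assuming a dask version (>= 2025.12.0) that exposes the `optimization.tune.active` config key:

```python
from dask import config

from spatialdata import disable_dask_tune_optimization

print(config.config["optimization"]["tune"]["active"])  # default setting

with disable_dask_tune_optimization():
    # Inside the block, graph "tune" optimization is off, so multi-partition
    # dataframes are not collapsed into a single partition during optimization.
    assert config.config["optimization"]["tune"]["active"] is False

# On exit, the previous setting is restored.
print(config.config["optimization"]["tune"]["active"])
```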

tests/core/operations/test_transform.py

Lines changed: 41 additions & 1 deletion
@@ -1,16 +1,18 @@
+import contextlib
 import math
 import tempfile
 from pathlib import Path

 import numpy as np
 import pytest
+from dask import config
 from geopandas.testing import geom_almost_equals
 from xarray import DataArray, DataTree

 from spatialdata import transform
 from spatialdata._core.data_extent import are_extents_equal, get_extent
 from spatialdata._core.spatialdata import SpatialData
-from spatialdata._utils import unpad_raster
+from spatialdata._utils import disable_dask_tune_optimization, unpad_raster
 from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names
 from spatialdata.transformations.operations import (
     align_elements_using_landmarks,
@@ -586,6 +588,44 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa
     _ = full_sdata.transform_to_coordinate_system("my_space", maintain_positioning=maintain_positioning)


+def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str):
+    tmpdir = Path(tmp_path) / "tmp.zarr"
+    points_memory = full_sdata["points_0"].compute()
+    full_sdata["points_0"] = PointsModel.parse(
+        full_sdata["points_0"].repartition(npartitions=4),
+        transformations={"global": get_transformation(full_sdata["points_0"])},
+    )
+    assert points_memory.equals(full_sdata["points_0"].compute())
+
+    full_sdata.write(tmpdir)
+
+    full_sdata = SpatialData.read(tmpdir)
+
+    # This just needs to run without error
+    data = transform(full_sdata["points_0"], to_coordinate_system="global")
+
+    # test that data still can be computed
+    data.compute()
+
+
+@pytest.mark.parametrize(
+    "tune,partition",
+    [
+        (True, None),
+        (False, 4),
+    ],
+)
+def test_dask_tune_contextmanager(full_sdata: SpatialData, partition: int | None, tune: bool):
+    if partition:
+        full_sdata["points_0"] = PointsModel.parse(
+            full_sdata["points_0"].repartition(npartitions=4),
+            transformations={"global": get_transformation(full_sdata["points_0"])},
+        )
+
+    with disable_dask_tune_optimization() if full_sdata["points_0"].npartitions > 1 else contextlib.nullcontext():
+        assert config.config["optimization"]["tune"]["active"] is tune
+
+
 @pytest.mark.parametrize("maintain_positioning", [True, False])
 def test_transform_elements_and_entire_spatial_data_object_multi_hop(
     full_sdata: SpatialData, maintain_positioning: bool
