
Commit 53b9438

Authored by melonora, pre-commit-ci[bot], and LucaMarconato
unpinning dask (#1006)
* add attrs accessor
* change deprecated Index access
* add accessor to init
* remove query planning
* additional changes to accessor
* divisions is not settable anymore
* add fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix rasterize points
* copy partitioned attrs
* fix mypy
* fix last mypy error
* Apply suggestion from @melonora
* Apply suggestion from @melonora
* Apply suggestion from @melonora
* deduplicate
* .attrs is now always an accessor, never a dict
* simplify wrapper logic
* revert after loc/iloc indexer
* clean-up, simplify accessor logic
* remove asserts
* remove asserts
* simplify accessor logic by reducing number of classes
* rename wrap_with_attrs
* remove comment
* wrapping methods for dd.Series
* add dask tests for accessor
* fix index.compute() attrs missing
* change fix .attrs on index
* wrap dd.Series.loc
* remove old code, add comments
* move accessor code
* change git workflow
* some fixes
* remove old test code
* test dask across operating systems
* fix
* fix
* fix
* revert changes
* fix
* adjust
* adjust dask pin
* adjust dask pin
* fix dask backing files and windows permissions
* fix dask mixed graph problem
* temporary fix indexing
* fix rasterize
* adjust github workflow
* move 3.13 to include
* make more concise
* Apply suggestion from @melonora
* fix str representation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: LucaMarconato <[email protected]>
Co-authored-by: Luca Marconato <[email protected]>
1 parent c641609 · commit 53b9438

File tree

22 files changed (+465, -79 lines)


.github/workflows/test.yaml

Lines changed: 20 additions & 14 deletions
@@ -4,7 +4,7 @@ on:
   push:
     branches: [main]
     tags:
-      - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
+      - "v*"
   pull_request:
     branches: "*"
@@ -13,26 +13,24 @@ jobs:
     runs-on: ${{ matrix.os }}
     defaults:
       run:
-        shell: bash -e {0} # -e to fail on error
+        shell: bash -e {0}

     strategy:
       fail-fast: false
       matrix:
-        python: ["3.11", "3.13"]
-        os: [ubuntu-latest]
         include:
-          - os: macos-latest
-            python: "3.11"
-          - os: macos-latest
-            python: "3.12"
-            pip-flags: "--pre"
-            name: "Python 3.12 (pre-release)"
-          - os: windows-latest
-            python: "3.11"
+          - {os: windows-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"}
+          - {os: windows-latest, python: "3.11", dask-version: "latest", name: "Dask latest"}
+          - {os: ubuntu-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"}
+          - {os: ubuntu-latest, python: "3.11", dask-version: "latest", name: "Dask latest"}
+          - {os: ubuntu-latest, python: "3.13", dask-version: "latest", name: "Dask latest"}
+          - {os: macos-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"}
+          - {os: macos-latest, python: "3.11", dask-version: "latest", name: "Dask latest"}
+          - {os: macos-latest, python: "3.12", pip-flags: "--pre", name: "Python 3.12 (pre-release)"}
     env:
       OS: ${{ matrix.os }}
       PYTHON: ${{ matrix.python }}
+      DASK_VERSION: ${{ matrix.dask-version }}

     steps:
       - uses: actions/checkout@v2
@@ -42,7 +40,15 @@ jobs:
           version: "latest"
           python-version: ${{ matrix.python }}
       - name: Install dependencies
-        run: "uv sync --extra test"
+        run: |
+          uv sync --extra test
+          if [[ -n "${DASK_VERSION}" ]]; then
+            if [[ "${DASK_VERSION}" == "latest" ]]; then
+              uv pip install --upgrade dask
+            else
+              uv pip install dask==${DASK_VERSION}
+            fi
+          fi
       - name: Test
         env:
           MPLBACKEND: agg

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ repos:
     rev: v3.5.3
     hooks:
       - id: prettier
+        exclude: ^.github/workflows/test.yaml
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.15.0
     hooks:

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ dependencies = [
     "anndata>=0.9.1",
     "click",
     "dask-image",
-    "dask>=2024.10.0,<=2024.11.2",
+    "dask>=2025.2.0",
     "datashader",
     "fsspec[s3,http]",
     "geopandas>=0.14",

src/spatialdata/__init__.py

Lines changed: 2 additions & 15 deletions
@@ -1,20 +1,7 @@
-import dask
-
-dask.config.set({"dataframe.query-planning": False})
-import dask.dataframe as dd
-
-# Setting `dataframe.query-planning` to False is effective only if run before `dask.dataframe` is initialized. In
-# the case in which the user had initialized `dask.dataframe` before, we would have DASK_EXPR_ENABLED set to `True`.
-# Here we check that this does not happen.
-if hasattr(dd, "DASK_EXPR_ENABLED") and dd.DASK_EXPR_ENABLED:
-    raise RuntimeError(
-        "Unsupported backend: dask-expr has been detected as the backend of dask.dataframe. Please "
-        "use:\nimport dask\ndask.config.set({'dataframe.query-planning': False})\nbefore importing "
-        "dask.dataframe to disable dask-expr. The support is being worked on, for more information please see "
-        "https://github.com/scverse/spatialdata/pull/570"
-    )
 from importlib.metadata import version

+import spatialdata.models._accessor  # noqa: F401
+
 __version__ = version("spatialdata")

 __all__ = [
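The `spatialdata.models._accessor` import registers dask accessors at import time so that `.attrs` metadata keeps working now that query planning can no longer be disabled. The module's contents are not shown in this diff; the snippet below is only a minimal sketch of dask's accessor-registration mechanism that such a module can build on (the `sdmeta` name and its behaviour are illustrative, not taken from the commit).

import dask.dataframe as dd
import pandas as pd
from dask.dataframe.extensions import register_dataframe_accessor


@register_dataframe_accessor("sdmeta")  # hypothetical accessor name, not the one used by spatialdata
class SdMetaAccessor:
    def __init__(self, ddf: dd.DataFrame) -> None:
        self._ddf = ddf

    @property
    def n_columns(self) -> int:
        # metadata-only access: no partition is computed here
        return len(self._ddf.columns)


ddf = dd.from_pandas(pd.DataFrame({"x": [0.0, 1.0], "y": [2.0, 3.0]}), npartitions=1)
print(ddf.sdmeta.n_columns)  # -> 2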

src/spatialdata/_core/_deepcopy.py

Lines changed: 5 additions & 2 deletions
@@ -94,9 +94,12 @@ def _(gdf: GeoDataFrame) -> GeoDataFrame:
 @deepcopy.register(DaskDataFrame)
 def _(df: DaskDataFrame) -> DaskDataFrame:
     # bug: the parser may change the order of the columns
-    new_ddf = PointsModel.parse(df.compute().copy(deep=True))
+    compute_df = df.compute().copy(deep=True)
+    new_ddf = PointsModel.parse(compute_df)
     # the problem is not .copy(deep=True), but the parser, which discards some metadata https://github.com/scverse/spatialdata/issues/503#issuecomment-2015275322
-    new_ddf.attrs = _deepcopy(df.attrs)
+    # We need to take .attrs from compute_df here: after the deepcopy, df._attrs no longer exists.
+    new_ddf.attrs.update(_deepcopy(compute_df.attrs))
     return new_ddf
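A hedged usage sketch of the registration above: `deepcopy` is the dispatch entry point that this function handles for points elements, and the element construction below is illustrative toy data, not taken from the commit.

import pandas as pd
from spatialdata import deepcopy
from spatialdata.models import PointsModel

points = PointsModel.parse(pd.DataFrame({"x": [0.0, 1.0], "y": [2.0, 3.0]}))
copied = deepcopy(points)
print(copied.attrs)      # metadata copied from the computed frame rather than shared with `points`
print(copied.compute())  # the copied element still evaluates to the same data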

src/spatialdata/_core/operations/rasterize.py

Lines changed: 9 additions & 1 deletion
@@ -653,20 +653,28 @@ def rasterize_shapes_points(

     table_name = table_name if table_name is not None else "table"

+    index = False
     if value_key is not None:
         kwargs = {"sdata": sdata, "element_name": element_name} if element_name is not None else {"element": data}
         data[VALUES_COLUMN] = get_values(value_key, table_name=table_name, **kwargs).iloc[:, 0]  # type: ignore[arg-type, union-attr]
     elif isinstance(data, GeoDataFrame) or isinstance(data, DaskDataFrame) and return_regions_as_labels is True:
         value_key = VALUES_COLUMN
         data[VALUES_COLUMN] = data.index.astype("category")
+        index = True
     else:
         value_key = VALUES_COLUMN
         data[VALUES_COLUMN] = 1

     label_index_to_category = None
     if VALUES_COLUMN in data and data[VALUES_COLUMN].dtype == "category":
         if isinstance(data, DaskDataFrame):
-            data[VALUES_COLUMN] = data[VALUES_COLUMN].cat.as_known()
+            # We have to do this because as_known() no longer preserves the category order in the latest dask versions.
+            # TODO: discuss whether we can always expect the index to be monotonically increasing; in that case we
+            # would not need to enforce the order.
+            if index:
+                data[VALUES_COLUMN] = data[VALUES_COLUMN].cat.set_categories(data.index, ordered=True)
+            else:
+                data[VALUES_COLUMN] = data[VALUES_COLUMN].cat.as_known()
         label_index_to_category = dict(enumerate(data[VALUES_COLUMN].cat.categories, start=1))

     if return_single_channel is None:
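A minimal sketch of the ordering concern the branch above addresses, on toy data that mirrors the `data.index.astype("category")` pattern in the diff (the exact categories returned by `as_known()` may differ between dask versions):

import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"x": [0.0, 1.0, 2.0]}, index=["c", "a", "b"])
ddf = dd.from_pandas(pdf, npartitions=1)
ddf["label"] = ddf.index.astype("category")  # categories are "unknown" at graph-construction time

via_as_known = ddf["label"].cat.as_known()                            # category order not guaranteed
via_index = ddf["label"].cat.set_categories(ddf.index, ordered=True)  # order pinned to the index

print(list(via_as_known.compute().cat.categories))
print(list(via_index.compute().cat.categories))  # expected: ['c', 'a', 'b']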

src/spatialdata/_core/operations/transform.py

Lines changed: 12 additions & 9 deletions
@@ -3,11 +3,12 @@
 import itertools
 import warnings
 from functools import singledispatch
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast

 import dask.array as da
 import dask_image.ndinterp
 import numpy as np
+import pandas as pd
 from dask.array.core import Array as DaskArray
 from dask.dataframe import DataFrame as DaskDataFrame
 from geopandas import GeoDataFrame
@@ -432,18 +433,20 @@ def _(
     xtransformed = transformation._transform_coordinates(xdata)
     transformed = data.drop(columns=list(axes)).copy()
     # dummy transformation that will be replaced by _adjust_transformation()
-    transformed.attrs[TRANSFORM_KEY] = {DEFAULT_COORDINATE_SYSTEM: Identity()}
-    # TODO: the following line, used in place of the line before, leads to an incorrect aggregation result. Look into
-    # this! Reported here: ...
-    # transformed.attrs = {TRANSFORM_KEY: {DEFAULT_COORDINATE_SYSTEM: Identity()}}
-    assert isinstance(transformed, DaskDataFrame)
+    default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()}
+    transformed.attrs[TRANSFORM_KEY] = default_cs
+
     for ax in axes:
         indices = xtransformed["dim"] == ax
         new_ax = xtransformed[:, indices]
-        transformed[ax] = new_ax.data.flatten()
+        # TODO: discuss with the dask team.
+        # This is not ideal, but otherwise there is a problem with the joint graph of new_ax and transformed,
+        # which causes a missing getattr dependency in the dependent from_dask_array.
+        new_col = pd.Series(new_ax.data.flatten().compute(), index=transformed.index)
+        transformed[ax] = new_col
+
+    old_transformations = cast(dict[str, Any], get_transformation(data, get_all=True))

-    old_transformations = get_transformation(data, get_all=True)
-    assert isinstance(old_transformations, dict)
     _set_transformation_for_transformed_elements(
         transformed,
         old_transformations,
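A toy sketch of the workaround above: the dask-array column is materialised before being assigned to the dask dataframe, instead of handing dask a mixed array/dataframe graph. The data and names below are illustrative; the assignment pattern mirrors the one in the diff, and its exact behaviour depends on the dask version.

import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd

ddf = dd.from_pandas(pd.DataFrame({"x": [0.0, 1.0, 2.0]}), npartitions=1)
transformed_y = da.from_array(np.array([10.0, 11.0, 12.0]), chunks=3)

# Instead of `ddf["y"] = transformed_y` (which would join an array graph and a dataframe graph),
# compute the array first and align it on the dataframe's index.
ddf["y"] = pd.Series(transformed_y.compute(), index=ddf.index)
print(ddf.compute())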

src/spatialdata/_core/query/spatial_query.py

Lines changed: 13 additions & 3 deletions
@@ -672,14 +672,24 @@ def _(
         max_coordinate=max_coordinate_intrinsic,
     )

-    # assert that the number of bounding boxes is correct
-    assert len(in_intrinsic_bounding_box) == len(min_coordinate)
+    if not (len_df := len(in_intrinsic_bounding_box)) == (len_bb := len(min_coordinate)):
+        raise ValueError(f"Number of dataframes `{len_df}` is not equal to the number of bounding boxes `{len_bb}`.")
     points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = []
+    points_pd = points.compute()
+    attrs = points.attrs.copy()
     for mask in in_intrinsic_bounding_box:
         if mask.sum() == 0:
             points_in_intrinsic_bounding_box.append(None)
         else:
-            points_in_intrinsic_bounding_box.append(points.loc[mask])
+            # TODO: there is a problem when mixing a dask dataframe graph with a dask array graph, so we need to
+            # compute for now. Computing only one of `mask` or `points` is not enough: in that case
+            # test_query_points_multiple_partitions fails because the mask is used to index each partition.
+            # If we compute and then recreate the dask array, we hit the mixed-graph problem again.
+            mask_np = mask.compute()
+            filtered_pd = points_pd[mask_np]
+            points_filtered = dd.from_pandas(filtered_pd, npartitions=points.npartitions)
+            points_filtered.attrs.update(attrs)
+            points_in_intrinsic_bounding_box.append(points_filtered)
     if len(points_in_intrinsic_bounding_box) == 0:
         return None
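A sketch of the compute-then-rebuild pattern used above, on toy data. The column and metadata names are illustrative; importing spatialdata registers the `.attrs` accessor added by this PR (see the `__init__.py` change above).

import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
import spatialdata  # noqa: F401  # registers the .attrs accessor on dask collections

points = dd.from_pandas(pd.DataFrame({"x": [0.0, 5.0, 9.0], "y": [1.0, 6.0, 2.0]}), npartitions=2)
points.attrs["transform"] = {"global": "Identity"}             # illustrative metadata
mask = da.from_array(np.array([True, False, True]), chunks=2)  # e.g. bounding-box membership per point

points_pd = points.compute()
attrs = points.attrs.copy()
filtered = dd.from_pandas(points_pd[mask.compute()], npartitions=points.npartitions)
filtered.attrs.update(attrs)  # metadata survives the round trip
print(filtered.compute())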

src/spatialdata/_core/spatialdata.py

Lines changed: 2 additions & 5 deletions
@@ -13,8 +13,7 @@
 import zarr
 from anndata import AnnData
 from dask.dataframe import DataFrame as DaskDataFrame
-from dask.dataframe import read_parquet
-from dask.delayed import Delayed
+from dask.dataframe import Scalar, read_parquet
 from geopandas import GeoDataFrame
 from shapely import MultiPolygon, Polygon
 from xarray import DataArray, DataTree
@@ -1985,9 +1984,7 @@ def h(s: str) -> str:
             else:
                 shape_str = (
                     "("
-                    + ", ".join(
-                        [(str(dim) if not isinstance(dim, Delayed) else "<Delayed>") for dim in v.shape]
-                    )
+                    + ", ".join([(str(dim) if not isinstance(dim, Scalar) else "<Delayed>") for dim in v.shape])
                     + ")"
                 )
             descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} with shape: {shape_str} {dim_string}"
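Why the string representation switches from checking `Delayed` to `Scalar`: with the unpinned dask, the lazy row count in a dataframe's `.shape` is a `dask.dataframe.Scalar`, which the repr above renders as `<Delayed>`. A small sketch on toy data:

import dask.dataframe as dd
import pandas as pd
from dask.dataframe import Scalar

ddf = dd.from_pandas(pd.DataFrame({"x": range(4)}), npartitions=2)
nrows, ncols = ddf.shape            # lazy row count, concrete column count
print(isinstance(nrows, Scalar), ncols)

shape_str = "(" + ", ".join(str(d) if not isinstance(d, Scalar) else "<Delayed>" for d in ddf.shape) + ")"
print(shape_str)                    # -> "(<Delayed>, 1)"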

src/spatialdata/_io/_utils.py

Lines changed: 32 additions & 12 deletions
@@ -14,6 +14,7 @@

 import zarr
 from anndata import AnnData
+from dask._task_spec import Task
 from dask.array import Array as DaskArray
 from dask.dataframe import DataFrame as DaskDataFrame
 from geopandas import GeoDataFrame
@@ -301,6 +302,19 @@ def _get_backing_files(element: DaskArray | DaskDataFrame) -> list[str]:
     return files


+def _find_piece_dict(obj: dict[str, tuple[str | None]] | Task) -> dict[str, tuple[str | None]] | None:
+    """Recursively search Dask task specs for the dict whose 'piece' key holds the parquet file path."""
+    if isinstance(obj, dict):
+        if "piece" in obj:
+            return obj
+    elif hasattr(obj, "args"):  # handles dask._task_spec.* objects such as Task and List
+        for v in obj.args:
+            result = _find_piece_dict(v)
+            if result is not None:
+                return result
+    return None
+
+
 def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> None:
     # see the types allowed for the dask graph here: https://docs.dask.org/en/stable/spec.html

@@ -327,25 +341,31 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> None:
                 path = getattr(v.store, "path", None) if getattr(v.store, "path", None) else v.store.root
                 files.append(str(UPath(path).resolve()))
             elif name.startswith("read-parquet") or name.startswith("read_parquet"):
-                if hasattr(v, "creation_info"):
-                    # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L625
-                    t = v.creation_info["args"]
-                    if not isinstance(t, tuple) or len(t) != 1:
-                        raise ValueError(
-                            f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please "
-                            f"report this bug."
-                        )
-                    parquet_file = t[0]
-                    files.append(str(UPath(parquet_file).resolve()))
-                elif isinstance(v, tuple) and len(v) > 1 and isinstance(v[1], dict) and "piece" in v[1]:
+                # Here v is a read_parquet task whose only argument is a dictionary.
+                if "piece" in v.args[0]:
                     # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L870
-                    parquet_file, check0, check1 = v[1]["piece"]
+                    parquet_file, check0, check1 = v.args[0]["piece"]
                     if not parquet_file.endswith(".parquet") or check0 is not None or check1 is not None:
                         raise ValueError(
                             f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please "
                             f"report this bug."
                         )
                     files.append(os.path.realpath(parquet_file))
+                else:
+                    # This occurs when, for example, points and images are mixed: the main task still starts with
+                    # read_parquet, but the execution happens through a subgraph, which we iterate over to find the
+                    # actual read_parquet task.
+                    for task in v.args[0].values():
+                        # Recurse through the tasks; this is required because of differences between dask versions.
+                        piece_dict = _find_piece_dict(task)
+                        if isinstance(piece_dict, dict) and "piece" in piece_dict:
+                            parquet_file, check0, check1 = piece_dict["piece"]  # type: ignore[misc]
+                            if not parquet_file.endswith(".parquet") or check0 is not None or check1 is not None:
+                                raise ValueError(
+                                    f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please "
+                                    f"report this bug."
+                                )
+                            files.append(os.path.realpath(parquet_file))


 def _backed_elements_contained_in_path(path: Path, object: SpatialData | SpatialElement | AnnData) -> list[bool]:
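A hedged sketch of how one can peek at the task graph these helpers walk. The exact layer names, task types, and the location of the "piece" dict depend on the dask version, so this only illustrates where to look; the parquet file is created on the fly to keep the example self-contained.

import dask.dataframe as dd
import pandas as pd

pd.DataFrame({"x": [0.0, 1.0], "y": [2.0, 3.0]}).to_parquet("points.parquet")  # illustrative path
ddf = dd.read_parquet("points.parquet")

for key, task in ddf.__dask_graph__().items():
    name = key[0] if isinstance(key, tuple) else key
    if str(name).startswith(("read-parquet", "read_parquet")):
        # the arguments of this task (or of a nested subgraph) contain the dict with the "piece" entry
        print(name, type(task))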
