|
| 1 | +import contextlib |
1 | 2 | import math |
2 | 3 | import tempfile |
3 | 4 | from pathlib import Path |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import pytest |
| 8 | +from dask import config |
7 | 9 | from geopandas.testing import geom_almost_equals |
8 | 10 | from xarray import DataArray, DataTree |
9 | 11 |
|
10 | 12 | from spatialdata import transform |
11 | 13 | from spatialdata._core.data_extent import are_extents_equal, get_extent |
12 | 14 | from spatialdata._core.spatialdata import SpatialData |
13 | | -from spatialdata._utils import unpad_raster |
| 15 | +from spatialdata._utils import disable_dask_tune_optimization, unpad_raster |
14 | 16 | from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names |
15 | 17 | from spatialdata.transformations.operations import ( |
16 | 18 | align_elements_using_landmarks, |
@@ -586,6 +588,44 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa |
586 | 588 | _ = full_sdata.transform_to_coordinate_system("my_space", maintain_positioning=maintain_positioning) |
587 | 589 |
|
588 | 590 |
|
| 591 | +def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str): |
| 592 | + tmpdir = Path(tmp_path) / "tmp.zarr" |
| 593 | + points_memory = full_sdata["points_0"].compute() |
| 594 | + full_sdata["points_0"] = PointsModel.parse( |
| 595 | + full_sdata["points_0"].repartition(npartitions=4), |
| 596 | + transformations={"global": get_transformation(full_sdata["points_0"])}, |
| 597 | + ) |
| 598 | + assert points_memory.equals(full_sdata["points_0"].compute()) |
| 599 | + |
| 600 | + full_sdata.write(tmpdir) |
| 601 | + |
| 602 | + full_sdata = SpatialData.read(tmpdir) |
| 603 | + |
| 604 | + # This just needs to run without error |
| 605 | + data = transform(full_sdata["points_0"], to_coordinate_system="global") |
| 606 | + |
| 607 | + # test that data still can be computed |
| 608 | + data.compute() |
| 609 | + |
| 610 | + |
| 611 | +@pytest.mark.parametrize( |
| 612 | + "tune,partition", |
| 613 | + [ |
| 614 | + (True, None), |
| 615 | + (False, 4), |
| 616 | + ], |
| 617 | +) |
| 618 | +def test_dask_tune_contextmanager(full_sdata: SpatialData, partition: int | None, tune: bool): |
| 619 | + if partition: |
| 620 | + full_sdata["points_0"] = PointsModel.parse( |
| 621 | + full_sdata["points_0"].repartition(npartitions=4), |
| 622 | + transformations={"global": get_transformation(full_sdata["points_0"])}, |
| 623 | + ) |
| 624 | + |
| 625 | + with disable_dask_tune_optimization() if full_sdata["points_0"].npartitions > 1 else contextlib.nullcontext(): |
| 626 | + assert config.config["optimization"]["tune"]["active"] is tune |
| 627 | + |
| 628 | + |
589 | 629 | @pytest.mark.parametrize("maintain_positioning", [True, False]) |
590 | 630 | def test_transform_elements_and_entire_spatial_data_object_multi_hop( |
591 | 631 | full_sdata: SpatialData, maintain_positioning: bool |
|
0 commit comments