diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml new file mode 100644 index 000000000..2b9cde4e8 --- /dev/null +++ b/.github/workflows/cubed.yml @@ -0,0 +1,33 @@ +name: Cubed + +on: + push: + pull_request: + # manual trigger + workflow_dispatch: + +jobs: + build: + # This workflow only runs on the origin org + # if: github.repository_owner == 'sgkit-dev' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install deps and sgkit + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt -r requirements-dev.txt + python -m pip install -U git+https://github.com/cubed-dev/cubed.git -U git+https://github.com/cubed-dev/cubed-xarray.git -U git+https://github.com/pydata/xarray.git + + - name: Test with pytest + run: | + pytest -v sgkit/tests/test_aggregation.py -k "test_count_call_alleles" --use-cubed diff --git a/conftest.py b/conftest.py index bcccb6e8a..9343393f3 100644 --- a/conftest.py +++ b/conftest.py @@ -2,6 +2,30 @@ collect_ignore_glob = ["benchmarks/**", "sgkit/io/vcf/*.py", ".github/scripts/*.py"] +def pytest_addoption(parser): + parser.addoption( + "--use-cubed", action="store_true", default=False, help="run with cubed" + ) + + +def use_cubed(): + import dask + import xarray as xr + + # set xarray to use cubed by default + xr.set_options(chunk_manager="cubed") + + # ensure that dask compute raises if it is ever called + class AlwaysRaiseScheduler: + def __call__(self, dsk, keys, **kwargs): + raise RuntimeError("Dask 'compute' was called") + + dask.config.set(scheduler=AlwaysRaiseScheduler()) + + def pytest_configure(config) -> None: # type: ignore # Add "gpu" marker config.addinivalue_line("markers", "gpu:Run tests that run on GPU") + + if config.getoption("--use-cubed"): + use_cubed() diff --git a/sgkit/distarray.py b/sgkit/distarray.py new file mode 100644 index 000000000..fe3bda790 --- /dev/null +++ b/sgkit/distarray.py @@ -0,0 +1,10 @@ +from xarray.namedarray.parallelcompat import guess_chunkmanager + +# use the xarray chunk manager to determine the distributed array module to use +cm = guess_chunkmanager(None) + +if cm.array_cls.__module__.split(".")[0] == "cubed": + from cubed import * # pragma: no cover # noqa: F401, F403 +else: + # default to dask + from dask.array import * # noqa: F401, F403 diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 41adcc221..f76dbe691 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,11 +1,11 @@ from typing import Hashable -import dask.array as da import numpy as np import xarray as xr from typing_extensions import Literal from xarray import Dataset +import sgkit.distarray as da from sgkit import variables from sgkit.display import genotype_as_bytes from sgkit.utils import ( @@ -77,6 +77,11 @@ def count_call_alleles( variables.validate(ds, {call_genotype: variables.call_genotype_spec}) n_alleles = ds.sizes["alleles"] G = da.asarray(ds[call_genotype]) + if G.numblocks[2] > 1: + raise ValueError( + f"Variable {call_genotype} must have only a single chunk in the ploidy dimension. " + "Consider rechunking to change the size of chunks." + ) shape = (G.chunks[0], G.chunks[1], n_alleles) # use numpy array to avoid dask task dependencies between chunks N = np.empty(n_alleles, dtype=np.uint8) @@ -85,7 +90,13 @@ def count_call_alleles( variables.call_allele_count: ( ("variants", "samples", "alleles"), da.map_blocks( - count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2 + count_alleles, + G, + N, + chunks=shape, + dtype=np.uint8, + drop_axis=2, + new_axis=2, ), ) } diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index a5e68645e..e2719c1c1 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -139,8 +139,10 @@ def test_count_variant_alleles__chunked(using): calls = rs.randint(0, 1, size=(50, 10, 2)) ds = get_dataset(calls) ac1 = count_variant_alleles(ds, using=using) - # Coerce from numpy to multiple chunks in all dimensions - ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) + # Coerce from numpy to multiple chunks in all non-core dimensions + ds["call_genotype"] = ds["call_genotype"].chunk( + chunks={"variants": 5, "samples": 5} + ) ac2 = count_variant_alleles(ds, using=using) assert isinstance(ac2["variant_allele_count"].data, da.Array) xr.testing.assert_equal(ac1, ac2) @@ -265,12 +267,22 @@ def test_count_call_alleles__chunked(): calls = rs.randint(0, 1, size=(50, 10, 2)) ds = get_dataset(calls) ac1 = count_call_alleles(ds) - # Coerce from numpy to multiple chunks in all dimensions - ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) + # Coerce from numpy to multiple chunks in all non-core dimensions + ds["call_genotype"] = ds["call_genotype"].chunk( + chunks={"variants": 5, "samples": 5} + ) ac2 = count_call_alleles(ds) - assert isinstance(ac2["call_allele_count"].data, da.Array) + assert hasattr(ac2["call_allele_count"].data, "chunks") xr.testing.assert_equal(ac1, ac2) + # Multiple chunks in core dimension should fail + ds["call_genotype"] = ds["call_genotype"].chunk(chunks={"ploidy": 1}) + with pytest.raises( + ValueError, + match="Variable call_genotype must have only a single chunk in the ploidy dimension", + ): + count_call_alleles(ds) + def test_count_cohort_alleles__multi_variant_multi_sample(): ds = get_dataset( diff --git a/sgkit/tests/test_popgen.py b/sgkit/tests/test_popgen.py index 50fc9bb4c..3014b9b31 100644 --- a/sgkit/tests/test_popgen.py +++ b/sgkit/tests/test_popgen.py @@ -533,7 +533,7 @@ def test_Garud_h__raise_on_no_windows(): @pytest.mark.filterwarnings("ignore::RuntimeWarning") -@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (2, 2))]) +@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (4))]) def test_observed_heterozygosity(chunks): ds = simulate_genotype_call_dataset( n_variant=4, @@ -599,7 +599,7 @@ def test_observed_heterozygosity(chunks): @pytest.mark.filterwarnings("ignore::RuntimeWarning") -@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (2, 2))]) +@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (4,))]) @pytest.mark.parametrize( "cohorts,expectation", [