Skip to content

Commit c552a84

Browse files
committed
Use sgkit.distarray for sample_stats
1 parent 4eea44c commit c552a84

File tree

3 files changed

+22 additions, -10 deletions (lines changed)

3 files changed

+22 additions, -10 deletions (lines changed)

.github/workflows/cubed.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,4 +30,4 @@ jobs:
3030
3131
- name: Test with pytest
3232
run: |
33-
pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
33+
pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed

sgkit/stats/aggregation.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -803,22 +803,23 @@ def sample_stats(
803803
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
804804
if mixed_ploidy:
805805
raise ValueError("Mixed-ploidy dataset")
806-
G = da.asarray(ds[call_genotype].data)
806+
GT = da.asarray(ds[call_genotype].transpose("samples", "variants", "ploidy").data)
807807
H = xr.DataArray(
808808
da.map_blocks(
809-
count_hom,
810-
G.transpose(1, 0, 2),
809+
lambda *args: count_hom(*args)[:, np.newaxis, :],
810+
GT,
811811
np.zeros(3, np.uint64),
812-
drop_axis=(1, 2),
813-
new_axis=1,
812+
drop_axis=2,
813+
new_axis=2,
814814
dtype=np.int64,
815-
chunks=(G.chunks[1], 3),
815+
chunks=(GT.chunks[0], 1, 3),
816816
),
817-
dims=["samples", "categories"],
817+
dims=["samples", "variants", "categories"],
818818
)
819-
n_variant, _, _ = G.shape
819+
H = H.sum(axis=1)
820+
_, n_variant, _ = GT.shape
820821
n_called = H.sum(axis=-1)
821-
call_rate = n_called / n_variant
822+
call_rate = n_called.astype(float) / float(n_variant)
822823
n_hom_ref = H[:, 0]
823824
n_hom_alt = H[:, 1]
824825
n_het = H[:, 2]

sgkit/tests/test_aggregation.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -857,6 +857,17 @@ def test_sample_stats__raise_on_mixed_ploidy():
857857
sample_stats(ds)
858858

859859

860+
@pytest.mark.parametrize("chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1)])
861+
def test_sample_stats__chunks(chunks):
862+
ds = simulate_genotype_call_dataset(
863+
n_variant=1000, n_sample=30, missing_pct=0.01, seed=0
864+
)
865+
expect = sample_stats(ds, merge=False).compute()
866+
ds["call_genotype"] = ds["call_genotype"].chunk(chunks)
867+
actual = sample_stats(ds, merge=False).compute()
868+
assert actual.equals(expect)
869+
870+
860871
def test_infer_call_ploidy():
861872
ds = get_dataset(
862873
[

0 commit comments

Comments
 (0)