Skip to content

Commit f6a0ed7

Browse files
committed
popgen progress
1 parent dcfc110 commit f6a0ed7

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

sgkit/stats/aggregation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,8 @@ def count_cohort_alleles(
273273
ds, variables.call_allele_count, call_allele_count, count_call_alleles
274274
)
275275
variables.validate(ds, {call_allele_count: variables.call_allele_count_spec})
276-
# ensure cohorts is a numpy array to minimize dask task
277-
# dependencies between chunks in other dimensions
278-
AC, SC = da.asarray(ds[call_allele_count]), ds[sample_cohort].values
279-
n_cohorts = SC.max() + 1 # 0-based indexing
276+
AC, SC = da.asarray(ds[call_allele_count]), da.asarray(ds[sample_cohort])
277+
n_cohorts = ds[sample_cohort].values.max() + 1 # 0-based indexing
280278
AC = cohort_sum(AC, SC, n_cohorts, axis=1)
281279
new_ds = create_dataset(
282280
{variables.cohort_allele_count: (("variants", "cohorts", "alleles"), AC)}

sgkit/stats/cohort_numba_fns.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,19 @@ def cohort_reduction(gufunc: Callable) -> Callable:
4343

4444
@wraps(gufunc)
4545
def func(x: ArrayLike, cohort: ArrayLike, n: int, axis: int = -1) -> ArrayLike:
46-
x = da.swapaxes(da.asarray(x), axis, -1)
46+
x = da.moveaxis(da.asarray(x), [axis, -1], [-1, axis])
4747
replaced = len(x.shape) - 1
4848
chunks = x.chunks[0:-1] + (n,)
4949
out = da.map_blocks(
5050
gufunc,
5151
x,
5252
cohort,
53-
np.empty(n, np.int8),
53+
da.empty(n, dtype=np.int8),
5454
chunks=chunks,
5555
drop_axis=replaced,
5656
new_axis=replaced,
5757
)
58-
return da.swapaxes(out, axis, -1)
58+
return da.moveaxis(out, [axis, -1], [-1, axis])
5959

6060
return func
6161

0 commit comments

Comments
 (0)