Skip to content

Commit 8258883

Browse files
timothymillarmergify[bot]
authored and committed
Add 'matching' method to identity_by_state #1227
1 parent 02b9ce5 commit 8258883

File tree

3 files changed

+297
-71
lines changed

3 files changed

+297
-71
lines changed

sgkit/stats/ibs.py

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Hashable
22

33
import dask.array as da
4+
from typing_extensions import Literal
45
from xarray import Dataset
56

67
from sgkit import variables
@@ -12,10 +13,53 @@
1213
)
1314

1415

16+
def _ibs_of_frequencies(af, skipna=True):
    """Estimate pairwise identity-by-state from call allele frequencies.

    Parameters
    ----------
    af
        Array-like of call allele frequencies; converted with ``da.array``.
        NOTE(review): assumed to be indexed (variants, samples, alleles),
        as suggested by the ``transpose(2, 0, 1)`` below -- confirm with
        the caller.
    skipna
        If True, nan frequencies are replaced by zero in the numerator and
        the denominator is the self-product of the nan-ignoring sum of
        ``af`` over its last axis. If False, the frequencies are used as
        they are and the denominator is ``len(af)``.

    Returns
    -------
    The numerator divided by the denominator.
    """
    freqs = da.array(af)
    if skipna:
        # nan entries must not contribute to the per-allele products
        per_allele = da.where(da.isnan(freqs), 0.0, freqs)
    else:
        per_allele = freqs
    # one matrix product per slice along the last axis, summed together
    products = sum(layer.T @ layer for layer in per_allele.transpose(2, 0, 1))
    if skipna:
        weights = da.nansum(freqs, axis=-1)
        return products / (weights.T @ weights)
    return products / len(freqs)
27+
28+
29+
def _ibs_of_genotypes(gt):
    """Estimate pairwise identity-by-state by direct allele matching.

    The genotype array is processed chunk by chunk. For every pair of
    chunks along the second block axis, the per-chunk numerators and
    denominators returned by the matching kernels are summed over the
    chunks of the first block axis and divided to give one tile of the
    result. Only the diagonal and lower-triangle tiles are computed; the
    upper-triangle tiles are their transposes.

    Parameters
    ----------
    gt
        Array-like of call genotypes; converted with ``da.array``.
        NOTE(review): assumed to be indexed (variants, samples, ploidy),
        matching the names the chunk counts are unpacked into -- confirm.

    Returns
    -------
    The tiles stacked into a single two-dimensional dask array.

    Raises
    ------
    ValueError
        If the array has more than one chunk along its last dimension.
    """
    from .ibs_numba_fns import allele_matching_block, allele_matching_diag

    blocks = da.array(gt).blocks
    n_variant_chunks, n_sample_chunks, n_ploidy_chunks = blocks.shape
    if n_ploidy_chunks != 1:
        raise ValueError(
            "The 'matching' method does not support chunking in the ploidy dimension"
        )
    variant_chunk_ids = range(n_variant_chunks)
    tiles = [[None] * n_sample_chunks for _ in range(n_sample_chunks)]
    for row in range(n_sample_chunks):
        # diagonal tile: a sample chunk compared with itself
        diag_nums, diag_denoms = zip(
            *(allele_matching_diag(blocks[v, row]) for v in variant_chunk_ids)
        )
        tiles[row][row] = sum(diag_nums) / sum(diag_denoms)
        for col in range(row):
            # off-diagonal tile: computed once, mirrored by transposition
            pair_nums, pair_denoms = zip(
                *(
                    allele_matching_block(blocks[v, row], blocks[v, col])
                    for v in variant_chunk_ids
                )
            )
            tile = sum(pair_nums) / sum(pair_denoms)
            tiles[row][col] = tile
            tiles[col][row] = tile.T
    return da.vstack([da.hstack(tile_row) for tile_row in tiles])
55+
56+
1557
def identity_by_state(
1658
ds: Dataset,
1759
*,
60+
call_genotype: Hashable = variables.call_genotype,
1861
call_allele_frequency: Hashable = variables.call_allele_frequency,
62+
method: Literal["frequencies", "matching"] = "frequencies",
1963
skipna: bool = True,
2064
merge: bool = True,
2165
) -> Dataset:
@@ -31,11 +75,24 @@ def identity_by_state(
3175
----------
3276
ds
3377
Dataset containing call genotype alleles.
78+
call_genotype
79+
Input variable name holding call_genotype as defined by
80+
:data:`sgkit.variables.call_genotype_spec`.
81+
This variable is only required for the "matching" method.
3482
call_allele_frequency
3583
Input variable name holding call_allele_frequency as defined by
3684
:data:`sgkit.variables.call_allele_frequency_spec`.
85+
This variable is only required for the "frequencies" method.
3786
If the variable is not present in ``ds``, it will be computed
3887
using :func:`call_allele_frequencies`.
88+
method
89+
The method used for IBS estimation. Defaults to "frequencies"
90+
which calculates IBS probabilities by matrix multiplication
91+
of call allele frequencies which is more efficient when the
92+
alleles dimension is small.
93+
The "matching" method calculates IBS probabilities directly
94+
from the call genotypes and is more efficient when the alleles
95+
dimension is large.
3996
skipna
4097
If True (the default), missing (nan) allele frequencies will be
4198
skipped.
@@ -66,29 +123,29 @@ def identity_by_state(
66123
[0.5, 1. , 0.5],
67124
[0.5, 0.5, 0.5]])
68125
"""
69-
ds = define_variable_if_absent(
70-
ds,
71-
variables.call_allele_frequency,
72-
call_allele_frequency,
73-
call_allele_frequencies,
74-
)
75-
variables.validate(
76-
ds, {call_allele_frequency: variables.call_allele_frequency_spec}
77-
)
78-
af = da.asarray(ds[call_allele_frequency])
79-
if skipna:
80-
af0 = da.where(da.isnan(af), 0.0, af)
81-
num = sum(m.T @ m for m in af0.transpose(2, 0, 1))
82-
called = da.nansum(af, axis=-1)
83-
denom = called.T @ called
126+
if method == "frequencies":
127+
ds = define_variable_if_absent(
128+
ds,
129+
variables.call_allele_frequency,
130+
call_allele_frequency,
131+
call_allele_frequencies,
132+
)
133+
variables.validate(
134+
ds, {call_allele_frequency: variables.call_allele_frequency_spec}
135+
)
136+
af = ds[call_allele_frequency]
137+
ibs = _ibs_of_frequencies(af, skipna=skipna)
138+
elif method == "matching":
139+
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
140+
gt = ds[call_genotype]
141+
ibs = _ibs_of_genotypes(gt)
84142
else:
85-
num = sum(m.T @ m for m in af.transpose(2, 0, 1))
86-
denom = len(af)
143+
raise ValueError(f"Unknown method '{method}'.")
87144
new_ds = create_dataset(
88145
{
89146
variables.stat_identity_by_state: (
90147
("samples_0", "samples_1"),
91-
num / denom,
148+
ibs,
92149
)
93150
}
94151
)

sgkit/stats/ibs_numba_fns.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from sgkit.accelerate import numba_guvectorize
2+
from sgkit.typing import ArrayLike
3+
4+
5+
@numba_guvectorize(  # type: ignore
    [
        "void(int8[:,:,:], float64[:,:], float64[:,:])",
        "void(int16[:,:,:], float64[:,:], float64[:,:])",
        "void(int32[:,:,:], float64[:,:], float64[:,:])",
        "void(int64[:,:,:], float64[:,:], float64[:,:])",
    ],
    "(v,s,k)->(s,s),(s,s)",
)
def allele_matching_diag(
    gt: ArrayLike,
    numerator: ArrayLike,
    denominator: ArrayLike,
) -> None:  # pragma: no cover
    """Accumulate allele-matching IBS terms for a block against itself.

    For every variant and every pair of samples ``(s0, s1)`` with
    ``s1 <= s0``, all pairs of non-negative alleles are compared; negative
    allele values are skipped. The fraction of matching allele pairs is
    added to ``numerator`` and ``1.0`` to ``denominator`` at ``[s0, s1]``
    and, for off-diagonal pairs, at the mirrored position ``[s1, s0]``,
    so both outputs are symmetric.

    Parameters
    ----------
    gt
        Integer genotype array unpacked as (variants, samples, ploidy).
    numerator
        Output array of shape (samples, samples); overwritten with the sum
        over variants of the per-variant matching fractions.
    denominator
        Output array of shape (samples, samples); overwritten with the
        number of variants at which the sample pair had at least one
        comparable pair of alleles.
    """
    n_variant, n_sample, ploidy = gt.shape
    numerator[:] = 0.0
    denominator[:] = 0.0
    for v in range(n_variant):
        for s0 in range(n_sample):
            for s1 in range(s0 + 1):
                # local IBS prob to ensure even weighting of loci
                local_num = 0
                local_denom = 0
                for i in range(ploidy):
                    a0 = gt[v, s0, i]
                    if a0 >= 0:
                        for j in range(ploidy):
                            a1 = gt[v, s1, j]
                            if a1 >= 0:
                                local_denom += 1
                                if a0 == a1:
                                    local_num += 1
                if local_denom > 0:
                    p_ibs = local_num / local_denom
                    numerator[s0, s1] += p_ibs
                    denominator[s0, s1] += 1.0
                    # Mirror only off-diagonal pairs: writing the diagonal
                    # twice and subtracting afterwards would depend on
                    # variables leaking out of this loop and adds an
                    # avoidable floating-point round trip.
                    if s1 != s0:
                        numerator[s1, s0] += p_ibs
                        denominator[s1, s0] += 1.0
47+
48+
49+
@numba_guvectorize(  # type: ignore
    [
        "void(int8[:,:,:], int8[:,:,:], float64[:,:], float64[:,:])",
        "void(int16[:,:,:], int16[:,:,:], float64[:,:], float64[:,:])",
        "void(int32[:,:,:], int32[:,:,:], float64[:,:], float64[:,:])",
        "void(int64[:,:,:], int64[:,:,:], float64[:,:], float64[:,:])",
    ],
    "(v,s0,k),(v,s1,k)->(s0,s1),(s0,s1)",
)
def allele_matching_block(
    gt0: ArrayLike,
    gt1: ArrayLike,
    numerator: ArrayLike,
    denominator: ArrayLike,
) -> None:  # pragma: no cover
    """Accumulate allele-matching IBS terms between two genotype blocks.

    For every variant and every pair of a sample from ``gt0`` and a sample
    from ``gt1``, all pairs of non-negative alleles are compared; negative
    allele values are skipped. The fraction of matching allele pairs is
    added to ``numerator[s0, s1]`` and ``1.0`` to ``denominator[s0, s1]``
    whenever at least one allele pair could be compared.

    Parameters
    ----------
    gt0
        Integer genotype array unpacked as (variants, samples, ploidy).
    gt1
        Second genotype array; its second dimension gives the number of
        columns of the outputs.
    numerator
        Output array of shape (samples of gt0, samples of gt1); overwritten
        with the sum over variants of the per-variant matching fractions.
    denominator
        Output array of the same shape; overwritten with the number of
        variants that contributed to each sample pair.
    """
    n_variant, n_sample0, ploidy = gt0.shape
    n_sample1 = gt1.shape[1]
    numerator[:] = 0.0
    denominator[:] = 0.0
    for v in range(n_variant):
        for s0 in range(n_sample0):
            for s1 in range(n_sample1):
                # per-locus fraction so that every locus is weighted evenly
                matches = 0
                comparisons = 0
                for i in range(ploidy):
                    a0 = gt0[v, s0, i]
                    if a0 < 0:
                        continue
                    for j in range(ploidy):
                        a1 = gt1[v, s1, j]
                        if a1 < 0:
                            continue
                        comparisons += 1
                        if a0 == a1:
                            matches += 1
                if comparisons == 0:
                    # nothing comparable at this locus for this pair
                    continue
                numerator[s0, s1] += matches / comparisons
                denominator[s0, s1] += 1.0

0 commit comments

Comments
 (0)