Skip to content

Commit afa6e4c

Browse files
authored
Merge pull request #220 from tskit-dev/fst
Fst implementation (and doc fixes)
2 parents 5664a68 + 3d54501 commit afa6e4c

File tree

3 files changed

+162
-43
lines changed

3 files changed

+162
-43
lines changed

docs/stats.rst

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ in units of crossovers and mutations per base pair, respectively).
140140
.. _sec_general_stats_sample_sets:
141141

142142
***********************
143-
Sample sets and indices
143+
Sample sets and indexes
144144
***********************
145145

146146
Many standard population genetics statistics
@@ -164,40 +164,35 @@ and want to compute **all** fourty-five pairwise divergences?
164164
You could call ``divergence`` fourty-five times, but this would be tedious
165165
and also inefficient, because the allele frequencies for one population
166166
gets used in computing many of those values.
167-
So, statistics that take a ``sample_sets`` argument also take an ``indices`` argument,
167+
So, statistics that take a ``sample_sets`` argument also take an ``indexes`` argument,
168168
which for a statistic that operates on ``k`` sample sets will be a list of ``k``-tuples.
169-
If ``indices`` is a length ``n`` list of ``k``-tuples,
169+
If ``indexes`` is a length ``n`` list of ``k``-tuples,
170170
then the output will have ``n`` columns,
171-
and if ``indices[j]`` is a tuple ``(i0, ..., ik)``,
171+
and if ``indexes[j]`` is a tuple ``(i0, ..., ik)``,
172172
then the ``j``-th column will contain values of the statistic computed on
173173
``(sample_sets[i0], sample_sets[i1], ..., sample_sets[ik])``.
174174

175-
To recap: ``indices`` must be a list of tuples, each of length ``k``,
175+
To recap: ``indexes`` must be a list of tuples, each of length ``k``,
176176
of integers between ``0`` and ``len(sample_sets) - 1``.
177177
The appropriate value of ``k`` depends on the statistic.
178178

179179
Here are some additional special cases:
180180

181-
``indices = None``
181+
``indexes = None``
182182
If the statistic takes ``k`` inputs for ``k > 1``,
183183
and there are exactly ``k`` lists in ``sample_sets``,
184184
then this will compute just one statistic, and is equivalent to passing
185-
``indices = [(0, 1, ..., k-1)]``.
185+
``indexes = [(0, 1, ..., k-1)]``.
186186
If there are not exactly ``k`` sample sets, this will throw an error.
187187

188-
``k=1`` does not allow ``indices``:
188+
``k=1`` does not allow ``indexes``:
189189
Statistics that operate on one sample set at a time (i.e., ``k=1``)
190190
do **not** take the ``indexes`` argument,
191191
and instead just return the value of the statistic separately for each of ``sample_sets``
192192
in the order they are given.
193-
(This would be equivalent to passing ``indices = [[0], [1], ..., [len(sample_sets)]]``,
193+
(This would be equivalent to passing ``indexes = [[0], [1], ..., [len(sample_sets)]]``,
194194
were that allowed.)
195195

196-
``stat_type = "node"`` does not allow ``indices``:
197-
Since node statistics output one value per node (unlike the other types, which output
198-
something summed across all nodes), it is an error to specify ``indices`` when computing
199-
a node statistic (consequently, you need to have exactly ``k`` sample sets).
200-
201196

202197
.. _sec_general_stats_output:
203198

@@ -216,7 +211,6 @@ from ``windows[i]`` to ``windows[i + 1]`` (including the left but not the right
216211
The output is a two-dimensional array,
217212
with columns corresponding to the different statistics computed: ``out[i, j]`` is the ``j``-th statistic
218213
in the ``i``-th window.
219-
If the statistic takes an ``indices`` argument, then ``out[i, j]`` has the statistic computed with ``indices[j]``.
220214

221215
``mode="node"``
222216
The output is a three-dimensional array,

python/tests/test_tree_stats.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def path_length(tr, x, y):
125125

126126
@contextlib.contextmanager
127127
def suppress_division_by_zero_warning():
128-
with np.errstate(invalid='ignore'):
128+
with np.errstate(invalid='ignore', divide='ignore'):
129129
yield
130130

131131

@@ -830,8 +830,6 @@ def wrapped_summary_func(x):
830830
self.assertEqual(sigma1.shape, sigma4.shape)
831831
self.assertArrayAlmostEqual(sigma1, sigma2)
832832
self.assertArrayAlmostEqual(sigma1, sigma3)
833-
# print("computed", sigma1)
834-
# print("definition", sigma4)
835833
self.assertArrayAlmostEqual(sigma1, sigma4)
836834

837835

@@ -1301,6 +1299,68 @@ class TestSiteDivergence(TestDivergence, MutatedTopologyExamplesMixin):
13011299
mode = "site"
13021300

13031301

1302+
############################################
1303+
# Fst
1304+
############################################
1305+
1306+
def single_site_Fst(ts, sample_sets, indexes):
1307+
"""
1308+
Compute single-site Fst, which between two groups with frequencies p and q is
1309+
1 - 2 * (p (1-p) + q(1-q)) / ( p(1-p) + q(1-q) + p(1-q) + q(1-p) )
1310+
or in the multiallelic case, replacing p(1-p) with the sum over alleles of p(1-p),
1311+
and adjusted for sampling without replacement.
1312+
"""
1313+
# TODO: what to do in this case?
1314+
if ts.num_sites == 0:
1315+
out = np.array([np.repeat(np.nan, len(indexes))])
1316+
return out
1317+
out = np.zeros((ts.num_sites, len(indexes)))
1318+
samples = ts.samples()
1319+
for j, v in enumerate(ts.variants()):
1320+
for i, (ix, iy) in enumerate(indexes):
1321+
g = v.genotypes
1322+
X = sample_sets[ix]
1323+
Y = sample_sets[iy]
1324+
gX = [a for k, a in zip(samples, g) if k in X]
1325+
gY = [a for k, a in zip(samples, g) if k in Y]
1326+
nX = len(X)
1327+
nY = len(Y)
1328+
dX = dY = dXY = 0
1329+
for a in set(g):
1330+
fX = np.sum(gX == a)
1331+
fY = np.sum(gY == a)
1332+
with suppress_division_by_zero_warning():
1333+
dX += fX * (nX - fX) / (nX * (nX - 1))
1334+
dY += fY * (nY - fY) / (nY * (nY - 1))
1335+
dXY += (fX * (nY - fY) + (nX - fX) * fY) / (2 * nX * nY)
1336+
with suppress_division_by_zero_warning():
1337+
out[j][i] = 1 - 2 * (dX + dY) / (dX + dY + 2 * dXY)
1338+
return out
1339+
1340+
1341+
class TestFst(StatsTestCase, TwoWaySampleSetStatsMixin):
1342+
1343+
# Derived classes define this to get a specific stats mode.
1344+
mode = None
1345+
1346+
def verify(self, ts):
1347+
# only check per-site
1348+
for sample_sets in example_sample_sets(ts, min_size=2):
1349+
for indexes in example_sample_set_index_pairs(sample_sets):
1350+
self.verify_persite_Fst(ts, sample_sets, indexes)
1351+
1352+
def verify_persite_Fst(self, ts, sample_sets, indexes):
1353+
sigma1 = ts.Fst(sample_sets, indexes=indexes, windows="sites",
1354+
mode=self.mode, span_normalise=False)
1355+
sigma2 = single_site_Fst(ts, sample_sets, indexes)
1356+
self.assertEqual(sigma1.shape, sigma2.shape)
1357+
self.assertArrayAlmostEqual(sigma1, sigma2)
1358+
1359+
1360+
class TestSiteFst(TestFst, MutatedTopologyExamplesMixin):
1361+
mode = "site"
1362+
1363+
13041364
############################################
13051365
# Y2
13061366
############################################
@@ -2289,6 +2349,11 @@ class TestGeneralStatInterface(StatsTestCase):
22892349
Tests for the basic interface for general_stats.
22902350
"""
22912351

2352+
def get_tree_sequence(self):
2353+
ts = msprime.simulate(10, recombination_rate=2,
2354+
mutation_rate=2, random_seed=1)
2355+
return ts
2356+
22922357
def test_default_mode(self):
22932358
ts = msprime.simulate(10, recombination_rate=1, random_seed=2)
22942359
W = np.ones((ts.num_samples, 2))
@@ -2303,6 +2368,15 @@ def test_bad_mode(self):
23032368
with self.assertRaises(ValueError):
23042369
ts.general_stat(W, lambda x: x, mode=bad_mode)
23052370

2371+
def test_bad_window_strings(self):
2372+
ts = self.get_tree_sequence()
2373+
with self.assertRaises(ValueError):
2374+
ts.diversity([list(ts.samples())], mode="site", windows="abc")
2375+
with self.assertRaises(ValueError):
2376+
ts.diversity([list(ts.samples())], mode="site", windows="")
2377+
with self.assertRaises(ValueError):
2378+
ts.diversity([list(ts.samples())], mode="tree", windows="abc")
2379+
23062380

23072381
class TestGeneralBranchStats(StatsTestCase):
23082382
"""

0 commit comments

Comments
 (0)