Skip to content

Commit e36db7a

Browse files
Merge pull request #330 from petrelharp/fix_testing
Fix testing
2 parents 47ec75f + 5e90c6e commit e36db7a

File tree

3 files changed

+183
-190
lines changed

3 files changed

+183
-190
lines changed

docs/stats.rst

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,14 @@ Windowing
7070
*********
7171

7272
Each statistic has an argument, ``windows``,
73-
which defines a collection of contiguous windows along the genome.
74-
If ``windows`` is a list of ``n+1`` increasing numbers between 0 and the ``sequence_length``,
75-
then the statistic will be computed separately in each of the ``n`` windows,
73+
which defines a collection of contiguous windows spanning the genome.
74+
``windows`` should be a list of ``n+1`` increasing numbers beginning with 0
75+
and ending with the ``sequence_length``.
76+
The statistic will be computed separately in each of the ``n`` windows,
7677
and the ``k``-th row of the output will report the values of the statistic
7778
in the ``k``-th window, i.e., from (and including) ``windows[k]`` to (but not including) ``windows[k+1]``.
7879

79-
All windowed statistics by default return **averages** within each of the windows,
80+
Most windowed statistics by default return **averages** within each of the windows,
8081
so the values are comparable between windows, even of different lengths.
8182
(However, shorter windows may be noisier.)
8283
Suppose for instance that you compute some statistic with ``windows = [a, b, c]``
@@ -108,6 +109,13 @@ There are some shortcuts to other useful options:
108109
since the windows are all different sizes you probably want to also pass
109110
``span_normalise=False`` (see below).
110111

112+
113+
.. _sec_general_stats_span_normalise:
114+
115+
+++++++++++++
116+
Normalisation
117+
+++++++++++++
118+
111119
Furthermore, there is an option, ``span_normalise`` (default ``True``),
112120
that if ``False`` returns the **sum** of the relevant statistic across each window rather than the average.
113121
The statistic that is returned by default is an average because we divide by
@@ -206,6 +214,7 @@ Here are some additional special cases:
206214
were that allowed.)
207215

208216

217+
209218
.. _sec_general_stats_output_format:
210219

211220
*************
@@ -343,6 +352,18 @@ regression with other covariates (as in GWAS).
343352
- :meth:`.TreeSequence.trait_covariance`
344353
- :meth:`.TreeSequence.trait_correlation`
345354

355+
------------------
356+
Derived statistics
357+
------------------
358+
359+
The other statistics above all have the property that `mode="branch"` and
360+
`mode="site"` are "dual" in the sense that they are equal, on average, under
361+
a high neutral mutation rate. The following statistics do not have this
362+
property (since both are ratios of statistics that do have this property).
363+
364+
- :meth:`.TreeSequence.Fst`
365+
- :meth:`.TreeSequence.TajimasD`
366+
346367
---------------
347368
General methods
348369
---------------
@@ -355,15 +376,35 @@ using these methods directly, so they should be preferred.
355376
- :meth:`.TreeSequence.general_stat`
356377
- :meth:`.TreeSequence.sample_count_stat`
357378

358-
------------------
359-
Derived statistics
360-
------------------
361379

362-
The other statistics above all have the property that `mode="branch"` and
363-
`mode="site"` are "dual" in the sense that they are equal, on average, under
364-
a high neutral mutation rate. The following statistics do not have this
365-
property (since both are ratios of statistics that do have this property).
380+
.. _sec_general_stats_advanced:
366381

367-
- :meth:`.TreeSequence.Fst`
368-
- :meth:`.TreeSequence.TajimasD`
382+
****************
383+
Advanced methods
384+
****************
385+
386+
The methods :meth:`.TreeSequence.general_stat` and :meth:`.TreeSequence.sample_count_stat`
387+
provide access to the general-purpose algorithm for computing statistics.
388+
Here is a bit more discussion of how to use these.
389+
390+
.. _sec_general_stats_polarisation:
391+
392+
++++++++++++
393+
Polarisation
394+
++++++++++++
395+
396+
Many statistics calculated from genome sequence treat all alleles on equal footing,
397+
as one must without knowledge of the ancestral state and sequence of mutations that produced the data.
398+
Separating out the *ancestral* allele (e.g., as inferred using an outgroup)
399+
is known as *polarisiation*.
400+
For instance, in the allele frequency spectrum, a site with alleles at 20% and 80% frequency
401+
is no different than another whose alleles are at 80% and 20%,
402+
unless we know in each case which allele is ancestral,
403+
and so while the unpolarised allele frequency spectrum gives the distribution of frequencies of *all* alleles,
404+
the *polarised* allele frequency spectrum gives the distribution of frequencies of only *derived* alleles.
369405

406+
This concept is extended to more general statistics as follows.
407+
For site statistics, summary functions are applied to the total weight or number of samples
408+
associated with each allele; but if polarised, then the ancestral allele is left out of this sum.
409+
For branch or node statistics, summary functions are applied to the total weight or number of samples
410+
below, and above each branch or node; if polarised, then only the weight below is used.

python/tests/test_tree_stats.py

Lines changed: 18 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -796,13 +796,16 @@ def example_sample_sets(ts, min_size=1):
796796
number of sample sets returned in each example must be at least min_size
797797
"""
798798
samples = ts.samples()
799+
np.random.shuffle(samples)
799800
splits = np.array_split(samples, min_size)
800801
yield splits
801802
yield [[s] for s in samples]
802803
if min_size == 1:
803804
yield [samples[:1]]
804-
if ts.num_samples <= 2 and min_size >= 2:
805+
if ts.num_samples > 2 and min_size <= 2:
805806
yield [samples[:2], samples[2:]]
807+
if ts.num_samples > 7 and min_size <= 4:
808+
yield [samples[:2], samples[2:4], samples[4:6], samples[6:]]
806809

807810

808811
def example_sample_set_index_pairs(sample_sets):
@@ -1163,7 +1166,7 @@ def site_segregating_sites(ts, sample_sets, windows=None, span_normalise=True):
11631166
haps = ts.genotype_matrix(impute_missing_data=True)
11641167
site_positions = [x.position for x in ts.sites()]
11651168
for i, X in enumerate(sample_sets):
1166-
X_index = np.where(np.in1d(X, samples))[0]
1169+
X_index = np.where(np.in1d(samples, X))[0]
11671170
for k in range(ts.num_sites):
11681171
if (site_positions[k] >= begin) and (site_positions[k] < end):
11691172
num_alleles = len(set(haps[k, X_index]))
@@ -1303,7 +1306,7 @@ def site_tajimas_d(ts, sample_sets, windows=None):
13031306
nn = n[i]
13041307
S = 0
13051308
T = 0
1306-
X_index = np.where(np.in1d(X, samples))[0]
1309+
X_index = np.where(np.in1d(samples, X))[0]
13071310
for k in range(ts.num_sites):
13081311
if (site_positions[k] >= begin) and (site_positions[k] < end):
13091312
hX = haps[k, X_index]
@@ -1497,7 +1500,8 @@ def verify_sample_sets(self, ts, sample_sets, windows):
14971500
denom = n * (n - 1) * (n - 2)
14981501

14991502
def f(x):
1500-
return x * (n - x) * (n - x - 1) / denom
1503+
with np.errstate(invalid='ignore', divide='ignore'):
1504+
return x * (n - x) * (n - x - 1) / denom
15011505

15021506
self.verify_definition(ts, sample_sets, windows, f, ts.Y1, Y1)
15031507

@@ -1875,8 +1879,6 @@ def node_Y2(ts, sample_sets, indexes, windows=None, span_normalise=True):
18751879

18761880

18771881
def Y2(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True):
1878-
windows = ts.parse_windows(windows)
1879-
18801882
windows = ts.parse_windows(windows)
18811883
if indexes is None:
18821884
indexes = [(0, 1)]
@@ -3330,11 +3332,11 @@ def test_bad_mode(self):
33303332
def test_bad_window_strings(self):
33313333
ts = self.get_tree_sequence()
33323334
with self.assertRaises(ValueError):
3333-
ts.diversity([list(ts.samples())], mode="site", windows="abc")
3335+
ts.diversity([ts.samples()], mode="site", windows="abc")
33343336
with self.assertRaises(ValueError):
3335-
ts.diversity([list(ts.samples())], mode="site", windows="")
3337+
ts.diversity([ts.samples()], mode="site", windows="")
33363338
with self.assertRaises(ValueError):
3337-
ts.diversity([list(ts.samples())], mode="tree", windows="abc")
3339+
ts.diversity([ts.samples()], mode="tree", windows="abc")
33383340

33393341
def test_bad_summary_function(self):
33403342
ts = self.get_tree_sequence()
@@ -3441,8 +3443,7 @@ class TestGeneralSiteStats(StatsTestCase):
34413443
def compare_general_stat(self, ts, W, f, windows=None, polarised=False):
34423444
# Determine output_dim of the function
34433445
M = len(f(W[0]))
3444-
py_ssc = PythonSiteStatCalculator(ts)
3445-
sigma1 = py_ssc.naive_general_stat(W, f, windows, polarised=polarised)
3446+
sigma1 = naive_site_general_stat(ts, W, f, windows, polarised=polarised)
34463447
sigma2 = ts.general_stat(W, f, M, windows, polarised=polarised, mode="site")
34473448
sigma3 = site_general_stat(ts, W, f, windows, polarised=polarised)
34483449
self.assertEqual(sigma1.shape, sigma2.shape)
@@ -4391,7 +4392,6 @@ class BranchSampleSetStatsTestCase(SampleSetStatTestCase):
43914392
def setUp(self):
43924393
self.rng = random.Random(self.random_seed)
43934394
self.stat_type = "branch"
4394-
self.py_stat_class = PythonBranchStatCalculator
43954395

43964396
def get_ts(self):
43974397
for N in [12, 15, 20]:
@@ -4595,7 +4595,7 @@ def f(x):
45954595
branch_true_mean_diversity)
45964596
self.assertAlmostEqual(ts.divergence([A[0], A[1]], [(0, 1)], mode=mode),
45974597
branch_true_mean_diversity)
4598-
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
4598+
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
45994599
branch_true_mean_diversity)
46004600

46014601
# Y-statistic for (0/12)
@@ -4610,7 +4610,7 @@ def f(x):
46104610
py_bsc_Y = Y3(ts, [[0], [1], [2]], [(0, 1, 2)], windows=[0.0, 1.0], mode=mode)
46114611
self.assertArrayAlmostEqual(bts_Y, branch_true_Y)
46124612
self.assertArrayAlmostEqual(py_bsc_Y, branch_true_Y)
4613-
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
4613+
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
46144614
branch_true_Y)
46154615

46164616
mode = "site"
@@ -4619,7 +4619,7 @@ def f(x):
46194619
py_ssc_Y = Y3(ts, [[0], [1], [2]], [(0, 1, 2)], windows=[0.0, 1.0], mode=mode)
46204620
self.assertArrayAlmostEqual(sts_Y, site_true_Y)
46214621
self.assertArrayAlmostEqual(py_ssc_Y, site_true_Y)
4622-
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
4622+
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
46234623
site_true_Y)
46244624

46254625
A = [[0, 1, 2]]
@@ -5002,7 +5002,7 @@ def f(x):
50025002
branch_true_diversity_02]):
50035003

50045004
self.assertAlmostEqual(diversity(ts, A, mode=mode)[0][0], truth)
5005-
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0], truth)
5005+
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0], truth)
50065006
self.assertAlmostEqual(ts.diversity(A, mode="branch")[0], truth)
50075007

50085008
# Y-statistic for (0/12)
@@ -5017,7 +5017,7 @@ def f(x):
50175017
branch_true_Y)
50185018
self.assertArrayAlmostEqual(ts.Y3([[0], [1], [2]], [(0, 1, 2)], mode=mode),
50195019
branch_true_Y)
5020-
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
5020+
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
50215021
branch_true_Y)
50225022

50235023
# sites:
@@ -5026,7 +5026,7 @@ def f(x):
50265026
py_ssc_Y = Y3(ts, [[0], [1], [2]], [(0, 1, 2)], windows=[0.0, 1.0])
50275027
self.assertAlmostEqual(site_tsc_Y, site_true_Y)
50285028
self.assertAlmostEqual(py_ssc_Y, site_true_Y)
5029-
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
5029+
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
50305030
site_true_Y)
50315031

50325032

@@ -5294,122 +5294,3 @@ def test_Y3_windows(self):
52945294
def test_f3_windows(self):
52955295
ts = self.get_example_ts()
52965296
self.verify_three_way_stat_windows(ts, ts.f3)
5297-
5298-
############################################
5299-
# Old code where stats are defined within type
5300-
# specific calculattors. These definititions have been
5301-
# move to stat-specific regions above
5302-
# The only thing left to port is the SFS code.
5303-
############################################
5304-
5305-
5306-
class PythonBranchStatCalculator(object):
5307-
"""
5308-
Python implementations of various ("tree") branch-length statistics -
5309-
inefficient but more clear what they are doing.
5310-
"""
5311-
5312-
def __init__(self, tree_sequence):
5313-
self.tree_sequence = tree_sequence
5314-
5315-
def site_frequency_spectrum(self, sample_set, windows=None):
5316-
if windows is None:
5317-
windows = [0.0, self.tree_sequence.sequence_length]
5318-
n_out = len(sample_set)
5319-
out = np.zeros((n_out, len(windows) - 1))
5320-
for j in range(len(windows) - 1):
5321-
begin = windows[j]
5322-
end = windows[j + 1]
5323-
S = [0.0 for j in range(n_out)]
5324-
for t in self.tree_sequence.trees(tracked_samples=sample_set,
5325-
sample_counts=True):
5326-
root = t.root
5327-
tr_len = min(end, t.interval[1]) - max(begin, t.interval[0])
5328-
if tr_len > 0:
5329-
for node in t.nodes():
5330-
if node != root:
5331-
x = t.num_tracked_samples(node)
5332-
if x > 0:
5333-
S[x - 1] += t.branch_length(node) * tr_len
5334-
for j in range(n_out):
5335-
S[j] /= (end-begin)
5336-
out[j] = S
5337-
return(out)
5338-
5339-
5340-
class PythonSiteStatCalculator(object):
5341-
"""
5342-
Python implementations of various single-site statistics -
5343-
inefficient but more clear what they are doing.
5344-
"""
5345-
5346-
def __init__(self, tree_sequence):
5347-
self.tree_sequence = tree_sequence
5348-
5349-
def sample_count_stats(self, sample_sets, f, windows=None, polarised=False):
5350-
'''
5351-
Here sample_sets is a list of lists of samples, and f is a function
5352-
whose argument is a list of integers of the same length as sample_sets
5353-
that returns a list of numbers; there will be one output for each element.
5354-
For each value, each allele in a tree is weighted by f(x), where
5355-
x[i] is the number of samples in sample_sets[i] that inherit that allele.
5356-
This finds the sum of this value for all alleles at all polymorphic sites,
5357-
and across the tree sequence ts, weighted by genomic length.
5358-
5359-
This version is inefficient as it works directly with haplotypes.
5360-
'''
5361-
if windows is None:
5362-
windows = [0.0, self.tree_sequence.sequence_length]
5363-
for U in sample_sets:
5364-
if max([U.count(x) for x in set(U)]) > 1:
5365-
raise ValueError("elements of sample_sets",
5366-
"cannot contain repeated elements.")
5367-
haps = list(self.tree_sequence.haplotypes())
5368-
n_out = len(f([0 for a in sample_sets]))
5369-
out = np.zeros((n_out, len(windows) - 1))
5370-
for j in range(len(windows) - 1):
5371-
begin = windows[j]
5372-
end = windows[j + 1]
5373-
site_positions = [x.position for x in self.tree_sequence.sites()]
5374-
S = [0.0 for j in range(n_out)]
5375-
for k in range(self.tree_sequence.num_sites):
5376-
if (site_positions[k] >= begin) and (site_positions[k] < end):
5377-
all_g = [haps[j][k] for j in range(self.tree_sequence.num_samples)]
5378-
g = [[haps[j][k] for j in u] for u in sample_sets]
5379-
for a in set(all_g):
5380-
x = [h.count(a) for h in g]
5381-
w = f(x)
5382-
for j in range(n_out):
5383-
S[j] += w[j]
5384-
for j in range(n_out):
5385-
S[j] /= (end - begin)
5386-
out[j] = np.array([S])
5387-
return out
5388-
5389-
def naive_general_stat(self, W, f, windows=None, polarised=False):
5390-
return naive_site_general_stat(
5391-
self.tree_sequence, W, f, windows=windows, polarised=polarised)
5392-
5393-
def site_frequency_spectrum(self, sample_set, windows=None):
5394-
if windows is None:
5395-
windows = [0.0, self.tree_sequence.sequence_length]
5396-
haps = list(self.tree_sequence.haplotypes())
5397-
site_positions = [x.position for x in self.tree_sequence.sites()]
5398-
n_out = len(sample_set)
5399-
out = np.zeros((n_out, len(windows) - 1))
5400-
for j in range(len(windows) - 1):
5401-
begin = windows[j]
5402-
end = windows[j + 1]
5403-
S = [0.0 for j in range(n_out)]
5404-
for k in range(self.tree_sequence.num_sites):
5405-
if (site_positions[k] >= begin) and (site_positions[k] < end):
5406-
all_g = [haps[j][k] for j in range(self.tree_sequence.num_samples)]
5407-
g = [haps[j][k] for j in sample_set]
5408-
for a in set(all_g):
5409-
x = g.count(a)
5410-
if x > 0:
5411-
S[x - 1] += 1.0
5412-
for j in range(n_out):
5413-
S[j] /= (end - begin)
5414-
out[j] = S
5415-
return out

0 commit comments

Comments
 (0)