Skip to content

Commit dd9ddb4

Browse files
petrelharpjeromekelleher
authored andcommitted
stats tests fixups
1 parent 47ec75f commit dd9ddb4

File tree

2 files changed

+25
-142
lines changed

2 files changed

+25
-142
lines changed

python/tests/test_tree_stats.py

Lines changed: 18 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -796,13 +796,16 @@ def example_sample_sets(ts, min_size=1):
796796
number of sample sets returned in each example must be at least min_size
797797
"""
798798
samples = ts.samples()
799+
np.random.shuffle(samples)
799800
splits = np.array_split(samples, min_size)
800801
yield splits
801802
yield [[s] for s in samples]
802803
if min_size == 1:
803804
yield [samples[:1]]
804-
if ts.num_samples <= 2 and min_size >= 2:
805+
if ts.num_samples > 2 and min_size <= 2:
805806
yield [samples[:2], samples[2:]]
807+
if ts.num_samples > 7 and min_size <= 4:
808+
yield [samples[:2], samples[2:4], samples[4:6], samples[6:]]
806809

807810

808811
def example_sample_set_index_pairs(sample_sets):
@@ -1163,7 +1166,7 @@ def site_segregating_sites(ts, sample_sets, windows=None, span_normalise=True):
11631166
haps = ts.genotype_matrix(impute_missing_data=True)
11641167
site_positions = [x.position for x in ts.sites()]
11651168
for i, X in enumerate(sample_sets):
1166-
X_index = np.where(np.in1d(X, samples))[0]
1169+
X_index = np.where(np.in1d(samples, X))[0]
11671170
for k in range(ts.num_sites):
11681171
if (site_positions[k] >= begin) and (site_positions[k] < end):
11691172
num_alleles = len(set(haps[k, X_index]))
@@ -1303,7 +1306,7 @@ def site_tajimas_d(ts, sample_sets, windows=None):
13031306
nn = n[i]
13041307
S = 0
13051308
T = 0
1306-
X_index = np.where(np.in1d(X, samples))[0]
1309+
X_index = np.where(np.in1d(samples, X))[0]
13071310
for k in range(ts.num_sites):
13081311
if (site_positions[k] >= begin) and (site_positions[k] < end):
13091312
hX = haps[k, X_index]
@@ -1497,7 +1500,8 @@ def verify_sample_sets(self, ts, sample_sets, windows):
14971500
denom = n * (n - 1) * (n - 2)
14981501

14991502
def f(x):
1500-
return x * (n - x) * (n - x - 1) / denom
1503+
with np.errstate(invalid='ignore', divide='ignore'):
1504+
return x * (n - x) * (n - x - 1) / denom
15011505

15021506
self.verify_definition(ts, sample_sets, windows, f, ts.Y1, Y1)
15031507

@@ -1875,8 +1879,6 @@ def node_Y2(ts, sample_sets, indexes, windows=None, span_normalise=True):
18751879

18761880

18771881
def Y2(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True):
1878-
windows = ts.parse_windows(windows)
1879-
18801882
windows = ts.parse_windows(windows)
18811883
if indexes is None:
18821884
indexes = [(0, 1)]
@@ -3330,11 +3332,11 @@ def test_bad_mode(self):
33303332
def test_bad_window_strings(self):
33313333
ts = self.get_tree_sequence()
33323334
with self.assertRaises(ValueError):
3333-
ts.diversity([list(ts.samples())], mode="site", windows="abc")
3335+
ts.diversity([ts.samples()], mode="site", windows="abc")
33343336
with self.assertRaises(ValueError):
3335-
ts.diversity([list(ts.samples())], mode="site", windows="")
3337+
ts.diversity([ts.samples()], mode="site", windows="")
33363338
with self.assertRaises(ValueError):
3337-
ts.diversity([list(ts.samples())], mode="tree", windows="abc")
3339+
ts.diversity([ts.samples()], mode="tree", windows="abc")
33383340

33393341
def test_bad_summary_function(self):
33403342
ts = self.get_tree_sequence()
@@ -3441,8 +3443,7 @@ class TestGeneralSiteStats(StatsTestCase):
34413443
def compare_general_stat(self, ts, W, f, windows=None, polarised=False):
34423444
# Determine output_dim of the function
34433445
M = len(f(W[0]))
3444-
py_ssc = PythonSiteStatCalculator(ts)
3445-
sigma1 = py_ssc.naive_general_stat(W, f, windows, polarised=polarised)
3446+
sigma1 = naive_site_general_stat(ts, W, f, windows, polarised=polarised)
34463447
sigma2 = ts.general_stat(W, f, M, windows, polarised=polarised, mode="site")
34473448
sigma3 = site_general_stat(ts, W, f, windows, polarised=polarised)
34483449
self.assertEqual(sigma1.shape, sigma2.shape)
@@ -4391,7 +4392,6 @@ class BranchSampleSetStatsTestCase(SampleSetStatTestCase):
43914392
def setUp(self):
43924393
self.rng = random.Random(self.random_seed)
43934394
self.stat_type = "branch"
4394-
self.py_stat_class = PythonBranchStatCalculator
43954395

43964396
def get_ts(self):
43974397
for N in [12, 15, 20]:
@@ -4595,7 +4595,7 @@ def f(x):
45954595
branch_true_mean_diversity)
45964596
self.assertAlmostEqual(ts.divergence([A[0], A[1]], [(0, 1)], mode=mode),
45974597
branch_true_mean_diversity)
4598-
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
4598+
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
45994599
branch_true_mean_diversity)
46004600

46014601
# Y-statistic for (0/12)
@@ -4610,7 +4610,7 @@ def f(x):
46104610
py_bsc_Y = Y3(ts, [[0], [1], [2]], [(0, 1, 2)], windows=[0.0, 1.0], mode=mode)
46114611
self.assertArrayAlmostEqual(bts_Y, branch_true_Y)
46124612
self.assertArrayAlmostEqual(py_bsc_Y, branch_true_Y)
4613-
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
4613+
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
46144614
branch_true_Y)
46154615

46164616
mode = "site"
@@ -4619,7 +4619,7 @@ def f(x):
46194619
py_ssc_Y = Y3(ts, [[0], [1], [2]], [(0, 1, 2)], windows=[0.0, 1.0], mode=mode)
46204620
self.assertArrayAlmostEqual(sts_Y, site_true_Y)
46214621
self.assertArrayAlmostEqual(py_ssc_Y, site_true_Y)
4622-
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
4622+
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
46234623
site_true_Y)
46244624

46254625
A = [[0, 1, 2]]
@@ -5002,7 +5002,7 @@ def f(x):
50025002
branch_true_diversity_02]):
50035003

50045004
self.assertAlmostEqual(diversity(ts, A, mode=mode)[0][0], truth)
5005-
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0], truth)
5005+
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0], truth)
50065006
self.assertAlmostEqual(ts.diversity(A, mode="branch")[0], truth)
50075007

50085008
# Y-statistic for (0/12)
@@ -5017,7 +5017,7 @@ def f(x):
50175017
branch_true_Y)
50185018
self.assertArrayAlmostEqual(ts.Y3([[0], [1], [2]], [(0, 1, 2)], mode=mode),
50195019
branch_true_Y)
5020-
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
5020+
self.assertArrayAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
50215021
branch_true_Y)
50225022

50235023
# sites:
@@ -5026,7 +5026,7 @@ def f(x):
50265026
py_ssc_Y = Y3(ts, [[0], [1], [2]], [(0, 1, 2)], windows=[0.0, 1.0])
50275027
self.assertAlmostEqual(site_tsc_Y, site_true_Y)
50285028
self.assertAlmostEqual(py_ssc_Y, site_true_Y)
5029-
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0][0],
5029+
self.assertAlmostEqual(ts.sample_count_stat(A, f, 1, mode=mode)[0],
50305030
site_true_Y)
50315031

50325032

@@ -5294,122 +5294,3 @@ def test_Y3_windows(self):
52945294
def test_f3_windows(self):
52955295
ts = self.get_example_ts()
52965296
self.verify_three_way_stat_windows(ts, ts.f3)
5297-
5298-
############################################
5299-
# Old code where stats are defined within type
5300-
# specific calculattors. These definititions have been
5301-
# move to stat-specific regions above
5302-
# The only thing left to port is the SFS code.
5303-
############################################
5304-
5305-
5306-
class PythonBranchStatCalculator(object):
5307-
"""
5308-
Python implementations of various ("tree") branch-length statistics -
5309-
inefficient but more clear what they are doing.
5310-
"""
5311-
5312-
def __init__(self, tree_sequence):
5313-
self.tree_sequence = tree_sequence
5314-
5315-
def site_frequency_spectrum(self, sample_set, windows=None):
5316-
if windows is None:
5317-
windows = [0.0, self.tree_sequence.sequence_length]
5318-
n_out = len(sample_set)
5319-
out = np.zeros((n_out, len(windows) - 1))
5320-
for j in range(len(windows) - 1):
5321-
begin = windows[j]
5322-
end = windows[j + 1]
5323-
S = [0.0 for j in range(n_out)]
5324-
for t in self.tree_sequence.trees(tracked_samples=sample_set,
5325-
sample_counts=True):
5326-
root = t.root
5327-
tr_len = min(end, t.interval[1]) - max(begin, t.interval[0])
5328-
if tr_len > 0:
5329-
for node in t.nodes():
5330-
if node != root:
5331-
x = t.num_tracked_samples(node)
5332-
if x > 0:
5333-
S[x - 1] += t.branch_length(node) * tr_len
5334-
for j in range(n_out):
5335-
S[j] /= (end-begin)
5336-
out[j] = S
5337-
return(out)
5338-
5339-
5340-
class PythonSiteStatCalculator(object):
5341-
"""
5342-
Python implementations of various single-site statistics -
5343-
inefficient but more clear what they are doing.
5344-
"""
5345-
5346-
def __init__(self, tree_sequence):
5347-
self.tree_sequence = tree_sequence
5348-
5349-
def sample_count_stats(self, sample_sets, f, windows=None, polarised=False):
5350-
'''
5351-
Here sample_sets is a list of lists of samples, and f is a function
5352-
whose argument is a list of integers of the same length as sample_sets
5353-
that returns a list of numbers; there will be one output for each element.
5354-
For each value, each allele in a tree is weighted by f(x), where
5355-
x[i] is the number of samples in sample_sets[i] that inherit that allele.
5356-
This finds the sum of this value for all alleles at all polymorphic sites,
5357-
and across the tree sequence ts, weighted by genomic length.
5358-
5359-
This version is inefficient as it works directly with haplotypes.
5360-
'''
5361-
if windows is None:
5362-
windows = [0.0, self.tree_sequence.sequence_length]
5363-
for U in sample_sets:
5364-
if max([U.count(x) for x in set(U)]) > 1:
5365-
raise ValueError("elements of sample_sets",
5366-
"cannot contain repeated elements.")
5367-
haps = list(self.tree_sequence.haplotypes())
5368-
n_out = len(f([0 for a in sample_sets]))
5369-
out = np.zeros((n_out, len(windows) - 1))
5370-
for j in range(len(windows) - 1):
5371-
begin = windows[j]
5372-
end = windows[j + 1]
5373-
site_positions = [x.position for x in self.tree_sequence.sites()]
5374-
S = [0.0 for j in range(n_out)]
5375-
for k in range(self.tree_sequence.num_sites):
5376-
if (site_positions[k] >= begin) and (site_positions[k] < end):
5377-
all_g = [haps[j][k] for j in range(self.tree_sequence.num_samples)]
5378-
g = [[haps[j][k] for j in u] for u in sample_sets]
5379-
for a in set(all_g):
5380-
x = [h.count(a) for h in g]
5381-
w = f(x)
5382-
for j in range(n_out):
5383-
S[j] += w[j]
5384-
for j in range(n_out):
5385-
S[j] /= (end - begin)
5386-
out[j] = np.array([S])
5387-
return out
5388-
5389-
def naive_general_stat(self, W, f, windows=None, polarised=False):
5390-
return naive_site_general_stat(
5391-
self.tree_sequence, W, f, windows=windows, polarised=polarised)
5392-
5393-
def site_frequency_spectrum(self, sample_set, windows=None):
5394-
if windows is None:
5395-
windows = [0.0, self.tree_sequence.sequence_length]
5396-
haps = list(self.tree_sequence.haplotypes())
5397-
site_positions = [x.position for x in self.tree_sequence.sites()]
5398-
n_out = len(sample_set)
5399-
out = np.zeros((n_out, len(windows) - 1))
5400-
for j in range(len(windows) - 1):
5401-
begin = windows[j]
5402-
end = windows[j + 1]
5403-
S = [0.0 for j in range(n_out)]
5404-
for k in range(self.tree_sequence.num_sites):
5405-
if (site_positions[k] >= begin) and (site_positions[k] < end):
5406-
all_g = [haps[j][k] for j in range(self.tree_sequence.num_samples)]
5407-
g = [haps[j][k] for j in sample_set]
5408-
for a in set(all_g):
5409-
x = g.count(a)
5410-
if x > 0:
5411-
S[x - 1] += 1.0
5412-
for j in range(n_out):
5413-
S[j] /= (end - begin)
5414-
out[j] = S
5415-
return out

python/tskit/trees.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,7 +3327,6 @@ def general_stat(self, W, f, output_dim, windows=None, polarised=False, mode=Non
33273327
"""
33283328
if mode is None:
33293329
mode = "site"
3330-
windows = self.parse_windows(windows)
33313330
if strict:
33323331
total_weights = np.sum(W, axis=0)
33333332
for x in [total_weights, total_weights * 0.0]:
@@ -3337,13 +3336,14 @@ def general_stat(self, W, f, output_dim, windows=None, polarised=False, mode=Non
33373336
if not np.allclose(fx, np.zeros((output_dim, ))):
33383337
raise ValueError("Summary function does not return zero for both"
33393338
"zero weight and total weight.")
3340-
return self.ll_tree_sequence.general_stat(
3341-
W, f, output_dim, windows, polarised=polarised,
3339+
return self.__run_windowed_stat(
3340+
windows, self.ll_tree_sequence.general_stat,
3341+
W, f, output_dim, polarised=polarised,
33423342
span_normalise=span_normalise, mode=mode)
33433343

33443344
def sample_count_stat(
33453345
self, sample_sets, f, output_dim, windows=None, polarised=False, mode=None,
3346-
span_normalise=True):
3346+
span_normalise=True, strict=True):
33473347
"""
33483348
Compute a windowed statistic from sample counts and a summary function.
33493349
This is a wrapper around :meth:`.general_stat` for the common case in
@@ -3403,6 +3403,7 @@ def f(x):
34033403
(defaults to "site").
34043404
:param bool span_normalise: Whether to divide the result by the span of the
34053405
window (defaults to True).
3406+
:param bool strict: Whether to check that f(0) and f(total weight) are zero.
34063407
:return: A ndarray with shape equal to (num windows, num statistics).
34073408
"""
34083409
# helper function for common case where weights are indicators of sample sets
@@ -3418,7 +3419,8 @@ def f(x):
34183419

34193420
W = np.array([[float(u in A) for A in sample_sets] for u in self.samples()])
34203421
return self.general_stat(W, f, output_dim, windows=windows, polarised=polarised,
3421-
mode=mode, span_normalise=span_normalise)
3422+
mode=mode, span_normalise=span_normalise,
3423+
strict=strict)
34223424

34233425
def parse_windows(self, windows):
34243426
# Note: need to make sure windows is a string or we try to compare the

0 commit comments

Comments
 (0)