Skip to content

Commit 8e695cc

Browse files
Add dimension stripping to derived stats.
1 parent f40be31 commit 8e695cc

File tree

2 files changed

+136
-113
lines changed

2 files changed

+136
-113
lines changed

python/tests/test_tree_stats.py

Lines changed: 79 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,12 +1719,9 @@ def verify_persite_Fst(self, ts, sample_sets, indexes):
17191719
self.assertArrayAlmostEqual(sigma1, sigma2)
17201720

17211721

1722-
class FstInterfaceMixin(StatsTestCase):
1722+
class FstInterfaceMixin(object):
17231723

1724-
# Since Fst is defined using diversity and divergence, we don't seriously
1725-
# test it for correctness, and only test the interface.
1726-
1727-
def verify_interface(self):
1724+
def test_interface(self):
17281725
ts = msprime.simulate(10, mutation_rate=0.0)
17291726
sample_sets = [[0, 1, 2], [6, 7], [4]]
17301727
with self.assertRaises(ValueError):
@@ -1735,32 +1732,23 @@ def verify_interface(self):
17351732
ts.Fst(sample_sets, indexes=[(0, 1), (0, 20)])
17361733
sigma1 = ts.Fst(sample_sets, indexes=[(0, 1)], mode=self.mode)
17371734
sigma2 = ts.Fst(sample_sets, indexes=[(0, 1), (0, 2), (1, 2)], mode=self.mode)
1738-
if self.mode == "node":
1739-
self.assertArrayAlmostEqual(sigma1[:, :, 0], sigma2[:, :, 0])
1740-
else:
1741-
self.assertArrayAlmostEqual(sigma1[:, 0], sigma2[:, 0])
1735+
self.assertArrayAlmostEqual(sigma1[..., 0], sigma2[..., 0])
17421736

17431737

17441738
class TestSiteFst(TestFst, MutatedTopologyExamplesMixin, FstInterfaceMixin):
17451739
mode = "site"
17461740

1747-
def test_interface(self):
1748-
self.verify_interface()
17491741

1742+
# Since Fst is defined using diversity and divergence, we don't seriously
1743+
# test it for correctness for node and branch, and only test the interface.
17501744

1751-
class TestNodeFst(FstInterfaceMixin):
1745+
class TestNodeFst(StatsTestCase, FstInterfaceMixin):
17521746
mode = "node"
17531747

1754-
def test_interface(self):
1755-
self.verify_interface()
17561748

1757-
1758-
class TestBranchFst(FstInterfaceMixin):
1749+
class TestBranchFst(StatsTestCase, FstInterfaceMixin):
17591750
mode = "node"
17601751

1761-
def test_interface(self):
1762-
self.verify_interface()
1763-
17641752

17651753
############################################
17661754
# Y2
@@ -5105,56 +5093,67 @@ def test_one_way_stat_default_windows(self):
51055093
# We're adding on the *last* dimension, so must reshape
51065094
self.assertArrayEqual(x.reshape(ts.num_nodes, 1), y)
51075095

5108-
def test_one_way_stat_windows(self):
5109-
ts = self.get_example_ts()
5096+
def verify_one_way_stat_windows(self, ts, method):
51105097
L = ts.sequence_length
51115098
N = ts.num_nodes
51125099

5113-
windows = [0, L / 4, L / 2, L]
5100+
windows = [0, L / 4, L / 2, 0.75 * L, L]
51145101
A = ts.samples()[:6]
51155102
B = ts.samples()[6:]
51165103
for mode in ["site", "branch"]:
5117-
x = ts.diversity([A, B], windows=windows, mode=mode)
5118-
# Three windows, 2 sets.
5119-
self.assertEqual(x.shape, (3, 2))
5104+
x = method([A, B], windows=windows, mode=mode)
5105+
# Four windows, 2 sets.
5106+
self.assertEqual(x.shape, (4, 2))
51205107

5121-
x = ts.diversity([A], windows=windows, mode=mode)
5122-
# Three windows, 1 sets.
5123-
self.assertEqual(x.shape, (3, 1))
5108+
x = method([A], windows=windows, mode=mode)
5109+
# Four windows, 1 set.
5110+
self.assertEqual(x.shape, (4, 1))
51245111

5125-
x = ts.diversity(A, windows=windows, mode=mode)
5112+
x = method(A, windows=windows, mode=mode)
51265113
# Dropping the outer list removes the last dimension
5127-
self.assertEqual(x.shape, (3, ))
5114+
self.assertEqual(x.shape, (4, ))
51285115

5129-
x = ts.diversity(windows=windows, mode=mode)
5116+
x = method(windows=windows, mode=mode)
51305117
# Default returns this for all samples
5131-
self.assertEqual(x.shape, (3, ))
5132-
y = ts.diversity(ts.samples(), windows=windows, mode=mode)
5118+
self.assertEqual(x.shape, (4, ))
5119+
y = method(ts.samples(), windows=windows, mode=mode)
51335120
self.assertArrayEqual(x, y)
51345121

51355122
mode = "node"
5136-
x = ts.diversity([A, B], windows=windows, mode=mode)
5137-
# Three windows, N nodes and 2 sets.
5138-
self.assertEqual(x.shape, (3, N, 2))
5123+
x = method([A, B], windows=windows, mode=mode)
5124+
# Four windows, N nodes and 2 sets.
5125+
self.assertEqual(x.shape, (4, N, 2))
51395126

5140-
x = ts.diversity([A], windows=windows, mode=mode)
5141-
# Three windows, N nodes and 1 set.
5142-
self.assertEqual(x.shape, (3, N, 1))
5127+
x = method([A], windows=windows, mode=mode)
5128+
# Four windows, N nodes and 1 set.
5129+
self.assertEqual(x.shape, (4, N, 1))
51435130

5144-
x = ts.diversity(A, windows=windows, mode=mode)
5131+
x = method(A, windows=windows, mode=mode)
51455132
# Drop the outer list, so we lose the last dimension
5146-
self.assertEqual(x.shape, (3, N))
5133+
self.assertEqual(x.shape, (4, N))
51475134

5148-
x = ts.diversity(windows=windows, mode=mode)
5135+
x = method(windows=windows, mode=mode)
51495136
# The default sample sets also drops the last dimension
5150-
self.assertEqual(x.shape, (3, N))
5137+
self.assertEqual(x.shape, (4, N))
51515138

51525139
self.assertEqual(ts.num_trees, 1)
51535140
# In this example, we know that the trees are all the same so check this
51545141
# for sanity.
51555142
self.assertArrayEqual(x[0], x[1])
51565143
self.assertArrayEqual(x[0], x[2])
51575144

5145+
def test_diversity_windows(self):
5146+
ts = self.get_example_ts()
5147+
self.verify_one_way_stat_windows(ts, ts.diversity)
5148+
5149+
def test_Tajimas_D_windows(self):
5150+
ts = self.get_example_ts()
5151+
self.verify_one_way_stat_windows(ts, ts.Tajimas_D)
5152+
5153+
def test_segregating_sites_windows(self):
5154+
ts = self.get_example_ts()
5155+
self.verify_one_way_stat_windows(ts, ts.segregating_sites)
5156+
51585157
def test_two_way_stat_default_windows(self):
51595158
ts = self.get_example_ts()
51605159
# Use divergence as the example two-way stat.
@@ -5179,47 +5178,45 @@ def test_two_way_stat_default_windows(self):
51795178
# We're adding on the *last* dimension, so must reshape
51805179
self.assertArrayEqual(x.reshape(ts.num_nodes, 1), y)
51815180

5182-
def test_two_way_stat_windows(self):
5183-
5184-
ts = self.get_example_ts()
5181+
def verify_two_way_stat_windows(self, ts, method):
51855182
L = ts.sequence_length
51865183
N = ts.num_nodes
51875184

51885185
windows = [0, L / 4, L / 2, L]
51895186
A = ts.samples()[:7]
51905187
B = ts.samples()[7:]
51915188
for mode in ["site", "branch"]:
5192-
x = ts.divergence(
5189+
x = method(
51935190
[A, B, A], indexes=[[0, 1], [0, 2]], windows=windows, mode=mode)
51945191
# Three windows, 2 pairs
51955192
self.assertEqual(x.shape, (3, 2))
51965193

5197-
x = ts.divergence([A, B], indexes=[[0, 1]], windows=windows, mode=mode)
5194+
x = method([A, B], indexes=[[0, 1]], windows=windows, mode=mode)
51985195
# Three windows, 1 pair
51995196
self.assertEqual(x.shape, (3, 1))
52005197

5201-
x = ts.divergence([A, B], indexes=[0, 1], windows=windows, mode=mode)
5198+
x = method([A, B], indexes=[0, 1], windows=windows, mode=mode)
52025199
# Dropping the outer list removes the last dimension
52035200
self.assertEqual(x.shape, (3, ))
52045201

5205-
y = ts.divergence([A, B], windows=windows, mode=mode)
5202+
y = method([A, B], windows=windows, mode=mode)
52065203
self.assertEqual(y.shape, (3, ))
52075204
self.assertArrayEqual(x, y)
52085205

52095206
mode = "node"
5210-
x = ts.divergence([A, B], indexes=[[0, 1], [0, 1]], windows=windows, mode=mode)
5207+
x = method([A, B], indexes=[[0, 1], [0, 1]], windows=windows, mode=mode)
52115208
# Three windows, N nodes and 2 pairs
52125209
self.assertEqual(x.shape, (3, N, 2))
52135210

5214-
x = ts.divergence([A, B], indexes=[[0, 1]], windows=windows, mode=mode)
5211+
x = method([A, B], indexes=[[0, 1]], windows=windows, mode=mode)
52155212
# Three windows, N nodes and 1 pair
52165213
self.assertEqual(x.shape, (3, N, 1))
52175214

5218-
x = ts.divergence([A, B], indexes=[0, 1], windows=windows, mode=mode)
5215+
x = method([A, B], indexes=[0, 1], windows=windows, mode=mode)
52195216
# Drop the outer list, so we lose the last dimension
52205217
self.assertEqual(x.shape, (3, N))
52215218

5222-
x = ts.divergence([A, B], windows=windows, mode=mode)
5219+
x = method([A, B], windows=windows, mode=mode)
52235220
# The default sample sets also drops the last dimension
52245221
self.assertEqual(x.shape, (3, N))
52255222

@@ -5229,8 +5226,19 @@ def test_two_way_stat_windows(self):
52295226
self.assertArrayEqual(x[0], x[1])
52305227
self.assertArrayEqual(x[0], x[2])
52315228

5232-
def test_three_way_stat_windows(self):
5229+
def test_divergence_windows(self):
52335230
ts = self.get_example_ts()
5231+
self.verify_two_way_stat_windows(ts, ts.divergence)
5232+
5233+
def test_Fst_windows(self):
5234+
ts = self.get_example_ts()
5235+
self.verify_two_way_stat_windows(ts, ts.Fst)
5236+
5237+
def test_f2_windows(self):
5238+
ts = self.get_example_ts()
5239+
self.verify_two_way_stat_windows(ts, ts.f2)
5240+
5241+
def verify_three_way_stat_windows(self, ts, method):
52345242
L = ts.sequence_length
52355243
N = ts.num_nodes
52365244

@@ -5239,37 +5247,37 @@ def test_three_way_stat_windows(self):
52395247
B = ts.samples()[2: 4]
52405248
C = ts.samples()[4:]
52415249
for mode in ["site", "branch"]:
5242-
x = ts.Y3(
5250+
x = method(
52435251
[A, B, C], indexes=[[0, 1, 2], [0, 2, 1]], windows=windows, mode=mode)
52445252
# Three windows, 2 triples
52455253
self.assertEqual(x.shape, (3, 2))
52465254

5247-
x = ts.Y3([A, B, C], indexes=[[0, 1, 2]], windows=windows, mode=mode)
5255+
x = method([A, B, C], indexes=[[0, 1, 2]], windows=windows, mode=mode)
52485256
# Three windows, 1 triple
52495257
self.assertEqual(x.shape, (3, 1))
52505258

5251-
x = ts.Y3([A, B, C], indexes=[0, 1, 2], windows=windows, mode=mode)
5259+
x = method([A, B, C], indexes=[0, 1, 2], windows=windows, mode=mode)
52525260
# Dropping the outer list removes the last dimension
52535261
self.assertEqual(x.shape, (3, ))
52545262

5255-
y = ts.Y3([A, B, C], windows=windows, mode=mode)
5263+
y = method([A, B, C], windows=windows, mode=mode)
52565264
self.assertEqual(y.shape, (3, ))
52575265
self.assertArrayEqual(x, y)
52585266

52595267
mode = "node"
5260-
x = ts.Y3([A, B, C], indexes=[[0, 1, 2], [0, 2, 1]], windows=windows, mode=mode)
5268+
x = method([A, B, C], indexes=[[0, 1, 2], [0, 2, 1]], windows=windows, mode=mode)
52615269
# Three windows, N nodes and 2 triples
52625270
self.assertEqual(x.shape, (3, N, 2))
52635271

5264-
x = ts.Y3([A, B, C], indexes=[[0, 1, 2]], windows=windows, mode=mode)
5272+
x = method([A, B, C], indexes=[[0, 1, 2]], windows=windows, mode=mode)
52655273
# Three windows, N nodes and 1 triple
52665274
self.assertEqual(x.shape, (3, N, 1))
52675275

5268-
x = ts.Y3([A, B, C], indexes=[0, 1, 2], windows=windows, mode=mode)
5276+
x = method([A, B, C], indexes=[0, 1, 2], windows=windows, mode=mode)
52695277
# Drop the outer list, so we lose the last dimension
52705278
self.assertEqual(x.shape, (3, N))
52715279

5272-
x = ts.Y3([A, B, C], windows=windows, mode=mode)
5280+
x = method([A, B, C], windows=windows, mode=mode)
52735281
# The default sample sets also drops the last dimension
52745282
self.assertEqual(x.shape, (3, N))
52755283

@@ -5279,6 +5287,13 @@ def test_three_way_stat_windows(self):
52795287
self.assertArrayEqual(x[0], x[1])
52805288
self.assertArrayEqual(x[0], x[2])
52815289

5290+
def test_Y3_windows(self):
5291+
ts = self.get_example_ts()
5292+
self.verify_three_way_stat_windows(ts, ts.Y3)
5293+
5294+
def test_f3_windows(self):
5295+
ts = self.get_example_ts()
5296+
self.verify_three_way_stat_windows(ts, ts.f3)
52825297

52835298
############################################
52845299
# Old code where stats are defined within type
@@ -5287,6 +5302,7 @@ def test_three_way_stat_windows(self):
52875302
# The only thing left to port is the SFS code.
52885303
############################################
52895304

5305+
52905306
class PythonBranchStatCalculator(object):
52915307
"""
52925308
Python implementations of various ("tree") branch-length statistics -

0 commit comments

Comments
 (0)