@@ -1719,12 +1719,9 @@ def verify_persite_Fst(self, ts, sample_sets, indexes):
17191719 self .assertArrayAlmostEqual (sigma1 , sigma2 )
17201720
17211721
1722- class FstInterfaceMixin (StatsTestCase ):
1722+ class FstInterfaceMixin (object ):
17231723
1724- # Since Fst is defined using diversity and divergence, we don't seriously
1725- # test it for correctness, and only test the interface.
1726-
1727- def verify_interface (self ):
1724+ def test_interface (self ):
17281725 ts = msprime .simulate (10 , mutation_rate = 0.0 )
17291726 sample_sets = [[0 , 1 , 2 ], [6 , 7 ], [4 ]]
17301727 with self .assertRaises (ValueError ):
@@ -1735,32 +1732,23 @@ def verify_interface(self):
17351732 ts .Fst (sample_sets , indexes = [(0 , 1 ), (0 , 20 )])
17361733 sigma1 = ts .Fst (sample_sets , indexes = [(0 , 1 )], mode = self .mode )
17371734 sigma2 = ts .Fst (sample_sets , indexes = [(0 , 1 ), (0 , 2 ), (1 , 2 )], mode = self .mode )
1738- if self .mode == "node" :
1739- self .assertArrayAlmostEqual (sigma1 [:, :, 0 ], sigma2 [:, :, 0 ])
1740- else :
1741- self .assertArrayAlmostEqual (sigma1 [:, 0 ], sigma2 [:, 0 ])
1735+ self .assertArrayAlmostEqual (sigma1 [..., 0 ], sigma2 [..., 0 ])
17421736
17431737
17441738class TestSiteFst (TestFst , MutatedTopologyExamplesMixin , FstInterfaceMixin ):
17451739 mode = "site"
17461740
1747- def test_interface (self ):
1748- self .verify_interface ()
17491741
1742+ # Since Fst is defined using diversity and divergence, we don't seriously
1743+ # test it for correctness for node and branch, and only test the interface.
17501744
1751- class TestNodeFst (FstInterfaceMixin ):
1745+ class TestNodeFst (StatsTestCase , FstInterfaceMixin ):
17521746 mode = "node"
17531747
1754- def test_interface (self ):
1755- self .verify_interface ()
17561748
1757-
1758- class TestBranchFst (FstInterfaceMixin ):
1749+ class TestBranchFst (StatsTestCase , FstInterfaceMixin ):
17591750 mode = "node"
17601751
1761- def test_interface (self ):
1762- self .verify_interface ()
1763-
17641752
17651753############################################
17661754# Y2
@@ -5105,56 +5093,67 @@ def test_one_way_stat_default_windows(self):
51055093 # We're adding on the *last* dimension, so must reshape
51065094 self .assertArrayEqual (x .reshape (ts .num_nodes , 1 ), y )
51075095
5108- def test_one_way_stat_windows (self ):
5109- ts = self .get_example_ts ()
5096+ def verify_one_way_stat_windows (self , ts , method ):
51105097 L = ts .sequence_length
51115098 N = ts .num_nodes
51125099
5113- windows = [0 , L / 4 , L / 2 , L ]
5100+ windows = [0 , L / 4 , L / 2 , 0.75 * L , L ]
51145101 A = ts .samples ()[:6 ]
51155102 B = ts .samples ()[6 :]
51165103 for mode in ["site" , "branch" ]:
5117- x = ts . diversity ([A , B ], windows = windows , mode = mode )
5118- # Three windows, 2 sets.
5119- self .assertEqual (x .shape , (3 , 2 ))
5104+ x = method ([A , B ], windows = windows , mode = mode )
5105+ # Four windows, 2 sets.
5106+ self .assertEqual (x .shape , (4 , 2 ))
51205107
5121- x = ts . diversity ([A ], windows = windows , mode = mode )
5122- # Three windows, 1 sets.
5123- self .assertEqual (x .shape , (3 , 1 ))
5108+ x = method ([A ], windows = windows , mode = mode )
5109+ # Four windows, 1 sets.
5110+ self .assertEqual (x .shape , (4 , 1 ))
51245111
5125- x = ts . diversity (A , windows = windows , mode = mode )
5112+ x = method (A , windows = windows , mode = mode )
51265113 # Dropping the outer list removes the last dimension
5127- self .assertEqual (x .shape , (3 , ))
5114+ self .assertEqual (x .shape , (4 , ))
51285115
5129- x = ts . diversity (windows = windows , mode = mode )
5116+ x = method (windows = windows , mode = mode )
51305117 # Default returns this for all samples
5131- self .assertEqual (x .shape , (3 , ))
5132- y = ts . diversity (ts .samples (), windows = windows , mode = mode )
5118+ self .assertEqual (x .shape , (4 , ))
5119+ y = method (ts .samples (), windows = windows , mode = mode )
51335120 self .assertArrayEqual (x , y )
51345121
51355122 mode = "node"
5136- x = ts . diversity ([A , B ], windows = windows , mode = mode )
5137- # Three windows, N nodes and 2 sets.
5138- self .assertEqual (x .shape , (3 , N , 2 ))
5123+ x = method ([A , B ], windows = windows , mode = mode )
5124+ # Four windows, N nodes and 2 sets.
5125+ self .assertEqual (x .shape , (4 , N , 2 ))
51395126
5140- x = ts . diversity ([A ], windows = windows , mode = mode )
5141- # Three windows, N nodes and 1 set.
5142- self .assertEqual (x .shape , (3 , N , 1 ))
5127+ x = method ([A ], windows = windows , mode = mode )
5128+ # Four windows, N nodes and 1 set.
5129+ self .assertEqual (x .shape , (4 , N , 1 ))
51435130
5144- x = ts . diversity (A , windows = windows , mode = mode )
5131+ x = method (A , windows = windows , mode = mode )
51455132 # Drop the outer list, so we lose the last dimension
5146- self .assertEqual (x .shape , (3 , N ))
5133+ self .assertEqual (x .shape , (4 , N ))
51475134
5148- x = ts . diversity (windows = windows , mode = mode )
5135+ x = method (windows = windows , mode = mode )
51495136 # The default sample sets also drops the last dimension
5150- self .assertEqual (x .shape , (3 , N ))
5137+ self .assertEqual (x .shape , (4 , N ))
51515138
51525139 self .assertEqual (ts .num_trees , 1 )
51535140 # In this example, we know that the trees are all the same so check this
51545141 # for sanity.
51555142 self .assertArrayEqual (x [0 ], x [1 ])
51565143 self .assertArrayEqual (x [0 ], x [2 ])
51575144
5145+ def test_diversity_windows (self ):
5146+ ts = self .get_example_ts ()
5147+ self .verify_one_way_stat_windows (ts , ts .diversity )
5148+
5149+ def test_Tajimas_D_windows (self ):
5150+ ts = self .get_example_ts ()
5151+ self .verify_one_way_stat_windows (ts , ts .Tajimas_D )
5152+
5153+ def test_segregating_sites_windows (self ):
5154+ ts = self .get_example_ts ()
5155+ self .verify_one_way_stat_windows (ts , ts .segregating_sites )
5156+
51585157 def test_two_way_stat_default_windows (self ):
51595158 ts = self .get_example_ts ()
51605159 # Use divergence as the example one-way stat.
@@ -5179,47 +5178,45 @@ def test_two_way_stat_default_windows(self):
51795178 # We're adding on the *last* dimension, so must reshape
51805179 self .assertArrayEqual (x .reshape (ts .num_nodes , 1 ), y )
51815180
5182- def test_two_way_stat_windows (self ):
5183-
5184- ts = self .get_example_ts ()
5181+ def verify_two_way_stat_windows (self , ts , method ):
51855182 L = ts .sequence_length
51865183 N = ts .num_nodes
51875184
51885185 windows = [0 , L / 4 , L / 2 , L ]
51895186 A = ts .samples ()[:7 ]
51905187 B = ts .samples ()[7 :]
51915188 for mode in ["site" , "branch" ]:
5192- x = ts . divergence (
5189+ x = method (
51935190 [A , B , A ], indexes = [[0 , 1 ], [0 , 2 ]], windows = windows , mode = mode )
51945191 # Three windows, 2 pairs
51955192 self .assertEqual (x .shape , (3 , 2 ))
51965193
5197- x = ts . divergence ([A , B ], indexes = [[0 , 1 ]], windows = windows , mode = mode )
5194+ x = method ([A , B ], indexes = [[0 , 1 ]], windows = windows , mode = mode )
51985195 # Three windows, 1 pair
51995196 self .assertEqual (x .shape , (3 , 1 ))
52005197
5201- x = ts . divergence ([A , B ], indexes = [0 , 1 ], windows = windows , mode = mode )
5198+ x = method ([A , B ], indexes = [0 , 1 ], windows = windows , mode = mode )
52025199 # Dropping the outer list removes the last dimension
52035200 self .assertEqual (x .shape , (3 , ))
52045201
5205- y = ts . divergence ([A , B ], windows = windows , mode = mode )
5202+ y = method ([A , B ], windows = windows , mode = mode )
52065203 self .assertEqual (y .shape , (3 , ))
52075204 self .assertArrayEqual (x , y )
52085205
52095206 mode = "node"
5210- x = ts . divergence ([A , B ], indexes = [[0 , 1 ], [0 , 1 ]], windows = windows , mode = mode )
5207+ x = method ([A , B ], indexes = [[0 , 1 ], [0 , 1 ]], windows = windows , mode = mode )
52115208 # Three windows, N nodes and 2 pairs
52125209 self .assertEqual (x .shape , (3 , N , 2 ))
52135210
5214- x = ts . divergence ([A , B ], indexes = [[0 , 1 ]], windows = windows , mode = mode )
5211+ x = method ([A , B ], indexes = [[0 , 1 ]], windows = windows , mode = mode )
52155212 # Three windows, N nodes and 1 pairs
52165213 self .assertEqual (x .shape , (3 , N , 1 ))
52175214
5218- x = ts . divergence ([A , B ], indexes = [0 , 1 ], windows = windows , mode = mode )
5215+ x = method ([A , B ], indexes = [0 , 1 ], windows = windows , mode = mode )
52195216 # Drop the outer list, so we lose the last dimension
52205217 self .assertEqual (x .shape , (3 , N ))
52215218
5222- x = ts . divergence ([A , B ], windows = windows , mode = mode )
5219+ x = method ([A , B ], windows = windows , mode = mode )
52235220 # The default sample sets also drops the last dimension
52245221 self .assertEqual (x .shape , (3 , N ))
52255222
@@ -5229,8 +5226,19 @@ def test_two_way_stat_windows(self):
52295226 self .assertArrayEqual (x [0 ], x [1 ])
52305227 self .assertArrayEqual (x [0 ], x [2 ])
52315228
5232- def test_three_way_stat_windows (self ):
5229+ def test_divergence_windows (self ):
52335230 ts = self .get_example_ts ()
5231+ self .verify_two_way_stat_windows (ts , ts .divergence )
5232+
5233+ def test_Fst_windows (self ):
5234+ ts = self .get_example_ts ()
5235+ self .verify_two_way_stat_windows (ts , ts .Fst )
5236+
5237+ def test_f2_windows (self ):
5238+ ts = self .get_example_ts ()
5239+ self .verify_two_way_stat_windows (ts , ts .f2 )
5240+
5241+ def verify_three_way_stat_windows (self , ts , method ):
52345242 L = ts .sequence_length
52355243 N = ts .num_nodes
52365244
@@ -5239,37 +5247,37 @@ def test_three_way_stat_windows(self):
52395247 B = ts .samples ()[2 : 4 ]
52405248 C = ts .samples ()[4 :]
52415249 for mode in ["site" , "branch" ]:
5242- x = ts . Y3 (
5250+ x = method (
52435251 [A , B , C ], indexes = [[0 , 1 , 2 ], [0 , 2 , 1 ]], windows = windows , mode = mode )
52445252 # Three windows, 2 triple
52455253 self .assertEqual (x .shape , (3 , 2 ))
52465254
5247- x = ts . Y3 ([A , B , C ], indexes = [[0 , 1 , 2 ]], windows = windows , mode = mode )
5255+ x = method ([A , B , C ], indexes = [[0 , 1 , 2 ]], windows = windows , mode = mode )
52485256 # Three windows, 1 triple
52495257 self .assertEqual (x .shape , (3 , 1 ))
52505258
5251- x = ts . Y3 ([A , B , C ], indexes = [0 , 1 , 2 ], windows = windows , mode = mode )
5259+ x = method ([A , B , C ], indexes = [0 , 1 , 2 ], windows = windows , mode = mode )
52525260 # Dropping the outer list removes the last dimension
52535261 self .assertEqual (x .shape , (3 , ))
52545262
5255- y = ts . Y3 ([A , B , C ], windows = windows , mode = mode )
5263+ y = method ([A , B , C ], windows = windows , mode = mode )
52565264 self .assertEqual (y .shape , (3 , ))
52575265 self .assertArrayEqual (x , y )
52585266
52595267 mode = "node"
5260- x = ts . Y3 ([A , B , C ], indexes = [[0 , 1 , 2 ], [0 , 2 , 1 ]], windows = windows , mode = mode )
5268+ x = method ([A , B , C ], indexes = [[0 , 1 , 2 ], [0 , 2 , 1 ]], windows = windows , mode = mode )
52615269 # Three windows, N nodes and 2 triples
52625270 self .assertEqual (x .shape , (3 , N , 2 ))
52635271
5264- x = ts . Y3 ([A , B , C ], indexes = [[0 , 1 , 2 ]], windows = windows , mode = mode )
5272+ x = method ([A , B , C ], indexes = [[0 , 1 , 2 ]], windows = windows , mode = mode )
52655273 # Three windows, N nodes and 1 triples
52665274 self .assertEqual (x .shape , (3 , N , 1 ))
52675275
5268- x = ts . Y3 ([A , B , C ], indexes = [0 , 1 , 2 ], windows = windows , mode = mode )
5276+ x = method ([A , B , C ], indexes = [0 , 1 , 2 ], windows = windows , mode = mode )
52695277 # Drop the outer list, so we lose the last dimension
52705278 self .assertEqual (x .shape , (3 , N ))
52715279
5272- x = ts . Y3 ([A , B , C ], windows = windows , mode = mode )
5280+ x = method ([A , B , C ], windows = windows , mode = mode )
52735281 # The default sample sets also drops the last dimension
52745282 self .assertEqual (x .shape , (3 , N ))
52755283
@@ -5279,6 +5287,13 @@ def test_three_way_stat_windows(self):
52795287 self .assertArrayEqual (x [0 ], x [1 ])
52805288 self .assertArrayEqual (x [0 ], x [2 ])
52815289
5290+ def test_Y3_windows (self ):
5291+ ts = self .get_example_ts ()
5292+ self .verify_three_way_stat_windows (ts , ts .Y3 )
5293+
5294+ def test_f3_windows (self ):
5295+ ts = self .get_example_ts ()
5296+ self .verify_three_way_stat_windows (ts , ts .f3 )
52825297
52835298############################################
52845299# Old code where stats are defined within type
@@ -5287,6 +5302,7 @@ def test_three_way_stat_windows(self):
52875302# The only thing left to port is the SFS code.
52885303############################################
52895304
5305+
52905306class PythonBranchStatCalculator (object ):
52915307 """
52925308 Python implementations of various ("tree") branch-length statistics -
0 commit comments