@@ -796,13 +796,16 @@ def example_sample_sets(ts, min_size=1):
796796 number of sample sets returned in each example must be at least min_size
797797 """
798798 samples = ts .samples ()
799+ np .random .shuffle (samples )
799800 splits = np .array_split (samples , min_size )
800801 yield splits
801802 yield [[s ] for s in samples ]
802803 if min_size == 1 :
803804 yield [samples [:1 ]]
804- if ts .num_samples <= 2 and min_size > = 2 :
805+ if ts .num_samples > 2 and min_size < = 2 :
805806 yield [samples [:2 ], samples [2 :]]
807+ if ts .num_samples > 7 and min_size <= 4 :
808+ yield [samples [:2 ], samples [2 :4 ], samples [4 :6 ], samples [6 :]]
806809
807810
808811def example_sample_set_index_pairs (sample_sets ):
@@ -1163,7 +1166,7 @@ def site_segregating_sites(ts, sample_sets, windows=None, span_normalise=True):
11631166 haps = ts .genotype_matrix (impute_missing_data = True )
11641167 site_positions = [x .position for x in ts .sites ()]
11651168 for i , X in enumerate (sample_sets ):
1166- X_index = np .where (np .in1d (X , samples ))[0 ]
1169+ X_index = np .where (np .in1d (samples , X ))[0 ]
11671170 for k in range (ts .num_sites ):
11681171 if (site_positions [k ] >= begin ) and (site_positions [k ] < end ):
11691172 num_alleles = len (set (haps [k , X_index ]))
@@ -1303,7 +1306,7 @@ def site_tajimas_d(ts, sample_sets, windows=None):
13031306 nn = n [i ]
13041307 S = 0
13051308 T = 0
1306- X_index = np .where (np .in1d (X , samples ))[0 ]
1309+ X_index = np .where (np .in1d (samples , X ))[0 ]
13071310 for k in range (ts .num_sites ):
13081311 if (site_positions [k ] >= begin ) and (site_positions [k ] < end ):
13091312 hX = haps [k , X_index ]
@@ -1497,7 +1500,8 @@ def verify_sample_sets(self, ts, sample_sets, windows):
14971500 denom = n * (n - 1 ) * (n - 2 )
14981501
14991502 def f (x ):
1500- return x * (n - x ) * (n - x - 1 ) / denom
1503+ with np .errstate (invalid = 'ignore' , divide = 'ignore' ):
1504+ return x * (n - x ) * (n - x - 1 ) / denom
15011505
15021506 self .verify_definition (ts , sample_sets , windows , f , ts .Y1 , Y1 )
15031507
@@ -1875,8 +1879,6 @@ def node_Y2(ts, sample_sets, indexes, windows=None, span_normalise=True):
18751879
18761880
18771881def Y2 (ts , sample_sets , indexes = None , windows = None , mode = "site" , span_normalise = True ):
1878- windows = ts .parse_windows (windows )
1879-
18801882 windows = ts .parse_windows (windows )
18811883 if indexes is None :
18821884 indexes = [(0 , 1 )]
@@ -3330,11 +3332,11 @@ def test_bad_mode(self):
33303332 def test_bad_window_strings (self ):
33313333 ts = self .get_tree_sequence ()
33323334 with self .assertRaises (ValueError ):
3333- ts .diversity ([list ( ts .samples () )], mode = "site" , windows = "abc" )
3335+ ts .diversity ([ts .samples ()], mode = "site" , windows = "abc" )
33343336 with self .assertRaises (ValueError ):
3335- ts .diversity ([list ( ts .samples () )], mode = "site" , windows = "" )
3337+ ts .diversity ([ts .samples ()], mode = "site" , windows = "" )
33363338 with self .assertRaises (ValueError ):
3337- ts .diversity ([list ( ts .samples () )], mode = "tree" , windows = "abc" )
3339+ ts .diversity ([ts .samples ()], mode = "tree" , windows = "abc" )
33383340
33393341 def test_bad_summary_function (self ):
33403342 ts = self .get_tree_sequence ()
@@ -3441,8 +3443,7 @@ class TestGeneralSiteStats(StatsTestCase):
34413443 def compare_general_stat (self , ts , W , f , windows = None , polarised = False ):
34423444 # Determine output_dim of the function
34433445 M = len (f (W [0 ]))
3444- py_ssc = PythonSiteStatCalculator (ts )
3445- sigma1 = py_ssc .naive_general_stat (W , f , windows , polarised = polarised )
3446+ sigma1 = naive_site_general_stat (ts , W , f , windows , polarised = polarised )
34463447 sigma2 = ts .general_stat (W , f , M , windows , polarised = polarised , mode = "site" )
34473448 sigma3 = site_general_stat (ts , W , f , windows , polarised = polarised )
34483449 self .assertEqual (sigma1 .shape , sigma2 .shape )
@@ -4391,7 +4392,6 @@ class BranchSampleSetStatsTestCase(SampleSetStatTestCase):
43914392 def setUp (self ):
43924393 self .rng = random .Random (self .random_seed )
43934394 self .stat_type = "branch"
4394- self .py_stat_class = PythonBranchStatCalculator
43954395
43964396 def get_ts (self ):
43974397 for N in [12 , 15 , 20 ]:
@@ -4595,7 +4595,7 @@ def f(x):
45954595 branch_true_mean_diversity )
45964596 self .assertAlmostEqual (ts .divergence ([A [0 ], A [1 ]], [(0 , 1 )], mode = mode ),
45974597 branch_true_mean_diversity )
4598- self .assertAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ][ 0 ] ,
4598+ self .assertAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ],
45994599 branch_true_mean_diversity )
46004600
46014601 # Y-statistic for (0/12)
@@ -4610,7 +4610,7 @@ def f(x):
46104610 py_bsc_Y = Y3 (ts , [[0 ], [1 ], [2 ]], [(0 , 1 , 2 )], windows = [0.0 , 1.0 ], mode = mode )
46114611 self .assertArrayAlmostEqual (bts_Y , branch_true_Y )
46124612 self .assertArrayAlmostEqual (py_bsc_Y , branch_true_Y )
4613- self .assertArrayAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ][ 0 ] ,
4613+ self .assertArrayAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ],
46144614 branch_true_Y )
46154615
46164616 mode = "site"
@@ -4619,7 +4619,7 @@ def f(x):
46194619 py_ssc_Y = Y3 (ts , [[0 ], [1 ], [2 ]], [(0 , 1 , 2 )], windows = [0.0 , 1.0 ], mode = mode )
46204620 self .assertArrayAlmostEqual (sts_Y , site_true_Y )
46214621 self .assertArrayAlmostEqual (py_ssc_Y , site_true_Y )
4622- self .assertArrayAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ][ 0 ] ,
4622+ self .assertArrayAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ],
46234623 site_true_Y )
46244624
46254625 A = [[0 , 1 , 2 ]]
@@ -5002,7 +5002,7 @@ def f(x):
50025002 branch_true_diversity_02 ]):
50035003
50045004 self .assertAlmostEqual (diversity (ts , A , mode = mode )[0 ][0 ], truth )
5005- self .assertAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ][ 0 ] , truth )
5005+ self .assertAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ], truth )
50065006 self .assertAlmostEqual (ts .diversity (A , mode = "branch" )[0 ], truth )
50075007
50085008 # Y-statistic for (0/12)
@@ -5017,7 +5017,7 @@ def f(x):
50175017 branch_true_Y )
50185018 self .assertArrayAlmostEqual (ts .Y3 ([[0 ], [1 ], [2 ]], [(0 , 1 , 2 )], mode = mode ),
50195019 branch_true_Y )
5020- self .assertArrayAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ][ 0 ] ,
5020+ self .assertArrayAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ],
50215021 branch_true_Y )
50225022
50235023 # sites:
@@ -5026,7 +5026,7 @@ def f(x):
50265026 py_ssc_Y = Y3 (ts , [[0 ], [1 ], [2 ]], [(0 , 1 , 2 )], windows = [0.0 , 1.0 ])
50275027 self .assertAlmostEqual (site_tsc_Y , site_true_Y )
50285028 self .assertAlmostEqual (py_ssc_Y , site_true_Y )
5029- self .assertAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ][ 0 ] ,
5029+ self .assertAlmostEqual (ts .sample_count_stat (A , f , 1 , mode = mode )[0 ],
50305030 site_true_Y )
50315031
50325032
@@ -5294,122 +5294,3 @@ def test_Y3_windows(self):
52945294 def test_f3_windows (self ):
52955295 ts = self .get_example_ts ()
52965296 self .verify_three_way_stat_windows (ts , ts .f3 )
5297-
5298- ############################################
5299- # Old code where stats are defined within type
5300- # specific calculattors. These definititions have been
5301- # move to stat-specific regions above
5302- # The only thing left to port is the SFS code.
5303- ############################################
5304-
5305-
5306- class PythonBranchStatCalculator (object ):
5307- """
5308- Python implementations of various ("tree") branch-length statistics -
5309- inefficient but more clear what they are doing.
5310- """
5311-
5312- def __init__ (self , tree_sequence ):
5313- self .tree_sequence = tree_sequence
5314-
5315- def site_frequency_spectrum (self , sample_set , windows = None ):
5316- if windows is None :
5317- windows = [0.0 , self .tree_sequence .sequence_length ]
5318- n_out = len (sample_set )
5319- out = np .zeros ((n_out , len (windows ) - 1 ))
5320- for j in range (len (windows ) - 1 ):
5321- begin = windows [j ]
5322- end = windows [j + 1 ]
5323- S = [0.0 for j in range (n_out )]
5324- for t in self .tree_sequence .trees (tracked_samples = sample_set ,
5325- sample_counts = True ):
5326- root = t .root
5327- tr_len = min (end , t .interval [1 ]) - max (begin , t .interval [0 ])
5328- if tr_len > 0 :
5329- for node in t .nodes ():
5330- if node != root :
5331- x = t .num_tracked_samples (node )
5332- if x > 0 :
5333- S [x - 1 ] += t .branch_length (node ) * tr_len
5334- for j in range (n_out ):
5335- S [j ] /= (end - begin )
5336- out [j ] = S
5337- return (out )
5338-
5339-
5340- class PythonSiteStatCalculator (object ):
5341- """
5342- Python implementations of various single-site statistics -
5343- inefficient but more clear what they are doing.
5344- """
5345-
5346- def __init__ (self , tree_sequence ):
5347- self .tree_sequence = tree_sequence
5348-
5349- def sample_count_stats (self , sample_sets , f , windows = None , polarised = False ):
5350- '''
5351- Here sample_sets is a list of lists of samples, and f is a function
5352- whose argument is a list of integers of the same length as sample_sets
5353- that returns a list of numbers; there will be one output for each element.
5354- For each value, each allele in a tree is weighted by f(x), where
5355- x[i] is the number of samples in sample_sets[i] that inherit that allele.
5356- This finds the sum of this value for all alleles at all polymorphic sites,
5357- and across the tree sequence ts, weighted by genomic length.
5358-
5359- This version is inefficient as it works directly with haplotypes.
5360- '''
5361- if windows is None :
5362- windows = [0.0 , self .tree_sequence .sequence_length ]
5363- for U in sample_sets :
5364- if max ([U .count (x ) for x in set (U )]) > 1 :
5365- raise ValueError ("elements of sample_sets" ,
5366- "cannot contain repeated elements." )
5367- haps = list (self .tree_sequence .haplotypes ())
5368- n_out = len (f ([0 for a in sample_sets ]))
5369- out = np .zeros ((n_out , len (windows ) - 1 ))
5370- for j in range (len (windows ) - 1 ):
5371- begin = windows [j ]
5372- end = windows [j + 1 ]
5373- site_positions = [x .position for x in self .tree_sequence .sites ()]
5374- S = [0.0 for j in range (n_out )]
5375- for k in range (self .tree_sequence .num_sites ):
5376- if (site_positions [k ] >= begin ) and (site_positions [k ] < end ):
5377- all_g = [haps [j ][k ] for j in range (self .tree_sequence .num_samples )]
5378- g = [[haps [j ][k ] for j in u ] for u in sample_sets ]
5379- for a in set (all_g ):
5380- x = [h .count (a ) for h in g ]
5381- w = f (x )
5382- for j in range (n_out ):
5383- S [j ] += w [j ]
5384- for j in range (n_out ):
5385- S [j ] /= (end - begin )
5386- out [j ] = np .array ([S ])
5387- return out
5388-
5389- def naive_general_stat (self , W , f , windows = None , polarised = False ):
5390- return naive_site_general_stat (
5391- self .tree_sequence , W , f , windows = windows , polarised = polarised )
5392-
5393- def site_frequency_spectrum (self , sample_set , windows = None ):
5394- if windows is None :
5395- windows = [0.0 , self .tree_sequence .sequence_length ]
5396- haps = list (self .tree_sequence .haplotypes ())
5397- site_positions = [x .position for x in self .tree_sequence .sites ()]
5398- n_out = len (sample_set )
5399- out = np .zeros ((n_out , len (windows ) - 1 ))
5400- for j in range (len (windows ) - 1 ):
5401- begin = windows [j ]
5402- end = windows [j + 1 ]
5403- S = [0.0 for j in range (n_out )]
5404- for k in range (self .tree_sequence .num_sites ):
5405- if (site_positions [k ] >= begin ) and (site_positions [k ] < end ):
5406- all_g = [haps [j ][k ] for j in range (self .tree_sequence .num_samples )]
5407- g = [haps [j ][k ] for j in sample_set ]
5408- for a in set (all_g ):
5409- x = g .count (a )
5410- if x > 0 :
5411- S [x - 1 ] += 1.0
5412- for j in range (n_out ):
5413- S [j ] /= (end - begin )
5414- out [j ] = S
5415- return out
0 commit comments