@@ -50,6 +50,9 @@ class Likelihoods:
50
50
A class to store and process likelihoods. Likelihoods for edges are stored as a
51
51
flattened lower triangular matrix of all the possible delta t's. This class also
52
52
provides methods for accessing this lower triangular matrix, multiplying it, etc.
53
+
54
+ If ``standardize`` is true, routines will operate to standardize the likelihoods
55
+ such that their maximum is one (in linear space) or zero (in log space)
53
56
"""
54
57
55
58
probability_space = base .LIN
@@ -65,7 +68,7 @@ def __init__(
65
68
* ,
66
69
eps = 0 ,
67
70
fixed_node_set = None ,
68
- normalize = True ,
71
+ standardize = True ,
69
72
progress = False ,
70
73
):
71
74
self .ts = ts
@@ -75,7 +78,7 @@ def __init__(
75
78
)
76
79
self .mut_rate = mutation_rate
77
80
self .rec_rate = recombination_rate
78
- self .normalize = normalize
81
+ self .standardize = standardize
79
82
self .grid_size = len (timepoints )
80
83
self .tri_size = self .grid_size * (self .grid_size + 1 ) / 2
81
84
self .ll_mut = {}
@@ -145,25 +148,25 @@ def get_mut_edges(ts):
145
148
return mut_edges
146
149
147
150
@staticmethod
148
- def _lik (muts , span , dt , mutation_rate , normalize = True ):
151
+ def _lik (muts , span , dt , mutation_rate , standardize = True ):
149
152
"""
150
153
The likelihood of an edge given a number of mutations, as set of time deltas (dt)
151
154
and a span. This is a static function to allow parallelization
152
155
"""
153
156
ll = scipy .stats .poisson .pmf (muts , dt * mutation_rate * span )
154
- if normalize :
157
+ if standardize :
155
158
return ll / np .max (ll )
156
159
else :
157
160
return ll
158
161
159
162
@staticmethod
160
- def _lik_wrapper (muts_span , dt , mutation_rate , normalize = True ):
163
+ def _lik_wrapper (muts_span , dt , mutation_rate , standardize = True ):
161
164
"""
162
165
A wrapper to allow this _lik to be called by pool.imap_unordered, returning the
163
166
mutation and span values
164
167
"""
165
168
return muts_span , Likelihoods ._lik (
166
- muts_span [0 ], muts_span [1 ], dt , mutation_rate , normalize = normalize
169
+ muts_span [0 ], muts_span [1 ], dt , mutation_rate , standardize = standardize
167
170
)
168
171
169
172
def precalculate_mutation_likelihoods (self , num_threads = None , unique_method = 0 ):
@@ -206,7 +209,7 @@ def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
206
209
self ._lik_wrapper ,
207
210
dt = self .timediff_lower_tri ,
208
211
mutation_rate = self .mut_rate ,
209
- normalize = self .normalize ,
212
+ standardize = self .standardize ,
210
213
)
211
214
if num_threads == 1 :
212
215
# Useful for testing
@@ -240,7 +243,7 @@ def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
240
243
span ,
241
244
dt = self .timediff_lower_tri ,
242
245
mutation_rate = self .mut_rate ,
243
- normalize = self .normalize ,
246
+ standardize = self .standardize ,
244
247
)
245
248
246
249
def get_mut_lik_fixed_node (self , edge ):
@@ -266,7 +269,7 @@ def get_mut_lik_fixed_node(self, edge):
266
269
edge .span ,
267
270
self .timediff ,
268
271
self .mut_rate ,
269
- normalize = self .normalize ,
272
+ standardize = self .standardize ,
270
273
)
271
274
272
275
def get_mut_lik_lower_tri (self , edge ):
@@ -423,24 +426,24 @@ def logsumexp(X):
423
426
return np .log (r ) + alpha
424
427
425
428
@staticmethod
426
- def _lik (muts , span , dt , mutation_rate , normalize = True ):
429
+ def _lik (muts , span , dt , mutation_rate , standardize = True ):
427
430
"""
428
431
The likelihood of an edge given a number of mutations, as set of time deltas (dt)
429
432
and a span. This is a static function to allow parallelization
430
433
"""
431
434
ll = scipy .stats .poisson .logpmf (muts , dt * mutation_rate * span )
432
- if normalize :
435
+ if standardize :
433
436
return ll - np .max (ll )
434
437
else :
435
438
return ll
436
439
437
440
@staticmethod
438
- def _lik_wrapper (muts_span , dt , mutation_rate , normalize = True ):
441
+ def _lik_wrapper (muts_span , dt , mutation_rate , standardize = True ):
439
442
"""
440
443
Needs redefining to refer to the LogLikelihoods class
441
444
"""
442
445
return muts_span , LogLikelihoods ._lik (
443
- muts_span [0 ], muts_span [1 ], dt , mutation_rate , normalize = normalize
446
+ muts_span [0 ], muts_span [1 ], dt , mutation_rate , standardize = standardize
444
447
)
445
448
446
449
def rowsum_lower_tri (self , input_array ):
@@ -626,7 +629,7 @@ def edges_by_child_then_parent_desc(self):
626
629
627
630
# === MAIN ALGORITHMS ===
628
631
629
- def inside_pass (self , * , normalize = True , cache_inside = False , progress = None ):
632
+ def inside_pass (self , * , standardize = True , cache_inside = False , progress = None ):
630
633
"""
631
634
Use dynamic programming to find approximate posterior to sample from
632
635
"""
@@ -639,7 +642,7 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
639
642
g_i = np .full (
640
643
(self .ts .num_edges , self .lik .grid_size ), self .lik .identity_constant
641
644
)
642
- norm = np .full (self .ts .num_nodes , np .nan )
645
+ denominator = np .full (self .ts .num_nodes , np .nan )
643
646
# Iterate through the nodes via groupby on parent node
644
647
for parent , edges in tqdm (
645
648
self .edges_by_parent_asc (),
@@ -680,18 +683,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
680
683
val = self .lik .combine (val , edge_lik )
681
684
if cache_inside :
682
685
g_i [edge .id ] = edge_lik
683
- norm [parent ] = np .max (val ) if normalize else 1
684
- inside [parent ] = self .lik .reduce (val , norm [parent ])
686
+ denominator [parent ] = (
687
+ np .max (val ) if standardize else self .lik .identity_constant
688
+ )
689
+ inside [parent ] = self .lik .reduce (val , denominator [parent ])
685
690
if cache_inside :
686
- self .g_i = self .lik .reduce (g_i , norm [self .ts .tables .edges .child , None ])
691
+ self .g_i = self .lik .reduce (
692
+ g_i , denominator [self .ts .tables .edges .child , None ]
693
+ )
687
694
# Keep the results in this object
688
695
self .inside = inside
689
- self .norm = norm
696
+ self .denominator = denominator
690
697
691
698
def outside_pass (
692
699
self ,
693
700
* ,
694
- normalize = False ,
701
+ standardize = False ,
695
702
ignore_oldest_root = False ,
696
703
progress = None ,
697
704
):
@@ -700,8 +707,8 @@ def outside_pass(
700
707
posterior values. These are *not* probabilities, as they do not sum to one:
701
708
to convert to probabilities, call posterior.to_probabilities()
702
709
703
- Normalising *during* the outside process may be necessary if there is overflow,
704
- but means that we cannot check the total functional value at each node
710
+ Standardizing *during* the outside process may be necessary if there is
711
+ overflow, but means that we cannot check the total functional value at each node
705
712
706
713
Ignoring the oldest root may also be necessary when the oldest root node
707
714
causes numerical stability issues.
@@ -750,7 +757,7 @@ def outside_pass(
750
757
spanfrac , self .lik .make_lower_tri (self .inside [edge .child ])
751
758
)
752
759
edge_lik = self .lik .get_inside (daughter_val , edge )
753
- cur_g_i = self .lik .reduce (edge_lik , self .norm [child ])
760
+ cur_g_i = self .lik .reduce (edge_lik , self .denominator [child ])
754
761
inside_div_gi = self .lik .reduce (
755
762
self .inside [edge .parent ], cur_g_i , div_0_null = True
756
763
)
@@ -760,15 +767,15 @@ def outside_pass(
760
767
self .lik .combine (outside [edge .parent ], inside_div_gi )
761
768
),
762
769
)
763
- if normalize :
770
+ if standardize :
764
771
parent_val = self .lik .reduce (parent_val , np .max (parent_val ))
765
772
edge_lik = self .lik .get_outside (parent_val , edge )
766
773
val = self .lik .combine (val , edge_lik )
767
774
768
775
# vv[0] = 0 # Seems a hack: internal nodes should be allowed at time 0
769
- assert self .norm [edge .child ] > self .lik .null_constant
770
- outside [child ] = self .lik .reduce (val , self .norm [child ])
771
- if normalize :
776
+ assert self .denominator [edge .child ] > self .lik .null_constant
777
+ outside [child ] = self .lik .reduce (val , self .denominator [child ])
778
+ if standardize :
772
779
outside [child ] = self .lik .reduce (val , np .max (val ))
773
780
self .outside = outside
774
781
posterior = outside .clone_with_new_data (
@@ -1054,7 +1061,7 @@ def get_dates(
1054
1061
eps = 1e-6 ,
1055
1062
num_threads = None ,
1056
1063
method = "inside_outside" ,
1057
- outside_normalize = True ,
1064
+ outside_standardize = True ,
1058
1065
ignore_oldest_root = False ,
1059
1066
progress = False ,
1060
1067
cache_inside = False ,
@@ -1134,10 +1141,10 @@ def get_dates(
1134
1141
posterior = None
1135
1142
if method == "inside_outside" :
1136
1143
posterior = dynamic_prog .outside_pass (
1137
- normalize = outside_normalize , ignore_oldest_root = ignore_oldest_root
1144
+ standardize = outside_standardize , ignore_oldest_root = ignore_oldest_root
1138
1145
)
1139
1146
# Turn the posterior into probabilities
1140
- posterior .normalize () # Just to make sure there are no floating point issues
1147
+ posterior .standardize () # Just to make sure there are no floating point issues
1141
1148
posterior .force_probability_space (base .LIN )
1142
1149
posterior .to_probabilities ()
1143
1150
tree_sequence , mn_post , _ = posterior_mean_var (
0 commit comments