Skip to content

Commit 4e5de92

Browse files
authored
Merge pull request #244 from hyanwong/fix-posteriors
Change "normalize" to "standardize"
2 parents d56c63d + d0ffb86 commit 4e5de92

File tree

5 files changed

+75
-61
lines changed

5 files changed

+75
-61
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
individuals, populations, or sites, aiming to change the tree sequence tables as
99
little as possible.
1010

11+
- Not strictly breaking, as not in the published API, but the "normalize" flag
12+
in ``get_dates`` and the internal ``normalize`` terminology is changed to
13+
``standardize`` to better reflect the fact that the maximum (not sum) is one.
14+
1115
**Bugfixes**
1216

1317
- The returned posteriors when ``return_posteriors=True`` now return actual
14-
probabilities (scaled so that they sum to one) rather than normalised
15-
probabilites whose maximum value is one.
18+
probabilities (scaled so that they sum to one) rather than standardized
19+
"probabilites" whose maximum value is one.
1620

1721
--------------------
1822
[0.1.5] - 2022-06-07

tests/test_functions.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -570,17 +570,17 @@ def test_one_tree_n2_intervals(self):
570570

571571

572572
class TestLikelihoodClass:
573-
def poisson(self, param, x, normalize=True):
573+
def poisson(self, param, x, standardize=True):
574574
ll = np.exp(-param) * param**x / scipy.special.factorial(x)
575-
if normalize:
575+
if standardize:
576576
return ll / np.max(ll)
577577
else:
578578
return ll
579579

580-
def log_poisson(self, param, x, normalize=True):
580+
def log_poisson(self, param, x, standardize=True):
581581
with np.errstate(divide="ignore"):
582582
ll = np.log(np.exp(-param) * param**x / scipy.special.factorial(x))
583-
if normalize:
583+
if standardize:
584584
return ll - np.max(ll)
585585
else:
586586
return ll
@@ -669,8 +669,8 @@ def test_precalc_lik_upper_multithread(self):
669669
(Likelihoods, self.poisson),
670670
(LogLikelihoods, self.log_poisson),
671671
]:
672-
for normalize in (True, False):
673-
lik = L(ts, grid, mut_rate, eps, normalize=normalize)
672+
for standardize in (True, False):
673+
lik = L(ts, grid, mut_rate, eps, standardize=standardize)
674674
dt = grid
675675
for num_threads in (None, 1, 2):
676676
n_internal_edges = 0
@@ -691,7 +691,7 @@ def test_precalc_lik_upper_multithread(self):
691691
expected_lik_dt = pois(
692692
dt * (mut_rate * span),
693693
num_muts,
694-
normalize=normalize,
694+
standardize=standardize,
695695
)
696696
upper_tri = lik.get_mut_lik_upper_tri(edge)
697697

@@ -946,7 +946,7 @@ def test_nonmatching_prior_vs_lik_fixednodes(self):
946946

947947

948948
class TestInsideAlgorithm:
949-
def run_inside_algorithm(self, ts, prior_distr, normalize=True, **kwargs):
949+
def run_inside_algorithm(self, ts, prior_distr, standardize=True, **kwargs):
950950
Ne = 0.5
951951
priors = tsdate.build_prior_grid(
952952
ts,
@@ -961,7 +961,7 @@ def run_inside_algorithm(self, ts, prior_distr, normalize=True, **kwargs):
961961
lls = Likelihoods(ts, priors.timepoints, mut_rate, eps=eps)
962962
lls.precalculate_mutation_likelihoods()
963963
algo = InOutAlgorithms(priors, lls)
964-
algo.inside_pass(normalize=normalize)
964+
algo.inside_pass(standardize=standardize)
965965
return algo, priors
966966

967967
def test_one_tree_n2(self):
@@ -989,7 +989,7 @@ def test_polytomy_tree(self):
989989

990990
def test_two_tree_ts(self):
991991
ts = utility_functions.two_tree_ts()
992-
algo, priors = self.run_inside_algorithm(ts, "gamma", normalize=False)
992+
algo, priors = self.run_inside_algorithm(ts, "gamma", standardize=False)
993993
mut_rate = 0.5
994994
# priors[3][1] * Ll_(0->3)(1.2 - 0 + eps) ** 2
995995
node3_t1 = (
@@ -1098,7 +1098,7 @@ def test_dangling_fails(self):
10981098

10991099
class TestOutsideAlgorithm:
11001100
def run_outside_algorithm(
1101-
self, ts, prior_distr="lognorm", normalize=False, ignore_oldest_root=False
1101+
self, ts, prior_distr="lognorm", standardize=False, ignore_oldest_root=False
11021102
):
11031103
span_data = SpansBySamples(ts)
11041104
Ne = 0.5
@@ -1113,7 +1113,9 @@ def run_outside_algorithm(
11131113
lls.precalculate_mutation_likelihoods()
11141114
algo = InOutAlgorithms(prior_vals, lls)
11151115
algo.inside_pass()
1116-
algo.outside_pass(normalize=normalize, ignore_oldest_root=ignore_oldest_root)
1116+
algo.outside_pass(
1117+
standardize=standardize, ignore_oldest_root=ignore_oldest_root
1118+
)
11171119
return algo
11181120

11191121
def test_one_tree_n2(self):
@@ -1157,17 +1159,17 @@ def test_outside_before_inside_fails(self):
11571159
with pytest.raises(RuntimeError):
11581160
algo.outside_pass()
11591161

1160-
def test_normalize_outside(self):
1162+
def test_standardize_outside(self):
11611163
ts = msprime.simulate(
11621164
50, Ne=10000, mutation_rate=1e-8, recombination_rate=1e-8, random_seed=12
11631165
)
1164-
normalize = self.run_outside_algorithm(ts, normalize=True)
1165-
no_normalize = self.run_outside_algorithm(ts, normalize=False)
1166+
standardize = self.run_outside_algorithm(ts, standardize=True)
1167+
no_standardize = self.run_outside_algorithm(ts, standardize=False)
11661168
assert np.allclose(
1167-
normalize.outside.grid_data[:],
1169+
standardize.outside.grid_data[:],
11681170
(
1169-
no_normalize.outside.grid_data[:]
1170-
/ np.max(no_normalize.outside.grid_data[:], axis=1)[:, np.newaxis]
1171+
no_standardize.outside.grid_data[:]
1172+
/ np.max(no_standardize.outside.grid_data[:], axis=1)[:, np.newaxis]
11711173
),
11721174
)
11731175

@@ -1213,7 +1215,7 @@ def find_posterior(self, ts, prior_distr):
12131215
lls.precalculate_mutation_likelihoods()
12141216
algo = InOutAlgorithms(prior_vals, lls)
12151217
algo.inside_pass()
1216-
posterior = algo.outside_pass(normalize=False)
1218+
posterior = algo.outside_pass(standardize=False)
12171219
assert np.array_equal(
12181220
np.sum(algo.inside.grid_data * algo.outside.grid_data, axis=1),
12191221
np.sum(algo.inside.grid_data * algo.outside.grid_data, axis=1),
@@ -1278,11 +1280,11 @@ def test_gil_tree(self):
12781280
prior_vals.grid_data[1] = [0, 0.05, 0.1, 0.2, 0.45, 0.1, 0.1]
12791281
mut_rate = 1
12801282
eps = 0.01
1281-
lls = Likelihoods(ts, grid, mut_rate, eps=eps, normalize=False)
1283+
lls = Likelihoods(ts, grid, mut_rate, eps=eps, standardize=False)
12821284
lls.precalculate_mutation_likelihoods()
12831285
algo = InOutAlgorithms(prior_vals, lls)
1284-
algo.inside_pass(normalize=False, cache_inside=cache_inside)
1285-
algo.outside_pass(normalize=False)
1286+
algo.inside_pass(standardize=False, cache_inside=cache_inside)
1287+
algo.outside_pass(standardize=False)
12861288
assert np.allclose(
12871289
np.sum(algo.inside.grid_data * algo.outside.grid_data, axis=1),
12881290
[7.44449e-05, 7.44449e-05],

tsdate/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# MIT License
22
#
3-
# Copyright (c) 2020 University of Oxford
3+
# Copyright (c) 2021-23 Tskit Developers
4+
# Copyright (c) 2020-21 University of Oxford
45
#
56
# Permission is hereby granted, free of charge, to any person obtaining a copy
67
# of this software and associated documentation files (the "Software"), to deal
@@ -127,9 +128,9 @@ def force_probability_space(self, probability_space):
127128
else:
128129
logging.warning("Cannot force", *descr)
129130

130-
def normalize(self):
131+
def standardize(self):
131132
"""
132-
normalize grid data so the max is one (in linear space) or zero
133+
Standardize grid data so the max for each row is one (in linear space) or zero
133134
(in logarithmic space)
134135
135136
TODO - is it clear why we omit the first element of the

tsdate/core.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class Likelihoods:
5050
A class to store and process likelihoods. Likelihoods for edges are stored as a
5151
flattened lower triangular matrix of all the possible delta t's. This class also
5252
provides methods for accessing this lower triangular matrix, multiplying it, etc.
53+
54+
If ``standardize`` is true, routines will operate to standardize the likelihoods
55+
such that their maximum is one (in linear space) or zero (in log space)
5356
"""
5457

5558
probability_space = base.LIN
@@ -65,7 +68,7 @@ def __init__(
6568
*,
6669
eps=0,
6770
fixed_node_set=None,
68-
normalize=True,
71+
standardize=True,
6972
progress=False,
7073
):
7174
self.ts = ts
@@ -75,7 +78,7 @@ def __init__(
7578
)
7679
self.mut_rate = mutation_rate
7780
self.rec_rate = recombination_rate
78-
self.normalize = normalize
81+
self.standardize = standardize
7982
self.grid_size = len(timepoints)
8083
self.tri_size = self.grid_size * (self.grid_size + 1) / 2
8184
self.ll_mut = {}
@@ -145,25 +148,25 @@ def get_mut_edges(ts):
145148
return mut_edges
146149

147150
@staticmethod
148-
def _lik(muts, span, dt, mutation_rate, normalize=True):
151+
def _lik(muts, span, dt, mutation_rate, standardize=True):
149152
"""
150153
The likelihood of an edge given a number of mutations, as set of time deltas (dt)
151154
and a span. This is a static function to allow parallelization
152155
"""
153156
ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span)
154-
if normalize:
157+
if standardize:
155158
return ll / np.max(ll)
156159
else:
157160
return ll
158161

159162
@staticmethod
160-
def _lik_wrapper(muts_span, dt, mutation_rate, normalize=True):
163+
def _lik_wrapper(muts_span, dt, mutation_rate, standardize=True):
161164
"""
162165
A wrapper to allow this _lik to be called by pool.imap_unordered, returning the
163166
mutation and span values
164167
"""
165168
return muts_span, Likelihoods._lik(
166-
muts_span[0], muts_span[1], dt, mutation_rate, normalize=normalize
169+
muts_span[0], muts_span[1], dt, mutation_rate, standardize=standardize
167170
)
168171

169172
def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
@@ -206,7 +209,7 @@ def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
206209
self._lik_wrapper,
207210
dt=self.timediff_lower_tri,
208211
mutation_rate=self.mut_rate,
209-
normalize=self.normalize,
212+
standardize=self.standardize,
210213
)
211214
if num_threads == 1:
212215
# Useful for testing
@@ -240,7 +243,7 @@ def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
240243
span,
241244
dt=self.timediff_lower_tri,
242245
mutation_rate=self.mut_rate,
243-
normalize=self.normalize,
246+
standardize=self.standardize,
244247
)
245248

246249
def get_mut_lik_fixed_node(self, edge):
@@ -266,7 +269,7 @@ def get_mut_lik_fixed_node(self, edge):
266269
edge.span,
267270
self.timediff,
268271
self.mut_rate,
269-
normalize=self.normalize,
272+
standardize=self.standardize,
270273
)
271274

272275
def get_mut_lik_lower_tri(self, edge):
@@ -423,24 +426,24 @@ def logsumexp(X):
423426
return np.log(r) + alpha
424427

425428
@staticmethod
426-
def _lik(muts, span, dt, mutation_rate, normalize=True):
429+
def _lik(muts, span, dt, mutation_rate, standardize=True):
427430
"""
428431
The likelihood of an edge given a number of mutations, as set of time deltas (dt)
429432
and a span. This is a static function to allow parallelization
430433
"""
431434
ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span)
432-
if normalize:
435+
if standardize:
433436
return ll - np.max(ll)
434437
else:
435438
return ll
436439

437440
@staticmethod
438-
def _lik_wrapper(muts_span, dt, mutation_rate, normalize=True):
441+
def _lik_wrapper(muts_span, dt, mutation_rate, standardize=True):
439442
"""
440443
Needs redefining to refer to the LogLikelihoods class
441444
"""
442445
return muts_span, LogLikelihoods._lik(
443-
muts_span[0], muts_span[1], dt, mutation_rate, normalize=normalize
446+
muts_span[0], muts_span[1], dt, mutation_rate, standardize=standardize
444447
)
445448

446449
def rowsum_lower_tri(self, input_array):
@@ -626,7 +629,7 @@ def edges_by_child_then_parent_desc(self):
626629

627630
# === MAIN ALGORITHMS ===
628631

629-
def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
632+
def inside_pass(self, *, standardize=True, cache_inside=False, progress=None):
630633
"""
631634
Use dynamic programming to find approximate posterior to sample from
632635
"""
@@ -639,7 +642,7 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
639642
g_i = np.full(
640643
(self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
641644
)
642-
norm = np.full(self.ts.num_nodes, np.nan)
645+
denominator = np.full(self.ts.num_nodes, np.nan)
643646
# Iterate through the nodes via groupby on parent node
644647
for parent, edges in tqdm(
645648
self.edges_by_parent_asc(),
@@ -680,18 +683,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
680683
val = self.lik.combine(val, edge_lik)
681684
if cache_inside:
682685
g_i[edge.id] = edge_lik
683-
norm[parent] = np.max(val) if normalize else 1
684-
inside[parent] = self.lik.reduce(val, norm[parent])
686+
denominator[parent] = (
687+
np.max(val) if standardize else self.lik.identity_constant
688+
)
689+
inside[parent] = self.lik.reduce(val, denominator[parent])
685690
if cache_inside:
686-
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
691+
self.g_i = self.lik.reduce(
692+
g_i, denominator[self.ts.tables.edges.child, None]
693+
)
687694
# Keep the results in this object
688695
self.inside = inside
689-
self.norm = norm
696+
self.denominator = denominator
690697

691698
def outside_pass(
692699
self,
693700
*,
694-
normalize=False,
701+
standardize=False,
695702
ignore_oldest_root=False,
696703
progress=None,
697704
):
@@ -700,8 +707,8 @@ def outside_pass(
700707
posterior values. These are *not* probabilities, as they do not sum to one:
701708
to convert to probabilities, call posterior.to_probabilities()
702709
703-
Normalising *during* the outside process may be necessary if there is overflow,
704-
but means that we cannot check the total functional value at each node
710+
Standardizing *during* the outside process may be necessary if there is
711+
overflow, but means that we cannot check the total functional value at each node
705712
706713
Ignoring the oldest root may also be necessary when the oldest root node
707714
causes numerical stability issues.
@@ -750,7 +757,7 @@ def outside_pass(
750757
spanfrac, self.lik.make_lower_tri(self.inside[edge.child])
751758
)
752759
edge_lik = self.lik.get_inside(daughter_val, edge)
753-
cur_g_i = self.lik.reduce(edge_lik, self.norm[child])
760+
cur_g_i = self.lik.reduce(edge_lik, self.denominator[child])
754761
inside_div_gi = self.lik.reduce(
755762
self.inside[edge.parent], cur_g_i, div_0_null=True
756763
)
@@ -760,15 +767,15 @@ def outside_pass(
760767
self.lik.combine(outside[edge.parent], inside_div_gi)
761768
),
762769
)
763-
if normalize:
770+
if standardize:
764771
parent_val = self.lik.reduce(parent_val, np.max(parent_val))
765772
edge_lik = self.lik.get_outside(parent_val, edge)
766773
val = self.lik.combine(val, edge_lik)
767774

768775
# vv[0] = 0 # Seems a hack: internal nodes should be allowed at time 0
769-
assert self.norm[edge.child] > self.lik.null_constant
770-
outside[child] = self.lik.reduce(val, self.norm[child])
771-
if normalize:
776+
assert self.denominator[edge.child] > self.lik.null_constant
777+
outside[child] = self.lik.reduce(val, self.denominator[child])
778+
if standardize:
772779
outside[child] = self.lik.reduce(val, np.max(val))
773780
self.outside = outside
774781
posterior = outside.clone_with_new_data(
@@ -1054,7 +1061,7 @@ def get_dates(
10541061
eps=1e-6,
10551062
num_threads=None,
10561063
method="inside_outside",
1057-
outside_normalize=True,
1064+
outside_standardize=True,
10581065
ignore_oldest_root=False,
10591066
progress=False,
10601067
cache_inside=False,
@@ -1134,10 +1141,10 @@ def get_dates(
11341141
posterior = None
11351142
if method == "inside_outside":
11361143
posterior = dynamic_prog.outside_pass(
1137-
normalize=outside_normalize, ignore_oldest_root=ignore_oldest_root
1144+
standardize=outside_standardize, ignore_oldest_root=ignore_oldest_root
11381145
)
11391146
# Turn the posterior into probabilities
1140-
posterior.normalize() # Just to make sure there are no floating point issues
1147+
posterior.standardize() # Just to make sure there are no floating point issues
11411148
posterior.force_probability_space(base.LIN)
11421149
posterior.to_probabilities()
11431150
tree_sequence, mn_post, _ = posterior_mean_var(

0 commit comments

Comments
 (0)