Skip to content

Commit d706941

Browse files
committed
Refactor grid building into MixturePrior class
1 parent 9c7232d commit d706941

File tree

1 file changed

+101
-78
lines changed

1 file changed

+101
-78
lines changed

tsdate/prior.py

Lines changed: 101 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,104 @@ def fill_priors(
10061006
return prior_times
10071007

10081008

1009+
class MixturePrior:
    """
    Maps ConditionalCoalescentTimes priors onto the nodes of a tree sequence and
    creates time-discretized priors from them.

    Construction computes the per-node mixture prior parameters once; they can
    then be evaluated on different time grids / population size histories via
    :meth:`make_discretized_prior`.

    Attributes set by ``__init__``:

    * ``prior_params``: the mixture prior parameters, re-indexed through the
      node map so that rows follow the node ids of the original tree sequence.
    * ``base_priors``: the ``ConditionalCoalescentTimes`` instance used to
      derive the parameters (also used to create default timepoints).
    * ``tree_sequence``: the tree sequence passed to the constructor.
    * ``prior_distribution``: the distribution name passed to the constructor.
    """

    def __init__(
        self,
        tree_sequence,
        approximate_priors=False,
        approx_prior_size=None,
        prior_distribution="lognorm",
        allow_unary=False,
        progress=False,
    ):
        """
        Compute mixture prior parameters for every node in ``tree_sequence``.

        :param tree_sequence: The tree sequence whose nodes receive priors.
        :param bool approximate_priors: Passed through to
            ``ConditionalCoalescentTimes.add``. Default: False
        :param int approx_prior_size: Only allowed when ``approximate_priors``
            is true; if unset (or falsy) in that case, 1000 is used.
        :param str prior_distribution: Passed to ``ConditionalCoalescentTimes``
            and later to ``fill_priors``. Default: "lognorm"
        :param bool allow_unary: Passed to ``SpansBySamples``. Default: False
        :param bool progress: Passed to ``SpansBySamples`` and
            ``ConditionalCoalescentTimes``. Default: False
        :raises ValueError: If ``approx_prior_size`` is given while
            ``approximate_priors`` is false, or if reducing the tree sequence
            to contemporaneous samples changes the number of nodes.
        """
        if approximate_priors:
            if not approx_prior_size:
                approx_prior_size = 1000
        elif approx_prior_size is not None:
            raise ValueError(
                "Can't set approx_prior_size if approximate_priors is False"
            )

        contmpr_ts, node_map = util.reduce_to_contemporaneous(tree_sequence)
        if contmpr_ts.num_nodes != tree_sequence.num_nodes:
            raise ValueError(
                "Passed tree sequence is not simplified and/or contains "
                "noncontemporaneous samples"
            )
        span_data = SpansBySamples(
            contmpr_ts, progress=progress, allow_unary=allow_unary
        )

        base_priors = ConditionalCoalescentTimes(
            approx_prior_size, prior_distribution, progress=progress
        )

        base_priors.add(contmpr_ts.num_samples, approximate_priors)
        for total_fixed in span_data.total_fixed_at_0_counts:
            # For missing data: trees vary in total fixed node count => have
            # different priors
            if total_fixed > 0:
                base_priors.add(total_fixed, approximate_priors)
        prior_params_contmpr = base_priors.get_mixture_prior_params(span_data)

        # Map the nodes in the prior params back to the node ids in the original ts
        self.prior_params = prior_params_contmpr[node_map, :]
        self.base_priors = base_priors
        self.tree_sequence = tree_sequence
        self.prior_distribution = prior_distribution

    @staticmethod
    def _validate_population_size(population_size):
        """
        Check ``population_size`` and return it as a two-column epoch array.

        A numpy array is validated and returned unchanged; any other value is
        treated as a single constant size and wrapped as ``[[0, size]]``.
        """
        if not isinstance(population_size, np.ndarray):
            if population_size <= 0:
                raise ValueError("Parameter 'population_size' must be greater than 0")
            return np.array([[0, population_size]], dtype=float)

        if population_size.ndim != 2:
            raise ValueError("Array 'population_size' must be two-dimensional")
        if population_size.shape[1] != 2:
            # Implicit string concatenation rather than a backslash continuation
            # inside the literal, which would embed source indentation in the
            # message.
            raise ValueError(
                "Population size array must have two columns that contain "
                "epoch start times and population sizes, respectively"
            )
        if np.any(population_size[:, 0] < 0.0):
            raise ValueError("Epoch start times must be nonnegative")
        if np.any(population_size[:, 1] <= 0.0):
            raise ValueError("Population sizes must be positive")
        if population_size[0, 0] != 0:
            raise ValueError("The first epoch must start at time 0")
        if not np.all(np.diff(population_size[:, 0]) > 0):
            raise ValueError("Epoch start times must be unique and increasing")
        return population_size

    def _resolve_timepoints(self, timepoints):
        """
        Turn ``timepoints`` (an int count or a numpy array) into a time grid.

        An integer is handed to ``create_timepoints`` together with
        ``self.base_priors``; an array is cast to ``base.FLOAT_DTYPE``, sorted
        and validated.
        """
        if isinstance(timepoints, int):
            if timepoints < 2:
                raise ValueError("You must have at least 2 time points")
            return create_timepoints(self.base_priors, timepoints + 1)

        if not isinstance(timepoints, np.ndarray):
            raise ValueError(
                "timepoints must be an integer or a numpy array of floats"
            )

        try:
            timepoints = np.sort(timepoints.astype(base.FLOAT_DTYPE, casting="safe"))
        except TypeError as err:
            # Keep the numpy casting error attached as the cause
            raise TypeError(
                "Timepoints array cannot be converted to float dtype"
            ) from err
        if len(timepoints) < 2:
            raise ValueError("You must have at least 2 time points")
        if np.any(timepoints < 0):
            raise ValueError("Timepoints cannot be negative")
        if np.any(np.unique(timepoints, return_counts=True)[1] > 1):
            raise ValueError("Timepoints cannot have duplicate values")
        return timepoints

    def make_discretized_prior(self, population_size, timepoints=20, progress=False):
        """
        Calculate prior grid for a set of timepoints and a population size history

        :param population_size: Either a single positive number, or a
            two-dimensional numpy array with two columns containing epoch start
            times and population sizes, respectively. Epoch start times must
            begin at 0 and be strictly increasing; sizes must be positive.
        :param timepoints: Either an integer (at least 2) from which a grid is
            created via ``create_timepoints``, or a numpy array of nonnegative,
            unique values safely castable to float. Default: 20
        :param bool progress: Passed to ``fill_priors``. Default: False
        :return: The value returned by ``fill_priors`` for the stored prior
            parameters (presumably a ``base.NodeGridValues`` object -- confirm
            against ``fill_priors``).
        :raises ValueError: If ``population_size`` or ``timepoints`` fail the
            checks described above, or ``timepoints`` is neither an int nor a
            numpy array.
        :raises TypeError: If a ``timepoints`` array cannot be safely cast to
            the float dtype.
        """
        # Population size is validated before timepoints, so that errors are
        # reported in the same order as the arguments.
        population_size = self._validate_population_size(population_size)
        timepoints = self._resolve_timepoints(timepoints)

        # Evaluate the stored per-node prior parameters on the time grid
        return fill_priors(
            self.prior_params,
            timepoints,
            self.tree_sequence,
            population_size,
            prior_distr=self.prior_distribution,
            progress=progress,
        )
1105+
1106+
10091107
def build_grid(
10101108
tree_sequence,
10111109
population_size,
@@ -1014,7 +1112,6 @@ def build_grid(
10141112
approximate_priors=False,
10151113
approx_prior_size=None,
10161114
prior_distribution="lognorm",
1017-
eps=1e-6,
10181115
# Parameters below undocumented
10191116
progress=False,
10201117
allow_unary=False,
@@ -1044,87 +1141,13 @@ def build_grid(
10441141
better fit, but slightly slower to calculate) or "gamma" for the gamma
10451142
distribution (slightly faster, but a poorer fit for recent nodes). Default:
10461143
"lognorm"
1047-
:param float eps: Specify minimum distance separating points in the time grid. Also
1048-
specifies the error factor in time difference calculations. Default: 1e-6
10491144
:return: A prior object to pass to tsdate.date() containing prior values for
10501145
inference and a discretised time grid
10511146
:rtype: base.NodeGridValues Object
10521147
"""
1053-
if isinstance(population_size, np.ndarray):
1054-
if population_size.ndim != 2:
1055-
raise ValueError("Array 'population_size' must be two-dimensional")
1056-
if population_size.shape[1] != 2:
1057-
raise ValueError(
1058-
"Population size array must have two columns that contain \
1059-
epoch start times and population sizes, respectively"
1060-
)
1061-
if np.any(population_size[:, 0] < 0.0):
1062-
raise ValueError("Epoch start times must be nonnegative")
1063-
if np.any(population_size[:, 1] <= 0.0):
1064-
raise ValueError("Population sizes must be positive")
1065-
if population_size[0, 0] != 0:
1066-
raise ValueError("The first epoch must start at time 0")
1067-
if not np.all(np.diff(population_size[:, 0]) > 0):
1068-
raise ValueError("Epoch start times must be unique and increasing")
1069-
else:
1070-
if population_size <= 0:
1071-
raise ValueError("Parameter 'population_size' must be greater than 0")
1072-
population_size = np.array([[0, population_size]], dtype=float)
1073-
if approximate_priors:
1074-
if not approx_prior_size:
1075-
approx_prior_size = 1000
1076-
else:
1077-
if approx_prior_size is not None:
1078-
raise ValueError(
1079-
"Can't set approx_prior_size if approximate_prior is False"
1080-
)
10811148

1082-
contmpr_ts, node_map = util.reduce_to_contemporaneous(tree_sequence)
1083-
if contmpr_ts.num_nodes != tree_sequence.num_nodes:
1084-
raise ValueError(
1085-
"Passed tree sequence is not simplified and/or contains "
1086-
"noncontemporaneous samples"
1087-
)
1088-
span_data = SpansBySamples(contmpr_ts, progress=progress, allow_unary=allow_unary)
1089-
1090-
base_priors = ConditionalCoalescentTimes(
1091-
approx_prior_size, prior_distribution, progress=progress
1149+
mixture_prior = MixturePrior(
1150+
tree_sequence, approximate_priors, approx_prior_size, prior_distribution, allow_unary, progress
10921151
)
1152+
return mixture_prior.make_discretized_prior(population_size, timepoints)
10931153

1094-
base_priors.add(contmpr_ts.num_samples, approximate_priors)
1095-
for total_fixed in span_data.total_fixed_at_0_counts:
1096-
# For missing data: trees vary in total fixed node count => have different priors
1097-
if total_fixed > 0:
1098-
base_priors.add(total_fixed, approximate_priors)
1099-
1100-
if isinstance(timepoints, int):
1101-
if timepoints < 2:
1102-
raise ValueError("You must have at least 2 time points")
1103-
timepoints = create_timepoints(base_priors, timepoints + 1)
1104-
elif isinstance(timepoints, np.ndarray):
1105-
try:
1106-
timepoints = np.sort(timepoints.astype(base.FLOAT_DTYPE, casting="safe"))
1107-
except TypeError:
1108-
raise TypeError("Timepoints array cannot be converted to float dtype")
1109-
if len(timepoints) < 2:
1110-
raise ValueError("You must have at least 2 time points")
1111-
elif np.any(timepoints < 0):
1112-
raise ValueError("Timepoints cannot be negative")
1113-
elif np.any(np.unique(timepoints, return_counts=True)[1] > 1):
1114-
raise ValueError("Timepoints cannot have duplicate values")
1115-
else:
1116-
raise ValueError("time_slices must be an integer or a numpy array of floats")
1117-
1118-
prior_params_contmpr = base_priors.get_mixture_prior_params(span_data)
1119-
# Map the nodes in the prior params back to the node ids in the original ts
1120-
prior_params = prior_params_contmpr[node_map, :]
1121-
# Set all fixed nodes (i.e. samples) to have 0 variance
1122-
priors = fill_priors(
1123-
prior_params,
1124-
timepoints,
1125-
tree_sequence,
1126-
population_size,
1127-
prior_distr=prior_distribution,
1128-
progress=progress,
1129-
)
1130-
return priors

0 commit comments

Comments
 (0)