@@ -419,10 +419,10 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False):
419
419
420
420
self .ts = tree_sequence
421
421
self .sample_node_set = set (self .ts .samples ())
422
- if np .any (self .ts .tables .nodes .time [self .ts .samples ()] != 0 ):
423
- raise ValueError (
424
- "The SpansBySamples class needs a tree seq with all samples at time 0"
425
- )
422
+ # if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
423
+ # raise ValueError(
424
+ # "The SpansBySamples class needs a tree seq with all samples at time 0"
425
+ # )
426
426
self .progress = progress
427
427
428
428
# We will store the spans in here, and normalize them at the end
@@ -996,6 +996,59 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
996
996
return prior_times
997
997
998
998
999
+ def truncate_priors (ts , sample_times , priors , nodes_to_date = None , progress = False ):
1000
+ """
1001
+ Truncate priors so they conform to the age of nodes in the tree sequence
1002
+ """
1003
+ grid_data = np .copy (priors .grid_data [:])
1004
+ timepoints = priors .timepoints
1005
+ if np .max (sample_times ) >= np .max (timepoints ):
1006
+ raise ValueError ("Sample times cannot be larger than the oldest timepoint" )
1007
+ if priors .probability_space == "linear" :
1008
+ zero_value = 0
1009
+ one_value = 1
1010
+ elif priors .probability_space == "logarithmic" :
1011
+ zero_value = - np .inf
1012
+ one_value = 0
1013
+ constrained_min_times = np .copy (sample_times )
1014
+ constrained_max_times = np .full (sample_times .shape [0 ], np .inf )
1015
+ if nodes_to_date is None :
1016
+ nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
1017
+ nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
1018
+
1019
+ tables = ts .tables
1020
+ parents = tables .edges .parent
1021
+ nd_children = tables .edges .child [np .argsort (parents )]
1022
+ parents = sorted (parents )
1023
+ parents_unique = np .unique (parents , return_index = True )
1024
+ parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
1025
+ for index , nd in tqdm (
1026
+ enumerate (sorted (nodes_to_date )), desc = "Constrain Ages" , disable = not progress
1027
+ ):
1028
+ if index + 1 != len (nodes_to_date ):
1029
+ children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
1030
+ else :
1031
+ children_index = np .arange (parent_indices [index ], ts .num_edges )
1032
+ children = nd_children [children_index ]
1033
+ time = np .max (constrained_min_times [children ])
1034
+ # The constrained time of the node should be the age of the oldest child
1035
+ if constrained_min_times [nd ] <= time :
1036
+ constrained_min_times [nd ] = time
1037
+ nearest_time = np .argmin (np .abs (timepoints - time ))
1038
+ lookup_index = priors .row_lookup [int (nd )]
1039
+ grid_data [lookup_index ][:nearest_time ] = zero_value
1040
+ assert np .all (constrained_min_times < constrained_max_times )
1041
+ all_zeros = np .where (np .all (grid_data == zero_value , axis = 1 ))[0 ]
1042
+
1043
+ rowmax = grid_data [:, 1 :].max (axis = 1 )
1044
+ if priors .probability_space == "linear" :
1045
+ grid_data = grid_data / rowmax [:, np .newaxis ]
1046
+ elif priors .probability_space == "logarithmic" :
1047
+ grid_data = grid_data - rowmax [:, np .newaxis ]
1048
+
1049
+ priors .grid_data [:] = grid_data
1050
+ return constrained_min_times , constrained_max_times , priors
1051
+
999
1052
def build_grid (
1000
1053
tree_sequence ,
1001
1054
Ne ,
@@ -1007,7 +1060,7 @@ def build_grid(
1007
1060
eps = 1e-6 ,
1008
1061
# Parameters below undocumented
1009
1062
progress = False ,
1010
- allow_unary = False ,
1063
+ sample_times = None
1011
1064
):
1012
1065
"""
1013
1066
Using the conditional coalescent, calculate the prior distribution for the age of
@@ -1038,6 +1091,8 @@ def build_grid(
1038
1091
inference and a discretised time grid
1039
1092
:rtype: base.NodeGridValues Object
1040
1093
"""
1094
+ #tree_sequence = tree_sequence.simplify(tree_sequence.samples())
1095
+
1041
1096
if Ne <= 0 :
1042
1097
raise ValueError ("Parameter 'Ne' must be greater than 0" )
1043
1098
if approximate_priors :
@@ -1049,19 +1104,13 @@ def build_grid(
1049
1104
"Can't set approx_prior_size if approximate_prior is False"
1050
1105
)
1051
1106
1052
- contmpr_ts , node_map = util .reduce_to_contemporaneous (tree_sequence )
1053
- if contmpr_ts .num_nodes != tree_sequence .num_nodes :
1054
- raise ValueError (
1055
- "Passed tree sequence is not simplified and/or contains "
1056
- "noncontemporaneous samples"
1057
- )
1058
- span_data = SpansBySamples (contmpr_ts , progress = progress , allow_unary = allow_unary )
1107
+ span_data = SpansBySamples (tree_sequence , progress = progress )
1059
1108
1060
1109
base_priors = ConditionalCoalescentTimes (
1061
1110
approx_prior_size , Ne , prior_distribution , progress = progress
1062
1111
)
1063
1112
1064
- base_priors .add (contmpr_ts .num_samples , approximate_priors )
1113
+ base_priors .add (tree_sequence .num_samples , approximate_priors )
1065
1114
for total_fixed in span_data .total_fixed_at_0_counts :
1066
1115
# For missing data: trees vary in total fixed node count => have different priors
1067
1116
if total_fixed > 0 :
@@ -1085,9 +1134,7 @@ def build_grid(
1085
1134
else :
1086
1135
raise ValueError ("time_slices must be an integer or a numpy array of floats" )
1087
1136
1088
- prior_params_contmpr = base_priors .get_mixture_prior_params (span_data )
1089
- # Map the nodes in the prior params back to the node ids in the original ts
1090
- prior_params = prior_params_contmpr [node_map , :]
1137
+ prior_params = base_priors .get_mixture_prior_params (span_data )
1091
1138
# Set all fixed nodes (i.e. samples) to have 0 variance
1092
1139
priors = fill_priors (
1093
1140
prior_params ,
@@ -1097,4 +1144,7 @@ def build_grid(
1097
1144
prior_distr = prior_distribution ,
1098
1145
progress = progress ,
1099
1146
)
1147
+ if np .any (tree_sequence .tables .nodes .time [tree_sequence .samples ()] != 0 ):
1148
+ if False :
1149
+ priors = truncate_priors (tree_sequence , sample_times , priors , eps , progress = progress )
1100
1150
return priors
0 commit comments