@@ -419,10 +419,10 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False):
419419
420420 self .ts = tree_sequence
421421 self .sample_node_set = set (self .ts .samples ())
422- if np .any (self .ts .tables .nodes .time [self .ts .samples ()] != 0 ):
423- raise ValueError (
424- "The SpansBySamples class needs a tree seq with all samples at time 0"
425- )
422+ # if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
423+ # raise ValueError(
424+ # "The SpansBySamples class needs a tree seq with all samples at time 0"
425+ # )
426426 self .progress = progress
427427
428428 # We will store the spans in here, and normalize them at the end
@@ -996,6 +996,59 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
996996 return prior_times
997997
998998
999+ def truncate_priors (ts , sample_times , priors , nodes_to_date = None , progress = False ):
1000+ """
1001+ Truncate priors so they conform to the age of nodes in the tree sequence
1002+ """
1003+ grid_data = np .copy (priors .grid_data [:])
1004+ timepoints = priors .timepoints
1005+ if np .max (sample_times ) >= np .max (timepoints ):
1006+ raise ValueError ("Sample times cannot be larger than the oldest timepoint" )
1007+ if priors .probability_space == "linear" :
1008+ zero_value = 0
1009+ one_value = 1
1010+ elif priors .probability_space == "logarithmic" :
1011+ zero_value = - np .inf
1012+ one_value = 0
1013+ constrained_min_times = np .copy (sample_times )
1014+ constrained_max_times = np .full (sample_times .shape [0 ], np .inf )
1015+ if nodes_to_date is None :
1016+ nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
1017+ nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
1018+
1019+ tables = ts .tables
1020+ parents = tables .edges .parent
1021+ nd_children = tables .edges .child [np .argsort (parents )]
1022+ parents = sorted (parents )
1023+ parents_unique = np .unique (parents , return_index = True )
1024+ parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
1025+ for index , nd in tqdm (
1026+ enumerate (sorted (nodes_to_date )), desc = "Constrain Ages" , disable = not progress
1027+ ):
1028+ if index + 1 != len (nodes_to_date ):
1029+ children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
1030+ else :
1031+ children_index = np .arange (parent_indices [index ], ts .num_edges )
1032+ children = nd_children [children_index ]
1033+ time = np .max (constrained_min_times [children ])
1034+ # The constrained time of the node should be the age of the oldest child
1035+ if constrained_min_times [nd ] <= time :
1036+ constrained_min_times [nd ] = time
1037+ nearest_time = np .argmin (np .abs (timepoints - time ))
1038+ lookup_index = priors .row_lookup [int (nd )]
1039+ grid_data [lookup_index ][:nearest_time ] = zero_value
1040+ assert np .all (constrained_min_times < constrained_max_times )
1041+ all_zeros = np .where (np .all (grid_data == zero_value , axis = 1 ))[0 ]
1042+
1043+ rowmax = grid_data [:, 1 :].max (axis = 1 )
1044+ if priors .probability_space == "linear" :
1045+ grid_data = grid_data / rowmax [:, np .newaxis ]
1046+ elif priors .probability_space == "logarithmic" :
1047+ grid_data = grid_data - rowmax [:, np .newaxis ]
1048+
1049+ priors .grid_data [:] = grid_data
1050+ return constrained_min_times , constrained_max_times , priors
1051+
9991052def build_grid (
10001053 tree_sequence ,
10011054 Ne ,
@@ -1007,7 +1060,7 @@ def build_grid(
10071060 eps = 1e-6 ,
10081061 # Parameters below undocumented
10091062 progress = False ,
1010- allow_unary = False ,
1063+ sample_times = None
10111064):
10121065 """
10131066 Using the conditional coalescent, calculate the prior distribution for the age of
@@ -1038,6 +1091,8 @@ def build_grid(
10381091 inference and a discretised time grid
10391092 :rtype: base.NodeGridValues Object
10401093 """
1094+ #tree_sequence = tree_sequence.simplify(tree_sequence.samples())
1095+
10411096 if Ne <= 0 :
10421097 raise ValueError ("Parameter 'Ne' must be greater than 0" )
10431098 if approximate_priors :
@@ -1049,19 +1104,13 @@ def build_grid(
10491104 "Can't set approx_prior_size if approximate_prior is False"
10501105 )
10511106
1052- contmpr_ts , node_map = util .reduce_to_contemporaneous (tree_sequence )
1053- if contmpr_ts .num_nodes != tree_sequence .num_nodes :
1054- raise ValueError (
1055- "Passed tree sequence is not simplified and/or contains "
1056- "noncontemporaneous samples"
1057- )
1058- span_data = SpansBySamples (contmpr_ts , progress = progress , allow_unary = allow_unary )
1107+ span_data = SpansBySamples (tree_sequence , progress = progress )
10591108
10601109 base_priors = ConditionalCoalescentTimes (
10611110 approx_prior_size , Ne , prior_distribution , progress = progress
10621111 )
10631112
1064- base_priors .add (contmpr_ts .num_samples , approximate_priors )
1113+ base_priors .add (tree_sequence .num_samples , approximate_priors )
10651114 for total_fixed in span_data .total_fixed_at_0_counts :
10661115 # For missing data: trees vary in total fixed node count => have different priors
10671116 if total_fixed > 0 :
@@ -1085,9 +1134,7 @@ def build_grid(
10851134 else :
10861135 raise ValueError ("time_slices must be an integer or a numpy array of floats" )
10871136
1088- prior_params_contmpr = base_priors .get_mixture_prior_params (span_data )
1089- # Map the nodes in the prior params back to the node ids in the original ts
1090- prior_params = prior_params_contmpr [node_map , :]
1137+ prior_params = base_priors .get_mixture_prior_params (span_data )
10911138 # Set all fixed nodes (i.e. samples) to have 0 variance
10921139 priors = fill_priors (
10931140 prior_params ,
@@ -1097,4 +1144,7 @@ def build_grid(
10971144 prior_distr = prior_distribution ,
10981145 progress = progress ,
10991146 )
1147+ if np .any (tree_sequence .tables .nodes .time [tree_sequence .samples ()] != 0 ):
1148+ if False :
1149+ priors = truncate_priors (tree_sequence , sample_times , priors , eps , progress = progress )
11001150 return priors
0 commit comments