@@ -419,7 +419,7 @@ def __init__(self, tree_sequence, progress=False):
419
419
420
420
self .ts = tree_sequence
421
421
self .sample_node_set = set (self .ts .samples ())
422
- #if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
422
+ # if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
423
423
# raise ValueError(
424
424
# "The SpansBySamples class needs a tree seq with all samples at time 0"
425
425
# )
@@ -991,34 +991,36 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
991
991
return prior_times
992
992
993
993
994
- def truncate_priors (ts , sample_times , priors , nodes_to_date = None , progress = False ):
994
+ def truncate_priors (ts , priors , nodes_to_date = None , progress = False ):
995
995
"""
996
996
Truncate priors so they conform to the age of nodes in the tree sequence
997
997
"""
998
+ tables = ts .tables
999
+ if nodes_to_date is None :
1000
+ nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
1001
+ nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
1002
+ # ensure nodes_to_date is ordered by node time
1003
+ nodes_to_date = nodes_to_date [np .argsort (tables .nodes .time [nodes_to_date ])]
1004
+
998
1005
grid_data = np .copy (priors .grid_data [:])
999
1006
timepoints = priors .timepoints
1000
- if np .max (sample_times ) >= np .max (timepoints ):
1007
+ if np .max (tables . nodes . time [ ts . samples ()] ) >= np .max (timepoints ):
1001
1008
raise ValueError ("Sample times cannot be larger than the oldest timepoint" )
1002
1009
if priors .probability_space == "linear" :
1003
1010
zero_value = 0
1004
- one_value = 1
1005
1011
elif priors .probability_space == "logarithmic" :
1006
1012
zero_value = - np .inf
1007
- one_value = 0
1008
- constrained_min_times = np .copy (sample_times )
1009
- constrained_max_times = np .full (sample_times .shape [0 ], np .inf )
1010
- if nodes_to_date is None :
1011
- nodes_to_date = np .arange (ts .num_nodes , dtype = np .uint64 )
1012
- nodes_to_date = nodes_to_date [~ np .isin (nodes_to_date , ts .samples ())]
1013
+ constrained_min_times = np .copy (tables .nodes .time )
1014
+ constrained_min_times [ts .samples ()] = tables .nodes .time [ts .samples ()]
1015
+ constrained_max_times = np .full_like (constrained_min_times , np .inf )
1013
1016
1014
- tables = ts .tables
1015
1017
parents = tables .edges .parent
1016
1018
nd_children = tables .edges .child [np .argsort (parents )]
1017
1019
parents = sorted (parents )
1018
1020
parents_unique = np .unique (parents , return_index = True )
1019
1021
parent_indices = parents_unique [1 ][np .isin (parents_unique [0 ], nodes_to_date )]
1020
1022
for index , nd in tqdm (
1021
- enumerate (sorted ( nodes_to_date ) ), desc = "Constrain Ages" , disable = not progress
1023
+ enumerate (nodes_to_date ), desc = "Constrain Ages" , disable = not progress
1022
1024
):
1023
1025
if index + 1 != len (nodes_to_date ):
1024
1026
children_index = np .arange (parent_indices [index ], parent_indices [index + 1 ])
@@ -1033,17 +1035,17 @@ def truncate_priors(ts, sample_times, priors, nodes_to_date=None, progress=False
1033
1035
lookup_index = priors .row_lookup [int (nd )]
1034
1036
grid_data [lookup_index ][:nearest_time ] = zero_value
1035
1037
assert np .all (constrained_min_times < constrained_max_times )
1036
- all_zeros = np .where (np .all (grid_data == zero_value , axis = 1 ))[0 ]
1037
1038
1038
1039
rowmax = grid_data [:, 1 :].max (axis = 1 )
1039
1040
if priors .probability_space == "linear" :
1040
1041
grid_data = grid_data / rowmax [:, np .newaxis ]
1041
1042
elif priors .probability_space == "logarithmic" :
1042
1043
grid_data = grid_data - rowmax [:, np .newaxis ]
1043
-
1044
+
1044
1045
priors .grid_data [:] = grid_data
1045
1046
return constrained_min_times , constrained_max_times , priors
1046
1047
1048
+
1047
1049
def build_grid (
1048
1050
tree_sequence ,
1049
1051
Ne ,
@@ -1054,7 +1056,6 @@ def build_grid(
1054
1056
prior_distribution = "lognorm" ,
1055
1057
eps = 1e-6 ,
1056
1058
progress = False ,
1057
- sample_times = None
1058
1059
):
1059
1060
"""
1060
1061
Using the conditional coalescent, calculate the prior distribution for the age of
@@ -1085,7 +1086,7 @@ def build_grid(
1085
1086
inference and a discretised time grid
1086
1087
:rtype: base.NodeGridValues Object
1087
1088
"""
1088
- #tree_sequence = tree_sequence.simplify(tree_sequence.samples())
1089
+ # tree_sequence = tree_sequence.simplify(tree_sequence.samples())
1089
1090
1090
1091
if Ne <= 0 :
1091
1092
raise ValueError ("Parameter 'Ne' must be greater than 0" )
@@ -1139,6 +1140,5 @@ def build_grid(
1139
1140
progress = progress ,
1140
1141
)
1141
1142
if np .any (tree_sequence .tables .nodes .time [tree_sequence .samples ()] != 0 ):
1142
- if False :
1143
- priors = truncate_priors (tree_sequence , sample_times , priors , eps , progress = progress )
1143
+ _ , _ , priors = truncate_priors (tree_sequence , priors , progress = progress )
1144
1144
return priors
0 commit comments