22
22
"""
23
23
Routines and classes for creating priors and timeslices for use in tsdate
24
24
"""
25
+ import itertools
25
26
import logging
26
27
import os
27
28
from collections import defaultdict
@@ -1035,10 +1036,8 @@ def _truncate_priors(ts, priors, progress=False):
1035
1036
Truncate priors for all nonfixed nodes
1036
1037
so they conform to the age of fixed nodes in the tree sequence
1037
1038
"""
1038
- tables = ts .tables
1039
-
1040
1039
fixed_nodes = priors .fixed_node_ids ()
1041
- fixed_times = tables . nodes . time [fixed_nodes ]
1040
+ fixed_times = ts . nodes_time [fixed_nodes ]
1042
1041
1043
1042
grid_data = np .copy (priors .grid_data [:])
1044
1043
timepoints = priors .timepoints
@@ -1048,24 +1047,25 @@ def _truncate_priors(ts, priors, progress=False):
1048
1047
zero_value = 0
1049
1048
elif priors .probability_space == "logarithmic" :
1050
1049
zero_value = - np .inf
1051
- constrained_min_times = np .zeros_like (tables . nodes . time )
1050
+ constrained_min_times = np .zeros_like (ts . nodes_time )
1052
1051
# Set the min times of fixed nodes to those in the tree sequence
1053
1052
constrained_min_times [fixed_nodes ] = fixed_times
1054
1053
1055
1054
# Traverse through the ARG, ensuring children come before parents.
1056
1055
# This can be done by iterating over groups of edges with the same parent
1057
- new_parent_edge_idx = np .concatenate (
1058
- (
1059
- [0 ],
1060
- np .where (np .diff (tables .edges .parent ) != 0 )[0 ] + 1 ,
1061
- [tables .edges .num_rows ],
1062
- )
1063
- )
1064
- for edges_start , edges_end in zip (
1065
- new_parent_edge_idx [:- 1 ], new_parent_edge_idx [1 :]
1056
+ edges_parent = ts .edges_parent
1057
+ edges_child = ts .edges_child
1058
+ new_parent_edge_idx = np .where (np .diff (edges_parent ) != 0 )[0 ] + 1
1059
+ for edges_start , edges_end in tqdm (
1060
+ zip (
1061
+ itertools .chain ([0 ], new_parent_edge_idx ),
1062
+ itertools .chain (new_parent_edge_idx , [len (edges_parent )]),
1063
+ ),
1064
+ desc = "Trunc priors" ,
1065
+ disable = not progress ,
1066
1066
):
1067
- parent = tables . edges . parent [edges_start ]
1068
- child_ids = tables . edges . child [edges_start :edges_end ] # May contain dups
1067
+ parent = edges_parent [edges_start ]
1068
+ child_ids = edges_child [edges_start :edges_end ] # May contain dups
1069
1069
oldest_child_time = np .max (constrained_min_times [child_ids ])
1070
1070
if oldest_child_time > constrained_min_times [parent ]:
1071
1071
if priors .is_fixed (parent ):
@@ -1204,15 +1204,17 @@ def build_grid(
1204
1204
node_var_override = node_var_override ,
1205
1205
progress = progress ,
1206
1206
)
1207
- tables = tree_sequence .tables
1208
- if np .any (tables .nodes .time [tree_sequence .samples ()] > 0 ):
1207
+ if np .any (tree_sequence .nodes_time [tree_sequence .samples ()] > 0 ):
1209
1208
if not allow_historical_samples :
1210
1209
raise ValueError (
1211
1210
"There are samples at non-zero times, invalidating the conditional "
1212
1211
"coalescent prior. You can set allow_historical_samples=True to carry "
1213
1212
"on regardless, calculating a prior as if all samples were "
1214
1213
"contemporaneous (reasonable if you only have a few ancient samples)"
1215
1214
)
1216
- if np .any (tables .nodes .time [priors .fixed_node_ids ()] > 0 ) and truncate_priors :
1215
+ if (
1216
+ np .any (tree_sequence .nodes_time [priors .fixed_node_ids ()] > 0 )
1217
+ and truncate_priors
1218
+ ):
1217
1219
priors = _truncate_priors (tree_sequence , priors , progress = progress )
1218
1220
return priors
0 commit comments