Skip to content

Commit faae8e1

Browse files
committed
Use direct memory access for checks
1 parent c7a348a commit faae8e1

File tree

3 files changed

+21
-27
lines changed

3 files changed

+21
-27
lines changed

tsdate/core.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,12 @@ def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
180180
if e.child not in self.fixednodes
181181
}
182182
else:
183-
edges = self.ts.tables.edges
184183
fixed_nodes = np.array(list(self.fixednodes))
185184
keys = np.unique(
186185
np.core.records.fromarrays(
187-
(self.mut_edges, edges.right - edges.left), names="muts,span"
188-
)[np.logical_not(np.isin(edges.child, fixed_nodes))]
186+
(self.mut_edges, self.ts.edges_right - self.ts.edges_left),
187+
names="muts,span",
188+
)[np.logical_not(np.isin(self.ts.edges_child, fixed_nodes))]
189189
)
190190
if unique_method == 1:
191191
self.unfixed_likelihood_cache = dict.fromkeys({tuple(t) for t in keys})
@@ -602,8 +602,8 @@ def __init__(self, priors, lik, *, progress=False):
602602
self.priors.force_probability_space(lik.probability_space)
603603

604604
self.spans = np.bincount(
605-
self.ts.tables.edges.child,
606-
weights=self.ts.tables.edges.right - self.ts.tables.edges.left,
605+
self.ts.edges_child,
606+
weights=self.ts.edges_right - self.ts.edges_left,
607607
)
608608
self.spans = np.pad(self.spans, (0, self.ts.num_nodes - len(self.spans)))
609609

@@ -653,15 +653,15 @@ def edges_by_child_then_parent_desc(self, grouped=True):
653653
"""
654654
wtype = np.dtype(
655655
[
656-
("child_age", self.ts.tables.nodes.time.dtype),
657-
("child_node", self.ts.tables.edges.child.dtype),
658-
("parent_age", self.ts.tables.nodes.time.dtype),
656+
("child_age", self.ts.nodes_time.dtype),
657+
("child_node", self.ts.edges_child.dtype),
658+
("parent_age", self.ts.nodes_time.dtype),
659659
]
660660
)
661661
w = np.empty(self.ts.num_edges, dtype=wtype)
662-
w["child_age"] = self.ts.tables.nodes.time[self.ts.tables.edges.child]
663-
w["child_node"] = self.ts.tables.edges.child
664-
w["parent_age"] = -self.ts.tables.nodes.time[self.ts.tables.edges.parent]
662+
w["child_age"] = self.ts.nodes_time[self.ts.edges_child]
663+
w["child_node"] = self.ts.edges_child
664+
w["parent_age"] = -self.ts.nodes_time[self.ts.edges_parent]
665665
sorted_child_parent = (
666666
self.ts.edge(i)
667667
for i in reversed(
@@ -740,9 +740,7 @@ def inside_pass(self, *, standardize=True, cache_inside=False, progress=None):
740740
if standardize:
741741
marginal_lik = self.lik.combine(marginal_lik, denominator[parent])
742742
if cache_inside:
743-
self.g_i = self.lik.ratio(
744-
g_i, denominator[self.ts.tables.edges.child, None]
745-
)
743+
self.g_i = self.lik.ratio(g_i, denominator[self.ts.edges_child, None])
746744
# Keep the results in this object
747745
self.inside = inside
748746
self.denominator = denominator
@@ -791,7 +789,7 @@ def outside_pass(
791789
for child, edges in tqdm(
792790
self.edges_by_child_desc(),
793791
desc="Outside",
794-
total=len(np.unique(self.ts.tables.edges.child)),
792+
total=len(np.unique(self.ts.edges_child)),
795793
disable=not progress,
796794
):
797795
if child in self.fixednodes:
@@ -859,9 +857,7 @@ def outside_maximization(self, *, eps, progress=None):
859857

860858
mut_edges = self.lik.mut_edges
861859
mrcas = np.where(
862-
np.isin(
863-
np.arange(self.ts.num_nodes), self.ts.tables.edges.child, invert=True
864-
)
860+
np.isin(np.arange(self.ts.num_nodes), self.ts.edges_child, invert=True)
865861
)[0]
866862
for i in mrcas:
867863
if i not in self.fixednodes:
@@ -870,7 +866,7 @@ def outside_maximization(self, *, eps, progress=None):
870866
for child, edges in tqdm(
871867
self.edges_by_child_then_parent_desc(),
872868
desc="Maximization",
873-
total=len(np.unique(self.ts.tables.edges.child)),
869+
total=len(np.unique(self.ts.edges_child)),
874870
disable=not progress,
875871
):
876872
if child in self.fixednodes:

tsdate/prior.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False):
467467

468468
self.ts = tree_sequence
469469
self.sample_node_set = set(self.ts.samples())
470-
if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
470+
if np.any(self.ts.nodes_time[self.ts.samples()] != 0):
471471
raise ValueError(
472472
"The SpansBySamples class needs a tree seq with all samples at time 0"
473473
)
@@ -1032,7 +1032,7 @@ def fill_priors(
10321032
# convert timepoints to generational timescale
10331033
prior_times = base.NodeGridValues(
10341034
ts.num_nodes,
1035-
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
1035+
datable_nodes[np.argsort(ts.nodes_time[datable_nodes])].astype(np.int32),
10361036
population_size.to_natural_timescale(timepoints),
10371037
)
10381038

@@ -1167,9 +1167,7 @@ def make_parameter_grid(self, population_size, progress=False):
11671167

11681168
prior_pars = base.NodeGridValues(
11691169
self.tree_sequence.num_nodes,
1170-
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(
1171-
np.int32
1172-
),
1170+
datable_nodes[np.argsort(ts.nodes_time[datable_nodes])].astype(np.int32),
11731171
np.array([0, np.inf]),
11741172
)
11751173
prior_pars.probability_space = base.GAMMA_PAR

tsdate/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def reduce_to_contemporaneous(ts):
3939
Simplify the ts to only the contemporaneous samples, and return the new ts + node map
4040
"""
4141
samples = ts.samples()
42-
contmpr_samples = samples[ts.tables.nodes.time[samples] == 0]
42+
contmpr_samples = samples[ts.nodes_time[samples] == 0]
4343
return ts.simplify(
4444
contmpr_samples,
4545
map_nodes=True,
@@ -187,7 +187,7 @@ def nodes_time_unconstrained(tree_sequence):
187187
stored in the node metadata). Will produce an error if the tree sequence does
188188
not contain this information.
189189
"""
190-
nodes_time = tree_sequence.tables.nodes.time.copy()
190+
nodes_time = tree_sequence.nodes_time.copy()
191191
metadata = tree_sequence.tables.nodes.metadata
192192
metadata_offset = tree_sequence.tables.nodes.metadata_offset
193193
for index, met in enumerate(tskit.unpack_bytes(metadata, metadata_offset)):
@@ -270,7 +270,7 @@ def sites_time_from_ts(
270270
e.args += "Try calling sites_time_from_ts() with unconstrained=False."
271271
raise
272272
else:
273-
nodes_time = tree_sequence.tables.nodes.time
273+
nodes_time = tree_sequence.nodes_time
274274
sites_time = np.full(tree_sequence.num_sites, np.nan)
275275

276276
for tree in tree_sequence.trees():

0 commit comments

Comments
 (0)