Skip to content

Commit 6da3a58

Browse files
authored
Merge pull request #356 from hyanwong/API-extras
Use direct memory access for checks
2 parents 9b6b275 + faae8e1 commit 6da3a58

File tree

3 files changed

+21
-27
lines changed

3 files changed

+21
-27
lines changed

tsdate/core.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
181181
if e.child not in self.fixednodes
182182
}
183183
else:
184-
edges = self.ts.tables.edges
185184
fixed_nodes = np.array(list(self.fixednodes))
186185
keys = np.unique(
187186
np.core.records.fromarrays(
188-
(self.mut_edges, edges.right - edges.left), names="muts,span"
189-
)[np.logical_not(np.isin(edges.child, fixed_nodes))]
187+
(self.mut_edges, self.ts.edges_right - self.ts.edges_left),
188+
names="muts,span",
189+
)[np.logical_not(np.isin(self.ts.edges_child, fixed_nodes))]
190190
)
191191
if unique_method == 1:
192192
self.unfixed_likelihood_cache = dict.fromkeys({tuple(t) for t in keys})
@@ -603,8 +603,8 @@ def __init__(self, priors, lik, *, progress=False):
603603
self.priors.force_probability_space(lik.probability_space)
604604

605605
self.spans = np.bincount(
606-
self.ts.tables.edges.child,
607-
weights=self.ts.tables.edges.right - self.ts.tables.edges.left,
606+
self.ts.edges_child,
607+
weights=self.ts.edges_right - self.ts.edges_left,
608608
)
609609
self.spans = np.pad(self.spans, (0, self.ts.num_nodes - len(self.spans)))
610610

@@ -654,15 +654,15 @@ def edges_by_child_then_parent_desc(self, grouped=True):
654654
"""
655655
wtype = np.dtype(
656656
[
657-
("child_age", self.ts.tables.nodes.time.dtype),
658-
("child_node", self.ts.tables.edges.child.dtype),
659-
("parent_age", self.ts.tables.nodes.time.dtype),
657+
("child_age", self.ts.nodes_time.dtype),
658+
("child_node", self.ts.edges_child.dtype),
659+
("parent_age", self.ts.nodes_time.dtype),
660660
]
661661
)
662662
w = np.empty(self.ts.num_edges, dtype=wtype)
663-
w["child_age"] = self.ts.tables.nodes.time[self.ts.tables.edges.child]
664-
w["child_node"] = self.ts.tables.edges.child
665-
w["parent_age"] = -self.ts.tables.nodes.time[self.ts.tables.edges.parent]
663+
w["child_age"] = self.ts.nodes_time[self.ts.edges_child]
664+
w["child_node"] = self.ts.edges_child
665+
w["parent_age"] = -self.ts.nodes_time[self.ts.edges_parent]
666666
sorted_child_parent = (
667667
self.ts.edge(i)
668668
for i in reversed(
@@ -741,9 +741,7 @@ def inside_pass(self, *, standardize=True, cache_inside=False, progress=None):
741741
if standardize:
742742
marginal_lik = self.lik.combine(marginal_lik, denominator[parent])
743743
if cache_inside:
744-
self.g_i = self.lik.ratio(
745-
g_i, denominator[self.ts.tables.edges.child, None]
746-
)
744+
self.g_i = self.lik.ratio(g_i, denominator[self.ts.edges_child, None])
747745
# Keep the results in this object
748746
self.inside = inside
749747
self.denominator = denominator
@@ -792,7 +790,7 @@ def outside_pass(
792790
for child, edges in tqdm(
793791
self.edges_by_child_desc(),
794792
desc="Outside",
795-
total=len(np.unique(self.ts.tables.edges.child)),
793+
total=len(np.unique(self.ts.edges_child)),
796794
disable=not progress,
797795
):
798796
if child in self.fixednodes:
@@ -860,9 +858,7 @@ def outside_maximization(self, *, eps, progress=None):
860858

861859
mut_edges = self.lik.mut_edges
862860
mrcas = np.where(
863-
np.isin(
864-
np.arange(self.ts.num_nodes), self.ts.tables.edges.child, invert=True
865-
)
861+
np.isin(np.arange(self.ts.num_nodes), self.ts.edges_child, invert=True)
866862
)[0]
867863
for i in mrcas:
868864
if i not in self.fixednodes:
@@ -871,7 +867,7 @@ def outside_maximization(self, *, eps, progress=None):
871867
for child, edges in tqdm(
872868
self.edges_by_child_then_parent_desc(),
873869
desc="Maximization",
874-
total=len(np.unique(self.ts.tables.edges.child)),
870+
total=len(np.unique(self.ts.edges_child)),
875871
disable=not progress,
876872
):
877873
if child in self.fixednodes:

tsdate/prior.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False):
469469

470470
self.ts = tree_sequence
471471
self.sample_node_set = set(self.ts.samples())
472-
if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
472+
if np.any(self.ts.nodes_time[self.ts.samples()] != 0):
473473
raise ValueError(
474474
"The SpansBySamples class needs a tree seq with all samples at time 0"
475475
)
@@ -1034,7 +1034,7 @@ def fill_priors(
10341034
# convert timepoints to generational timescale
10351035
prior_times = base.NodeGridValues(
10361036
ts.num_nodes,
1037-
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
1037+
datable_nodes[np.argsort(ts.nodes_time[datable_nodes])].astype(np.int32),
10381038
population_size.to_natural_timescale(timepoints),
10391039
)
10401040

@@ -1169,9 +1169,7 @@ def make_parameter_grid(self, population_size, progress=False):
11691169

11701170
prior_pars = base.NodeGridValues(
11711171
self.tree_sequence.num_nodes,
1172-
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(
1173-
np.int32
1174-
),
1172+
datable_nodes[np.argsort(ts.nodes_time[datable_nodes])].astype(np.int32),
11751173
np.array([0, np.inf]),
11761174
)
11771175
prior_pars.probability_space = base.GAMMA_PAR

tsdate/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def reduce_to_contemporaneous(ts):
3939
Simplify the ts to only the contemporaneous samples, and return the new ts + node map
4040
"""
4141
samples = ts.samples()
42-
contmpr_samples = samples[ts.tables.nodes.time[samples] == 0]
42+
contmpr_samples = samples[ts.nodes_time[samples] == 0]
4343
return ts.simplify(
4444
contmpr_samples,
4545
map_nodes=True,
@@ -187,7 +187,7 @@ def nodes_time_unconstrained(tree_sequence):
187187
stored in the node metadata). Will produce an error if the tree sequence does
188188
not contain this information.
189189
"""
190-
nodes_time = tree_sequence.tables.nodes.time.copy()
190+
nodes_time = tree_sequence.nodes_time.copy()
191191
metadata = tree_sequence.tables.nodes.metadata
192192
metadata_offset = tree_sequence.tables.nodes.metadata_offset
193193
for index, met in enumerate(tskit.unpack_bytes(metadata, metadata_offset)):
@@ -270,7 +270,7 @@ def sites_time_from_ts(
270270
e.args += "Try calling sites_time_from_ts() with unconstrained=False."
271271
raise
272272
else:
273-
nodes_time = tree_sequence.tables.nodes.time
273+
nodes_time = tree_sequence.nodes_time
274274
sites_time = np.full(tree_sequence.num_sites, np.nan)
275275

276276
for tree in tree_sequence.trees():

0 commit comments

Comments
 (0)