Skip to content

Commit 6463f4b

Browse files
authored
Merge pull request #260 from hyanwong/edge-iteration-drop-in
Some simpler functions
2 parents 9588521 + 4ba21a9 commit 6463f4b

File tree

1 file changed

+8
-30
lines changed

1 file changed

+8
-30
lines changed

tsdate/core.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -127,24 +127,9 @@ def get_mut_edges(ts):
127127
"""
128128
Get the number of mutations on each edge in the tree sequence.
129129
"""
130-
edge_diff_iter = ts.edge_diffs()
131-
right = 0
132-
edges_by_child = {} # contains {child_node:edge_id}
133130
mut_edges = np.zeros(ts.num_edges, dtype=np.int64)
134-
for site in ts.sites():
135-
while right <= site.position:
136-
(left, right), edges_out, edges_in = next(edge_diff_iter)
137-
for e in edges_out:
138-
del edges_by_child[e.child]
139-
for e in edges_in:
140-
assert e.child not in edges_by_child
141-
edges_by_child[e.child] = e.id
142-
for m in site.mutations:
143-
# In some cases, mutations occur above the root
144-
# These don't provide any information for the inside step
145-
if m.node in edges_by_child:
146-
edge_id = edges_by_child[m.node]
147-
mut_edges[edge_id] += 1
131+
for m in ts.mutations():
132+
mut_edges[m.edge] += 1
148133
return mut_edges
149134

150135
@staticmethod
@@ -585,20 +570,13 @@ def edges_by_child_desc(self):
585570
Return an itertools.groupby object of edges grouped by child in descending order
586571
of the time of the child.
587572
"""
588-
wtype = np.dtype(
589-
[
590-
("child_age", self.ts.tables.nodes.time.dtype),
591-
("child_node", self.ts.tables.edges.child.dtype),
592-
]
593-
)
594-
w = np.empty(self.ts.num_edges, dtype=wtype)
595-
w["child_age"] = self.ts.tables.nodes.time[self.ts.tables.edges.child]
596-
w["child_node"] = self.ts.tables.edges.child
597-
sorted_child_parent = (
598-
self.ts.edge(i)
599-
for i in reversed(np.argsort(w, order=("child_age", "child_node")))
573+
it = (
574+
self.ts.edge(u)
575+
for u in np.lexsort(
576+
(self.ts.edges_child, -self.ts.nodes_time[self.ts.edges_child])
577+
)
600578
)
601-
return itertools.groupby(sorted_child_parent, operator.attrgetter("child"))
579+
return itertools.groupby(it, operator.attrgetter("child"))
602580

603581
def edges_by_child_then_parent_desc(self):
604582
"""

0 commit comments

Comments
 (0)