Skip to content

Commit 6cfecbe

Browse files
committed
Switch parent_index to a count sort
1 parent 64b1ece commit 6cfecbe

File tree

2 files changed

+33
-44
lines changed

2 files changed

+33
-44
lines changed

python/tests/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_parent_index_correctness(ts):
152152
expected_parents.append(edge_id)
153153

154154
if len(expected_parents) == 0:
155-
assert start == -1 and stop == -1
155+
assert start == stop
156156
else:
157157
assert stop > start
158158
actual_parent_edge_ids = []

python/tskit/jit/numba.py

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -479,58 +479,47 @@ def child_index(self):
479479

480480
def parent_index(self):
481481
"""
482-
Create a :class:`ParentIndex` for finding parent edges of nodes. This
483-
operation requires sorting the edges by child ID and left coordinate,
484-
and therefore requires O(E log E) time complexity.
482+
Create a :class:`ParentIndex` for finding parent edges of nodes.
483+
484+
Edges within each child's group are not guaranteed to be in any
485+
specific order. This operation uses a two-pass algorithm with
486+
O(N + E) time complexity and O(N) auxiliary space.
485487
486488
:return: A new parent index container that can be used to
487489
efficiently find all edges where a given node is the child.
488490
:rtype: ParentIndex
489491
"""
490-
index_range = np.full((self.num_nodes, 2), -1, dtype=np.int32)
491-
edge_index = np.zeros(self.num_edges, dtype=np.int32)
492-
if self.num_edges == 0:
493-
return ParentIndex(edge_index, index_range)
494-
495-
# Create array of edge IDs
496-
edge_index[:] = np.arange(self.num_edges, dtype=np.int32)
497-
498-
# Sort edge IDs by child node (and by left coordinate as secondary sort)
499-
# We need to implement our own sorting since numba doesn't support lexsort
500-
# Use a stable sort to maintain order for secondary key
501-
# First sort by left coordinate (secondary key) using a stable sort
502-
edges_left = self.edges_left
492+
num_nodes = self.num_nodes
493+
num_edges = self.num_edges
503494
edges_child = self.edges_child
504495

505-
left_coords = np.zeros(self.num_edges, dtype=np.float64)
506-
for i in range(self.num_edges):
507-
left_coords[i] = edges_left[edge_index[i]]
508-
509-
# Stable sort by left coordinate
510-
sort_indices = np.argsort(left_coords, kind="mergesort")
511-
edge_index[:] = edge_index[sort_indices]
512-
513-
# Stable sort by child node
514-
child_nodes = np.zeros(self.num_edges, dtype=np.int32)
515-
for i in range(self.num_edges):
516-
child_nodes[i] = edges_child[edge_index[i]]
517-
sort_indices = np.argsort(child_nodes, kind="mergesort")
518-
edge_index[:] = edge_index[sort_indices]
519-
520-
# Find ranges
521-
last_child = -1
522-
for j in range(self.num_edges):
523-
edge_id = edge_index[j]
524-
child = edges_child[edge_id]
496+
child_counts = np.zeros(num_nodes, dtype=np.int32)
497+
edge_index = np.zeros(num_edges, dtype=np.int32)
498+
index_range = np.zeros((num_nodes, 2), dtype=np.int32)
525499

526-
if child != last_child:
527-
index_range[child, 0] = j
528-
if last_child != -1:
529-
index_range[last_child, 1] = j
530-
last_child = child
500+
if num_edges == 0:
501+
return ParentIndex(edge_index, index_range)
531502

532-
if last_child != -1:
533-
index_range[last_child, 1] = self.num_edges
503+
# Count how many children each node has
504+
for child_node in edges_child:
505+
child_counts[child_node] += 1
506+
507+
# From the counts build the index ranges. We set both the start and the
508+
# end index to the start - this lets us use the end index as a tracker
509+
# for where we should insert the next edge for that node - when all
510+
# edges are done these values will be the correct end values!
511+
current_start = 0
512+
for i in range(num_nodes):
513+
index_range[i, :] = current_start
514+
current_start += child_counts[i]
515+
516+
# Now go over the edges, inserting them at the index pointed to
517+
# by the node's current end value, then increment.
518+
for edge_id in range(num_edges):
519+
child = edges_child[edge_id]
520+
pos = index_range[child, 1]
521+
edge_index[pos] = edge_id
522+
index_range[child, 1] += 1
534523

535524
return ParentIndex(edge_index, index_range)
536525

0 commit comments

Comments
 (0)