Skip to content

Commit 808e965

Browse files
JoOkuma and yfukai authored
Adding offset to assign_track_ids and new tracklet_graph for napari (#116)
* adding optional track_id_offset to assign_track_id * adding rustworkx tracklet_graph method * moving to tracklet_graph to BaseGraph * adding sql graph tracklet_graph * fixing edge filtering * adding digraph to napari dict helper function * fixing tests * Update src/tracksdata/graph/_base_graph.py Co-authored-by: Yohsuke T. Fukai <[email protected]> --------- Co-authored-by: Yohsuke T. Fukai <[email protected]>
1 parent be07bce commit 808e965

File tree

8 files changed

+287
-13
lines changed

8 files changed

+287
-13
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Functional utilities for graph operations."""
22

3-
from tracksdata.functional._napari import to_napari_format
3+
from tracksdata.functional._napari import rx_digraph_to_napari_dict, to_napari_format
44

5-
__all__ = ["to_napari_format"]
5+
__all__ = ["rx_digraph_to_napari_dict", "to_napari_format"]

src/tracksdata/functional/_napari.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import TYPE_CHECKING, overload
22

33
import polars as pl
4+
import rustworkx as rx
45

56
from tracksdata.attrs import EdgeAttr, NodeAttr
67
from tracksdata.constants import DEFAULT_ATTR_KEYS
@@ -119,3 +120,27 @@ def to_napari_format(
119120
return tracks_data, dict_graph, array_view
120121

121122
return tracks_data, dict_graph
123+
124+
125+
def rx_digraph_to_napari_dict(
126+
tracklet_graph: rx.PyDiGraph,
127+
) -> dict[int, list[int]]:
128+
"""
129+
Convert a tracklet graph to a napari-ready dictionary.
130+
The input is a (parent -> child) graph (forward in time) and it is converted
131+
to a (child -> parent) dictionary (backward in time).
132+
133+
Parameters
134+
----------
135+
tracklet_graph : rx.PyDiGraph
136+
The tracklet graph to convert.
137+
138+
Returns
139+
-------
140+
dict[int, list[int]]
141+
A dictionary mapping each child tracklet id to the list of its parent tracklet ids.
142+
"""
143+
dict_graph = {}
144+
for parent, child in tracklet_graph.edges():
145+
dict_graph.setdefault(child, []).append(parent)
146+
return dict_graph

src/tracksdata/functional/_rx.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def _fast_path_transverse(
4545
def _fast_dag_transverse(
4646
starts: np.ndarray,
4747
dag: dict[int, int],
48+
track_id_offset: int,
4849
) -> tuple[list[np.ndarray], np.ndarray, np.ndarray, dict[int, int], dict[int, int]]:
4950
"""
5051
Traverse the tracks DAG, assigning a distinct id to each linear path.
@@ -61,6 +62,8 @@ def _fast_dag_transverse(
6162
dag : dict[int, int]
6263
Directed acyclic graph mapping parent → child for linear paths only.
6364
Dividing edges are excluded and handled separately.
65+
track_id_offset : int
66+
The starting track id, useful when assigning track ids to a subgraph.
6467
6568
Returns
6669
-------
@@ -77,7 +80,7 @@ def _fast_dag_transverse(
7780
last_to_track_id = {}
7881
first_to_track_id = {}
7982

80-
track_id = 1
83+
track_id = track_id_offset
8184

8285
for start in starts:
8386
path = _fast_path_transverse(start, dag)
@@ -247,6 +250,7 @@ def _track_id_edges_from_long_edges(
247250

248251
def _assign_track_ids(
249252
graph: rx.PyDiGraph,
253+
track_id_offset: int,
250254
) -> tuple[np.ndarray, np.ndarray, rx.PyDiGraph]:
251255
"""
252256
Assigns a unique `track_id` to each simple path in the graph and
@@ -256,6 +260,8 @@ def _assign_track_ids(
256260
----------
257261
graph : rx.PyDiGraph
258262
Directed acyclic graph of tracks.
263+
track_id_offset : int
264+
The starting track id, useful when assigning track ids to a subgraph.
259265
260266
Returns
261267
-------
@@ -272,7 +278,9 @@ def _assign_track_ids(
272278
# was it better (faster) when using a numpy array for the digraph as in ultrack?
273279
linear_dag, starts, long_edges_df = _rx_graph_to_dict_dag(graph)
274280

275-
paths, track_ids, lengths, last_to_track_id, first_to_track_id = _fast_dag_transverse(starts, linear_dag)
281+
paths, track_ids, lengths, last_to_track_id, first_to_track_id = _fast_dag_transverse(
282+
starts, linear_dag, track_id_offset
283+
)
276284

277285
n_tracks = len(track_ids)
278286

src/tracksdata/functional/_test/test_rx.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_empty_graph() -> None:
1010
"""Test that empty graph raises ValueError."""
1111
graph = rx.PyDiGraph()
1212
with pytest.raises(ValueError, match="Graph is empty"):
13-
_assign_track_ids(graph)
13+
_assign_track_ids(graph, track_id_offset=1)
1414

1515

1616
def test_single_path() -> None:
@@ -22,7 +22,7 @@ def test_single_path() -> None:
2222
graph.add_edge(nodes[0], nodes[1], None)
2323
graph.add_edge(nodes[1], nodes[2], None)
2424

25-
node_ids, track_ids, tracks_graph = _assign_track_ids(graph)
25+
node_ids, track_ids, tracks_graph = _assign_track_ids(graph, track_id_offset=1)
2626

2727
assert np.array_equal(node_ids, [0, 1, 2])
2828
assert np.array_equal(track_ids, [1, 1, 1])
@@ -46,7 +46,7 @@ def test_symmetric_branching_path() -> None:
4646
graph.add_edge(nodes[0], nodes[1], None)
4747
graph.add_edge(nodes[0], nodes[2], None)
4848

49-
node_ids, track_ids, tracks_graph = _assign_track_ids(graph)
49+
node_ids, track_ids, tracks_graph = _assign_track_ids(graph, track_id_offset=1)
5050

5151
# Should create 2 tracks: one for each branch
5252
assert len(node_ids) == 3
@@ -76,7 +76,7 @@ def test_asymmetric_branching_path() -> None:
7676
graph.add_edge(nodes[1], nodes[2], None)
7777
graph.add_edge(nodes[0], nodes[3], None)
7878

79-
node_ids, track_ids, tracks_graph = _assign_track_ids(graph)
79+
node_ids, track_ids, tracks_graph = _assign_track_ids(graph, track_id_offset=1)
8080

8181
# Should create 2 tracks: one for each branch
8282
assert len(node_ids) == 4
@@ -103,7 +103,7 @@ def test_invalid_multiple_parents() -> None:
103103
graph.add_edge(nodes[1], nodes[2], None)
104104

105105
with pytest.raises(RuntimeError, match="Invalid graph structure"):
106-
_assign_track_ids(graph)
106+
_assign_track_ids(graph, track_id_offset=1)
107107

108108

109109
def test_complex_valid_branching() -> None:
@@ -132,7 +132,7 @@ def test_complex_valid_branching() -> None:
132132
graph.add_edge(nodes[3], nodes[4], None)
133133
graph.add_edge(nodes[2], nodes[5], None)
134134

135-
node_ids, track_ids, tracks_graph = _assign_track_ids(graph)
135+
node_ids, track_ids, tracks_graph = _assign_track_ids(graph, track_id_offset=1)
136136

137137
# this order is an implementation detail, it could change
138138
# then the track ids should change accordingly
@@ -168,7 +168,7 @@ def test_three_children() -> None:
168168
graph.add_edge(nodes[0], nodes[2], None)
169169
graph.add_edge(nodes[0], nodes[3], None)
170170

171-
_, track_ids, tracks_graph = _assign_track_ids(graph)
171+
_, track_ids, tracks_graph = _assign_track_ids(graph, track_id_offset=1)
172172
assert set(tracks_graph.successor_indices(track_ids[0])) == set(track_ids[1:])
173173

174174

@@ -182,7 +182,7 @@ def test_multiple_roots() -> None:
182182
graph.add_edge(nodes[0], nodes[1], None)
183183
graph.add_edge(nodes[2], nodes[3], None)
184184

185-
node_ids, track_ids, tracks_graph = _assign_track_ids(graph)
185+
node_ids, track_ids, tracks_graph = _assign_track_ids(graph, track_id_offset=1)
186186

187187
assert len(node_ids) == 4
188188
assert len(track_ids) == 4

src/tracksdata/graph/_base_graph.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import polars as pl
10+
import rustworkx as rx
1011
from numpy.typing import ArrayLike
1112

1213
from tracksdata.attrs import AttrComparison, NodeAttr
@@ -948,3 +949,79 @@ def spatial_filter(self, attrs_keys: list[str] | None = None) -> "SpatialFilter"
948949
from tracksdata.graph.filters._spatial_filter import SpatialFilter
949950

950951
return SpatialFilter(self, attrs_keys=attrs_keys)
952+
953+
def tracklet_graph(
954+
self,
955+
track_id_key: str = DEFAULT_ATTR_KEYS.TRACK_ID,
956+
ignore_track_id: int | None = None,
957+
) -> rx.PyDiGraph:
958+
"""
959+
Create a compressed tracklet graph where each node is a tracklet
960+
and each edge is a transition between tracklets.
961+
962+
IMPORTANT:
963+
rx.PyDiGraph does not allow arbitrary indices, so the tracklet ids are stored as node values,
964+
and each edge value is the (source, target) tuple of tracklet ids.
965+
966+
Parameters
967+
----------
968+
track_id_key : str
969+
The key of the track id attribute.
970+
ignore_track_id : int | None
971+
The track id to ignore. If None, all track ids are used.
972+
973+
Returns
974+
-------
975+
rx.PyDiGraph
976+
A compressed tracklet graph.
977+
978+
See Also
979+
--------
980+
[rx_digraph_to_napari_dict][tracksdata.functional.rx_digraph_to_napari_dict]:
981+
Convert a tracklet graph to a napari-ready dictionary.
982+
"""
983+
984+
if track_id_key not in self.node_attr_keys:
985+
raise ValueError(f"Track id key '{track_id_key}' not found in graph. Expected '{self.node_attr_keys}'")
986+
987+
nodes_df = self.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, track_id_key])
988+
edges_df = self.edge_attrs(attr_keys=[])
989+
990+
if ignore_track_id is not None:
991+
nodes_df = nodes_df.filter(pl.col(track_id_key) != ignore_track_id)
992+
993+
nodes_df = nodes_df.unique(subset=[track_id_key])
994+
995+
graph = rx.PyDiGraph()
996+
nodes_df = nodes_df.with_columns(
997+
pl.Series(
998+
np.asarray(graph.add_nodes_from(nodes_df[track_id_key].to_list()), dtype=int),
999+
).alias("rx_id"),
1000+
)
1001+
1002+
edges_df = (
1003+
edges_df.join(
1004+
nodes_df.rename({track_id_key: "source_track_id", "rx_id": "source_rx_id"}),
1005+
left_on=DEFAULT_ATTR_KEYS.EDGE_SOURCE,
1006+
right_on=DEFAULT_ATTR_KEYS.NODE_ID,
1007+
how="right",
1008+
)
1009+
.join(
1010+
nodes_df.rename({track_id_key: "target_track_id", "rx_id": "target_rx_id"}),
1011+
left_on=DEFAULT_ATTR_KEYS.EDGE_TARGET,
1012+
right_on=DEFAULT_ATTR_KEYS.NODE_ID,
1013+
how="right",
1014+
)
1015+
.filter(~pl.col(DEFAULT_ATTR_KEYS.EDGE_ID).is_null())
1016+
)
1017+
1018+
graph.add_edges_from(
1019+
zip(
1020+
edges_df["source_rx_id"].to_list(),
1021+
edges_df["target_rx_id"].to_list(),
1022+
zip(edges_df["source_track_id"].to_list(), edges_df["target_track_id"].to_list(), strict=False),
1023+
strict=True,
1024+
)
1025+
)
1026+
1027+
return graph

src/tracksdata/graph/_graph_view.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ def assign_track_ids(
447447
self,
448448
output_key: str = DEFAULT_ATTR_KEYS.TRACK_ID,
449449
reset: bool = True,
450+
track_id_offset: int = 1,
450451
) -> rx.PyDiGraph:
451452
"""
452453
Compute and assign track ids to nodes.
@@ -457,14 +458,16 @@ def assign_track_ids(
457458
The key of the output track id attribute.
458459
reset : bool
459460
Whether to reset all track ids before assigning new ones.
461+
track_id_offset : int
462+
The starting track id, useful when assigning track ids to a subgraph.
460463
461464
Returns
462465
-------
463466
rx.PyDiGraph
464467
A compressed graph (parent -> child) with track ids lineage relationships.
465468
"""
466469
try:
467-
node_ids, track_ids, tracks_graph = _assign_track_ids(self.rx_graph)
470+
node_ids, track_ids, tracks_graph = _assign_track_ids(self.rx_graph, track_id_offset)
468471
except RuntimeError as e:
469472
raise RuntimeError(
470473
"Are you sure this graph is a valid lineage graph?\n"

src/tracksdata/graph/_sql_graph.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,3 +1341,92 @@ def __setstate__(self, state: dict) -> None:
13411341
# recreate deleted objects
13421342
self._engine = sa.create_engine(self._url, **self._engine_kwargs)
13431343
self._define_schema(overwrite=False)
1344+
1345+
def tracklet_graph(
1346+
self,
1347+
track_id_key: str = DEFAULT_ATTR_KEYS.TRACK_ID,
1348+
ignore_track_id: int | None = None,
1349+
) -> rx.PyDiGraph:
1350+
"""
1351+
Create a compressed tracklet graph where each node is a tracklet
1352+
and each edge is a transition between tracklets.
1353+
1354+
IMPORTANT:
1355+
rx.PyDiGraph does not allow arbitrary indices, so the tracklet ids are stored as node values,
1356+
and each edge value is the (source, target) tuple of tracklet ids.
1357+
1358+
Parameters
1359+
----------
1360+
track_id_key : str
1361+
The key of the track id attribute.
1362+
ignore_track_id : int | None
1363+
The track id to ignore. If None, all track ids are used.
1364+
1365+
Returns
1366+
-------
1367+
rx.PyDiGraph
1368+
A compressed tracklet graph.
1369+
"""
1370+
1371+
if track_id_key not in self.node_attr_keys:
1372+
raise ValueError(f"Track id key '{track_id_key}' not found in graph. Expected '{self.node_attr_keys}'")
1373+
1374+
with Session(self._engine) as session:
1375+
node_query = sa.select(getattr(self.Node, track_id_key)).distinct()
1376+
1377+
SourceNode = aliased(self.Node)
1378+
TargetNode = aliased(self.Node)
1379+
1380+
edge_query = (
1381+
sa.select(
1382+
getattr(self.Edge, DEFAULT_ATTR_KEYS.EDGE_SOURCE),
1383+
getattr(self.Edge, DEFAULT_ATTR_KEYS.EDGE_TARGET),
1384+
)
1385+
.join(
1386+
SourceNode,
1387+
SourceNode.node_id == self.Edge.source_id,
1388+
)
1389+
.join(
1390+
TargetNode,
1391+
TargetNode.node_id == self.Edge.target_id,
1392+
)
1393+
.filter(
1394+
getattr(SourceNode, track_id_key) != getattr(TargetNode, track_id_key),
1395+
)
1396+
)
1397+
1398+
if ignore_track_id is not None:
1399+
node_query = node_query.filter(getattr(self.Node, track_id_key) != ignore_track_id)
1400+
edge_query = edge_query.filter(
1401+
getattr(SourceNode, track_id_key) != ignore_track_id,
1402+
getattr(TargetNode, track_id_key) != ignore_track_id,
1403+
)
1404+
1405+
edge_query = edge_query.with_only_columns(
1406+
getattr(SourceNode, track_id_key).label("source_track_id"),
1407+
getattr(TargetNode, track_id_key).label("target_track_id"),
1408+
)
1409+
1410+
nodes_df = pl.read_database(
1411+
self._raw_query(node_query),
1412+
connection=session.connection(),
1413+
)
1414+
1415+
edges_df = pl.read_database(
1416+
self._raw_query(edge_query),
1417+
connection=session.connection(),
1418+
)
1419+
1420+
graph = rx.PyDiGraph()
1421+
tracklet_ids = nodes_df[track_id_key].to_list()
1422+
tracklet_id_to_rx = dict(zip(tracklet_ids, graph.add_nodes_from(tracklet_ids), strict=False))
1423+
graph.add_edges_from(
1424+
zip(
1425+
edges_df["source_track_id"].map_elements(tracklet_id_to_rx.__getitem__, return_dtype=int).to_list(),
1426+
edges_df["target_track_id"].map_elements(tracklet_id_to_rx.__getitem__, return_dtype=int).to_list(),
1427+
zip(edges_df["source_track_id"].to_list(), edges_df["target_track_id"].to_list(), strict=False),
1428+
strict=True,
1429+
)
1430+
)
1431+
1432+
return graph

0 commit comments

Comments
 (0)