127 changes: 78 additions & 49 deletions src/tracksdata/graph/filters/_spatial_filter.py
@@ -33,7 +33,7 @@ def __init__(
self,
indices: pl.Series,
df: pl.DataFrame,
) -> None:
):
from spatial_graph import PointRTree

start_time = time.time()
@@ -194,6 +194,73 @@ def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter":
return self._graph.filter(node_ids=node_ids)


class DataFrameBBoxFilter:
"""
Index a set of bounding boxes in an R-tree, optionally prefixed by a
frame (time) dimension, and query it with one slice per indexed dimension.
"""
def __init__(
self,
indices: pl.Series,
bboxes: pl.Series,
frames: pl.Series | None,
):
from spatial_graph import PointRTree

_indices = np.ascontiguousarray(indices.to_numpy(), dtype=np.int64).copy()

if bboxes.is_empty():
self._node_rtree = None
else:
_bboxes = bboxes.to_numpy()
if _bboxes.shape[1] % 2 != 0:
raise ValueError(
f"Bounding box coordinates must have even number of dimensions, got {_bboxes.shape[1]}"
)
num_dims = _bboxes.shape[1] // 2

if frames is None:
self._ndims = num_dims
positions_min = np.ascontiguousarray(_bboxes[:, :num_dims], dtype=np.float32)
positions_max = np.ascontiguousarray(_bboxes[:, num_dims:], dtype=np.float32)
else:
_frames = frames.to_numpy()
self._ndims = num_dims + 1 # +1 for the frame dimension
positions_min = np.ascontiguousarray(
np.hstack((_frames[:, np.newaxis], _bboxes[:, :num_dims])), dtype=np.float32
)
positions_max = np.ascontiguousarray(
np.hstack((_frames[:, np.newaxis], _bboxes[:, num_dims:])), dtype=np.float32
)

self._node_rtree = PointRTree(
item_dtype="int64",
coord_dtype="float32",
dims=self._ndims,
)
self._node_rtree.insert_bb_items(_indices, positions_min, positions_max)

def __getitem__(self, keys: tuple[slice, ...]) -> list[int]:
"""
Return the indices of the bounding boxes that intersect the query window,
given as one slice (with explicit start and stop) per indexed dimension.
"""
if self._node_rtree is None:
return []

for key in keys:
if key.start is None or key.stop is None:
raise ValueError(f"Slice {key} must have start and stop")

if len(keys) != self._ndims:
raise ValueError(f"Expected {self._ndims} keys, got {len(keys)}")

node_ids = self._node_rtree.search(
*(
np.stack(
[[s.start, s.stop] for s in keys],
axis=1,
dtype=np.float32,
)
)
)
return node_ids.tolist()


class BBoxSpatialFilter:
"""
Spatial filter for bounding box queries on graph nodes.
@@ -223,44 +290,25 @@ def __init__(
frame_attr_key: str | None = DEFAULT_ATTR_KEYS.T,
bbox_attr_key: str = DEFAULT_ATTR_KEYS.BBOX,
) -> None:
from spatial_graph import PointRTree

self._graph = graph

if frame_attr_key is None:
attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, bbox_attr_key]
else:
attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, frame_attr_key, bbox_attr_key]

nodes_df = graph.node_attrs(attr_keys=attr_keys)
node_ids = np.ascontiguousarray(nodes_df[DEFAULT_ATTR_KEYS.NODE_ID].to_numpy(), dtype=np.int64).copy()

if nodes_df.is_empty():
self._node_rtree = None
if frame_attr_key is None:
frames = None
else:
bboxes = nodes_df[bbox_attr_key].to_numpy()
if bboxes.shape[1] % 2 != 0:
raise ValueError(f"Bounding box coordinates must have even number of dimensions, got {bboxes.shape[1]}")
num_dims = bboxes.shape[1] // 2
frames = nodes_df[frame_attr_key]

if frame_attr_key is None:
self._ndims = num_dims
positions_min = np.ascontiguousarray(bboxes[:, :num_dims], dtype=np.float32)
positions_max = np.ascontiguousarray(bboxes[:, num_dims:], dtype=np.float32)
else:
frames = nodes_df[frame_attr_key].to_numpy()
self._ndims = num_dims + 1 # +1 for the frame dimension
positions_min = np.ascontiguousarray(
np.hstack((frames[:, np.newaxis], bboxes[:, :num_dims])), dtype=np.float32
)
positions_max = np.ascontiguousarray(
np.hstack((frames[:, np.newaxis], bboxes[:, num_dims:])), dtype=np.float32
)
self._node_rtree = PointRTree(
item_dtype="int64",
coord_dtype="float32",
dims=self._ndims,
)
self._node_rtree.insert_bb_items(node_ids, positions_min, positions_max)
self._df_filter = DataFrameBBoxFilter(
indices=nodes_df[DEFAULT_ATTR_KEYS.NODE_ID],
bboxes=nodes_df[bbox_attr_key],
frames=frames,
)

def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter":
"""
@@ -300,24 +348,5 @@ def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter":
subgraph = spatial_filter[0:10, 0:5, 10:50, 20:60].subgraph()
```
"""

if self._node_rtree is None:
return self._graph.filter(node_ids=[])

for key in keys:
if key.start is None or key.stop is None:
raise ValueError(f"Slice {key} must have start and stop")

if len(keys) != self._ndims:
raise ValueError(f"Expected {self._ndims} keys, got {len(keys)}")

node_ids = self._node_rtree.search(
*(
np.stack(
[[s.start, s.stop] for s in keys],
axis=1,
dtype=np.float32,
)
)
)
node_ids = self._df_filter[keys]
return self._graph.filter(node_ids=node_ids)
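For orientation, here is a minimal usage sketch of the new `DataFrameBBoxFilter` (not part of this PR). The box values are made up, the `pl.Array` construction assumes a recent Polars, and `frames` is a `pl.Series`, matching what `BBoxSpatialFilter` forwards from `nodes_df[frame_attr_key]`:

```python
import numpy as np
import polars as pl

from tracksdata.graph.filters._spatial_filter import DataFrameBBoxFilter

# Three 2D bounding boxes laid out as (min_y, min_x, max_y, max_x),
# i.e. all mins followed by all maxs, plus a frame index per box.
bboxes = pl.Series(
    "bbox",
    [[0, 0, 5, 5], [10, 10, 20, 20], [3, 3, 8, 8]],
    dtype=pl.Array(pl.Int64, 4),
)
frames = pl.Series("t", [0, 0, 1])
indices = pl.Series("node_id", np.arange(len(bboxes)))

bbox_filter = DataFrameBBoxFilter(indices=indices, bboxes=bboxes, frames=frames)

# One slice per indexed dimension (t, y, x). The slice bounds define the
# query box handed to the R-tree, and intersecting boxes come back as a
# list of the original indices (boxes 0 and 2 are expected here).
hits = bbox_filter[0:2, 0:6, 0:6]
print(hits)
```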
4 changes: 2 additions & 2 deletions src/tracksdata/graph/filters/_test/test_spatial_filter.py
@@ -173,10 +173,10 @@ def test_bbox_spatial_filter_with_edges() -> None:
def test_bbox_spatial_filter_initialization(sample_bbox_graph: RustWorkXGraph) -> None:
"""Test BoundingBoxSpatialFilter initialization with default and custom attributes."""
spatial_filter = BBoxSpatialFilter(sample_bbox_graph)
assert spatial_filter._node_rtree is not None
assert spatial_filter._df_filter._node_rtree is not None

spatial_filter = BBoxSpatialFilter(sample_bbox_graph, frame_attr_key="t", bbox_attr_key="bbox")
assert spatial_filter._node_rtree is not None
assert spatial_filter._df_filter._node_rtree is not None


def test_bbox_spatial_filter_querying(sample_bbox_graph: RustWorkXGraph) -> None:
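A hypothetical companion test (not included in this PR) could exercise the empty-input path of `DataFrameBBoxFilter` directly; the series construction below is a sketch and assumes a recent Polars:

```python
import polars as pl

from tracksdata.graph.filters._spatial_filter import DataFrameBBoxFilter


def test_dataframe_bbox_filter_empty() -> None:
    # An empty bbox series disables the R-tree, so every query should
    # return no hits, mirroring BBoxSpatialFilter on an empty graph.
    empty_filter = DataFrameBBoxFilter(
        indices=pl.Series([], dtype=pl.Int64),
        bboxes=pl.Series([], dtype=pl.Array(pl.Int64, 4)),
        frames=None,
    )
    assert empty_filter._node_rtree is None
    assert empty_filter[0:10, 0:10] == []
```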
33 changes: 27 additions & 6 deletions src/tracksdata/metrics/_ctc_metrics.py
@@ -6,6 +6,7 @@
from toolz import curry

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph.filters._spatial_filter import DataFrameBBoxFilter
from tracksdata.io._ctc import compressed_tracks_table
from tracksdata.options import get_options
from tracksdata.utils._dtypes import column_from_bytes, column_to_bytes
@@ -66,12 +67,30 @@ def _match_single_frame(
_rows = []
_cols = []

for i, (ref_id, ref_mask) in enumerate(
zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
comp_indices = np.arange(len(comp_group))
bbox_filter = DataFrameBBoxFilter(
indices=pl.Series(comp_indices),
bboxes=comp_group[DEFAULT_ATTR_KEYS.BBOX],
frames=None,
)

for i, (ref_id, ref_mask, ref_bbox) in enumerate(
zip(
ref_group[reference_graph_key],
ref_group[DEFAULT_ATTR_KEYS.MASK],
ref_group[DEFAULT_ATTR_KEYS.BBOX].to_numpy(),
strict=True,
)
):
for j, (comp_id, comp_mask) in enumerate(
zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
):
ndim = ref_bbox.size // 2
overlap_comp_idx = bbox_filter[
tuple(slice(s, e) for s, e in zip(ref_bbox[:ndim], ref_bbox[ndim:], strict=False))
]
if len(overlap_comp_idx) == 0:
continue
comp_track_ids = comp_group[input_graph_key][overlap_comp_idx]
comp_masks = comp_group[DEFAULT_ATTR_KEYS.MASK][overlap_comp_idx]
for j, comp_id, comp_mask in zip(overlap_comp_idx, comp_track_ids, comp_masks, strict=False):
# intersection over reference is used to select the matches
inter = ref_mask.intersection(comp_mask)
ctc_score = inter / ref_mask.size
@@ -150,7 +169,9 @@ def _matching_data(
("ref", reference_graph, reference_graph_key),
("comp", input_graph, input_graph_key),
]:
nodes_df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.T, track_id_key, DEFAULT_ATTR_KEYS.MASK])
nodes_df = graph.node_attrs(
attr_keys=[DEFAULT_ATTR_KEYS.T, track_id_key, DEFAULT_ATTR_KEYS.BBOX, DEFAULT_ATTR_KEYS.MASK]
)
if n_workers > 1:
# required by multiprocessing
nodes_df = column_to_bytes(nodes_df, DEFAULT_ATTR_KEYS.MASK)
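As a self-contained illustration of the pruning idea introduced in `_match_single_frame` above (made-up boxes; the real code builds the filter from `comp_group` and its bbox column), the query below returns only the candidates whose bounding boxes overlap the reference box, so exact mask intersections are skipped for everything else:

```python
import numpy as np
import polars as pl

from tracksdata.graph.filters._spatial_filter import DataFrameBBoxFilter

# Stand-in for the per-frame `comp_group` bounding boxes,
# laid out as (min_y, min_x, max_y, max_x).
comp_bboxes = pl.Series(
    "bbox",
    [[0, 0, 4, 4], [10, 10, 14, 14], [3, 3, 7, 7]],
    dtype=pl.Array(pl.Int64, 4),
)
bbox_filter = DataFrameBBoxFilter(
    indices=pl.Series(np.arange(len(comp_bboxes))),
    bboxes=comp_bboxes,
    frames=None,
)

ref_bbox = np.array([2, 2, 6, 6])
ndim = ref_bbox.size // 2
overlap_comp_idx = bbox_filter[
    tuple(slice(s, e) for s, e in zip(ref_bbox[:ndim], ref_bbox[ndim:], strict=True))
]
# Only these candidates (boxes 0 and 2 are expected) go on to the exact
# mask intersection; non-overlapping boxes never enter the inner loop.
```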