diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index 38052323..bc516c60 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -33,7 +33,7 @@ def __init__( self, indices: pl.Series, df: pl.DataFrame, - ) -> None: + ): from spatial_graph import PointRTree start_time = time.time() @@ -194,6 +194,73 @@ def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": return self._graph.filter(node_ids=node_ids) +class DataFrameBBoxFilter: + # TODO: ... + def __init__( + self, + indices: pl.Series, + bboxes: pl.Series, + frames: pl.DataFrame | None, + ): + from spatial_graph import PointRTree + + _indices = np.ascontiguousarray(indices.to_numpy(), dtype=np.int64).copy() + + if bboxes.is_empty(): + self._node_rtree = None + else: + _bboxes = bboxes.to_numpy() + if _bboxes.shape[1] % 2 != 0: + raise ValueError( + f"Bounding box coordinates must have even number of dimensions, got {_bboxes.shape[1]}" + ) + num_dims = _bboxes.shape[1] // 2 + + if frames is None: + self._ndims = num_dims + positions_min = np.ascontiguousarray(_bboxes[:, :num_dims], dtype=np.float32) + positions_max = np.ascontiguousarray(_bboxes[:, num_dims:], dtype=np.float32) + else: + _frames = frames.to_numpy() + self._ndims = num_dims + 1 # +1 for the frame dimension + positions_min = np.ascontiguousarray( + np.hstack((_frames[:, np.newaxis], _bboxes[:, :num_dims])), dtype=np.float32 + ) + positions_max = np.ascontiguousarray( + np.hstack((_frames[:, np.newaxis], _bboxes[:, num_dims:])), dtype=np.float32 + ) + + self._node_rtree = PointRTree( + item_dtype="int64", + coord_dtype="float32", + dims=self._ndims, + ) + self._node_rtree.insert_bb_items(_indices, positions_min, positions_max) + + def __getitem__(self, keys: tuple[slice, ...]) -> list[int]: + # TODO + if self._node_rtree is None: + return [] + + for key in keys: + if key.start is None or key.stop is None: + raise ValueError(f"Slice {key} must have start and stop") + + if len(keys) != self._ndims: + raise ValueError(f"Expected {self._ndims} keys, got {len(keys)}") + + node_ids = self._node_rtree.search( + *( + np.stack( + [[s.start, s.stop] for s in keys], + axis=1, + dtype=np.float32, + ) + ) + ) + return node_ids.tolist() + + class BBoxSpatialFilter: """ Spatial filter for bounding box queries on graph nodes. @@ -223,44 +290,25 @@ def __init__( frame_attr_key: str | None = DEFAULT_ATTR_KEYS.T, bbox_attr_key: str = DEFAULT_ATTR_KEYS.BBOX, ) -> None: - from spatial_graph import PointRTree - self._graph = graph if frame_attr_key is None: attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, bbox_attr_key] else: attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, frame_attr_key, bbox_attr_key] + nodes_df = graph.node_attrs(attr_keys=attr_keys) - node_ids = np.ascontiguousarray(nodes_df[DEFAULT_ATTR_KEYS.NODE_ID].to_numpy(), dtype=np.int64).copy() - if nodes_df.is_empty(): - self._node_rtree = None + if frame_attr_key is None: + frames = None else: - bboxes = nodes_df[bbox_attr_key].to_numpy() - if bboxes.shape[1] % 2 != 0: - raise ValueError(f"Bounding box coordinates must have even number of dimensions, got {bboxes.shape[1]}") - num_dims = bboxes.shape[1] // 2 + frames = nodes_df[frame_attr_key] - if frame_attr_key is None: - self._ndims = num_dims - positions_min = np.ascontiguousarray(bboxes[:, :num_dims], dtype=np.float32) - positions_max = np.ascontiguousarray(bboxes[:, num_dims:], dtype=np.float32) - else: - frames = nodes_df[frame_attr_key].to_numpy() - self._ndims = num_dims + 1 # +1 for the frame dimension - positions_min = np.ascontiguousarray( - np.hstack((frames[:, np.newaxis], bboxes[:, :num_dims])), dtype=np.float32 - ) - positions_max = np.ascontiguousarray( - np.hstack((frames[:, np.newaxis], bboxes[:, num_dims:])), dtype=np.float32 - ) - self._node_rtree = PointRTree( - item_dtype="int64", - coord_dtype="float32", - dims=self._ndims, - ) - self._node_rtree.insert_bb_items(node_ids, positions_min, positions_max) + self._df_filter = DataFrameBBoxFilter( + indices=nodes_df[DEFAULT_ATTR_KEYS.NODE_ID], + bboxes=nodes_df[bbox_attr_key], + frames=frames, + ) def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": """ @@ -300,24 +348,5 @@ def __getitem__(self, keys: tuple[slice, ...]) -> "BaseFilter": subgraph = spatial_filter[0:10, 0:5, 10:50, 20:60].subgraph() ``` """ - - if self._node_rtree is None: - return self._graph.filter(node_ids=[]) - - for key in keys: - if key.start is None or key.stop is None: - raise ValueError(f"Slice {key} must have start and stop") - - if len(keys) != self._ndims: - raise ValueError(f"Expected {self._ndims} keys, got {len(keys)}") - - node_ids = self._node_rtree.search( - *( - np.stack( - [[s.start, s.stop] for s in keys], - axis=1, - dtype=np.float32, - ) - ) - ) + node_ids = self._df_filter[keys] return self._graph.filter(node_ids=node_ids) diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index d31310b3..0c62916d 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -173,10 +173,10 @@ def test_bbox_spatial_filter_with_edges() -> None: def test_bbox_spatial_filter_initialization(sample_bbox_graph: RustWorkXGraph) -> None: """Test BoundingBoxSpatialFilter initialization with default and custom attributes.""" spatial_filter = BBoxSpatialFilter(sample_bbox_graph) - assert spatial_filter._node_rtree is not None + assert spatial_filter._df_filter._node_rtree is not None spatial_filter = BBoxSpatialFilter(sample_bbox_graph, frame_attr_key="t", bbox_attr_key="bbox") - assert spatial_filter._node_rtree is not None + assert spatial_filter._df_filter._node_rtree is not None def test_bbox_spatial_filter_querying(sample_bbox_graph: RustWorkXGraph) -> None: diff --git a/src/tracksdata/metrics/_ctc_metrics.py b/src/tracksdata/metrics/_ctc_metrics.py index 5d48b91b..b93c8a65 100644 --- a/src/tracksdata/metrics/_ctc_metrics.py +++ b/src/tracksdata/metrics/_ctc_metrics.py @@ -6,6 +6,7 @@ from toolz import curry from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.graph.filters._spatial_filter import DataFrameBBoxFilter from tracksdata.io._ctc import compressed_tracks_table from tracksdata.options import get_options from tracksdata.utils._dtypes import column_from_bytes, column_to_bytes @@ -66,12 +67,30 @@ def _match_single_frame( _rows = [] _cols = [] - for i, (ref_id, ref_mask) in enumerate( - zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True) + comp_indices = np.arange(len(comp_group)) + bbox_filter = DataFrameBBoxFilter( + indices=pl.Series(comp_indices), + bboxes=comp_group[DEFAULT_ATTR_KEYS.BBOX], + frames=None, + ) + + for i, (ref_id, ref_mask, ref_bbox) in enumerate( + zip( + ref_group[reference_graph_key], + ref_group[DEFAULT_ATTR_KEYS.MASK], + ref_group[DEFAULT_ATTR_KEYS.BBOX].to_numpy(), + strict=True, + ) ): - for j, (comp_id, comp_mask) in enumerate( - zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True) - ): + ndim = ref_bbox.size // 2 + overlap_comp_idx = bbox_filter[ + tuple(slice(s, e) for s, e in zip(ref_bbox[:ndim], ref_bbox[ndim:], strict=False)) + ] + if len(overlap_comp_idx) == 0: + continue + comp_track_ids = comp_group[input_graph_key][overlap_comp_idx] + comp_masks = comp_group[DEFAULT_ATTR_KEYS.MASK][overlap_comp_idx] + for j, comp_id, comp_mask in zip(overlap_comp_idx, comp_track_ids, comp_masks, strict=False): # intersection over reference is used to select the matches inter = ref_mask.intersection(comp_mask) ctc_score = inter / ref_mask.size @@ -150,7 +169,9 @@ def _matching_data( ("ref", reference_graph, reference_graph_key), ("comp", input_graph, input_graph_key), ]: - nodes_df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.T, track_id_key, DEFAULT_ATTR_KEYS.MASK]) + nodes_df = graph.node_attrs( + attr_keys=[DEFAULT_ATTR_KEYS.T, track_id_key, DEFAULT_ATTR_KEYS.BBOX, DEFAULT_ATTR_KEYS.MASK] + ) if n_workers > 1: # required by multiprocessing nodes_df = column_to_bytes(nodes_df, DEFAULT_ATTR_KEYS.MASK)