Skip to content

Commit 290d36e

Browse files
added filter + tracking tests
1 parent ef3d7d7 commit 290d36e

File tree

2 files changed

+40
-13
lines changed

2 files changed

+40
-13
lines changed

simple_filters/tracker.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,17 @@ def update(self, states):
6464

6565
states = np.array(states)
6666

67-
# check if array is 2d, otherwise make it so
67+
# check if the states array is 2d, otherwise make it so
6868
if len(states.shape) == 1:
6969
states = np.array([states])
7070

71-
# set initial properties
72-
number_of_states = states.shape[0]
71+
# check if the states array is empty
72+
if states.size == 0:
73+
number_of_states = 0
74+
else:
75+
number_of_states = states.shape[0]
76+
7377
number_of_tracked_objects = len(self.__tracked_objects)
74-
7578
objects_to_match = [i for i in range(0, number_of_tracked_objects)]
7679
states_to_match = [i for i in range(0, number_of_states)]
7780

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
from unittest import TestCase
22

3-
from ..simple_filters import Tracker, TrackedObject, Filter, PolynomialFilterStrategy
3+
from ..simple_filters import Tracker, TrackedObject, Filter, PolynomialFilterStrategy, DummyFilterStrategy
44

55
import pytest
66
import numpy as np
77

88
class TestSingleStepSingleObjectTracking(TestCase):
99

10-
def setUp(self):
11-
strategy = PolynomialFilterStrategy()
12-
filter_prototype = Filter(strategy, history_size=1)
13-
self.tracker = Tracker(filter_prototype, distance_threshold=1.)
14-
15-
def test_main(self):
10+
def test_new_obj(self):
11+
strategy = DummyFilterStrategy()
12+
filter_prototype = Filter(strategy, history_size=5)
13+
tracker = Tracker(filter_prototype, distance_threshold=1.)
14+
1615
states = [
1716
np.array([1.0, 1.0]),
1817
np.array([1.5, 1.5]),
@@ -26,8 +25,33 @@ def test_main(self):
2625

2726
for i, (state, expected_tracking_id) in enumerate(zip(states, expected_tracking_ids)):
2827
print("timestep", i)
29-
self.tracker.update(state)
30-
tracked_state = self.tracker.get_tracked_objects()
28+
tracker.update(state)
29+
tracked_state = tracker.get_tracked_objects()
3130

3231
self.assertEqual(1, len(tracked_state))
3332
self.assertEqual(tracked_state[0].id, expected_tracking_id)
33+
34+
def test_interpolate_object_with_ttl(self):
35+
strategy = PolynomialFilterStrategy(poly_degree=1, reject_outliers=False)
36+
filter_prototype = Filter(strategy, history_size=3)
37+
tracker = Tracker(filter_prototype, distance_threshold=1., time_to_live=1)
38+
39+
states = [
40+
np.array([[1.0, 1.0]]),
41+
np.array([[1.5, 1.5]]),
42+
np.array([[2.0, 2.0]]),
43+
np.array([]),
44+
np.array([[3.0, 3.0]]),
45+
np.array([[3.5, 3.5]]),
46+
np.array([[4.0, 4.0]])
47+
]
48+
49+
expected_tracking_ids = [1, 1, 1, 1, 1, 1, 1]
50+
51+
for i, (state, expected_tracking_id) in enumerate(zip(states, expected_tracking_ids)):
52+
print("timestep", i)
53+
tracker.update(state)
54+
tracked_state = tracker.get_tracked_objects()
55+
56+
self.assertEqual(1, len(tracked_state))
57+
self.assertEqual(tracked_state[0].id, expected_tracking_id)

0 commit comments

Comments
 (0)