11from unittest import TestCase
22
3- from ..simple_filters import Tracker , TrackedObject , Filter , PolynomialFilterStrategy
3+ from ..simple_filters import Tracker , TrackedObject , Filter , PolynomialFilterStrategy , DummyFilterStrategy
44
55import pytest
66import numpy as np
77
88class TestSingleStepSingleObjectTracking (TestCase ):
99
10- def setUp (self ):
11- strategy = PolynomialFilterStrategy ()
12- filter_prototype = Filter (strategy , history_size = 1 )
13- self .tracker = Tracker (filter_prototype , distance_threshold = 1. )
14-
15- def test_main (self ):
10+ def test_new_obj (self ):
11+ strategy = DummyFilterStrategy ()
12+ filter_prototype = Filter (strategy , history_size = 5 )
13+ tracker = Tracker (filter_prototype , distance_threshold = 1. )
14+
1615 states = [
1716 np .array ([1.0 , 1.0 ]),
1817 np .array ([1.5 , 1.5 ]),
@@ -26,8 +25,33 @@ def test_main(self):
2625
2726 for i , (state , expected_tracking_id ) in enumerate (zip (states , expected_tracking_ids )):
2827 print ("timestep" , i )
29- self . tracker .update (state )
30- tracked_state = self . tracker .get_tracked_objects ()
28+ tracker .update (state )
29+ tracked_state = tracker .get_tracked_objects ()
3130
3231 self .assertEqual (1 , len (tracked_state ))
3332 self .assertEqual (tracked_state [0 ].id , expected_tracking_id )
33+
34+ def test_interpolate_object_with_ttl (self ):
35+ strategy = PolynomialFilterStrategy (poly_degree = 1 , reject_outliers = False )
36+ filter_prototype = Filter (strategy , history_size = 3 )
37+ tracker = Tracker (filter_prototype , distance_threshold = 1. , time_to_live = 1 )
38+
39+ states = [
40+ np .array ([[1.0 , 1.0 ]]),
41+ np .array ([[1.5 , 1.5 ]]),
42+ np .array ([[2.0 , 2.0 ]]),
43+ np .array ([]),
44+ np .array ([[3.0 , 3.0 ]]),
45+ np .array ([[3.5 , 3.5 ]]),
46+ np .array ([[4.0 , 4.0 ]])
47+ ]
48+
49+ expected_tracking_ids = [1 , 1 , 1 , 1 , 1 , 1 , 1 ]
50+
51+ for i , (state , expected_tracking_id ) in enumerate (zip (states , expected_tracking_ids )):
52+ print ("timestep" , i )
53+ tracker .update (state )
54+ tracked_state = tracker .get_tracked_objects ()
55+
56+ self .assertEqual (1 , len (tracked_state ))
57+ self .assertEqual (tracked_state [0 ].id , expected_tracking_id )
0 commit comments