Skip to content

Commit 9b02e96

Browse files
major tracking improvements
1 parent 529e3eb commit 9b02e96

File tree

4 files changed

+51
-16
lines changed

4 files changed

+51
-16
lines changed

simple_filters/polynomial_filter_strategy.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ class PolynomialFilterStrategy(FilterStrategy):
1111
to the median, multiplied by the outlier_rejection_ratio
1212
"""
1313

14-
def __init__(self, poly_degree=3, reject_outliers=True, outlier_rejection_ratio=2.0):
14+
def __init__(self, poly_degree=3, reject_outliers=True, outlier_rejection_ratio=2.0, filter_weight=1.0, max_items=None):
1515
super().__init__()
1616

1717
self.poly_degree = poly_degree
1818
self.reject_outliers = reject_outliers
1919
self.outlier_rejection_ratio = outlier_rejection_ratio
20+
self.max_items = max_items
2021
self.history = None
22+
self.filter_weight = filter_weight
2123

2224
self.__poly_fn = None
2325

@@ -33,16 +35,28 @@ def eval(self, time=0):
3335
history_size = self.history.shape[0]
3436
offset_time = history_size + time - 1
3537

38+
# for debugging purposes
39+
if self.poly_degree == 0:
40+
return self.history[history_size - 1]
41+
3642
# in the case that the equation is underdetermined, we cannot predict a polynomial
3743
# simply return the last state in the history
3844
if history_size < self.poly_degree + 1:
3945
return self.history[history_size - 1]
40-
46+
4147
# if the polynomial functions are not existent, calculate them
4248
if self.__poly_fn is None:
4349
self.__update_polynomials()
4450

45-
return self.__eval_polynomials(offset_time)
51+
predictions = self.__eval_polynomials(offset_time)
52+
53+
# finally applying a weight to the prediction
54+
if time <= 0:
55+
result = (self.filter_weight * predictions) + ((1 - self.filter_weight) * self.history[offset_time])
56+
else:
57+
result = predictions
58+
59+
return result
4660

4761
def __eval_polynomials(self, t):
4862
length = self.history.shape[1]
@@ -69,6 +83,7 @@ def __calc_polynomial(self, x):
6983
rel_delta = delta / np.median(delta)
7084

7185
mask = rel_delta < self.outlier_rejection_ratio
86+
7287
x = x[mask]
7388
y = y[mask]
7489

simple_filters/tracker.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,21 @@ class TrackedObject(Filter):
88
Acts as a simple proxy to the actual filter provided in the initialization
99
"""
1010

11-
def __init__(self, id, filter):
11+
def __init__(self, id, filter, max_time_to_live):
1212
self.id = id
1313
self.filter = filter
14-
self.time_to_live = 0
14+
self.time_to_live = 1
15+
self.max_time_to_live = max_time_to_live
16+
17+
def increase_time_to_live(self):
18+
if self.time_to_live < self.max_time_to_live:
19+
self.time_to_live += 1
20+
21+
def decrease_time_to_live(self):
22+
if self.time_to_live > 0:
23+
self.time_to_live -= 1
24+
else:
25+
self.time_to_live = 0
1526

1627
def update(self, state):
1728
self.filter.update(state)
@@ -31,32 +42,35 @@ class Tracker:
3142
"""
3243

3344
def __init__(self, filter_prototype,
34-
time_to_live=0,
45+
max_time_to_live=1,
46+
time_to_birth=0,
3547
distance_threshold=1.0,
3648
distance_function=lambda x1, x2: np.linalg.norm(x1 - x2)):
3749
self.distance_threshold = distance_threshold
38-
self.time_to_live = time_to_live
50+
self.max_time_to_live = max_time_to_live
51+
self.time_to_birth = time_to_birth
52+
3953
self.object_counter = 0
4054

4155
self.__distance_function = distance_function
4256
self.__filter_prototype = filter_prototype
4357
self.__tracked_objects = []
4458

4559
def get_tracked_objects(self):
46-
return self.__tracked_objects
60+
return list(filter(lambda x: x.time_to_live > self.time_to_birth, self.__tracked_objects))
4761

4862
def to_numpy_array(self, raw=False):
4963
"""
5064
Returns the tracking id, plus the filtered object state if raw is False
5165
"""
5266
m = []
53-
for t in self.__tracked_objects:
67+
for t in self.get_tracked_objects():
5468
if raw:
5569
state = t.raw()
5670
else:
5771
state = t.eval()
5872

59-
m.append(np.insert(state, 0, t.id))
73+
m.append(np.insert(np.array(t.id, dtype=np.float32), 0, state))
6074

6175
return np.array(m)
6276

@@ -118,6 +132,7 @@ def update(self, states):
118132
if t in objects_to_match and s in states_to_match:
119133
objects_to_match.remove(t)
120134
states_to_match.remove(s)
135+
self.__tracked_objects[t].increase_time_to_live()
121136
self.__tracked_objects[t].update(states[s])
122137

123138
## Delete objects
@@ -127,9 +142,9 @@ def update(self, states):
127142
removals = []
128143
for i in objects_to_match:
129144
tracked_object = self.__tracked_objects[i]
130-
tracked_object.time_to_live += 1
145+
tracked_object.decrease_time_to_live()
131146

132-
if tracked_object.time_to_live > self.time_to_live:
147+
if tracked_object.time_to_live < 1:
133148
removals.append(tracked_object)
134149
else:
135150
# update the object with the next predicted state
@@ -142,6 +157,10 @@ def update(self, states):
142157
# now go through all unmatched objects and create new objects
143158
for i in states_to_match:
144159
self.object_counter += 1
145-
added_object = TrackedObject(self.object_counter, deepcopy(self.__filter_prototype))
160+
added_object = TrackedObject(
161+
self.object_counter,
162+
deepcopy(self.__filter_prototype),
163+
max_time_to_live=self.max_time_to_live
164+
)
146165
added_object.update(states[i])
147166
self.__tracked_objects.append(added_object)

tests/test_single_object_tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_new_obj(self):
3434
def test_interpolate_object_with_ttl(self):
3535
strategy = PolynomialFilterStrategy(poly_degree=1, reject_outliers=False)
3636
filter_prototype = Filter(strategy, history_size=3)
37-
tracker = Tracker(filter_prototype, distance_threshold=1., time_to_live=1)
37+
tracker = Tracker(filter_prototype, distance_threshold=1., max_time_to_live=2)
3838

3939
states = [
4040
np.array([[1.0, 1.0]]),

tests/test_tracker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,16 @@ def test_tracker_delete_objects(self):
2525
self.static_update_and_assert(1, 1, 2)
2626

2727
def test_tracker_delete_objects_time_to_live(self):
28-
self.tracker.time_to_live = 1
28+
self.tracker.max_time_to_live = 2
2929

3030
self.static_update_and_assert(2, 2, 2)
31+
self.static_update_and_assert(2, 2, 2) # increment ttl counter to 2
3132
self.static_update_and_assert(1, 2, 2) # object should be retained, even if it doesn't appear
3233
self.static_update_and_assert(1, 1, 2) # object should be removed after this update
3334

3435
def test_tracker_mapping(self):
3536
# TODO: This assumes that the order is retained, but makes it easier for testing
36-
reference_matrix = np.array([[1., 1., 2.], [2., 2., 3.]])
37+
reference_matrix = np.array([[1., 2., 1.], [2., 3., 2.]])
3738

3839
self.tracker.update(self.generate_static_states(2))
3940
self.assertTrue((self.tracker.to_numpy_array() == reference_matrix).all())

0 commit comments

Comments
 (0)