@@ -8,10 +8,21 @@ class TrackedObject(Filter):
88 Acts as a simple proxy to the actual filter provided in the initialization
99 """
1010
11- def __init__ (self , id , filter ):
11+ def __init__ (self , id , filter , max_time_to_live ):
1212 self .id = id
1313 self .filter = filter
14- self .time_to_live = 0
14+ self .time_to_live = 1
15+ self .max_time_to_live = max_time_to_live
16+
17+ def increase_time_to_live (self ):
18+ if self .time_to_live < self .max_time_to_live :
19+ self .time_to_live += 1
20+
21+ def decrease_time_to_live (self ):
22+ if self .time_to_live > 0 :
23+ self .time_to_live -= 1
24+ else :
25+ self .time_to_live = 0
1526
1627 def update (self , state ):
1728 self .filter .update (state )
@@ -31,32 +42,35 @@ class Tracker:
3142 """
3243
3344 def __init__ (self , filter_prototype ,
34- time_to_live = 0 ,
45+ max_time_to_live = 1 ,
46+ time_to_birth = 0 ,
3547 distance_threshold = 1.0 ,
3648 distance_function = lambda x1 , x2 : np .linalg .norm (x1 - x2 )):
3749 self .distance_threshold = distance_threshold
38- self .time_to_live = time_to_live
50+ self .max_time_to_live = max_time_to_live
51+ self .time_to_birth = time_to_birth
52+
3953 self .object_counter = 0
4054
4155 self .__distance_function = distance_function
4256 self .__filter_prototype = filter_prototype
4357 self .__tracked_objects = []
4458
4559 def get_tracked_objects (self ):
46- return self .__tracked_objects
60+ return list ( filter ( lambda x : x . time_to_live > self .time_to_birth , self . __tracked_objects ))
4761
4862 def to_numpy_array (self , raw = False ):
4963 """
5064 Returns the tracking id, plus the filtered object state if raw is False
5165 """
5266 m = []
53- for t in self .__tracked_objects :
67+ for t in self .get_tracked_objects () :
5468 if raw :
5569 state = t .raw ()
5670 else :
5771 state = t .eval ()
5872
59- m .append (np .insert (state , 0 , t . id ))
73+ m .append (np .insert (np . array ( t . id , dtype = np . float32 ), 0 , state ))
6074
6175 return np .array (m )
6276
@@ -118,6 +132,7 @@ def update(self, states):
118132 if t in objects_to_match and s in states_to_match :
119133 objects_to_match .remove (t )
120134 states_to_match .remove (s )
135+ self .__tracked_objects [t ].increase_time_to_live ()
121136 self .__tracked_objects [t ].update (states [s ])
122137
123138 ## Delete objects
@@ -127,9 +142,9 @@ def update(self, states):
127142 removals = []
128143 for i in objects_to_match :
129144 tracked_object = self .__tracked_objects [i ]
130- tracked_object .time_to_live += 1
145+ tracked_object .decrease_time_to_live ()
131146
132- if tracked_object .time_to_live > self . time_to_live :
147+ if tracked_object .time_to_live < 1 :
133148 removals .append (tracked_object )
134149 else :
135150 # update the object with the next predicted state
@@ -142,6 +157,10 @@ def update(self, states):
142157 # now go through all unmatched objects and create new objects
143158 for i in states_to_match :
144159 self .object_counter += 1
145- added_object = TrackedObject (self .object_counter , deepcopy (self .__filter_prototype ))
160+ added_object = TrackedObject (
161+ self .object_counter ,
162+ deepcopy (self .__filter_prototype ),
163+ max_time_to_live = self .max_time_to_live
164+ )
146165 added_object .update (states [i ])
147166 self .__tracked_objects .append (added_object )
0 commit comments