1- import numpy as np
1+ import numpy as np
22import uuid
3+ import math
34from scipy .optimize import linear_sum_assignment
45
56from nav2_dynamic_msgs .msg import Obstacle , ObstacleArray
67from visualization_msgs .msg import Marker , MarkerArray
78
89import rclpy
10+ import copy
911from rclpy .node import Node
1012import colorsys
1113from kf_hungarian_tracker .obstacle_class import ObstacleClass
1618from tf2_geometry_msgs import do_transform_point , do_transform_vector3
1719from geometry_msgs .msg import PointStamped , Vector3Stamped
1820
21+
1922class KFHungarianTracker (Node ):
20- ''' Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles
23+ """ Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles
2124
2225 Use Hungarian algorithm to match presenting obstacles with new detection and maintain a kalman filter for each obstacle.
2326 spawn ObstacleClass when new obstacles come and delete when they disappear for certain number of frames
@@ -28,25 +31,27 @@ class KFHungarianTracker(Node):
2831 detection_sub: subscrib detection result from detection node
2932 tracker_obstacle_pub: publish tracking obstacles with ObstacleArray
3033 tracker_pose_pub: publish tracking obstacles with PoseArray, for rviz visualization
31- '''
34+ """
3235
3336 def __init__ (self ):
34- ''' initialize attributes and setup subscriber and publisher'''
37+ """ initialize attributes and setup subscriber and publisher"""
3538
36- super ().__init__ (' kf_hungarian_node' )
39+ super ().__init__ (" kf_hungarian_node" )
3740 self .declare_parameters (
38- namespace = '' ,
41+ namespace = "" ,
3942 parameters = [
40- ('global_frame' , "camera_link" ),
41- ('process_noise_cov' , [2. , 2. , 0.5 ]),
42- ('top_down' , False ),
43- ('death_threshold' , 3 ),
44- ('measurement_noise_cov' , [1. , 1. , 1. ]),
45- ('error_cov_post' , [1. , 1. , 1. , 10. , 10. , 10. ]),
46- ('vel_filter' , [0.1 , 2.0 ]),
47- ('height_filter' , [- 2.0 , 2.0 ]),
48- ('cost_filter' , 1.0 )
49- ])
43+ ("global_frame" , "camera_link" ),
44+ ("process_noise_cov" , [2.0 , 2.0 , 0.5 ]),
45+ ("top_down" , False ),
46+ ("death_threshold" , 3 ),
47+ ("measurement_noise_cov" , [1.0 , 1.0 , 1.0 ]),
48+ ("error_cov_post" , [1.0 , 1.0 , 1.0 , 10.0 , 10.0 , 10.0 ]),
49+ ("vel_filter" , [0.1 , 2.0 ]),
50+ ("height_filter" , [- 2.0 , 2.0 ]),
51+ ("cost_filter" , 1.0 ),
52+ ("transform_to_global_frame" , False ),
53+ ],
54+ )
5055 self .global_frame = self .get_parameter ("global_frame" )._value
5156 self .death_threshold = self .get_parameter ("death_threshold" )._value
5257 self .measurement_noise_cov = self .get_parameter ("measurement_noise_cov" )._value
@@ -56,31 +61,36 @@ def __init__(self):
5661 self .height_filter = self .get_parameter ("height_filter" )._value
5762 self .top_down = self .get_parameter ("top_down" )._value
5863 self .cost_filter = self .get_parameter ("cost_filter" )._value
64+ self .transform_to_global_frame = self .get_parameter (
65+ "transform_to_global_frame"
66+ )._value
5967
6068 self .obstacle_list = []
6169 self .sec = 0
6270 self .nanosec = 0
6371
64- # subscribe to detector
72+ # subscribe to detector
6573 self .detection_sub = self .create_subscription (
66- ObstacleArray ,
67- 'detection' ,
68- self .callback ,
69- 10 )
74+ ObstacleArray , "detection" , self .callback , 10
75+ )
7076
7177 # publisher for tracking result
72- self .tracker_obstacle_pub = self .create_publisher (ObstacleArray , 'tracking' , 10 )
73- self .tracker_marker_pub = self .create_publisher (MarkerArray , 'tracking_marker' , 10 )
78+ self .tracker_obstacle_pub = self .create_publisher (ObstacleArray , "tracking" , 10 )
79+ self .tracker_marker_pub = self .create_publisher (
80+ MarkerArray , "tracking_marker" , 10
81+ )
7482
7583 # setup tf related
7684 self .tf_buffer = Buffer ()
7785 self .tf_listener = TransformListener (self .tf_buffer , self )
7886
7987 def callback (self , msg ):
80- ''' callback function for detection result'''
88+ """ callback function for detection result"""
8189
8290 # update delta time
83- dt = (msg .header .stamp .sec - self .sec ) + (msg .header .stamp .nanosec - self .nanosec ) / 1e9
91+ dt = (msg .header .stamp .sec - self .sec ) + (
92+ msg .header .stamp .nanosec - self .nanosec
93+ ) / 1e9
8494 self .sec = msg .header .stamp .sec
8595 self .nanosec = msg .header .stamp .nanosec
8696
@@ -93,38 +103,43 @@ def callback(self, msg):
93103 for obj in self .obstacle_list :
94104 obj .predict (dt )
95105
96- # transform to global frame
97- if self .global_frame is not None :
98- try :
99- trans = self .tf_buffer .lookup_transform (self .global_frame , msg .header .frame_id , rclpy .time .Time ())
100- msg .header .frame_id = self .global_frame
101- # do_transform_vector3(vector, trans) resets trans.transform.translation
102- # values to 0.0, so we need to preserve them for future usage in the loop below
103- translation_backup_x = trans .transform .translation .x
104- translation_backup_y = trans .transform .translation .y
105- translation_backup_z = trans .transform .translation .z
106- for i in range (len (detections )):
107- trans .transform .translation .x = translation_backup_x
108- trans .transform .translation .y = translation_backup_y
109- trans .transform .translation .z = translation_backup_z
110- # transform position (point)
111- p = PointStamped ()
112- p .point = detections [i ].position
113- detections [i ].position = do_transform_point (p , trans ).point
114- # transform velocity (vector3)
115- v = Vector3Stamped ()
116- v .vector = detections [i ].velocity
117- detections [i ].velocity = do_transform_vector3 (v , trans ).vector
118- # transform size (vector3)
119- s = Vector3Stamped ()
120- s .vector = detections [i ].size
121- detections [i ].size = do_transform_vector3 (s , trans ).vector
122-
123- except TransformException as ex :
124- self .get_logger ().error (
125- 'fail to get tf from {} to {}: {}' .format (
126- msg .header .frame_id , self .global_frame , ex ))
127- return
106+ if self .transform_to_global_frame :
107+ # transform to global frame
108+ if self .global_frame is not None :
109+ try :
110+ trans = self .tf_buffer .lookup_transform (
111+ self .global_frame , msg .header .frame_id , rclpy .time .Time ()
112+ )
113+ msg .header .frame_id = self .global_frame
114+ # do_transform_vector3(vector, trans) resets trans.transform.translation
115+ # values to 0.0, so we need to preserve them for future usage in the loop below
116+ translation_backup_x = trans .transform .translation .x
117+ translation_backup_y = trans .transform .translation .y
118+ translation_backup_z = trans .transform .translation .z
119+ for i in range (len (detections )):
120+ trans .transform .translation .x = translation_backup_x
121+ trans .transform .translation .y = translation_backup_y
122+ trans .transform .translation .z = translation_backup_z
123+ # transform position (point)
124+ p = PointStamped ()
125+ p .point = detections [i ].position
126+ detections [i ].position = do_transform_point (p , trans ).point
127+ # transform velocity (vector3)
128+ v = Vector3Stamped ()
129+ v .vector = detections [i ].velocity
130+ detections [i ].velocity = do_transform_vector3 (v , trans ).vector
131+ # transform size (vector3)
132+ s = Vector3Stamped ()
133+ s .vector = detections [i ].size
134+ detections [i ].size = do_transform_vector3 (s , trans ).vector
135+
136+ except TransformException as ex :
137+ self .get_logger ().error (
138+ "fail to get tf from {} to {}: {}" .format (
139+ msg .header .frame_id , self .global_frame , ex
140+ )
141+ )
142+ return
128143
129144 # hungarian matching
130145 cost = np .zeros ((num_of_obstacle , num_of_detect ))
@@ -154,19 +169,27 @@ def callback(self, msg):
154169 # apply velocity and height filter
155170 filtered_obstacle_list = []
156171 for obs in self .obstacle_list :
157- obs_vel = np .linalg .norm (np .array ([obs .msg .velocity .x , obs .msg .velocity .y , obs .msg .velocity .z ]))
172+ obs_vel = np .linalg .norm (
173+ np .array ([obs .msg .velocity .x , obs .msg .velocity .y , obs .msg .velocity .z ])
174+ )
158175 obs_height = obs .msg .position .z
159- if obs_vel > self .vel_filter [0 ] and obs_vel < self .vel_filter [1 ] and obs_height > self .height_filter [0 ] and obs_height < self .height_filter [1 ]:
176+ if (
177+ obs_vel > self .vel_filter [0 ]
178+ and obs_vel < self .vel_filter [1 ]
179+ and obs_height > self .height_filter [0 ]
180+ and obs_height < self .height_filter [1 ]
181+ ):
160182 filtered_obstacle_list .append (obs )
161183
162184 # construct ObstacleArray
163185 if self .tracker_obstacle_pub .get_subscription_count () > 0 :
164186 obstacle_array = ObstacleArray ()
165187 obstacle_array .header = msg .header
166188 track_list = []
189+
167190 for obs in filtered_obstacle_list :
168- # do not publish obstacles with low speed
169191 track_list .append (obs .msg )
192+
170193 obstacle_array .obstacles = track_list
171194 self .tracker_obstacle_pub .publish (obstacle_array )
172195
@@ -177,49 +200,55 @@ def callback(self, msg):
177200 # add current active obstacles
178201 for obs in filtered_obstacle_list :
179202 obstacle_uuid = uuid .UUID (bytes = bytes (obs .msg .uuid .uuid ))
180- (r , g , b ) = colorsys .hsv_to_rgb (obstacle_uuid .int % 360 / 360. , 1. , 1. ) # encode id with rgb color
181- # make a cube
203+ (r , g , b ) = colorsys .hsv_to_rgb (
204+ obstacle_uuid .int % 360 / 360.0 , 1.0 , 1.0
205+ ) # encode id with rgb color
206+
207+ # make a cube
182208 marker = Marker ()
183209 marker .header = msg .header
184210 marker .ns = str (obstacle_uuid )
185211 marker .id = 0
186- marker .type = 1 # CUBE
212+ marker .type = 1 # CUBE
187213 marker .action = 0
188214 marker .color .a = 0.5
189215 marker .color .r = r
190216 marker .color .g = g
191217 marker .color .b = b
192218 marker .pose .position = obs .msg .position
193219 angle = np .arctan2 (obs .msg .velocity .y , obs .msg .velocity .x )
194- marker .pose .orientation .z = np . float (np .sin (angle / 2 ))
195- marker .pose .orientation .w = np . float (np .cos (angle / 2 ))
220+ marker .pose .orientation .z = 0.0 # float(np.sin(angle / 2))
221+ marker .pose .orientation .w = 1.0 # float(np.cos(angle / 2))
196222 marker .scale = obs .msg .size
197223 marker_list .append (marker )
198224 # make an arrow
199225 arrow = Marker ()
200226 arrow .header = msg .header
201227 arrow .ns = str (obstacle_uuid )
202- arrow .id = 1
228+ arrow .id = 1
203229 arrow .type = 0
204230 arrow .action = 0
205231 arrow .color .a = 1.0
206232 arrow .color .r = r
207233 arrow .color .g = g
208234 arrow .color .b = b
209235 arrow .pose .position = obs .msg .position
210- arrow .pose .orientation .z = np .float (np .sin (angle / 2 ))
211- arrow .pose .orientation .w = np .float (np .cos (angle / 2 ))
212- arrow .scale .x = np .linalg .norm ([obs .msg .velocity .x , obs .msg .velocity .y , obs .msg .velocity .z ])
236+ arrow .pose .orientation .z = float (np .sin (angle / 2 ))
237+ arrow .pose .orientation .w = float (np .cos (angle / 2 ))
238+ arrow .scale .x = np .linalg .norm (
239+ [obs .msg .velocity .x , obs .msg .velocity .y , obs .msg .velocity .z ]
240+ )
213241 arrow .scale .y = 0.05
214242 arrow .scale .z = 0.05
215243 marker_list .append (arrow )
244+
216245 # add dead obstacles to delete in rviz
217246 for dead_uuid in dead_object_list :
218247 marker = Marker ()
219248 marker .header = msg .header
220249 marker .ns = str (dead_uuid )
221250 marker .id = 0
222- marker .action = 2 # delete
251+ marker .action = 2 # delete
223252 arrow = Marker ()
224253 arrow .header = msg .header
225254 arrow .ns = str (dead_uuid )
@@ -231,14 +260,20 @@ def callback(self, msg):
231260 self .tracker_marker_pub .publish (marker_array )
232261
233262 def birth (self , det_ind , num_of_detect , detections ):
234- ''' generate new ObstacleClass for detections that do not match any in current obstacle list'''
263+ """ generate new ObstacleClass for detections that do not match any in current obstacle list"""
235264 for det in range (num_of_detect ):
236265 if det not in det_ind :
237- obstacle = ObstacleClass (detections [det ], self .top_down , self .measurement_noise_cov , self .error_cov_post , self .process_noise_cov )
266+ obstacle = ObstacleClass (
267+ detections [det ],
268+ self .top_down ,
269+ self .measurement_noise_cov ,
270+ self .error_cov_post ,
271+ self .process_noise_cov ,
272+ )
238273 self .obstacle_list .append (obstacle )
239274
240275 def death (self , obj_ind , num_of_obstacle ):
241- ''' count obstacles' missing frames and delete when reach threshold'''
276+ """ count obstacles' missing frames and delete when reach threshold"""
242277 new_object_list = []
243278 dead_object_list = []
244279 # for previous obstacles
@@ -251,16 +286,19 @@ def death(self, obj_ind, num_of_obstacle):
251286 if self .obstacle_list [obs ].dying < self .death_threshold :
252287 new_object_list .append (self .obstacle_list [obs ])
253288 else :
254- obstacle_uuid = uuid .UUID (bytes = bytes (self .obstacle_list [obs ].msg .uuid .uuid ))
289+ obstacle_uuid = uuid .UUID (
290+ bytes = bytes (self .obstacle_list [obs ].msg .uuid .uuid )
291+ )
255292 dead_object_list .append (obstacle_uuid )
256-
293+
257294 # add newly born obstacles
258295 for obs in range (num_of_obstacle , len (self .obstacle_list )):
259296 new_object_list .append (self .obstacle_list [obs ])
260297
261298 self .obstacle_list = new_object_list
262299 return dead_object_list
263300
301+
264302def main (args = None ):
265303 rclpy .init (args = args )
266304
@@ -271,5 +309,6 @@ def main(args=None):
271309
272310 rclpy .shutdown ()
273311
312+
274313if __name__ == "__main__" :
275314 main ()
0 commit comments