Skip to content

Commit a976f9a

Browse files
Added position and velocity covariance to Obstacle message
Position and velocity covariance added in detection step of obstacle_class.py Allowed the ability to toggle transforming obstacle into global frame or keeping it sensor frame
1 parent 6eae17e commit a976f9a

File tree

7 files changed

+264
-129
lines changed

7 files changed

+264
-129
lines changed
Lines changed: 114 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import numpy as np
1+
import numpy as np
22
import uuid
3+
import math
34
from scipy.optimize import linear_sum_assignment
45

56
from nav2_dynamic_msgs.msg import Obstacle, ObstacleArray
67
from visualization_msgs.msg import Marker, MarkerArray
78

89
import rclpy
10+
import copy
911
from rclpy.node import Node
1012
import colorsys
1113
from kf_hungarian_tracker.obstacle_class import ObstacleClass
@@ -16,8 +18,9 @@
1618
from tf2_geometry_msgs import do_transform_point, do_transform_vector3
1719
from geometry_msgs.msg import PointStamped, Vector3Stamped
1820

21+
1922
class KFHungarianTracker(Node):
20-
'''Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles
23+
"""Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles
2124
2225
Use Hungarian algorithm to match presenting obstacles with new detection and maintain a kalman filter for each obstacle.
2326
spawn ObstacleClass when new obstacles come and delete when they disappear for certain number of frames
@@ -28,25 +31,27 @@ class KFHungarianTracker(Node):
2831
detection_sub: subscrib detection result from detection node
2932
tracker_obstacle_pub: publish tracking obstacles with ObstacleArray
3033
tracker_pose_pub: publish tracking obstacles with PoseArray, for rviz visualization
31-
'''
34+
"""
3235

3336
def __init__(self):
34-
'''initialize attributes and setup subscriber and publisher'''
37+
"""initialize attributes and setup subscriber and publisher"""
3538

36-
super().__init__('kf_hungarian_node')
39+
super().__init__("kf_hungarian_node")
3740
self.declare_parameters(
38-
namespace='',
41+
namespace="",
3942
parameters=[
40-
('global_frame', "camera_link"),
41-
('process_noise_cov', [2., 2., 0.5]),
42-
('top_down', False),
43-
('death_threshold', 3),
44-
('measurement_noise_cov', [1., 1., 1.]),
45-
('error_cov_post', [1., 1., 1., 10., 10., 10.]),
46-
('vel_filter', [0.1, 2.0]),
47-
('height_filter', [-2.0, 2.0]),
48-
('cost_filter', 1.0)
49-
])
43+
("global_frame", "camera_link"),
44+
("process_noise_cov", [2.0, 2.0, 0.5]),
45+
("top_down", False),
46+
("death_threshold", 3),
47+
("measurement_noise_cov", [1.0, 1.0, 1.0]),
48+
("error_cov_post", [1.0, 1.0, 1.0, 10.0, 10.0, 10.0]),
49+
("vel_filter", [0.1, 2.0]),
50+
("height_filter", [-2.0, 2.0]),
51+
("cost_filter", 1.0),
52+
("transform_to_global_frame", False),
53+
],
54+
)
5055
self.global_frame = self.get_parameter("global_frame")._value
5156
self.death_threshold = self.get_parameter("death_threshold")._value
5257
self.measurement_noise_cov = self.get_parameter("measurement_noise_cov")._value
@@ -56,31 +61,36 @@ def __init__(self):
5661
self.height_filter = self.get_parameter("height_filter")._value
5762
self.top_down = self.get_parameter("top_down")._value
5863
self.cost_filter = self.get_parameter("cost_filter")._value
64+
self.transform_to_global_frame = self.get_parameter(
65+
"transform_to_global_frame"
66+
)._value
5967

6068
self.obstacle_list = []
6169
self.sec = 0
6270
self.nanosec = 0
6371

64-
# subscribe to detector
72+
# subscribe to detector
6573
self.detection_sub = self.create_subscription(
66-
ObstacleArray,
67-
'detection',
68-
self.callback,
69-
10)
74+
ObstacleArray, "detection", self.callback, 10
75+
)
7076

7177
# publisher for tracking result
72-
self.tracker_obstacle_pub = self.create_publisher(ObstacleArray, 'tracking', 10)
73-
self.tracker_marker_pub = self.create_publisher(MarkerArray, 'tracking_marker', 10)
78+
self.tracker_obstacle_pub = self.create_publisher(ObstacleArray, "tracking", 10)
79+
self.tracker_marker_pub = self.create_publisher(
80+
MarkerArray, "tracking_marker", 10
81+
)
7482

7583
# setup tf related
7684
self.tf_buffer = Buffer()
7785
self.tf_listener = TransformListener(self.tf_buffer, self)
7886

7987
def callback(self, msg):
80-
'''callback function for detection result'''
88+
"""callback function for detection result"""
8189

8290
# update delta time
83-
dt = (msg.header.stamp.sec - self.sec) + (msg.header.stamp.nanosec - self.nanosec) / 1e9
91+
dt = (msg.header.stamp.sec - self.sec) + (
92+
msg.header.stamp.nanosec - self.nanosec
93+
) / 1e9
8494
self.sec = msg.header.stamp.sec
8595
self.nanosec = msg.header.stamp.nanosec
8696

@@ -93,38 +103,43 @@ def callback(self, msg):
93103
for obj in self.obstacle_list:
94104
obj.predict(dt)
95105

96-
# transform to global frame
97-
if self.global_frame is not None:
98-
try:
99-
trans = self.tf_buffer.lookup_transform(self.global_frame, msg.header.frame_id, rclpy.time.Time())
100-
msg.header.frame_id = self.global_frame
101-
# do_transform_vector3(vector, trans) resets trans.transform.translation
102-
# values to 0.0, so we need to preserve them for future usage in the loop below
103-
translation_backup_x = trans.transform.translation.x
104-
translation_backup_y = trans.transform.translation.y
105-
translation_backup_z = trans.transform.translation.z
106-
for i in range(len(detections)):
107-
trans.transform.translation.x = translation_backup_x
108-
trans.transform.translation.y = translation_backup_y
109-
trans.transform.translation.z = translation_backup_z
110-
# transform position (point)
111-
p = PointStamped()
112-
p.point = detections[i].position
113-
detections[i].position = do_transform_point(p, trans).point
114-
# transform velocity (vector3)
115-
v = Vector3Stamped()
116-
v.vector = detections[i].velocity
117-
detections[i].velocity = do_transform_vector3(v, trans).vector
118-
# transform size (vector3)
119-
s = Vector3Stamped()
120-
s.vector = detections[i].size
121-
detections[i].size = do_transform_vector3(s, trans).vector
122-
123-
except TransformException as ex:
124-
self.get_logger().error(
125-
'fail to get tf from {} to {}: {}'.format(
126-
msg.header.frame_id, self.global_frame, ex))
127-
return
106+
if self.transform_to_global_frame:
107+
# transform to global frame
108+
if self.global_frame is not None:
109+
try:
110+
trans = self.tf_buffer.lookup_transform(
111+
self.global_frame, msg.header.frame_id, rclpy.time.Time()
112+
)
113+
msg.header.frame_id = self.global_frame
114+
# do_transform_vector3(vector, trans) resets trans.transform.translation
115+
# values to 0.0, so we need to preserve them for future usage in the loop below
116+
translation_backup_x = trans.transform.translation.x
117+
translation_backup_y = trans.transform.translation.y
118+
translation_backup_z = trans.transform.translation.z
119+
for i in range(len(detections)):
120+
trans.transform.translation.x = translation_backup_x
121+
trans.transform.translation.y = translation_backup_y
122+
trans.transform.translation.z = translation_backup_z
123+
# transform position (point)
124+
p = PointStamped()
125+
p.point = detections[i].position
126+
detections[i].position = do_transform_point(p, trans).point
127+
# transform velocity (vector3)
128+
v = Vector3Stamped()
129+
v.vector = detections[i].velocity
130+
detections[i].velocity = do_transform_vector3(v, trans).vector
131+
# transform size (vector3)
132+
s = Vector3Stamped()
133+
s.vector = detections[i].size
134+
detections[i].size = do_transform_vector3(s, trans).vector
135+
136+
except TransformException as ex:
137+
self.get_logger().error(
138+
"fail to get tf from {} to {}: {}".format(
139+
msg.header.frame_id, self.global_frame, ex
140+
)
141+
)
142+
return
128143

129144
# hungarian matching
130145
cost = np.zeros((num_of_obstacle, num_of_detect))
@@ -154,19 +169,27 @@ def callback(self, msg):
154169
# apply velocity and height filter
155170
filtered_obstacle_list = []
156171
for obs in self.obstacle_list:
157-
obs_vel = np.linalg.norm(np.array([obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z]))
172+
obs_vel = np.linalg.norm(
173+
np.array([obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z])
174+
)
158175
obs_height = obs.msg.position.z
159-
if obs_vel > self.vel_filter[0] and obs_vel < self.vel_filter[1] and obs_height > self.height_filter[0] and obs_height < self.height_filter[1]:
176+
if (
177+
obs_vel > self.vel_filter[0]
178+
and obs_vel < self.vel_filter[1]
179+
and obs_height > self.height_filter[0]
180+
and obs_height < self.height_filter[1]
181+
):
160182
filtered_obstacle_list.append(obs)
161183

162184
# construct ObstacleArray
163185
if self.tracker_obstacle_pub.get_subscription_count() > 0:
164186
obstacle_array = ObstacleArray()
165187
obstacle_array.header = msg.header
166188
track_list = []
189+
167190
for obs in filtered_obstacle_list:
168-
# do not publish obstacles with low speed
169191
track_list.append(obs.msg)
192+
170193
obstacle_array.obstacles = track_list
171194
self.tracker_obstacle_pub.publish(obstacle_array)
172195

@@ -177,49 +200,55 @@ def callback(self, msg):
177200
# add current active obstacles
178201
for obs in filtered_obstacle_list:
179202
obstacle_uuid = uuid.UUID(bytes=bytes(obs.msg.uuid.uuid))
180-
(r, g, b) = colorsys.hsv_to_rgb(obstacle_uuid.int % 360 / 360., 1., 1.) # encode id with rgb color
181-
# make a cube
203+
(r, g, b) = colorsys.hsv_to_rgb(
204+
obstacle_uuid.int % 360 / 360.0, 1.0, 1.0
205+
) # encode id with rgb color
206+
207+
# make a cube
182208
marker = Marker()
183209
marker.header = msg.header
184210
marker.ns = str(obstacle_uuid)
185211
marker.id = 0
186-
marker.type = 1 # CUBE
212+
marker.type = 1 # CUBE
187213
marker.action = 0
188214
marker.color.a = 0.5
189215
marker.color.r = r
190216
marker.color.g = g
191217
marker.color.b = b
192218
marker.pose.position = obs.msg.position
193219
angle = np.arctan2(obs.msg.velocity.y, obs.msg.velocity.x)
194-
marker.pose.orientation.z = np.float(np.sin(angle / 2))
195-
marker.pose.orientation.w = np.float(np.cos(angle / 2))
220+
marker.pose.orientation.z = 0.0 # float(np.sin(angle / 2))
221+
marker.pose.orientation.w = 1.0 # float(np.cos(angle / 2))
196222
marker.scale = obs.msg.size
197223
marker_list.append(marker)
198224
# make an arrow
199225
arrow = Marker()
200226
arrow.header = msg.header
201227
arrow.ns = str(obstacle_uuid)
202-
arrow.id = 1
228+
arrow.id = 1
203229
arrow.type = 0
204230
arrow.action = 0
205231
arrow.color.a = 1.0
206232
arrow.color.r = r
207233
arrow.color.g = g
208234
arrow.color.b = b
209235
arrow.pose.position = obs.msg.position
210-
arrow.pose.orientation.z = np.float(np.sin(angle / 2))
211-
arrow.pose.orientation.w = np.float(np.cos(angle / 2))
212-
arrow.scale.x = np.linalg.norm([obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z])
236+
arrow.pose.orientation.z = float(np.sin(angle / 2))
237+
arrow.pose.orientation.w = float(np.cos(angle / 2))
238+
arrow.scale.x = np.linalg.norm(
239+
[obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z]
240+
)
213241
arrow.scale.y = 0.05
214242
arrow.scale.z = 0.05
215243
marker_list.append(arrow)
244+
216245
# add dead obstacles to delete in rviz
217246
for dead_uuid in dead_object_list:
218247
marker = Marker()
219248
marker.header = msg.header
220249
marker.ns = str(dead_uuid)
221250
marker.id = 0
222-
marker.action = 2 # delete
251+
marker.action = 2 # delete
223252
arrow = Marker()
224253
arrow.header = msg.header
225254
arrow.ns = str(dead_uuid)
@@ -231,14 +260,20 @@ def callback(self, msg):
231260
self.tracker_marker_pub.publish(marker_array)
232261

233262
def birth(self, det_ind, num_of_detect, detections):
234-
'''generate new ObstacleClass for detections that do not match any in current obstacle list'''
263+
"""generate new ObstacleClass for detections that do not match any in current obstacle list"""
235264
for det in range(num_of_detect):
236265
if det not in det_ind:
237-
obstacle = ObstacleClass(detections[det], self.top_down, self.measurement_noise_cov, self.error_cov_post, self.process_noise_cov)
266+
obstacle = ObstacleClass(
267+
detections[det],
268+
self.top_down,
269+
self.measurement_noise_cov,
270+
self.error_cov_post,
271+
self.process_noise_cov,
272+
)
238273
self.obstacle_list.append(obstacle)
239274

240275
def death(self, obj_ind, num_of_obstacle):
241-
'''count obstacles' missing frames and delete when reach threshold'''
276+
"""count obstacles' missing frames and delete when reach threshold"""
242277
new_object_list = []
243278
dead_object_list = []
244279
# for previous obstacles
@@ -251,16 +286,19 @@ def death(self, obj_ind, num_of_obstacle):
251286
if self.obstacle_list[obs].dying < self.death_threshold:
252287
new_object_list.append(self.obstacle_list[obs])
253288
else:
254-
obstacle_uuid = uuid.UUID(bytes=bytes(self.obstacle_list[obs].msg.uuid.uuid))
289+
obstacle_uuid = uuid.UUID(
290+
bytes=bytes(self.obstacle_list[obs].msg.uuid.uuid)
291+
)
255292
dead_object_list.append(obstacle_uuid)
256-
293+
257294
# add newly born obstacles
258295
for obs in range(num_of_obstacle, len(self.obstacle_list)):
259296
new_object_list.append(self.obstacle_list[obs])
260297

261298
self.obstacle_list = new_object_list
262299
return dead_object_list
263300

301+
264302
def main(args=None):
265303
rclpy.init(args=args)
266304

@@ -271,5 +309,6 @@ def main(args=None):
271309

272310
rclpy.shutdown()
273311

312+
274313
if __name__ == "__main__":
275314
main()

0 commit comments

Comments
 (0)