|
| 1 | +import numpy as np |
| 2 | +from scipy.optimize import linear_sum_assignment |
| 3 | + |
| 4 | +from nav2_dynamic_msgs.msg import Obstacle, ObstacleArray |
| 5 | +from visualization_msgs.msg import Marker, MarkerArray |
| 6 | + |
| 7 | +import rclpy |
| 8 | +from rclpy.node import Node |
| 9 | +import colorsys |
| 10 | +from kf_hungarian_tracker.obstacle_class import ObstacleClass |
| 11 | + |
| 12 | +from tf2_ros import LookupException |
| 13 | +from tf2_ros.buffer import Buffer |
| 14 | +from tf2_ros.transform_listener import TransformListener |
| 15 | +from tf2_geometry_msgs import do_transform_point, do_transform_vector3 |
| 16 | +from geometry_msgs.msg import PointStamped, Vector3Stamped |
| 17 | + |
| 18 | +class KFHungarianTracker(Node): |
| 19 | + '''Use Kalman Fiter and Hungarian algorithm to track multiple dynamic obstacles |
| 20 | +
|
| 21 | + Use Hungarian algorithm to match presenting obstacles with new detection and maintain a kalman filter for each obstacle. |
| 22 | + spawn ObstacleClass when new obstacles come and delete when they disappear for certain number of frames |
| 23 | +
|
| 24 | + Attributes: |
| 25 | + obstacle_list: a list of ObstacleClass that currently present in the scene |
| 26 | + max_id: the maximum id assigned |
| 27 | + sec, nanosec: timing from sensor msg |
| 28 | + detection_sub: subscrib detection result from detection node |
| 29 | + tracker_obstacle_pub: publish tracking obstacles with ObstacleArray |
| 30 | + tracker_pose_pub: publish tracking obstacles with PoseArray, for rviz visualization |
| 31 | + ''' |
| 32 | + |
| 33 | + def __init__(self): |
| 34 | + '''initialize attributes and setup subscriber and publisher''' |
| 35 | + |
| 36 | + super().__init__('kf_hungarian_node') |
| 37 | + self.declare_parameters( |
| 38 | + namespace='', |
| 39 | + parameters=[ |
| 40 | + ('global_frame', "camera_link"), |
| 41 | + ('process_noise_cov', [2., 2., 0.5]), |
| 42 | + ('top_down', False), |
| 43 | + ('death_threshold', 3), |
| 44 | + ('measurement_noise_cov', [1., 1., 1.]), |
| 45 | + ('error_cov_post', [1., 1., 1., 10., 10., 10.]), |
| 46 | + ('vel_filter', [0.1, 2.0]), |
| 47 | + ('height_filter', [-2.0, 2.0]), |
| 48 | + ('cost_filter', 1.0) |
| 49 | + ]) |
| 50 | + self.global_frame = self.get_parameter("global_frame")._value |
| 51 | + self.death_threshold = self.get_parameter("death_threshold")._value |
| 52 | + self.measurement_noise_cov = self.get_parameter("measurement_noise_cov")._value |
| 53 | + self.error_cov_post = self.get_parameter("error_cov_post")._value |
| 54 | + self.process_noise_cov = self.get_parameter("process_noise_cov")._value |
| 55 | + self.vel_filter = self.get_parameter("vel_filter")._value |
| 56 | + self.height_filter = self.get_parameter("height_filter")._value |
| 57 | + self.top_down = self.get_parameter("top_down")._value |
| 58 | + self.cost_filter = self.get_parameter("cost_filter")._value |
| 59 | + |
| 60 | + self.obstacle_list = [] |
| 61 | + self.max_id = 0 |
| 62 | + self.sec = 0 |
| 63 | + self.nanosec = 0 |
| 64 | + |
| 65 | + # subscribe to detector |
| 66 | + self.detection_sub = self.create_subscription( |
| 67 | + ObstacleArray, |
| 68 | + 'detection', |
| 69 | + self.callback, |
| 70 | + 10) |
| 71 | + |
| 72 | + # publisher for tracking result |
| 73 | + self.tracker_obstacle_pub = self.create_publisher(ObstacleArray, 'tracking', 10) |
| 74 | + self.tracker_marker_pub = self.create_publisher(MarkerArray, 'marker', 10) |
| 75 | + |
| 76 | + # setup tf related |
| 77 | + self.tf_buffer = Buffer() |
| 78 | + self.tf_listener = TransformListener(self.tf_buffer, self) |
| 79 | + |
| 80 | + def callback(self, msg): |
| 81 | + '''callback function for detection result''' |
| 82 | + |
| 83 | + # update delta time |
| 84 | + dt = (msg.header.stamp.sec - self.sec) + (msg.header.stamp.nanosec - self.nanosec) / 1e9 |
| 85 | + self.sec = msg.header.stamp.sec |
| 86 | + self.nanosec = msg.header.stamp.nanosec |
| 87 | + |
| 88 | + # get detection |
| 89 | + detections = msg.obstacles |
| 90 | + num_of_detect = len(detections) |
| 91 | + num_of_obstacle = len(self.obstacle_list) |
| 92 | + |
| 93 | + # kalman predict |
| 94 | + for obj in self.obstacle_list: |
| 95 | + obj.predict(dt) |
| 96 | + |
| 97 | + # transform to global frame |
| 98 | + if self.global_frame is not None: |
| 99 | + try: |
| 100 | + trans = self.tf_buffer.lookup_transform(self.global_frame, msg.header.frame_id, rclpy.time.Time()) |
| 101 | + msg.header.frame_id = self.global_frame |
| 102 | + for i in range(len(detections)): |
| 103 | + # transform position (point) |
| 104 | + p = PointStamped() |
| 105 | + p.point = detections[i].position |
| 106 | + detections[i].position = do_transform_point(p, trans).point |
| 107 | + # transform velocity (vector3) |
| 108 | + v = Vector3Stamped() |
| 109 | + v.vector = detections[i].velocity |
| 110 | + detections[i].velocity = do_transform_vector3(v, trans).vector |
| 111 | + # transform size (vector3) |
| 112 | + s = Vector3Stamped() |
| 113 | + s.vector = detections[i].size |
| 114 | + detections[i].size = do_transform_vector3(s, trans).vector |
| 115 | + |
| 116 | + except LookupException: |
| 117 | + self.get_logger().info('fail to get tf from {} to {}'.format(msg.header.frame_id, self.global_frame)) |
| 118 | + return |
| 119 | + |
| 120 | + # hungarian matching |
| 121 | + cost = np.zeros((num_of_obstacle, num_of_detect)) |
| 122 | + for i in range(num_of_obstacle): |
| 123 | + for j in range(num_of_detect): |
| 124 | + cost[i, j] = self.obstacle_list[i].distance(detections[j]) |
| 125 | + obs_ind, det_ind = linear_sum_assignment(cost) |
| 126 | + |
| 127 | + # filter assignment according to cost |
| 128 | + new_obs_ind = [] |
| 129 | + new_det_ind = [] |
| 130 | + for o, d in zip(obs_ind, det_ind): |
| 131 | + if cost[o, d] < self.cost_filter: |
| 132 | + new_obs_ind.append(o) |
| 133 | + new_det_ind.append(d) |
| 134 | + obs_ind = new_obs_ind |
| 135 | + det_ind = new_det_ind |
| 136 | + |
| 137 | + # kalman update |
| 138 | + for o, d in zip(obs_ind, det_ind): |
| 139 | + self.obstacle_list[o].correct(detections[d]) |
| 140 | + |
| 141 | + # birth of new detection obstacles and death of disappear obstacle |
| 142 | + self.birth(det_ind, num_of_detect, detections) |
| 143 | + dead_object_list = self.death(obs_ind, num_of_obstacle) |
| 144 | + |
| 145 | + # apply velocity and height filter |
| 146 | + filtered_obstacle_list = [] |
| 147 | + for obs in self.obstacle_list: |
| 148 | + obs_vel = np.linalg.norm(np.array([obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z])) |
| 149 | + obs_height = obs.msg.position.z |
| 150 | + if obs_vel > self.vel_filter[0] and obs_vel < self.vel_filter[1] and obs_height > self.height_filter[0] and obs_height < self.height_filter[1]: |
| 151 | + filtered_obstacle_list.append(obs) |
| 152 | + |
| 153 | + # construct ObstacleArray |
| 154 | + if self.tracker_obstacle_pub.get_subscription_count() > 0: |
| 155 | + obstacle_array = ObstacleArray() |
| 156 | + obstacle_array.header = msg.header |
| 157 | + track_list = [] |
| 158 | + for obs in filtered_obstacle_list: |
| 159 | + # do not publish obstacles with low speed |
| 160 | + track_list.append(obs.msg) |
| 161 | + obstacle_array.obstacles = track_list |
| 162 | + self.tracker_obstacle_pub.publish(obstacle_array) |
| 163 | + |
| 164 | + # rviz visualization |
| 165 | + if self.tracker_marker_pub.get_subscription_count() > 0: |
| 166 | + marker_array = MarkerArray() |
| 167 | + marker_list = [] |
| 168 | + # add current active obstacles |
| 169 | + for obs in filtered_obstacle_list: |
| 170 | + (r, g, b) = colorsys.hsv_to_rgb(obs.msg.id * 30. % 360 / 360., 1., 1.) # encode id with rgb color |
| 171 | + # make a cube |
| 172 | + marker = Marker() |
| 173 | + marker.header = msg.header |
| 174 | + marker.id = obs.msg.id |
| 175 | + marker.type = 1 # CUBE |
| 176 | + marker.action = 0 |
| 177 | + marker.color.a = 0.5 |
| 178 | + marker.color.r = r |
| 179 | + marker.color.g = g |
| 180 | + marker.color.b = b |
| 181 | + marker.pose.position = obs.msg.position |
| 182 | + angle = np.arctan2(obs.msg.velocity.y, obs.msg.velocity.x) |
| 183 | + marker.pose.orientation.z = np.float(np.sin(angle / 2)) |
| 184 | + marker.pose.orientation.w = np.float(np.cos(angle / 2)) |
| 185 | + marker.scale = obs.msg.size |
| 186 | + marker_list.append(marker) |
| 187 | + # make an arrow |
| 188 | + arrow = Marker() |
| 189 | + arrow.header = msg.header |
| 190 | + arrow.id = 255 - obs.msg.id |
| 191 | + arrow.type = 0 |
| 192 | + arrow.action = 0 |
| 193 | + arrow.color.a = 1.0 |
| 194 | + arrow.color.r = r |
| 195 | + arrow.color.g = g |
| 196 | + arrow.color.b = b |
| 197 | + arrow.pose.position = obs.msg.position |
| 198 | + arrow.pose.orientation.z = np.float(np.sin(angle / 2)) |
| 199 | + arrow.pose.orientation.w = np.float(np.cos(angle / 2)) |
| 200 | + arrow.scale.x = np.linalg.norm([obs.msg.velocity.x, obs.msg.velocity.y, obs.msg.velocity.z]) |
| 201 | + arrow.scale.y = 0.05 |
| 202 | + arrow.scale.z = 0.05 |
| 203 | + marker_list.append(arrow) |
| 204 | + # add dead obstacles to delete in rviz |
| 205 | + for idx in dead_object_list: |
| 206 | + marker = Marker() |
| 207 | + marker.header = msg.header |
| 208 | + marker.id = idx |
| 209 | + marker.action = 2 # delete |
| 210 | + arrow = Marker() |
| 211 | + arrow.header = msg.header |
| 212 | + arrow.id = 255 - idx |
| 213 | + arrow.action = 2 |
| 214 | + marker_list.append(marker) |
| 215 | + marker_list.append(arrow) |
| 216 | + |
| 217 | + marker_array.markers = marker_list |
| 218 | + self.tracker_marker_pub.publish(marker_array) |
| 219 | + |
| 220 | + def birth(self, det_ind, num_of_detect, detections): |
| 221 | + '''generate new ObstacleClass for detections that do not match any in current obstacle list''' |
| 222 | + for det in range(num_of_detect): |
| 223 | + if det not in det_ind: |
| 224 | + self.obstacle_list.append(ObstacleClass(detections[det], self.max_id, self.top_down, self.measurement_noise_cov, self.error_cov_post, self.process_noise_cov)) |
| 225 | + self.max_id = self.max_id + 1 |
| 226 | + |
| 227 | + def death(self, obj_ind, num_of_obstacle): |
| 228 | + '''count obstacles' missing frames and delete when reach threshold''' |
| 229 | + new_object_list = [] |
| 230 | + dead_object_list = [] |
| 231 | + # for previous obstacles |
| 232 | + for obs in range(num_of_obstacle): |
| 233 | + if obs not in obj_ind: |
| 234 | + self.obstacle_list[obs].dying += 1 |
| 235 | + else: |
| 236 | + self.obstacle_list[obs].dying = 0 |
| 237 | + |
| 238 | + if self.obstacle_list[obs].dying < self.death_threshold: |
| 239 | + new_object_list.append(self.obstacle_list[obs]) |
| 240 | + else: |
| 241 | + dead_object_list.append(self.obstacle_list[obs].msg.id) |
| 242 | + |
| 243 | + # add newly born obstacles |
| 244 | + for obs in range(num_of_obstacle, len(self.obstacle_list)): |
| 245 | + new_object_list.append(self.obstacle_list[obs]) |
| 246 | + |
| 247 | + self.obstacle_list = new_object_list |
| 248 | + return dead_object_list |
| 249 | + |
| 250 | +def main(args=None): |
| 251 | + rclpy.init(args=args) |
| 252 | + |
| 253 | + node = KFHungarianTracker() |
| 254 | + node.get_logger().info("start spining tracker node...") |
| 255 | + |
| 256 | + rclpy.spin(node) |
| 257 | + |
| 258 | + rclpy.shutdown() |
| 259 | + |
| 260 | +if __name__ == "__main__": |
| 261 | + main() |
0 commit comments