-
Notifications
You must be signed in to change notification settings - Fork 49
Description
代码如下所示：
seed_xyz = end_points['point_clouds'] # use all sampled point cloud, BNs3
B, point_num, _ = seed_xyz.shape # batch _size
coordinates_batch = end_points['coors']
features_batch = end_points['feats']
mink_input = ME.SparseTensor(features_batch, coordinates=coordinates_batch)
seed_features = self.backbone(mink_input).F
seed_features = seed_features[end_points['quantize2original']].view(B, point_num, -1).transpose(1, 2)
end_points = self.graspable(seed_features, end_points)
seed_features_flipped = seed_features.transpose(1, 2) # BNsfeat_dim
objectness_score = end_points['objectness_score']
graspness_score = end_points['graspness_score'].squeeze(1)
objectness_pred = torch.argmax(objectness_score, 1)
objectness_mask = (objectness_pred == 1)
graspness_mask = graspness_score > GRASPNESS_THRESHOLD
graspable_mask = objectness_mask & graspness_mask
seed_features_graspable = []
seed_xyz_graspable = []
graspable_num_batch = 0.
for i in range(B):
cur_mask = graspable_mask[i]
graspable_num_batch += cur_mask.sum()
cur_feat = seed_features_flipped[i][cur_mask] # Nsfeat_dim
cur_seed_xyz = seed_xyz[i][cur_mask] # Ns3
`seed_xyz[i]` 的大小是 `Ns*3`，
那么 `cur_seed_xyz = seed_xyz[i][cur_mask]` 的大小还是 `Ns*3` 吗？