diff --git a/utils.py b/utils.py index 4b76ccd..b3c431a 100644 --- a/utils.py +++ b/utils.py @@ -669,11 +669,67 @@ def get_bbox_minmax(point_cloud): return (min_point, max_point) -def joint_optimize(surf_ncs, edge_ncs, surfPos, unique_vertices, EdgeVertexAdj, FaceEdgeAdj, num_edge, num_surf): +def ChamferDistance_batch(X, Y, bidirectional=False, reverse=False, batch_reduction: Optional[str] = "mean", + point_reduction: Optional[str] = "sum")->torch.Tensor: + """ + Args: + - X: Tensor of shape (B, N, d) representing a batch of point clouds. + - Y: Tensor of shape (B, M, d) representing a batch of point clouds. + - bidirectional: If True, compute the Chamfer distance in both directions and average the results. + - reverse: If False, the Chamfer distance is computed based on the nearest neighbor point in y \in Y for each point in x \in X; and vice versa if True. + - batch_reduction: Method used to reduce the distance between points in a batch. Can be "sum" or "mean". + - point_reduction: Method used to reduce the distance between points in a point cloud. Can be "sum" or "mean". + """ + xx = torch.bmm(X, X.transpose(2, 1)) # [b, N, N] + yy = torch.bmm(Y, Y.transpose(2, 1)) # [b, M, M] + zz = torch.bmm(X, Y.transpose(2, 1)) # [b, N, M] + diag_ind = torch.arange(0, X.size()[1]).to(X).long() + diag_ind_2 = torch.arange(0, Y.size()[1]).to(X).long() + rx = xx[:, diag_ind, diag_ind].unsqueeze(2).expand_as(zz) # [b, N] -> [b, N, 1] -> [b, N, M] + ry = yy[:, diag_ind_2, diag_ind_2].unsqueeze(1).expand_as(zz) # [b, N] -> [b, 1, N] -> [b, N, M] + P = (rx + ry - 2 * zz) # [b, N, M] + + if reverse: + P_back = P.transpose(1, 2) + + P = P.min(2)[0] # [b, N] + if reverse: + P_back = P_back.min(2)[0] # [b, M] + if point_reduction == "sum": + P = P.sum(1) # [b] + if reverse: + P_back = P_back.sum(1) # [b] + elif point_reduction == "mean": + P = P.mean(1) + if reverse: + P_back = P_back.mean(1) + else: + raise ValueError("Invalid point reduction") + + if batch_reduction == "sum": + P = P.sum() + if reverse: + P_back = P_back.sum() + elif batch_reduction == "mean": + P = P.mean() + if reverse: + P_back = P_back.mean() + else: + raise ValueError("Invalid batch reduction") + + if bidirectional: + return P + P_back + elif reverse: + return P_back + else: + return P + + +def joint_optimize(surf_ncs, edge_ncs, surfPos, unique_vertices, EdgeVertexAdj, FaceEdgeAdj, num_edge, num_surf, use_local_cd=False): """ Jointly optimize the face/edge/vertex based on topology """ - loss_func = ChamferDistance() + loss_func = ChamferDistance() if not use_local_cd else ChamferDistance_batch model = STModel(num_edge, num_surf) model = model.cuda().train()