Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,11 +669,67 @@ def get_bbox_minmax(point_cloud):
return (min_point, max_point)


def joint_optimize(surf_ncs, edge_ncs, surfPos, unique_vertices, EdgeVertexAdj, FaceEdgeAdj, num_edge, num_surf):
def ChamferDistance_batch(X, Y, bidirectional=False, reverse=False, batch_reduction: Optional[str] = "mean",
point_reduction: Optional[str] = "sum")->torch.Tensor:
"""
Args:
- X: Tensor of shape (B, N, d) representing a batch of point clouds.
- Y: Tensor of shape (B, M, d) representing a batch of point clouds.
- bidirectional: If True, compute the Chamfer distance in both directions and average the results.
- reverse: If False, the Chamfer distance is computed based on the nearest neighbor point in y \in Y for each point in x \in X; and vice versa if True.
- batch_reduction: Method used to reduce the distance between points in a batch. Can be "sum" or "mean".
- point_reduction: Method used to reduce the distance between points in a point cloud. Can be "sum" or "mean".
"""
xx = torch.bmm(X, X.transpose(2, 1)) # [b, N, N]
yy = torch.bmm(Y, Y.transpose(2, 1)) # [b, M, M]
zz = torch.bmm(X, Y.transpose(2, 1)) # [b, N, M]
diag_ind = torch.arange(0, X.size()[1]).to(X).long()
diag_ind_2 = torch.arange(0, Y.size()[1]).to(X).long()
rx = xx[:, diag_ind, diag_ind].unsqueeze(2).expand_as(zz) # [b, N] -> [b, N, 1] -> [b, N, M]
ry = yy[:, diag_ind_2, diag_ind_2].unsqueeze(1).expand_as(zz) # [b, N] -> [b, 1, N] -> [b, N, M]
P = (rx + ry - 2 * zz) # [b, N, M]

if reverse:
P_back = P.transpose(1, 2)

P = P.min(2)[0] # [b, N]
if reverse:
P_back = P_back.min(2)[0] # [b, M]
if point_reduction == "sum":
P = P.sum(1) # [b]
if reverse:
P_back = P_back.sum(1) # [b]
elif point_reduction == "mean":
P = P.mean(1)
if reverse:
P_back = P_back.mean(1)
else:
raise ValueError("Invalid point reduction")

if batch_reduction == "sum":
P = P.sum()
if reverse:
P_back = P_back.sum()
elif batch_reduction == "mean":
P = P.mean()
if reverse:
P_back = P_back.mean()
else:
raise ValueError("Invalid batch reduction")

if bidirectional:
return P + P_back
elif reverse:
return P_back
else:
return P


def joint_optimize(surf_ncs, edge_ncs, surfPos, unique_vertices, EdgeVertexAdj, FaceEdgeAdj, num_edge, num_surf, use_local_cd=False):
"""
Jointly optimize the face/edge/vertex based on topology
"""
loss_func = ChamferDistance()
loss_func = ChamferDistance() if not use_local_cd else ChamferDistance_batch

model = STModel(num_edge, num_surf)
model = model.cuda().train()
Expand Down