@@ -30,10 +30,7 @@ def furthest_point_sample(xyz, npoint):
3030 (B, npoint) tensor containing the set
3131 """
3232 if npoint > xyz .shape [1 ]:
33- raise ValueError (
34- "caanot sample %i points from an input set of %i points"
35- % (npoint , xyz .shape [1 ])
36- )
33+ raise ValueError ("caanot sample %i points from an input set of %i points" % (npoint , xyz .shape [1 ]))
3734 if xyz .is_cuda :
3835 return tpcuda .furthest_point_sampling (xyz , npoint )
3936 else :
@@ -102,13 +99,9 @@ def backward(ctx, grad_out):
10299 idx , weight , m = ctx .three_interpolate_for_backward
103100
104101 if grad_out .is_cuda :
105- grad_features = tpcuda .three_interpolate_grad (
106- grad_out .contiguous (), idx , weight , m
107- )
102+ grad_features = tpcuda .three_interpolate_grad (grad_out .contiguous (), idx , weight , m )
108103 else :
109- grad_features = tpcpu .knn_interpolate_grad (
110- grad_out .contiguous (), idx , weight , m
111- )
104+ grad_features = tpcpu .knn_interpolate_grad (grad_out .contiguous (), idx , weight , m )
112105
113106 return grad_features , None , None
114107
@@ -150,23 +143,17 @@ def grouping_operation(features, idx):
150143 all_idx = idx .reshape (idx .shape [0 ], - 1 )
151144 all_idx = all_idx .unsqueeze (1 ).repeat (1 , features .shape [1 ], 1 )
152145 grouped_features = features .gather (2 , all_idx )
153- return grouped_features .reshape (
154- idx .shape [0 ], features .shape [1 ], idx .shape [1 ], idx .shape [2 ]
155- )
146+ return grouped_features .reshape (idx .shape [0 ], features .shape [1 ], idx .shape [1 ], idx .shape [2 ])
156147
157148
158- def ball_query_dense (
159- radius , nsample , xyz , new_xyz , batch_xyz = None , batch_new_xyz = None , sort = False
160- ):
149+ def ball_query_dense (radius , nsample , xyz , new_xyz , batch_xyz = None , batch_new_xyz = None , sort = False ):
161150 # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
162151 if new_xyz .is_cuda :
163152 if sort :
164153 raise NotImplementedError ("CUDA version does not sort the neighbors" )
165154 ind , dist = tpcuda .ball_query_dense (new_xyz , xyz , radius , nsample )
166155 else :
167- ind , dist = tpcpu .dense_ball_query (
168- new_xyz , xyz , radius , nsample , mode = 0 , sorted = sort
169- )
156+ ind , dist = tpcpu .dense_ball_query (new_xyz , xyz , radius , nsample , mode = 0 , sorted = sort )
170157 return ind , dist
171158
172159
@@ -175,13 +162,9 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False
175162 if x .is_cuda :
176163 if sort :
177164 raise NotImplementedError ("CUDA version does not sort the neighbors" )
178- ind , dist = tpcuda .ball_query_partial_dense (
179- x , y , batch_x , batch_y , radius , nsample
180- )
165+ ind , dist = tpcuda .ball_query_partial_dense (x , y , batch_x , batch_y , radius , nsample )
181166 else :
182- ind , dist = tpcpu .batch_ball_query (
183- x , y , batch_x , batch_y , radius , nsample , mode = 0 , sorted = sort
184- )
167+ ind , dist = tpcpu .batch_ball_query (x , y , batch_x , batch_y , radius , nsample , mode = 0 , sorted = sort )
185168 return ind , dist
186169
187170
@@ -224,9 +207,7 @@ def ball_query(
224207 assert x .size (0 ) == batch_x .size (0 )
225208 assert y .size (0 ) == batch_y .size (0 )
226209 assert x .dim () == 2
227- return ball_query_partial_dense (
228- radius , nsample , x , y , batch_x , batch_y , sort = sort
229- )
210+ return ball_query_partial_dense (radius , nsample , x , y , batch_x , batch_y , sort = sort )
230211
231212 elif mode .lower () == "dense" :
232213 if (batch_x is not None ) or (batch_y is not None ):
@@ -248,9 +229,7 @@ def forward(ctx, xyz1, xyz2):
248229 @staticmethod
249230 def backward (ctx , grad_dist1 , grad_dist2 ):
250231 xyz1 , xyz2 , idx1 , idx2 = ctx .saved_tensors
251- grad_xyz1 , grad_xyz2 = tpcuda .chamfer_dist_grad (
252- xyz1 , xyz2 , idx1 , idx2 , grad_dist1 , grad_dist2
253- )
232+ grad_xyz1 , grad_xyz2 = tpcuda .chamfer_dist_grad (xyz1 , xyz2 , idx1 , idx2 , grad_dist1 , grad_dist2 )
254233 return grad_xyz1 , grad_xyz2
255234
256235
@@ -281,4 +260,3 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
281260
282261 dist1 , dist2 = ChamferFunction .apply (xyz1 , xyz2 )
283262 return torch .mean (dist1 ) + torch .mean (dist2 )
284-
0 commit comments