Skip to content

Commit 6bbeb91

Browse files
committed
Reformat code with black.
1 parent be5d95e commit 6bbeb91

File tree

2 files changed

+14
-42
lines changed

2 files changed

+14
-42
lines changed

torch_points_kernels/cubic_feature_sampling.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,12 @@ class CubicFeatureSamplingFunction(torch.autograd.Function):
99
def forward(ctx, ptcloud, cubic_features, neighborhood_size=1):
1010
scale = cubic_features.size(2)
1111
if not torch.cuda.is_available():
12-
raise NotImplementedError(
13-
"CPU version is not available for Cubic Feature Sampling"
14-
)
12+
raise NotImplementedError("CPU version is not available for Cubic Feature Sampling")
1513

1614
point_features, grid_pt_indexes = tpcuda.cubic_feature_sampling(
1715
scale, neighborhood_size, ptcloud, cubic_features
1816
)
19-
ctx.save_for_backward(
20-
torch.Tensor([scale]), torch.Tensor([neighborhood_size]), grid_pt_indexes
21-
)
17+
ctx.save_for_backward(torch.Tensor([scale]), torch.Tensor([neighborhood_size]), grid_pt_indexes)
2218
return point_features
2319

2420
@staticmethod
@@ -56,10 +52,8 @@ def cubic_feature_sampling(ptcloud, cubic_features, neighborhood_size=1):
5652
(B, n_pts, n_vertices, c), where n_vertices = (neighborhood_size * 2)^3
5753
"""
5854
if len(ptcloud.shape) != 3 or ptcloud.shape[2] != 3:
59-
raise ValueError('The input point cloud should be of size (B, n_pts, 3).')
55+
raise ValueError("The input point cloud should be of size (B, n_pts, 3).")
6056

6157
h_scale = cubic_features.size(2) / 2
6258
ptcloud = ptcloud * h_scale + h_scale
63-
return CubicFeatureSamplingFunction.apply(
64-
ptcloud, cubic_features, neighborhood_size
65-
)
59+
return CubicFeatureSamplingFunction.apply(ptcloud, cubic_features, neighborhood_size)

torch_points_kernels/torchpoints.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@ def furthest_point_sample(xyz, npoint):
3030
(B, npoint) tensor containing the set
3131
"""
3232
if npoint > xyz.shape[1]:
33-
raise ValueError(
34-
"caanot sample %i points from an input set of %i points"
35-
% (npoint, xyz.shape[1])
36-
)
33+
raise ValueError("caanot sample %i points from an input set of %i points" % (npoint, xyz.shape[1]))
3734
if xyz.is_cuda:
3835
return tpcuda.furthest_point_sampling(xyz, npoint)
3936
else:
@@ -102,13 +99,9 @@ def backward(ctx, grad_out):
10299
idx, weight, m = ctx.three_interpolate_for_backward
103100

104101
if grad_out.is_cuda:
105-
grad_features = tpcuda.three_interpolate_grad(
106-
grad_out.contiguous(), idx, weight, m
107-
)
102+
grad_features = tpcuda.three_interpolate_grad(grad_out.contiguous(), idx, weight, m)
108103
else:
109-
grad_features = tpcpu.knn_interpolate_grad(
110-
grad_out.contiguous(), idx, weight, m
111-
)
104+
grad_features = tpcpu.knn_interpolate_grad(grad_out.contiguous(), idx, weight, m)
112105

113106
return grad_features, None, None
114107

@@ -150,23 +143,17 @@ def grouping_operation(features, idx):
150143
all_idx = idx.reshape(idx.shape[0], -1)
151144
all_idx = all_idx.unsqueeze(1).repeat(1, features.shape[1], 1)
152145
grouped_features = features.gather(2, all_idx)
153-
return grouped_features.reshape(
154-
idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2]
155-
)
146+
return grouped_features.reshape(idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2])
156147

157148

158-
def ball_query_dense(
159-
radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False
160-
):
149+
def ball_query_dense(radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False):
161150
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
162151
if new_xyz.is_cuda:
163152
if sort:
164153
raise NotImplementedError("CUDA version does not sort the neighbors")
165154
ind, dist = tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
166155
else:
167-
ind, dist = tpcpu.dense_ball_query(
168-
new_xyz, xyz, radius, nsample, mode=0, sorted=sort
169-
)
156+
ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0, sorted=sort)
170157
return ind, dist
171158

172159

@@ -175,13 +162,9 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False
175162
if x.is_cuda:
176163
if sort:
177164
raise NotImplementedError("CUDA version does not sort the neighbors")
178-
ind, dist = tpcuda.ball_query_partial_dense(
179-
x, y, batch_x, batch_y, radius, nsample
180-
)
165+
ind, dist = tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample)
181166
else:
182-
ind, dist = tpcpu.batch_ball_query(
183-
x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort
184-
)
167+
ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort)
185168
return ind, dist
186169

187170

@@ -224,9 +207,7 @@ def ball_query(
224207
assert x.size(0) == batch_x.size(0)
225208
assert y.size(0) == batch_y.size(0)
226209
assert x.dim() == 2
227-
return ball_query_partial_dense(
228-
radius, nsample, x, y, batch_x, batch_y, sort=sort
229-
)
210+
return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=sort)
230211

231212
elif mode.lower() == "dense":
232213
if (batch_x is not None) or (batch_y is not None):
@@ -248,9 +229,7 @@ def forward(ctx, xyz1, xyz2):
248229
@staticmethod
249230
def backward(ctx, grad_dist1, grad_dist2):
250231
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
251-
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
252-
xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
253-
)
232+
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
254233
return grad_xyz1, grad_xyz2
255234

256235

@@ -281,4 +260,3 @@ def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
281260

282261
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
283262
return torch.mean(dist1) + torch.mean(dist2)
284-

0 commit comments

Comments
 (0)