Skip to content

Commit eda4158

Browse files
committed
Resolve conflicts.
2 parents 6bbeb91 + db6c376 commit eda4158

File tree

12 files changed

+186
-66
lines changed

12 files changed

+186
-66
lines changed

setup.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ def get_ext_modules():
3737
name="torch_points_kernels.points_cuda",
3838
sources=ext_sources,
3939
include_dirs=["{}/include".format(ext_src_root)],
40-
extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,},
40+
extra_compile_args={
41+
"cxx": extra_compile_args,
42+
"nvcc": extra_compile_args,
43+
},
4144
)
4245
)
4346

@@ -49,7 +52,9 @@ def get_ext_modules():
4952
name="torch_points_kernels.points_cpu",
5053
sources=cpu_ext_sources,
5154
include_dirs=["{}/include".format(cpu_ext_src_root)],
52-
extra_compile_args={"cxx": extra_compile_args,},
55+
extra_compile_args={
56+
"cxx": extra_compile_args,
57+
},
5358
)
5459
)
5560
return ext_modules
@@ -81,5 +86,8 @@ def get_cmdclass():
8186
cmdclass=get_cmdclass(),
8287
long_description=long_description,
8388
long_description_content_type="text/markdown",
84-
classifiers=["Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License",],
89+
classifiers=[
90+
"Programming Language :: Python :: 3",
91+
"License :: OSI Approved :: MIT License",
92+
],
8593
)

test/speed_radius.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,26 @@ def test_speed(self):
2323
R = 1
2424
samples = 50
2525

26-
idx, dist = ball_query(R, samples, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,)
27-
idx1, dist = ball_query(R, samples, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,)
26+
idx, dist = ball_query(
27+
R,
28+
samples,
29+
a,
30+
b,
31+
mode="PARTIAL_DENSE",
32+
batch_x=batch_a,
33+
batch_y=batch_b,
34+
sort=True,
35+
)
36+
idx1, dist = ball_query(
37+
R,
38+
samples,
39+
a,
40+
b,
41+
mode="PARTIAL_DENSE",
42+
batch_x=batch_a,
43+
batch_y=batch_b,
44+
sort=True,
45+
)
2846
print(time.time() - start)
2947
torch.testing.assert_allclose(idx1, idx)
3048

test/test_ballquerry.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,48 @@ def test_random_cpu(self, cuda=False):
112112
batch_b = torch.tensor([0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])])
113113
R = 1
114114

115-
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,)
116-
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,)
115+
idx, dist = ball_query(
116+
R,
117+
15,
118+
a,
119+
b,
120+
mode="PARTIAL_DENSE",
121+
batch_x=batch_a,
122+
batch_y=batch_b,
123+
sort=True,
124+
)
125+
idx1, dist = ball_query(
126+
R,
127+
15,
128+
a,
129+
b,
130+
mode="PARTIAL_DENSE",
131+
batch_x=batch_a,
132+
batch_y=batch_b,
133+
sort=True,
134+
)
117135
torch.testing.assert_allclose(idx1, idx)
118136
with self.assertRaises(AssertionError):
119-
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,)
120-
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,)
137+
idx, dist = ball_query(
138+
R,
139+
15,
140+
a,
141+
b,
142+
mode="PARTIAL_DENSE",
143+
batch_x=batch_a,
144+
batch_y=batch_b,
145+
sort=False,
146+
)
147+
idx1, dist = ball_query(
148+
R,
149+
15,
150+
a,
151+
b,
152+
mode="PARTIAL_DENSE",
153+
batch_x=batch_a,
154+
batch_y=batch_b,
155+
sort=False,
156+
)
121157
torch.testing.assert_allclose(idx1, idx)
122158

123159
self.assertEqual(idx.shape[0], b.shape[0])
@@ -144,7 +180,16 @@ def test_random_gpu(self):
144180
).cuda()
145181
R = 1
146182

147-
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,)
183+
idx, dist = ball_query(
184+
R,
185+
15,
186+
a,
187+
b,
188+
mode="PARTIAL_DENSE",
189+
batch_x=batch_a,
190+
batch_y=batch_b,
191+
sort=False,
192+
)
148193

149194
# Comparison to see if we have the same result
150195
tree = KDTree(a.cpu().detach().numpy())

test/test_chamfer_dist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
1313
sys.path.insert(0, ROOT)
1414

15-
from torch_points_kernels import ChamferFunction, chamfer_dist
15+
from torch_points_kernels.chamfer_dist import ChamferFunction, chamfer_dist
1616

1717

1818
class TestChamferDistance(unittest.TestCase):

test/test_cluster.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,16 @@
1212
class TestGrow(unittest.TestCase):
1313
def setUp(self):
1414
self.pos = torch.tensor(
15-
[[0, 0, 0], [1, 0, 0], [2, 0, 0], [10, 0, 0], [0, 0, 0], [1, 0, 0], [2, 0, 0], [10, 0, 0],]
15+
[
16+
[0, 0, 0],
17+
[1, 0, 0],
18+
[2, 0, 0],
19+
[10, 0, 0],
20+
[0, 0, 0],
21+
[1, 0, 0],
22+
[2, 0, 0],
23+
[10, 0, 0],
24+
]
1625
)
1726
self.batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
1827
self.labels = torch.tensor([0, 0, 1, 1, 0, 1, 1, 10])

test/test_grouping.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ def test_simple(self):
1313
features = torch.tensor(
1414
[
1515
[[0, 10, 0], [1, 11, 0], [2, 12, 0]],
16-
[[100, 110, 120], [101, 111, 121], [102, 112, 122],], # x-coordinates # y-coordinates # z-coordinates
16+
[
17+
[100, 110, 120],
18+
[101, 111, 121],
19+
[102, 112, 122],
20+
], # x-coordinates # y-coordinates # z-coordinates
1721
]
1822
).type(torch.float)
1923
idx = torch.tensor([[[1, 0], [0, 0]], [[0, 1], [1, 2]]]).type(torch.long)
@@ -38,7 +42,8 @@ def test_simple(self):
3842

3943
if torch.cuda.is_available():
4044
npt.assert_array_equal(
41-
grouping_operation(features.cuda(), idx.cuda()).detach().cpu().numpy(), expected,
45+
grouping_operation(features.cuda(), idx.cuda()).detach().cpu().numpy(),
46+
expected,
4247
)
4348

4449

test/test_metrics.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def test_simple(self, cuda=False):
2525
gt_instances = gt_instances.cuda()
2626
ious = instance_iou(proposed_instances, gt_instances)
2727
torch.testing.assert_allclose(
28-
ious.cpu(), torch.tensor([[1, 0, 0], [0, 2 / 3.0, 0], [0, 1.0 / 4.0, 1.0 / 2.0]]),
28+
ious.cpu(),
29+
torch.tensor([[1, 0, 0], [0, 2 / 3.0, 0], [0, 1.0 / 4.0, 1.0 / 2.0]]),
2930
)
3031

3132
def test_batch(self, cuda=False):
@@ -42,7 +43,14 @@ def test_batch(self, cuda=False):
4243
batch = batch.cuda()
4344
ious = instance_iou(proposed_instances, gt_instances, batch=batch)
4445
torch.testing.assert_allclose(
45-
ious.cpu(), torch.tensor([[0.5, 0.5, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1],]),
46+
ious.cpu(),
47+
torch.tensor(
48+
[
49+
[0.5, 0.5, 0, 0, 0],
50+
[0, 0, 0, 1, 0],
51+
[0, 0, 0, 0, 1],
52+
]
53+
),
4654
)
4755

4856
@run_if_cuda
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import torch
2+
3+
if torch.cuda.is_available():
4+
import torch_points_kernels.points_cuda as tpcuda
5+
6+
7+
class ChamferFunction(torch.autograd.Function):
8+
@staticmethod
9+
def forward(ctx, xyz1, xyz2):
10+
if not torch.cuda.is_available():
11+
raise NotImplementedError(
12+
"CPU version is not available for Chamfer Distance"
13+
)
14+
15+
dist1, dist2, idx1, idx2 = tpcuda.chamfer_dist(xyz1, xyz2)
16+
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
17+
18+
return dist1, dist2
19+
20+
@staticmethod
21+
def backward(ctx, grad_dist1, grad_dist2):
22+
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
23+
grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(
24+
xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2
25+
)
26+
return grad_xyz1, grad_xyz2
27+
28+
29+
def chamfer_dist(xyz1, xyz2, ignore_zeros=False):
30+
r"""
31+
Calcuates the distance between B pairs of point clouds
32+
33+
Parameters
34+
----------
35+
xyz1 : torch.Tensor (dtype=torch.float32)
36+
(B, n1, 3) B point clouds containing n1 points
37+
xyz2 : torch.Tensor (dtype=torch.float32)
38+
(B, n2, 3) B point clouds containing n2 points
39+
ignore_zeros : bool
40+
ignore the point whose coordinate is (0, 0, 0) or not
41+
42+
Returns
43+
-------
44+
dist: torch.Tensor
45+
(B, ): the distances between B pairs of point clouds
46+
"""
47+
if len(xyz1.shape) != 3 or xyz1.size(2) != 3 or len(xyz2.shape) != 3 or xyz2.size(2) != 3:
48+
raise ValueError('The input point cloud should be of size (B, n_pts, 3)')
49+
50+
batch_size = xyz1.size(0)
51+
if batch_size == 1 and ignore_zeros:
52+
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
53+
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
54+
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
55+
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
56+
57+
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
58+
return torch.mean(dist1) + torch.mean(dist2)
59+

torch_points_kernels/cluster.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _grow_proximity_core(neighbours, min_cluster_size):
3838

3939

4040
def grow_proximity(pos, batch, nsample=16, radius=0.02, min_cluster_size=32):
41-
""" Grow based on proximity only
41+
"""Grow based on proximity only
4242
Neighbour search is done on device while the cluster assignement is done on cpu"""
4343
assert pos.shape[0] == batch.shape[0]
4444
neighbours = ball_query_partial_dense(radius, nsample, pos, pos, batch, batch)[0].cpu().numpy()
@@ -48,7 +48,7 @@ def grow_proximity(pos, batch, nsample=16, radius=0.02, min_cluster_size=32):
4848
def region_grow(
4949
pos, labels, batch, ignore_labels=[], nsample=16, radius=0.02, min_cluster_size=32
5050
) -> List[torch.Tensor]:
51-
""" Region growing clustering algorithm proposed in
51+
"""Region growing clustering algorithm proposed in
5252
PointGroup: Dual-Set Point Grouping for 3D Instance Segmentation
5353
https://arxiv.org/pdf/2004.01658.pdf
5454
for instance segmentation
@@ -93,7 +93,11 @@ def region_grow(
9393

9494
# Cluster
9595
label_clusters = grow_proximity(
96-
pos[label_mask, :], remaped_batch, nsample=nsample, radius=radius, min_cluster_size=min_cluster_size,
96+
pos[label_mask, :],
97+
remaped_batch,
98+
nsample=nsample,
99+
radius=radius,
100+
min_cluster_size=min_cluster_size,
97101
)
98102

99103
# Remap indices to original coordinates

torch_points_kernels/knn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def knn(pos_support, pos, k):
5-
""" Dense knn serach
5+
"""Dense knn serach
66
Arguments:
77
pos_support - [B,N,3] support points
88
pos - [B,M,3] centre of queries

0 commit comments

Comments
 (0)