Skip to content

Commit 94c870a

Browse files
Instance iou on GPU
1 parent a4c9048 commit 94c870a

File tree

13 files changed

+172
-145
lines changed

13 files changed

+172
-145
lines changed

cpu/include/utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22
#include <torch/extension.h>
33

4-
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be a CPU tensor")
4+
#define CHECK_CPU(x) AT_ASSERTM(!x.is_cuda(), #x " must be a CPU tensor")
55

6-
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor")
6+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be a contiguous tensor")

cuda/include/metrics.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
#include <torch/extension.h>
33

44
at::Tensor instance_iou_cuda(at::Tensor instance_idx, at::Tensor instance_offsets,
5-
at::Tensor instance_gt);
5+
at::Tensor gt_instances, at::Tensor gt_instance_sizes,
6+
long num_gt_instances);

cuda/include/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#define CHECK_CUDA(x) \
66
do \
77
{ \
8-
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
8+
TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); \
99
} while (0)
1010

1111
#define CHECK_CONTIGUOUS(x) \

cuda/src/ball_query.cpp

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,18 @@ std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tenso
1919
CHECK_IS_FLOAT(new_xyz);
2020
CHECK_IS_FLOAT(xyz);
2121

22-
if (new_xyz.type().is_cuda())
23-
{
24-
CHECK_CUDA(xyz);
25-
}
22+
CHECK_CUDA(xyz);
23+
CHECK_CUDA(new_xyz);
2624

2725
at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
2826
at::device(new_xyz.device()).dtype(at::ScalarType::Long));
2927
at::Tensor dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1,
3028
at::device(new_xyz.device()).dtype(at::ScalarType::Float));
3129

32-
if (new_xyz.type().is_cuda())
33-
{
34-
query_ball_point_kernel_dense_wrapper(
35-
xyz.size(0), xyz.size(1), new_xyz.size(1), radius, nsample, new_xyz.DATA_PTR<float>(),
36-
xyz.DATA_PTR<float>(), idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
37-
}
38-
else
39-
{
40-
TORCH_CHECK(false, "CPU not supported");
41-
}
30+
query_ball_point_kernel_dense_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), radius,
31+
nsample, new_xyz.DATA_PTR<float>(), xyz.DATA_PTR<float>(),
32+
idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
33+
4234
return std::make_pair(idx, dist);
4335
}
4436

@@ -57,14 +49,10 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Ten
5749
CHECK_CONTIGUOUS(y);
5850
CHECK_IS_FLOAT(x);
5951
CHECK_IS_FLOAT(y);
60-
61-
if (x.type().is_cuda())
62-
{
63-
CHECK_CUDA(x);
64-
CHECK_CUDA(y);
65-
CHECK_CUDA(batch_x);
66-
CHECK_CUDA(batch_y);
67-
}
52+
CHECK_CUDA(x);
53+
CHECK_CUDA(y);
54+
CHECK_CUDA(batch_x);
55+
CHECK_CUDA(batch_y);
6856

6957
at::Tensor idx =
7058
torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Long));
@@ -83,17 +71,10 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Ten
8371
batch_y = degree(batch_y, batch_size);
8472
batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);
8573

86-
if (x.type().is_cuda())
87-
{
88-
query_ball_point_kernel_partial_wrapper(batch_size, x.size(0), y.size(0), radius, nsample,
89-
x.DATA_PTR<float>(), y.DATA_PTR<float>(),
90-
batch_x.DATA_PTR<long>(), batch_y.DATA_PTR<long>(),
91-
idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
92-
}
93-
else
94-
{
95-
TORCH_CHECK(false, "CPU not supported");
96-
}
74+
query_ball_point_kernel_partial_wrapper(batch_size, x.size(0), y.size(0), radius, nsample,
75+
x.DATA_PTR<float>(), y.DATA_PTR<float>(),
76+
batch_x.DATA_PTR<long>(), batch_y.DATA_PTR<long>(),
77+
idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
9778

9879
return std::make_pair(idx, dist);
9980
}

cuda/src/bindings.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "ball_query.h"
22
#include "interpolate.h"
3+
#include "metrics.h"
34
#include "sampling.h"
45

56
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
@@ -12,4 +13,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
1213

1314
m.def("ball_query_dense", &ball_query_dense);
1415
m.def("ball_query_partial_dense", &ball_query_partial_dense);
16+
17+
m.def("instance_iou_cuda", &instance_iou_cuda);
1518
}

cuda/src/interpolate.cpp

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,17 @@ std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows)
1616
CHECK_IS_FLOAT(unknowns);
1717
CHECK_IS_FLOAT(knows);
1818

19-
if (unknowns.type().is_cuda())
20-
{
21-
CHECK_CUDA(knows);
22-
}
19+
CHECK_CUDA(knows);
20+
CHECK_CUDA(unknowns);
2321

2422
at::Tensor idx = torch::zeros({unknowns.size(0), unknowns.size(1), 3},
2523
at::device(unknowns.device()).dtype(at::ScalarType::Int));
2624
at::Tensor dist2 = torch::zeros({unknowns.size(0), unknowns.size(1), 3},
2725
at::device(unknowns.device()).dtype(at::ScalarType::Float));
2826

29-
if (unknowns.type().is_cuda())
30-
{
31-
three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
32-
unknowns.DATA_PTR<float>(), knows.DATA_PTR<float>(),
33-
dist2.DATA_PTR<float>(), idx.DATA_PTR<int>());
34-
}
35-
else
36-
{
37-
TORCH_CHECK(false, "CPU not supported");
38-
}
27+
three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
28+
unknowns.DATA_PTR<float>(), knows.DATA_PTR<float>(),
29+
dist2.DATA_PTR<float>(), idx.DATA_PTR<int>());
3930

4031
return {dist2, idx};
4132
}
@@ -49,25 +40,15 @@ at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, at::Tensor weigh
4940
CHECK_IS_INT(idx);
5041
CHECK_IS_FLOAT(weight);
5142

52-
if (points.type().is_cuda())
53-
{
54-
CHECK_CUDA(idx);
55-
CHECK_CUDA(weight);
56-
}
43+
CHECK_CUDA(idx);
44+
CHECK_CUDA(weight);
5745

5846
at::Tensor output = torch::zeros({points.size(0), points.size(1), idx.size(1)},
5947
at::device(points.device()).dtype(at::ScalarType::Float));
6048

61-
if (points.type().is_cuda())
62-
{
63-
three_interpolate_kernel_wrapper(points.size(0), points.size(1), points.size(2),
64-
idx.size(1), points.DATA_PTR<float>(), idx.DATA_PTR<int>(),
65-
weight.DATA_PTR<float>(), output.DATA_PTR<float>());
66-
}
67-
else
68-
{
69-
TORCH_CHECK(false, "CPU not supported");
70-
}
49+
three_interpolate_kernel_wrapper(points.size(0), points.size(1), points.size(2), idx.size(1),
50+
points.DATA_PTR<float>(), idx.DATA_PTR<int>(),
51+
weight.DATA_PTR<float>(), output.DATA_PTR<float>());
7152

7253
return output;
7354
}
@@ -80,26 +61,16 @@ at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tenso
8061
CHECK_IS_FLOAT(grad_out);
8162
CHECK_IS_INT(idx);
8263
CHECK_IS_FLOAT(weight);
83-
84-
if (grad_out.type().is_cuda())
85-
{
86-
CHECK_CUDA(idx);
87-
CHECK_CUDA(weight);
88-
}
64+
CHECK_CUDA(idx);
65+
CHECK_CUDA(weight);
66+
CHECK_CUDA(grad_out);
8967

9068
at::Tensor output = torch::zeros({grad_out.size(0), grad_out.size(1), m},
9169
at::device(grad_out.device()).dtype(at::ScalarType::Float));
9270

93-
if (grad_out.type().is_cuda())
94-
{
95-
three_interpolate_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), grad_out.size(2),
96-
m, grad_out.DATA_PTR<float>(), idx.DATA_PTR<int>(),
97-
weight.DATA_PTR<float>(), output.DATA_PTR<float>());
98-
}
99-
else
100-
{
101-
TORCH_CHECK(false, "CPU not supported");
102-
}
71+
three_interpolate_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
72+
grad_out.DATA_PTR<float>(), idx.DATA_PTR<int>(),
73+
weight.DATA_PTR<float>(), output.DATA_PTR<float>());
10374

10475
return output;
10576
}

cuda/src/metrics.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,33 @@
22
#include "compat.h"
33
#include "utils.h"
44

5-
void instance_iou_kernel_wrapper(int b, int n, int m, const float* dataset, float* temp, int* idxs);
5+
void instance_iou_kernel_wrapper(int nInstance, int nProposal, long* proposals_idx,
6+
long* proposals_offset, long* instance_labels,
7+
long* instance_pointnum, float* proposals_iou);
68

79
at::Tensor instance_iou_cuda(at::Tensor instance_idx, at::Tensor instance_offsets,
8-
at::Tensor instance_gt)
10+
at::Tensor gt_instances, at::Tensor gt_instance_sizes,
11+
long num_gt_instances)
912
{
1013
CHECK_CONTIGUOUS(instance_idx);
1114
CHECK_CONTIGUOUS(instance_offsets);
12-
CHECK_CONTIGUOUS(instance_gt);
13-
CHECK_CUDA(instance_idx)
14-
CHECK_CUDA(instance_offsets)
15-
CHECK_CUDA(instance_gt)
15+
CHECK_CONTIGUOUS(gt_instances);
16+
CHECK_CONTIGUOUS(gt_instance_sizes);
1617

17-
auto num_gt_instances = instance_gt.max(0);
18-
auto num_proposed_instances = instance_offsets.size(0);
18+
CHECK_CUDA(instance_idx);
19+
CHECK_CUDA(instance_offsets);
20+
CHECK_CUDA(gt_instances);
21+
CHECK_CUDA(gt_instance_sizes);
22+
23+
long num_proposed_instances = instance_offsets.size(0) - 1;
1924
at::Tensor output =
2025
torch::zeros({num_proposed_instances, num_gt_instances},
21-
at::device(num_gt_instances.device()).dtype(at::ScalarType::Float));
26+
at::device(gt_instances.device()).dtype(at::ScalarType::Float));
2227

23-
instance_iou_kernel_wrapper(points.size(0), points.size(1), nsamples, points.DATA_PTR<float>(),
24-
tmp.DATA_PTR<float>(), output.DATA_PTR<float>());
28+
instance_iou_kernel_wrapper(num_gt_instances, num_proposed_instances,
29+
instance_idx.DATA_PTR<long>(), instance_offsets.DATA_PTR<long>(),
30+
gt_instances.DATA_PTR<long>(), gt_instance_sizes.DATA_PTR<long>(),
31+
output.DATA_PTR<float>());
2532

2633
return output;
2734
}

cuda/src/metrics.cu

Lines changed: 0 additions & 8 deletions
This file was deleted.

cuda/src/metrics_gpu.cu

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include <math.h>
2+
#include <stdio.h>
3+
#include <stdlib.h>
4+
5+
#include "cuda_utils.h"
6+
7+
#define THREADS 512
8+
9+
__global__ void instance_iou_cuda_kernel(int nInstance, int nProposal, long* proposals_idx,
10+
long* proposals_offset, long* instance_labels,
11+
long* instance_pointnum, float* proposals_iou)
12+
{
13+
for (int proposal_id = blockIdx.x; proposal_id < nProposal; proposal_id += gridDim.x)
14+
{
15+
int start = proposals_offset[proposal_id];
16+
int end = proposals_offset[proposal_id + 1];
17+
int proposal_total = end - start;
18+
for (int instance_id = threadIdx.x; instance_id < nInstance; instance_id += blockDim.x)
19+
{
20+
int instance_total = instance_pointnum[instance_id];
21+
int intersection = 0;
22+
for (int i = start; i < end; i++)
23+
{
24+
int idx = proposals_idx[i];
25+
if ((int)instance_labels[idx] == instance_id + 1)
26+
{ // 0 is reserved for "no instance"
27+
intersection += 1;
28+
}
29+
}
30+
proposals_iou[proposal_id * nInstance + instance_id] =
31+
(float)intersection /
32+
((float)(proposal_total + instance_total - intersection) + 1e-5);
33+
}
34+
}
35+
}
36+
37+
// input: proposals_idx (sumNPoint), long
38+
// input: proposals_offset (nProposal + 1), long
39+
// input: instance_labels (N), long, 1~total_nInst (label k+1 is instance k); 0 = no instance, other values (e.g. -100) match nothing
40+
// input: instance_pointnum (total_nInst), long
41+
// output: proposals_iou (nProposal, total_nInst), float
42+
void instance_iou_kernel_wrapper(int nInstance, int nProposal, long* proposals_idx,
43+
long* proposals_offset, long* instance_labels,
44+
long* instance_pointnum, float* proposals_iou)
45+
{
46+
instance_iou_cuda_kernel<<<std::min(nProposal, THREADS * THREADS),
47+
std::min(nInstance, THREADS)>>>(nInstance, nProposal, proposals_idx,
48+
proposals_offset, instance_labels,
49+
instance_pointnum, proposals_iou);
50+
}

cuda/src/sampling.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,17 @@ at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples)
99
{
1010
CHECK_CONTIGUOUS(points);
1111
CHECK_IS_FLOAT(points);
12+
CHECK_CUDA(points);
1213

1314
at::Tensor output = torch::zeros({points.size(0), nsamples},
1415
at::device(points.device()).dtype(at::ScalarType::Int));
1516

1617
at::Tensor tmp = torch::full({points.size(0), points.size(1)}, 1e10,
1718
at::device(points.device()).dtype(at::ScalarType::Float));
1819

19-
if (points.type().is_cuda())
20-
{
21-
furthest_point_sampling_kernel_wrapper(points.size(0), points.size(1), nsamples,
22-
points.DATA_PTR<float>(), tmp.DATA_PTR<float>(),
23-
output.DATA_PTR<int>());
24-
}
25-
else
26-
{
27-
TORCH_CHECK(false, "CPU not supported");
28-
}
20+
furthest_point_sampling_kernel_wrapper(points.size(0), points.size(1), nsamples,
21+
points.DATA_PTR<float>(), tmp.DATA_PTR<float>(),
22+
output.DATA_PTR<int>());
2923

3024
return output;
3125
}

0 commit comments

Comments
 (0)