
Commit 80a7dc5

all tests on CPU+GPU

1 parent 5db0086

File tree

6 files changed: +138, -23 lines

benchmark/gather.py

Lines changed: 3 additions & 1 deletion
@@ -7,7 +7,6 @@
 
 from torch_scatter import gather_coo, gather_csr
 
-from scatter_segment import iters, sizes
 from scatter_segment import short_rows, long_rows, download, bold
 
 
@@ -125,6 +124,9 @@ def gat_csr(x):
 parser.add_argument('--with_backward', action='store_true')
 parser.add_argument('--device', type=str, default='cuda')
 args = parser.parse_args()
+iters = 1 if args.device == 'cpu' else 20
+sizes = [1, 16, 32, 64, 128, 256, 512]
+sizes = sizes[:3] if args.device == 'cpu' else sizes
 
 for _ in range(10):  # Warmup.
     torch.randn(100, 100, device=args.device).sum()
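Since iters and sizes are no longer imported from scatter_segment, both are now derived from the --device flag. A quick sketch of the resulting settings (values taken from the hunk above; the argparse plumbing is assumed from the surrounding file):

# CPU runs are deliberately cheap: one timing iteration over the three
# smallest sizes. GPU runs keep the full sweep.
device = 'cpu'  # stand-in for args.device
iters = 1 if device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if device == 'cpu' else sizes
print(iters, sizes)  # -> 1 [1, 16, 32]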

cpu/gather.cpp

Lines changed: 119 additions & 4 deletions
@@ -1,5 +1,8 @@
 #include <torch/extension.h>
 
+#include "compat.h"
+#include "index_info.h"
+
 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
 
 at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
@@ -8,8 +11,59 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   CHECK_CPU(indptr);
   if (out_opt.has_value())
     CHECK_CPU(out_opt.value());
-  AT_ASSERTM(false, "Not yet implemented");
-  return src;
+
+  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
+  for (int i = 0; i < indptr.dim() - 1; i++)
+    AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
+
+  src = src.contiguous();
+  auto gather_dim = indptr.dim() - 1;
+  AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
+             "Input mismatch");
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < out.dim(); i++)
+      if (i != gather_dim)
+        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
+    out = at::empty(sizes, src.options());
+  }
+
+  auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
+  auto K = src.numel() / N;
+  auto E = out.size(gather_dim);
+
+  auto indptr_info = getTensorInfo<int64_t>(indptr);
+  auto stride = indptr_info.strides[indptr_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    scalar_t vals[K];
+    int64_t row_start, row_end;
+    for (int n = 0; n < N; n++) {
+      int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
+      row_start = indptr_info.data[offset];
+      row_end = indptr_info.data[offset + stride];
+
+      for (int k = 0; k < K; k++) {
+        vals[k] = src_data[n * K + k];
+      }
+
+      offset = (n / (indptr.size(-1) - 1)) * E * K;
+      for (int64_t e = row_start; e < row_end; e++) {
+        for (int k = 0; k < K; k++) {
+          out_data[offset + e * K + k] = vals[k];
+        }
+      }
+    }
+  });
+
+  return out;
 }
 
 at::Tensor gather_coo(at::Tensor src, at::Tensor index,
@@ -18,8 +72,69 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   CHECK_CPU(index);
   if (out_opt.has_value())
     CHECK_CPU(out_opt.value());
-  AT_ASSERTM(false, "Not yet implemented");
-  return src;
+
+  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
+  for (int i = 0; i < index.dim() - 1; i++)
+    AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
+
+  src = src.contiguous();
+  auto gather_dim = index.dim() - 1;
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < index.dim(); i++)
+      AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
+    for (int i = index.dim() + 1; i < src.dim(); i++)
+      AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[gather_dim] = index.size(gather_dim);
+    out = at::empty(sizes, src.options());
+  }
+
+  auto E_1 = index.numel() / out.size(gather_dim);
+  auto E_2 = index.size(gather_dim);
+  auto K = out.numel() / index.numel();
+  auto N = src.size(gather_dim);
+
+  auto index_info = getTensorInfo<int64_t>(index);
+  auto stride = index_info.strides[index_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    scalar_t vals[K];
+    int64_t idx, next_idx;
+    for (int e_1 = 0; e_1 < E_1; e_1++) {
+      int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
+      idx = index_info.data[offset];
+
+      for (int k = 0; k < K; k++) {
+        vals[k] = src_data[e_1 * N * K + idx * K + k];
+      }
+
+      for (int e_2 = 0; e_2 < E_2; e_2++) {
+        for (int k = 0; k < K; k++) {
+          out_data[e_1 * E_2 * K + e_2 * K + k] = vals[k];
+        }
+
+        if (e_2 < E_2 - 1) {
+          next_idx = index_info.data[offset + (e_2 + 1) * stride];
+          assert(idx <= next_idx);
+
+          if (idx != next_idx) {
+            idx = next_idx;
+            for (int k = 0; k < K; k++) {
+              vals[k] = src_data[e_1 * N * K + idx * K + k];
+            }
+          }
+        }
+      }
+    }
+  });
+
+  return out;
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
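Read together, the two new kernels are the broadcasting inverses of the segment reductions: gather_csr expands each row value into indptr[i+1] - indptr[i] output entries, and gather_coo picks source values by a sorted index. A pure-PyTorch sketch of the 1-D semantics as read from the loops above (not part of the commit; the C++ additionally handles batched index tensors and K trailing feature elements per entry):

import torch

def gather_csr_ref(src, indptr):
    # Entry counts per CSR row; src[i] is repeated count[i] times.
    return src.repeat_interleave(indptr[1:] - indptr[:-1], dim=0)

def gather_coo_ref(src, index):
    # Plain index_select; sortedness lets the C++ loop reuse cached vals[].
    return src.index_select(0, index)

src = torch.tensor([10., 20., 30.])
indptr = torch.tensor([0, 2, 5, 6])
index = torch.tensor([0, 0, 1, 1, 1, 2])
assert gather_csr_ref(src, indptr).equal(gather_coo_ref(src, index))
# Both yield [10., 10., 20., 20., 20., 30.], which is also why gather_csr
# can size its output from the last element of indptr.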

cpu/segment.cpp

Lines changed: 2 additions & 2 deletions
@@ -184,7 +184,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
-  auto E = index.numel();
   auto E_1 = index.numel() / src.size(reduce_dim);
   auto E_2 = src.size(reduce_dim);
   auto K = src.numel() / index.numel();
@@ -202,12 +201,12 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     for (int e_1 = 0; e_1 < E_1; e_1++) {
       int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
       idx = index_info.data[offset];
-      row_start = 0;
 
       for (int k = 0; k < K; k++) {
         vals[k] = out_data[e_1 * N * K + k];
       }
 
+      row_start = 0;
       for (int e_2 = 0; e_2 < E_2; e_2++) {
 
         for (int k = 0; k < K; k++) {
@@ -224,6 +223,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
           }
         } else {
           next_idx = index_info.data[offset + (e_2 + 1) * stride];
+          assert(idx <= next_idx);
 
           if (idx != next_idx) {
             for (int k = 0; k < K; k++) {
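The added assert(idx <= next_idx) makes the implicit COO contract explicit: indices along the last dimension must be non-decreasing, since the loop treats each segment as one contiguous run of equal indices. A minimal illustration of the contract in plain PyTorch (illustrative values, not from the test suite):

import torch

good = torch.tensor([0, 0, 1, 2, 2])  # non-decreasing: accepted
bad = torch.tensor([0, 2, 1])         # 2 > 1: would trip the debug assert
assert bool((good[:-1] <= good[1:]).all())
assert not bool((bad[:-1] <= bad[1:]).all())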

test/test_gather.py

Lines changed: 2 additions & 9 deletions
@@ -5,10 +5,7 @@
 from torch.autograd import gradcheck
 from torch_scatter import gather_coo, gather_csr
 
-from .utils import tensor
-
-dtypes = [torch.float]
-devices = [torch.device('cuda')]
+from .utils import tensor, dtypes, devices
 
 tests = [
     {
@@ -50,7 +47,6 @@
 ]
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
 def test_forward(test, dtype, device):
     src = tensor(test['src'], dtype, device)
@@ -65,7 +61,6 @@ def test_forward(test, dtype, device):
     assert torch.all(out == expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,device', product(tests, devices))
 def test_backward(test, device):
     src = tensor(test['src'], torch.double, device)
@@ -77,9 +72,8 @@ def test_backward(test, device):
     assert gradcheck(gather_csr, (src, indptr, None)) is True
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
-def test_segment_out(test, dtype, device):
+def test_gather_out(test, dtype, device):
     src = tensor(test['src'], dtype, device)
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)
@@ -98,7 +92,6 @@ def test_segment_out(test, dtype, device):
     assert torch.all(out == expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
 def test_non_contiguous_segment(test, dtype, device):
     src = tensor(test['src'], dtype, device)
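With the CUDA-only skip marks removed, every test is parametrized over the shared devices list instead. For reference, the backward test boils down to a numerical gradient check like this standalone sketch (call signature copied from the test above; the example values are illustrative):

import torch
from torch.autograd import gradcheck
from torch_scatter import gather_csr

src = torch.randn(3, dtype=torch.double, requires_grad=True)
indptr = torch.tensor([0, 2, 5, 6])  # 3 rows expanded to 6 entries
assert gradcheck(gather_csr, (src, indptr, None)) is True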

test/test_segment.py

Lines changed: 1 addition & 4 deletions
@@ -5,13 +5,11 @@
 from torch.autograd import gradcheck
 from torch_scatter import segment_coo, segment_csr
 
-from .utils import tensor, dtypes
+from .utils import tensor, dtypes, devices
 
 reductions = ['add', 'mean', 'min', 'max']
 grad_reductions = ['add', 'mean']
 
-devices = [torch.device('cpu')]
-
 tests = [
     {
         'src': [1, 2, 3, 4, 5, 6],
@@ -105,7 +103,6 @@ def test_forward(test, reduce, dtype, device):
     assert torch.all(out == expected)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,reduce,device',
                          product(tests, grad_reductions, devices))
 def test_backward(test, reduce, device):
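Both test files now import devices from test/utils.py, which is not part of this diff. Presumably it gates the CUDA device on availability, roughly like this sketch (an assumption about utils.py, not its actual contents):

import torch

dtypes = [torch.float, torch.double]  # illustrative
devices = [torch.device('cpu')]
if torch.cuda.is_available():  # replaces the per-test skipif marks
    devices += [torch.device('cuda')]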

torch_scatter/segment.py

Lines changed: 11 additions & 3 deletions
@@ -56,12 +56,20 @@ def backward(ctx, grad_out, *args):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'add':
-                grad_src = gat(grad_out).gather_coo(
+                grad_src = gat(grad_out.is_cuda).gather_coo(
                     grad_out, index, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gat(grad_out).gather_coo(
+                grad_src = gat(grad_out.is_cuda).gather_coo(
                     grad_out, index, grad_out.new_empty(src_size))
-                count = arg_out
+
+                count = arg_out  # Gets pre-computed on GPU but not on CPU.
+                if count is None:
+                    size = list(index.size())
+                    size[-1] = grad_out.size(index.dim() - 1)
+                    count = segment_cpu.segment_coo(
+                        torch.ones_like(index, dtype=grad_out.dtype), index,
+                        grad_out.new_zeros(size), 'add')[0].clamp_(min=1)
+
                 count = gat(grad_out.is_cuda).gather_coo(
                     count, index, count.new_empty(src_size[:index.dim()]))
                 for _ in range(grad_out.dim() - index.dim()):
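For the 'mean' reduction, each source entry receives the grad_out of its segment scaled by 1 / count of that segment, which is why a count tensor must now be materialized on CPU (per the comment above, the GPU path pre-computes it as arg_out). A standalone check of that identity with plain PyTorch ops, independent of torch_scatter's internals:

import torch

index = torch.tensor([0, 0, 1, 1, 1])
src = torch.randn(5, requires_grad=True)
count = torch.zeros(2).scatter_add_(0, index, torch.ones(5)).clamp_(min=1)
out = torch.zeros(2).scatter_add_(0, index, src) / count  # segment mean
out.sum().backward()
print(src.grad)  # tensor([0.5000, 0.5000, 0.3333, 0.3333, 0.3333]) = 1/count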
