Commit ea94e54

cpu boilerplate

1 parent d824c8b

7 files changed: +121 -32 lines changed

benchmark/gather.py

Lines changed: 13 additions & 5 deletions

@@ -30,13 +30,16 @@ def correctness(dataset):
 
             assert torch.allclose(out1, out2, atol=1e-4)
             assert torch.allclose(out1, out3, atol=1e-4)
-        except RuntimeError:
+        except RuntimeError as e:
+            if 'out of memory' not in str(e):
+                raise RuntimeError(e)
             torch.cuda.empty_cache()
 
 
 def time_func(func, x):
     try:
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         t = time.perf_counter()
 
         if not args.with_backward:
@@ -49,9 +52,12 @@ def time_func(func, x):
                 out = func(x)
                 torch.autograd.grad(out, x, out, only_inputs=True)
 
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         return time.perf_counter() - t
-    except RuntimeError:
+    except RuntimeError as e:
+        if 'out of memory' not in str(e):
+            raise RuntimeError(e)
         torch.cuda.empty_cache()
         return float('inf')
 
@@ -88,7 +94,9 @@ def gat_csr(x):
 
             del x
 
-        except RuntimeError:
+        except RuntimeError as e:
+            if 'out of memory' not in str(e):
+                raise RuntimeError(e)
             torch.cuda.empty_cache()
             for t in (t1, t2, t3, t4):
                 t.append(float('inf'))
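The pattern these hunks introduce is what makes the benchmark usable on CPU-only machines: torch.cuda.synchronize() is guarded behind torch.cuda.is_available(), and only out-of-memory RuntimeErrors are swallowed, with everything else re-raised. A minimal self-contained sketch of the same pattern (the iteration count and workload are illustrative, not the benchmark's actual values):

import time
import torch

def time_func(func, x, iters=10):
    """Device-agnostic timing with OOM tolerance."""
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # drain pending kernels before timing
        t = time.perf_counter()
        for _ in range(iters):
            func(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the timed kernels to finish
        return time.perf_counter() - t
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise  # genuine error: propagate it
        torch.cuda.empty_cache()  # OOM: release cached blocks, report failure
        return float('inf')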

benchmark/scatter_segment.py

Lines changed: 21 additions & 8 deletions

@@ -82,13 +82,16 @@ def correctness(dataset):
             assert torch.allclose(out1, out2, atol=1e-4)
             assert torch.allclose(out1, out3, atol=1e-4)
 
-        except RuntimeError:
+        except RuntimeError as e:
+            if 'out of memory' not in str(e):
+                raise RuntimeError(e)
             torch.cuda.empty_cache()
 
 
 def time_func(func, x):
     try:
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         t = time.perf_counter()
 
         if not args.with_backward:
@@ -102,9 +105,12 @@ def time_func(func, x):
                 out = out[0] if isinstance(out, tuple) else out
                 torch.autograd.grad(out, x, out, only_inputs=True)
 
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         return time.perf_counter() - t
-    except RuntimeError:
+    except RuntimeError as e:
+        if 'out of memory' not in str(e):
+            raise RuntimeError(e)
         torch.cuda.empty_cache()
         return float('inf')
 
@@ -152,7 +158,9 @@ def dense2(x):
 
             del x
 
-        except RuntimeError:
+        except RuntimeError as e:
+            if 'out of memory' not in str(e):
+                raise RuntimeError(e)
             torch.cuda.empty_cache()
             for t in (t1, t2, t3, t4):
                 t.append(float('inf'))
@@ -167,7 +175,9 @@ def dense2(x):
 
             del x
 
-        except RuntimeError:
+        except RuntimeError as e:
+            if 'out of memory' not in str(e):
+                raise RuntimeError(e)
             torch.cuda.empty_cache()
             for t in (t5, t6):
                 t.append(float('inf'))
@@ -197,8 +207,11 @@ def dense2(x):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--reduce', type=str, required=True,
-                        choices=['add', 'mean', 'min', 'max'])
+    parser.add_argument(
+        '--reduce',
+        type=str,
+        required=True,
+        choices=['add', 'mean', 'min', 'max'])
     parser.add_argument('--with_backward', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()
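With the guarded synchronize calls above, the existing --device flag (which still defaults to 'cuda') can now meaningfully target either device. Something along these lines should work, assuming the script is invoked from the repository root:

python benchmark/scatter_segment.py --reduce add --device cpu
python benchmark/scatter_segment.py --reduce max --with_backward --device cuda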

cpu/gather.cpp

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+
+at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
+                      at::optional<at::Tensor> out_opt) {
+  CHECK_CPU(src);
+  CHECK_CPU(indptr);
+  if (out_opt.has_value())
+    CHECK_CPU(out_opt.value());
+  AT_ASSERTM(false, "Not yet implemented");
+  return src;
+}
+
+at::Tensor gather_coo(at::Tensor src, at::Tensor index,
+                      at::optional<at::Tensor> out_opt) {
+  CHECK_CPU(src);
+  CHECK_CPU(index);
+  if (out_opt.has_value())
+    CHECK_CPU(out_opt.value());
+  AT_ASSERTM(false, "Not yet implemented");
+  return src;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("gather_csr", &gather_csr, "Gather CSR (CPU)");
+  m.def("gather_coo", &gather_coo, "Gather COO (CPU)");
+}
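The file is pure boilerplate: both entry points type-check their inputs with CHECK_CPU and then fail with "Not yet implemented". A sketch of how the stub could be exercised stand-alone via PyTorch's JIT extension loader, rather than through setup.py (the source path assumes the repository root as working directory):

import torch
from torch.utils.cpp_extension import load

# Compile and import cpu/gather.cpp on the fly.
gather_cpu = load(name='gather_cpu', sources=['cpu/gather.cpp'])

src = torch.randn(10, 4)
indptr = torch.tensor([0, 4, 7, 10])
# The CPU checks pass; the stub then raises "Not yet implemented".
gather_cpu.gather_csr(src, indptr, None)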

cpu/segment.cpp

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#include <torch/extension.h>
+
+#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
+            std::string reduce) {
+  CHECK_CPU(src);
+  CHECK_CPU(indptr);
+  if (out_opt.has_value())
+    CHECK_CPU(out_opt.value());
+  AT_ASSERTM(false, "Not yet implemented");
+  return std::make_tuple(src, at::nullopt);
+}
+
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
+            std::string reduce) {
+  CHECK_CPU(src);
+  CHECK_CPU(index);
+  CHECK_CPU(out);
+  AT_ASSERTM(false, "Not yet implemented");
+  return std::make_tuple(src, at::nullopt);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("segment_csr", &segment_csr, "Segment CSR (CPU)");
+  m.def("segment_coo", &segment_coo, "Segment COO (CPU)");
+}
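For what the eventual CPU kernel has to compute: segment_csr reduces consecutive slices of src whose boundaries are given by the CSR pointer vector indptr. A pure-PyTorch reference sketch of those semantics, inferred from how the CUDA counterpart is used in torch_scatter/segment.py (the Python loop is illustrative only, not the planned implementation):

import torch

def segment_csr_reference(src, indptr, reduce='add'):
    # Reduce each slice src[indptr[i]:indptr[i + 1]] along dim 0.
    outs = []
    for i in range(indptr.numel() - 1):
        chunk = src[indptr[i]:indptr[i + 1]]
        if reduce == 'add':
            outs.append(chunk.sum(dim=0))
        elif reduce == 'mean':
            outs.append(chunk.mean(dim=0))
        elif reduce == 'min':
            outs.append(chunk.min(dim=0).values)
        else:  # 'max'
            outs.append(chunk.max(dim=0).values)
    return torch.stack(outs, dim=0)

src = torch.arange(10, dtype=torch.float).view(10, 1)
indptr = torch.tensor([0, 4, 7, 10])
print(segment_csr_reference(src, indptr, 'add'))  # [[6.], [15.], [24.]]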

setup.py

Lines changed: 5 additions & 3 deletions

@@ -25,16 +25,18 @@
 ext_modules = []
 exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
 ext_modules += [
-    CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
-                 extra_compile_args=cxx_extra_compile_args) for ext in exts
+    CppExtension(
+        f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
+        extra_compile_args=cxx_extra_compile_args) for ext in exts
 ]
 
 if CUDA_HOME is not None and USE_GPU:
     exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cuda', '*.cpp'))]
     ext_modules += [
         CUDAExtension(
             f'torch_scatter.{ext}_cuda',
-            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={
+            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'],
+            extra_compile_args={
                 'cxx': cxx_extra_compile_args,
                 'nvcc': nvcc_extra_compile_args,
             }) for ext in exts
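The CPU extension list mirrors the existing CUDA one: every cpu/*.cpp file becomes a torch_scatter.<name>_cpu module. For the two files added in this commit, the glob/stem expression resolves as follows (a quick illustration, assuming the repository root as working directory):

import os.path as osp
from glob import glob

exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
# With cpu/gather.cpp and cpu/segment.cpp present:
#   exts == ['gather', 'segment']
# -> CppExtension('torch_scatter.gather_cpu', ['cpu/gather.cpp'], ...), etc.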

torch_scatter/gather.py

Lines changed: 9 additions & 4 deletions

@@ -1,8 +1,13 @@
 import torch
 
+from torch_scatter import segment_cpu, gather_cpu
+
 if torch.cuda.is_available():
     from torch_scatter import gather_cuda, segment_cuda
 
+gat = lambda is_cuda: gather_cuda if is_cuda else gather_cpu  # noqa
+seg = lambda is_cuda: segment_cuda if is_cuda else segment_cpu  # noqa
+
 
 class GatherCOO(torch.autograd.Function):
     @staticmethod
@@ -12,15 +17,15 @@ def forward(ctx, src, index, out):
         ctx.src_size = list(src.size())
         ctx.save_for_backward(index)
 
-        return gather_cuda.gather_coo(src, index, out)
+        return gat(src.is_cuda).gather_coo(src, index, out)
 
     @staticmethod
     def backward(ctx, grad_out):
         (index, ), src_size = ctx.saved_tensors, ctx.src_size
 
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src, _ = segment_cuda.segment_coo(
+            grad_src, _ = seg(grad_out.is_cuda).segment_coo(
                 grad_out, index, grad_out.new_zeros(src_size), 'add')
 
         return grad_src, None, None
@@ -34,15 +39,15 @@ def forward(ctx, src, indptr, out):
         ctx.src_size = list(src.size())
         ctx.save_for_backward(indptr)
 
-        return gather_cuda.gather_csr(src, indptr, out)
+        return gat(src.is_cuda).gather_csr(src, indptr, out)
 
     @staticmethod
     def backward(ctx, grad_out):
         (indptr, ), src_size = ctx.saved_tensors, ctx.src_size
 
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src, _ = segment_cuda.segment_csr(
+            grad_src, _ = seg(grad_out.is_cuda).segment_csr(
                 grad_out, indptr, grad_out.new_empty(src_size), 'add')
 
         return grad_src, None, None
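The gat/seg lambdas do more than save typing: a lambda body is evaluated only when called, so the name gather_cuda is never touched on a CPU-only install where the conditional import above was skipped. A standalone sketch of that late-binding behaviour (the names here are illustrative):

# 'cuda_module' is deliberately never defined, mimicking a CPU-only install.
pick = lambda is_cuda: cuda_module if is_cuda else 'cpu module'

print(pick(False))  # works: 'cpu module' -- cuda_module is never evaluated
# pick(True) would raise NameError, just as requesting CUDA would fail on a
# build without the CUDA extensions.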

torch_scatter/segment.py

Lines changed: 16 additions & 12 deletions

@@ -1,10 +1,14 @@
 import torch
 
+from torch_scatter import segment_cpu, gather_cpu
 from torch_scatter.helpers import min_value, max_value
 
 if torch.cuda.is_available():
     from torch_scatter import segment_cuda, gather_cuda
 
+seg = lambda is_cuda: segment_cuda if is_cuda else segment_cpu  # noqa
+gat = lambda is_cuda: gather_cuda if is_cuda else gather_cpu  # noqa
+
 
 class SegmentCOO(torch.autograd.Function):
     @staticmethod
@@ -28,7 +32,7 @@ def forward(ctx, src, index, out, dim_size, reduce):
 
             out = src.new_full(size, fill_value)
 
-        out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
+        out, arg_out = seg(src.is_cuda).segment_coo(src, index, out, reduce)
 
         if fill_value != 0:
             out.masked_fill_(out == fill_value, 0)
@@ -47,13 +51,13 @@ def backward(ctx, grad_out, *args):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'add':
-                grad_src = gather_cuda.gather_coo(grad_out, index,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_coo(
+                    grad_out, index, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gather_cuda.gather_coo(grad_out, index,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_coo(
+                    grad_out, index, grad_out.new_empty(src_size))
                 count = arg_out
-                count = gather_cuda.gather_coo(
+                count = gat(grad_out.is_cuda).gather_coo(
                     count, index, count.new_empty(src_size[:index.dim()]))
                 for _ in range(grad_out.dim() - index.dim()):
                     count = count.unsqueeze(-1)
@@ -78,7 +82,7 @@ def forward(ctx, src, indptr, out, reduce):
         ctx.reduce = reduce
         ctx.src_size = list(src.size())
 
-        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+        out, arg_out = seg(src.is_cuda).segment_csr(src, indptr, out, reduce)
         ctx.save_for_backward(indptr, arg_out)
         return out if arg_out is None else (out, arg_out)
 
@@ -89,15 +93,15 @@ def backward(ctx, grad_out, *args):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'add':
-                grad_src = gather_cuda.gather_csr(grad_out, indptr,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gather_cuda.gather_csr(grad_out, indptr,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size))
                 indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                 indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                 count = (indptr2 - indptr1).to(grad_src.dtype)
-                count = gather_cuda.gather_csr(
+                count = gat(grad_out.is_cuda).gather_csr(
                     count, indptr, count.new_empty(src_size[:indptr.dim()]))
                 for _ in range(grad_out.dim() - indptr.dim()):
                     count = count.unsqueeze(-1)
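For the 'mean' reduction, the backward pass gathers grad_out back to the input shape and divides each row by the length of its segment; for CSR those lengths are just consecutive differences of indptr, which is what the indptr1/indptr2 narrowing computes. A small worked illustration of that count computation (the pointer values are made up):

import torch

indptr = torch.tensor([0, 4, 7, 10])
indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)  # tensor([0, 4, 7])
indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)  # tensor([4, 7, 10])
count = indptr2 - indptr1                            # segment sizes: [4, 3, 3]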
