@@ -1,10 +1,14 @@
 import torch
 
+from torch_scatter import segment_cpu, gather_cpu
 from torch_scatter.helpers import min_value, max_value
 
 if torch.cuda.is_available():
     from torch_scatter import segment_cuda, gather_cuda
 
+seg = lambda is_cuda: segment_cuda if is_cuda else segment_cpu  # noqa
+gat = lambda is_cuda: gather_cuda if is_cuda else gather_cpu  # noqa
+
 
 class SegmentCOO(torch.autograd.Function):
     @staticmethod
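
The `seg`/`gat` helpers added above pick the compiled extension at call time from a device flag, so the autograd functions below can route CPU and CUDA tensors through one code path. A minimal sketch of that selection, using `segment_cpu` as a stand-in for the CUDA module so it runs on a CPU-only build:

```python
import torch

from torch_scatter import segment_cpu

# Stand-in so this sketch runs without a CUDA build; the real module is
# torch_scatter.segment_cuda, imported only when torch.cuda.is_available().
segment_cuda = segment_cpu

seg = lambda is_cuda: segment_cuda if is_cuda else segment_cpu  # noqa

src = torch.randn(6, 4)                 # CPU tensor, so src.is_cuda is False
print(seg(src.is_cuda) is segment_cpu)  # True: a CPU input routes to segment_cpu
```
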
@@ -28,7 +32,7 @@ def forward(ctx, src, index, out, dim_size, reduce):
 
         out = src.new_full(size, fill_value)
 
-        out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
+        out, arg_out = seg(src.is_cuda).segment_coo(src, index, out, reduce)
 
         if fill_value != 0:
             out.masked_fill_(out == fill_value, 0)
@@ -47,13 +51,13 @@ def backward(ctx, grad_out, *args):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'add':
-                grad_src = gather_cuda.gather_coo(grad_out, index,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_coo(
+                    grad_out, index, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gather_cuda.gather_coo(grad_out, index,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_coo(
+                    grad_out, index, grad_out.new_empty(src_size))
                 count = arg_out
-                count = gather_cuda.gather_coo(
+                count = gat(grad_out.is_cuda).gather_coo(
                     count, index, count.new_empty(src_size[:index.dim()]))
                 for _ in range(grad_out.dim() - index.dim()):
                     count = count.unsqueeze(-1)
@@ -78,7 +82,7 @@ def forward(ctx, src, indptr, out, reduce):
         ctx.reduce = reduce
         ctx.src_size = list(src.size())
 
-        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+        out, arg_out = seg(src.is_cuda).segment_csr(src, indptr, out, reduce)
         ctx.save_for_backward(indptr, arg_out)
         return out if arg_out is None else (out, arg_out)
 
@@ -89,15 +93,15 @@ def backward(ctx, grad_out, *args):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'add':
-                grad_src = gather_cuda.gather_csr(grad_out, indptr,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gather_cuda.gather_csr(grad_out, indptr,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size))
                 indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                 indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                 count = (indptr2 - indptr1).to(grad_src.dtype)
-                count = gather_cuda.gather_csr(
+                count = gat(grad_out.is_cuda).gather_csr(
                     count, indptr, count.new_empty(src_size[:indptr.dim()]))
                 for _ in range(grad_out.dim() - indptr.dim()):
                     count = count.unsqueeze(-1)
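
For context, a hedged end-to-end sketch of the CPU path this change enables, assuming `segment_cpu.segment_coo` accepts the same `(src, index, out, reduce)` arguments the diff passes through `seg(...)` and returns an `(out, arg_out)` pair:

```python
import torch

from torch_scatter import segment_cpu

src = torch.randn(6, 4)                   # values to reduce, row-wise
index = torch.tensor([0, 0, 1, 1, 2, 2])  # sorted segment id for each row
out = src.new_zeros(3, 4)                 # one output row per segment

# 'add' sums the rows of `src` that share an index; arg_out is only populated
# by reductions that track argmin/argmax positions, so it may be None here.
out, arg_out = segment_cpu.segment_coo(src, index, out, 'add')
print(out.size())  # torch.Size([3, 4])
```
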