11#include < torch/script.h>
22
33#include " cpu/segment_csr_cpu.h"
4+ #include " utils.h"
45
56#ifdef WITH_CUDA
67#include " cuda/segment_csr_cuda.h"
@@ -55,7 +56,7 @@ class SegmentSumCSR : public torch::autograd::Function<SegmentSumCSR> {
5556 auto grad_out = grad_outs[0 ];
5657 auto saved = ctx->get_saved_variables ();
5758 auto indptr = saved[0 ];
58- auto src_shape = ctx->saved_data [" src_shape" ].toIntVector ( );
59+ auto src_shape = list2vec ( ctx->saved_data [" src_shape" ].toIntList () );
5960 auto grad_in = torch::empty (src_shape, grad_out.options ());
6061 gather_csr_fw (grad_out, indptr, grad_in);
6162 return {grad_in, Variable (), Variable ()};
@@ -79,7 +80,7 @@ class SegmentMeanCSR : public torch::autograd::Function<SegmentMeanCSR> {
7980 auto grad_out = grad_outs[0 ];
8081 auto saved = ctx->get_saved_variables ();
8182 auto indptr = saved[0 ];
82- auto src_shape = ctx->saved_data [" src_shape" ].toIntVector ( );
83+ auto src_shape = list2vec ( ctx->saved_data [" src_shape" ].toIntList () );
8384 auto grad_in = torch::empty (src_shape, grad_out.options ());
8485 gather_csr_fw (grad_out, indptr, grad_in);
8586 auto indptr1 = indptr.narrow (-1 , 0 , indptr.size (-1 ) - 1 );
@@ -114,7 +115,7 @@ class SegmentMinCSR : public torch::autograd::Function<SegmentMinCSR> {
114115 auto saved = ctx->get_saved_variables ();
115116 auto indptr = saved[0 ];
116117 auto arg_out = saved[1 ];
117- auto src_shape = ctx->saved_data [" src_shape" ].toIntVector ( );
118+ auto src_shape = list2vec ( ctx->saved_data [" src_shape" ].toIntList () );
118119 src_shape[indptr.dim () - 1 ] += 1 ;
119120 auto grad_in = torch::zeros (src_shape, grad_out.options ());
120121 grad_in.scatter_ (indptr.dim () - 1 , arg_out, grad_out);
@@ -145,7 +146,7 @@ class SegmentMaxCSR : public torch::autograd::Function<SegmentMaxCSR> {
145146 auto saved = ctx->get_saved_variables ();
146147 auto indptr = saved[0 ];
147148 auto arg_out = saved[1 ];
148- auto src_shape = ctx->saved_data [" src_shape" ].toIntVector ( );
149+ auto src_shape = list2vec ( ctx->saved_data [" src_shape" ].toIntList () );
149150 src_shape[indptr.dim () - 1 ] += 1 ;
150151 auto grad_in = torch::zeros (src_shape, grad_out.options ());
151152 grad_in.scatter_ (indptr.dim () - 1 , arg_out, grad_out);
@@ -172,7 +173,7 @@ class GatherCSR : public torch::autograd::Function<GatherCSR> {
172173 auto grad_out = grad_outs[0 ];
173174 auto saved = ctx->get_saved_variables ();
174175 auto indptr = saved[0 ];
175- auto src_shape = ctx->saved_data [" src_shape" ].toIntVector ( );
176+ auto src_shape = list2vec ( ctx->saved_data [" src_shape" ].toIntList () );
176177
177178 auto grad_in = torch::empty (src_shape, grad_out.options ());
178179 segment_csr_fw (grad_out, indptr, grad_in, " sum" );
0 commit comments