@@ -97,7 +97,7 @@ class SPMMSum : public torch::autograd::Function<SPMMSum> {
9797 if (torch::autograd::any_variable_requires_grad ({mat})) {
9898 torch::optional<torch::Tensor> opt_value = torch::nullopt ;
9999 if (has_value)
100- opt_value = value.index_select (0 , csr2csc);
100+ opt_value = value.view ({- 1 , 1 }). index_select (0 , csr2csc). view (- 1 );
101101
102102 grad_mat = std::get<0 >(spmm_fw (colptr, row.index_select (0 , csr2csc),
103103 opt_value, grad_out, " sum" ));
@@ -161,11 +161,12 @@ class SPMMMean : public torch::autograd::Function<SPMMMean> {
161161 auto grad_mat = Variable ();
162162 if (torch::autograd::any_variable_requires_grad ({mat})) {
163163 row = row.index_select (0 , csr2csc);
164- rowcount = rowcount.toType (mat.scalar_type ()). index_select ( 0 , row );
164+ rowcount = rowcount.index_select ( 0 , row). toType (mat.scalar_type ());
165165 rowcount.masked_fill_ (rowcount < 1 , 1 );
166166
167167 if (has_value > 0 )
168- rowcount = value.index_select (0 , csr2csc).div (rowcount);
168+ rowcount =
169+ value.view ({-1 , 1 }).index_select (0 , csr2csc).view (-1 ).div (rowcount);
169170 else
170171 rowcount.pow_ (-1 );
171172
@@ -219,8 +220,10 @@ class SPMMMin : public torch::autograd::Function<SPMMMin> {
219220 auto grad_mat = Variable ();
220221 if (torch::autograd::any_variable_requires_grad ({mat})) {
221222 if (has_value > 0 ) {
222- value = value.index_select (0 , arg_out.flatten ()).view_as (arg_out);
223- value.mul_ (grad_out);
223+ value = value.view ({-1 , 1 })
224+ .index_select (0 , arg_out.flatten ())
225+ .view_as (arg_out)
226+ .mul_ (grad_out);
224227 } else
225228 value = grad_out;
226229
@@ -277,8 +280,10 @@ class SPMMMax : public torch::autograd::Function<SPMMMax> {
277280 auto grad_mat = Variable ();
278281 if (torch::autograd::any_variable_requires_grad ({mat})) {
279282 if (has_value > 0 ) {
280- value = value.index_select (0 , arg_out.flatten ()).view_as (arg_out);
281- value.mul_ (grad_out);
283+ value = value.view ({-1 , 1 })
284+ .index_select (0 , arg_out.flatten ())
285+ .view_as (arg_out)
286+ .mul_ (grad_out);
282287 } else
283288 value = grad_out;
284289
0 commit comments