#include "dim_apply.h"

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

57void scatter_mul (at::Tensor src, at::Tensor index, at::Tensor out,
68 int64_t dim) {
9+ CHECK_CPU (src);
10+ CHECK_CPU (index);
11+ CHECK_CPU (out);
712 int64_t elems_per_row = index.size (dim), i, idx;
813 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_mul" , [&] {
914 DIM_APPLY3 (scalar_t , src, int64_t , index, scalar_t , out, dim, {
@@ -17,6 +22,9 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
1722
1823void scatter_div (at::Tensor src, at::Tensor index, at::Tensor out,
1924 int64_t dim) {
25+ CHECK_CPU (src);
26+ CHECK_CPU (index);
27+ CHECK_CPU (out);
2028 int64_t elems_per_row = index.size (dim), i, idx;
2129 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_div" , [&] {
2230 DIM_APPLY3 (scalar_t , src, int64_t , index, scalar_t , out, dim, {
@@ -30,6 +38,9 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
3038
3139void scatter_max (at::Tensor src, at::Tensor index, at::Tensor out,
3240 at::Tensor arg, int64_t dim) {
41+ CHECK_CPU (src);
42+ CHECK_CPU (index);
43+ CHECK_CPU (out);
3344 int64_t elems_per_row = index.size (dim), i, idx;
3445 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_max" , [&] {
3546 DIM_APPLY4 (scalar_t , src, int64_t , index, scalar_t , out, int64_t , arg, dim,
@@ -47,6 +58,10 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
4758
4859void scatter_min (at::Tensor src, at::Tensor index, at::Tensor out,
4960 at::Tensor arg, int64_t dim) {
61+ CHECK_CPU (src);
62+ CHECK_CPU (index);
63+ CHECK_CPU (out);
64+ CHECK_CPU (arg);
5065 int64_t elems_per_row = index.size (dim), i, idx;
5166 AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_min" , [&] {
5267 DIM_APPLY4 (scalar_t , src, int64_t , index, scalar_t , out, int64_t , arg, dim,