Skip to content

Commit 78a5549

Browse files
committed
added cpu checks
1 parent cd114fd commit 78a5549

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

cpu/scatter.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22

33
#include "dim_apply.h"
44

5+
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
6+
57
void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
68
int64_t dim) {
9+
CHECK_CPU(src);
10+
CHECK_CPU(index);
11+
CHECK_CPU(out);
712
int64_t elems_per_row = index.size(dim), i, idx;
813
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
914
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
@@ -17,6 +22,9 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
1722

1823
void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
1924
int64_t dim) {
25+
CHECK_CPU(src);
26+
CHECK_CPU(index);
27+
CHECK_CPU(out);
2028
int64_t elems_per_row = index.size(dim), i, idx;
2129
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
2230
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
@@ -30,6 +38,9 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
3038

3139
void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
3240
at::Tensor arg, int64_t dim) {
41+
CHECK_CPU(src);
42+
CHECK_CPU(index);
43+
CHECK_CPU(out);
3344
int64_t elems_per_row = index.size(dim), i, idx;
3445
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
3546
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
@@ -47,6 +58,10 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
4758

4859
void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
4960
at::Tensor arg, int64_t dim) {
61+
CHECK_CPU(src);
62+
CHECK_CPU(index);
63+
CHECK_CPU(out);
64+
CHECK_CPU(arg);
5065
int64_t elems_per_row = index.size(dim), i, idx;
5166
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
5267
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,

0 commit comments

Comments
 (0)