55void scatter_mul (at::Tensor src, at::Tensor index, at::Tensor out,
66 int64_t dim) {
77 int64_t elems_per_row = index.size (dim), i, idx;
8- AT_DISPATCH_ALL_TYPES (src.type (), " scatter_mul" , [&] {
8+ AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_mul" , [&] {
99 DIM_APPLY3 (scalar_t , src, int64_t , index, scalar_t , out, dim, {
1010 for (i = 0 ; i < elems_per_row; i++) {
1111 idx = index_data[i * index_stride];
@@ -18,7 +18,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
1818void scatter_div (at::Tensor src, at::Tensor index, at::Tensor out,
1919 int64_t dim) {
2020 int64_t elems_per_row = index.size (dim), i, idx;
21- AT_DISPATCH_ALL_TYPES (src.type (), " scatter_div" , [&] {
21+ AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_div" , [&] {
2222 DIM_APPLY3 (scalar_t , src, int64_t , index, scalar_t , out, dim, {
2323 for (i = 0 ; i < elems_per_row; i++) {
2424 idx = index_data[i * index_stride];
@@ -31,7 +31,7 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
3131void scatter_max (at::Tensor src, at::Tensor index, at::Tensor out,
3232 at::Tensor arg, int64_t dim) {
3333 int64_t elems_per_row = index.size (dim), i, idx;
34- AT_DISPATCH_ALL_TYPES (src.type (), " scatter_max" , [&] {
34+ AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_max" , [&] {
3535 DIM_APPLY4 (scalar_t , src, int64_t , index, scalar_t , out, int64_t , arg, dim,
3636 {
3737 for (i = 0 ; i < elems_per_row; i++) {
@@ -48,7 +48,7 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
4848void scatter_min (at::Tensor src, at::Tensor index, at::Tensor out,
4949 at::Tensor arg, int64_t dim) {
5050 int64_t elems_per_row = index.size (dim), i, idx;
51- AT_DISPATCH_ALL_TYPES (src.type (), " scatter_min" , [&] {
51+ AT_DISPATCH_ALL_TYPES (src.scalar_type (), " scatter_min" , [&] {
5252 DIM_APPLY4 (scalar_t , src, int64_t , index, scalar_t , out, int64_t , arg, dim,
5353 {
5454 for (i = 0 ; i < elems_per_row; i++) {
@@ -65,7 +65,7 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
6565void index_backward (at::Tensor grad, at::Tensor index, at::Tensor arg,
6666 at::Tensor out, int64_t dim) {
6767 int64_t elems_per_row = index.size (dim), i, idx;
68- AT_DISPATCH_ALL_TYPES (grad.type (), " index_backward" , [&] {
68+ AT_DISPATCH_ALL_TYPES (grad.scalar_type (), " index_backward" , [&] {
6969 DIM_APPLY4 (scalar_t , grad, int64_t , index, int64_t , arg, scalar_t , out, dim,
7070 {
7171 for (i = 0 ; i < elems_per_row; i++) {
0 commit comments