33#include " index_info.h"
44#include " reducer.h"
55#include " utils.h"
6+ #include < ATen/OpMathType.h>
67
78std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
89segment_coo_cpu (torch::Tensor src, torch::Tensor index,
@@ -70,11 +71,12 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
7071 auto stride = index_info.strides [index_info.dims - 1 ];
7172 std::vector<int64_t > args (K);
7273 AT_DISPATCH_ALL_TYPES_AND2 (at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type (), " segment_coo_cpu" , [&] {
74+ using opmath_t = at::opmath_type<scalar_t >;
7375 auto src_data = src.data_ptr <scalar_t >();
7476 auto out_data = out.data_ptr <scalar_t >();
7577 scalar_t *count_data = nullptr ;
7678
77- std::vector<scalar_t > vals (K);
79+ std::vector<opmath_t > vals (K);
7880 int64_t idx, next_idx, row_start;
7981 AT_DISPATCH_REDUCTION_TYPES (reduce, [&] {
8082 if (!optional_out.has_value ())
@@ -87,19 +89,19 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
8789 idx = index_info.data [offset];
8890
8991 for (auto k = 0 ; k < K; k++)
90- vals[k] = out_data[b * N * K + k];
92+ vals[k] = static_cast < opmath_t >( out_data[b * N * K + k]) ;
9193
9294 row_start = 0 ;
9395 for (auto e = 0 ; e < E; e++) {
9496
9597 for (auto k = 0 ; k < K; k++)
96- Reducer<scalar_t , REDUCE>::update (
97- &vals[k], src_data[b * E * K + e * K + k], &args[k], e);
98+ Reducer<opmath_t , REDUCE>::update (
99+ &vals[k], static_cast < opmath_t >( src_data[b * E * K + e * K + k]) , &args[k], e);
98100
99101 if (e == E - 1 ) {
100102 for (auto k = 0 ; k < K; k++)
101103 Reducer<scalar_t , REDUCE>::write (
102- out_data + b * N * K + idx * K + k, vals[k],
104+ out_data + b * N * K + idx * K + k, static_cast < scalar_t >( vals[k]) ,
103105 arg_out_data + b * N * K + idx * K + k, args[k],
104106 e + 1 - row_start);
105107 if (REDUCE == MEAN)
@@ -111,11 +113,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
111113 if (idx != next_idx) {
112114 for (auto k = 0 ; k < K; k++) {
113115 Reducer<scalar_t , REDUCE>::write (
114- out_data + b * N * K + idx * K + k, vals[k],
116+ out_data + b * N * K + idx * K + k, static_cast < scalar_t >( vals[k]) ,
115117 arg_out_data + b * N * K + idx * K + k, args[k],
116118 e + 1 - row_start);
117119
118- vals[k] = out_data[b * N * K + next_idx * K + k];
120+ vals[k] = static_cast < opmath_t >( out_data[b * N * K + next_idx * K + k]) ;
119121 }
120122 if (REDUCE == MEAN)
121123 count_data[b * N + idx] = (scalar_t )(e + 1 - row_start);
0 commit comments