
Commit 1f49a3a

cmputer and rusty1s authored
Add support for half (#137)
* compile with half
* Fix
* fix
* rename
* update
* update
* update
* update
* typo

Co-authored-by: rusty1s <[email protected]>
1 parent eda8dc6 commit 1f49a3a

File tree: 12 files changed (+77, −28 lines)


csrc/cpu/reducer.h

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
-      *address = val / (count > 0 ? count : (scalar_t)1);
+      *address = val / (scalar_t)(count > 0 ? count : 1);
     else if (REDUCE == MIN || REDUCE == MAX) {
       if (count > 0) {
         *address = val;
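Why the cast moved: with scalar_t = at::Half, the old ternary mixes an int64_t count with an at::Half literal, and the conditional operator cannot pick a common type (each operand is convertible to the other, so the conversion is ambiguous), which breaks the new Half instantiation at compile time. Resolving the ternary in integer arithmetic and casting once avoids this. A minimal sketch of the pattern (mean_write is a hypothetical stand-in for the MEAN branch of Reducer::write):

#include <c10/util/Half.h>

// Hypothetical stand-in for the MEAN branch of Reducer::write.
at::Half mean_write(at::Half val, int64_t count) {
  // Ill-formed once scalar_t = at::Half:
  //   return val / (count > 0 ? count : (at::Half)1);  // ambiguous ternary
  // Resolve the ternary in int64_t first, then cast once:
  return val / (at::Half)(count > 0 ? count : 1);
}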

csrc/cpu/scatter_cpu.cpp

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
   auto N = out.size(dim);
 
   auto index_info = getTensorInfo<int64_t>(index);
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

csrc/cpu/segment_coo_cpu.cpp

Lines changed: 4 additions & 3 deletions
@@ -69,7 +69,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   auto index_info = getTensorInfo<int64_t>(index);
   auto stride = index_info.strides[index_info.dims - 1];
   std::vector<int64_t> args(K);
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
     scalar_t *count_data = nullptr;
@@ -130,7 +130,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
       out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
 
       if (REDUCE == MEAN)
-        arg_out.value().clamp_(1);
+        arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
+                                     (scalar_t)1);
     });
   });
 
@@ -177,7 +178,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
 
   auto index_info = getTensorInfo<int64_t>(index);
   auto stride = index_info.strides[index_info.dims - 1];
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
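The clamp_(1) → masked_fill_ swap (here and in csrc/scatter.cpp and segment_coo_cuda.cu) is presumably motivated by clamp_ lacking a Half kernel at the time; masked_fill_ with an explicitly cast scalar dispatches for every supported type and is equivalent on this path, since both turn every count below 1 into 1 before the mean division. A sketch of the equivalence (clamp_counts_to_one is a hypothetical helper):

#include <torch/torch.h>

// Hypothetical helper: same effect as count.clamp_(1), i.e. every entry
// below 1 becomes 1, but expressed via masked_fill_, which works on Half.
void clamp_counts_to_one(torch::Tensor count) {
  count.masked_fill_(count < 1, 1);
}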

csrc/cpu/segment_csr_cpu.cpp

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
   auto indptr_info = getTensorInfo<int64_t>(indptr);
   auto stride = indptr_info.strides[indptr_info.dims - 1];
   std::vector<int64_t> args(K);
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
 
   auto indptr_info = getTensorInfo<int64_t>(indptr);
   auto stride = indptr_info.strides[indptr_info.dims - 1];
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

csrc/cuda/atomics.cuh

Lines changed: 40 additions & 0 deletions
@@ -68,6 +68,25 @@
                                                                               \
   template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl;   \
                                                                               \
+  template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 2> {    \
+    inline __device__ void operator()(scalar *address, scalar val) {          \
+      unsigned int *address_as_ui =                                           \
+          (unsigned int *)((char *)address - ((size_t)address & 2));          \
+      unsigned int old = *address_as_ui;                                      \
+      unsigned int assumed;                                                   \
+                                                                              \
+      do {                                                                    \
+        assumed = old;                                                        \
+        at::Half hsum;                                                        \
+        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);          \
+        hsum = OP(hsum, val);                                                 \
+        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16)           \
+                                  : (old & 0xffff0000) | hsum.x;              \
+        old = atomicCAS(address_as_ui, assumed, old);                         \
+      } while (assumed != old);                                               \
+    }                                                                         \
+  };                                                                          \
+                                                                              \
   template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> {    \
     inline __device__ void operator()(scalar *address, scalar val) {          \
       int *address_as_i = (int *)address;                                     \
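What the new <scalar, 2> specialization does: atomicCAS only operates on 32-bit (or wider) words, so a 16-bit half is updated through the aligned 32-bit word that contains it. The code masks the address down to 4-byte alignment, reads the whole word, splices the updated 16 bits into the correct half, and retries with atomicCAS until no other thread modified the word in between. A standalone sketch of the same technique outside the macro (atomic_add_half is a hypothetical helper; the committed code generates the equivalent body with OP plugged in):

#include <cuda_fp16.h>

// Hypothetical standalone version of the 16-bit CAS loop, specialized to add.
__device__ void atomic_add_half(__half *address, __half val) {
  // Round the address down to the enclosing, 4-byte-aligned word.
  unsigned int *base =
      (unsigned int *)((char *)address - ((size_t)address & 2));
  bool upper = ((size_t)address & 2) != 0; // target sits in the high 16 bits?
  unsigned int old = *base, assumed;
  do {
    assumed = old;
    // Extract the current 16-bit payload and update it (in float for clarity).
    __half cur = __ushort_as_half(upper ? (old >> 16) : (old & 0xffff));
    unsigned short upd =
        __half_as_ushort(__float2half(__half2float(cur) + __half2float(val)));
    // Splice the updated bits back into the proper half of the word.
    unsigned int repl = upper ? (old & 0x0000ffff) | ((unsigned int)upd << 16)
                              : (old & 0xffff0000) | upd;
    // Publish only if no other thread touched the word meanwhile; else retry.
    old = atomicCAS(base, assumed, repl);
  } while (assumed != old);
}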
@@ -116,6 +135,15 @@ static inline __device__ void atomAdd(int32_t *address, int32_t val) {
 static inline __device__ void atomAdd(int64_t *address, int64_t val) {
   AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000)
+static inline __device__ void atomAdd(at::Half *address, at::Half val) {
+  AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
+#else
+static inline __device__ void atomAdd(at::Half *address, at::Half val) {
+  atomicAdd(reinterpret_cast<__half *>(address), val);
+}
+#endif
 static inline __device__ void atomAdd(float *address, float val) {
   atomicAdd(address, val);
 }
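For addition there is a fast path: sm_70+ devices compiled with CUDA 10+ provide a native atomicAdd(__half *, __half), so atomAdd forwards to it and only older targets fall back to the CAS emulation above. A hypothetical kernel showing the overload in use (scatter_add_half_kernel and its parameters are illustrative, not part of the commit):

// Hypothetical kernel: each thread funnels its contribution through atomAdd,
// which resolves to the hardware half-precision atomicAdd on sm_70+/CUDA 10+
// and to the CAS loop otherwise.
__global__ void scatter_add_half_kernel(at::Half *out, const at::Half *src,
                                        const int64_t *index, int64_t n) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    atomAdd(out + index[i], src[i]);
}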
@@ -150,6 +178,9 @@ static inline __device__ void atomMul(int64_t *address, int64_t val) {
 static inline __device__ void atomMul(float *address, float val) {
   AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
 }
+static inline __device__ void atomMul(at::Half *address, at::Half val) {
+  AtomicMulDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
 static inline __device__ void atomMul(double *address, double val) {
   AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
 }
@@ -172,6 +203,9 @@ static inline __device__ void atomDiv(int32_t *address, int32_t val) {
 static inline __device__ void atomDiv(int64_t *address, int64_t val) {
   AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
+static inline __device__ void atomDiv(at::Half *address, at::Half val) {
+  AtomicDivDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
 static inline __device__ void atomDiv(float *address, float val) {
   AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
 }
@@ -197,6 +231,9 @@ static inline __device__ void atomMax(int32_t *address, int32_t val) {
 static inline __device__ void atomMax(int64_t *address, int64_t val) {
   AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
+static inline __device__ void atomMax(at::Half *address, at::Half val) {
+  AtomicMaxDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
 static inline __device__ void atomMax(float *address, float val) {
   AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
 }
@@ -222,6 +259,9 @@ static inline __device__ void atomMin(int32_t *address, int32_t val) {
 static inline __device__ void atomMin(int64_t *address, int64_t val) {
   AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
 }
+static inline __device__ void atomMin(at::Half *address, at::Half val) {
+  AtomicMinDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
+}
 static inline __device__ void atomMin(float *address, float val) {
   AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
 }

csrc/cuda/reducer.cuh

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
-      *address = val / (count > 0 ? count : (scalar_t)1);
+      *address = val / (scalar_t)(count > 0 ? count : 1);
     else if (REDUCE == MIN || REDUCE == MAX) {
       if (count > 0) {
         *address = val;

csrc/cuda/scatter_cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

csrc/cuda/segment_coo_cuda.cu

Lines changed: 12 additions & 6 deletions
@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <type_traits>
 
 #include "reducer.cuh"
 #include "utils.cuh"
@@ -25,6 +26,10 @@ segment_coo_kernel(const scalar_t *src_data,
   int lane_idx = row_idx & (32 - 1);
   int D = index_info.sizes[index_info.dims - 1];
 
+  using cuda_scalar_t =
+      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
+                                scalar_t>::type;
+
   if (row_idx < E) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
@@ -36,7 +41,7 @@ segment_coo_kernel(const scalar_t *src_data,
 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
       // Parallel reduction inside a single warp.
-      tmp = __shfl_up_sync(FULL_MASK, val, i);
+      tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i);
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
         assert(idx >= next_idx);
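The cuda_scalar_t alias exists because the warp-shuffle intrinsics (__shfl_up_sync here, __shfl_down_sync in segment_csr_cuda.cu) are overloaded for CUDA's native __half but not for at::Half; casting the value to the bit-identical __half selects the right overload, and the cast is a no-op for every other scalar type. A condensed sketch of the pattern (warp_shfl_up is a hypothetical helper; FULL_MASK is the full 32-lane mask, 0xffffffff):

#include <cuda_fp16.h>
#include <type_traits>
#include <c10/util/Half.h>

#define FULL_MASK 0xffffffff

// Hypothetical helper condensing the kernels' pattern: map at::Half to the
// bit-identical __half for the intrinsic; leave all other types unchanged.
template <typename scalar_t>
__device__ scalar_t warp_shfl_up(scalar_t val, unsigned int delta) {
  using cuda_scalar_t =
      typename std::conditional<std::is_same<scalar_t, at::Half>::value,
                                __half, scalar_t>::type;
  return (scalar_t)__shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, delta);
}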
@@ -214,7 +219,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

@@ -266,14 +271,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
       segment_coo_kernel<scalar_t, SUM, false>
           <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                  count_data, E, N);
-      arg_out.value().clamp_(1);
+      arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
+                                   (scalar_t)1);
       auto count = arg_out.value();
       for (int i = dim + 1; i < out.dim(); i++)
         count = count.unsqueeze(-1);
       if (out.is_floating_point())
-        out.true_divide_(count);
+        out.div_(count);
       else
-        out.floor_divide_(count);
+        out.div_(count, "floor");
     }
   });
 });
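The division rewrite is an API migration rather than a numeric change on this path: PyTorch 1.8 folded true_divide_ and floor_divide_ into div_ with an explicit rounding mode and deprecated the old spellings. A sketch of the equivalence in the C++ API (divide_out_by_count is a hypothetical helper; assumes PyTorch >= 1.8):

#include <torch/torch.h>

// Hypothetical helper mirroring the mean-normalization step above.
void divide_out_by_count(torch::Tensor out, torch::Tensor count) {
  if (out.is_floating_point())
    out.div_(count);          // true division; formerly out.true_divide_(count)
  else
    out.div_(count, "floor"); // floor division; formerly out.floor_divide_(count)
}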
@@ -364,7 +370,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
 
   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

csrc/cuda/segment_csr_cuda.cu

Lines changed: 8 additions & 3 deletions
@@ -26,6 +26,10 @@ segment_csr_kernel(const scalar_t *src_data,
   int row_idx = thread_idx / TB;
   int lane_idx = thread_idx & (TB - 1);
 
+  using cuda_scalar_t =
+      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
+                                scalar_t>::type;
+
   if (row_idx < N) {
     int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
     int64_t row_start = __ldg(indptr_info.data + offset);
@@ -48,7 +52,8 @@ segment_csr_kernel(const scalar_t *src_data,
       if (REDUCE == MIN || REDUCE == MAX)
         arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
       Reducer<scalar_t, REDUCE>::update(
-          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
+          &val, __shfl_down_sync(FULL_MASK, (cuda_scalar_t)val, i), &arg,
+          arg_tmp);
     }
 
     if (lane_idx == 0) {
@@ -147,7 +152,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
 
   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -264,7 +269,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
 
   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

csrc/scatter.cpp

Lines changed: 3 additions & 4 deletions
@@ -127,13 +127,12 @@ class ScatterMean : public torch::autograd::Function<ScatterMean> {
                          old_index.dim() <= dim ? old_index.dim() - 1 : dim,
                          torch::nullopt, out.size(dim), "sum");
     auto count = std::get<0>(result);
-    count.clamp_(1);
+    count.masked_fill_(count < 1, 1);
     count = broadcast(count, out, dim);
-
     if (out.is_floating_point())
-      out.true_divide_(count);
+      out.div_(count);
     else
-      out.floor_divide_(count);
+      out.div_(count, "floor");
 
     ctx->save_for_backward({index, count});
     if (optional_out.has_value())
