Skip to content

Commit e419e15

Browse files
committed
no half tensor type
1 parent a8bbb42 commit e419e15

File tree

3 files changed

+14
-28
lines changed

3 files changed

+14
-28
lines changed

torch_scatter/kernel/THCAtomics.cuh

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,6 @@ static inline __device__ void atomicAdd( double *address, double val) { AtomicA
109109
#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000)
110110
static inline __device__ void atomicAdd( double *address, double val) {}
111111
#endif
112-
#ifdef CUDA_HALF_TENSOR
113-
static inline __device__ void atomicAdd( half *address, half val) {}
114-
#endif
115112

116113
#define OP(X, Y) Y * X
117114
ATOMIC_(Mul)
@@ -123,9 +120,6 @@ static inline __device__ void atomicMul(int32_t *address, int32_t val) { AtomicM
123120
static inline __device__ void atomicMul(int64_t *address, int64_t val) { AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
124121
static inline __device__ void atomicMul( float *address, float val) { AtomicMulDecimalImpl< float, sizeof( float)>()(address, val); }
125122
static inline __device__ void atomicMul( double *address, double val) { AtomicMulDecimalImpl< double, sizeof( double)>()(address, val); }
126-
#ifdef CUDA_HALF_TENSOR
127-
static inline __device__ void atomicMul( half *address, half val) {}
128-
#endif
129123

130124
#define OP(X, Y) Y / X
131125
ATOMIC_(Div)
@@ -137,9 +131,6 @@ static inline __device__ void atomicDiv(int32_t *address, int32_t val) { AtomicD
137131
static inline __device__ void atomicDiv(int64_t *address, int64_t val) { AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
138132
static inline __device__ void atomicDiv( float *address, float val) { AtomicDivDecimalImpl< float, sizeof( float)>()(address, val); }
139133
static inline __device__ void atomicDiv( double *address, double val) { AtomicDivDecimalImpl< double, sizeof( double)>()(address, val); }
140-
#ifdef CUDA_HALF_TENSOR
141-
static inline __device__ void atomicDiv( half *address, half val) {}
142-
#endif
143134

144135
#define OP(X, Y) max(Y, X)
145136
ATOMIC_(Max)
@@ -150,9 +141,6 @@ static inline __device__ void atomicMax(int16_t *address, int16_t val) { AtomicM
150141
static inline __device__ void atomicMax(int64_t *address, int64_t val) { AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
151142
static inline __device__ void atomicMax( float *address, float val) { AtomicMaxDecimalImpl< float, sizeof( float)>()(address, val); }
152143
static inline __device__ void atomicMax( double *address, double val) { AtomicMaxDecimalImpl< double, sizeof( double)>()(address, val); }
153-
#ifdef CUDA_HALF_TENSOR
154-
static inline __device__ void atomicMax( half *address, half val) {}
155-
#endif
156144

157145
#define OP(X, Y) min(Y, X)
158146
ATOMIC_(Min)
@@ -163,6 +151,3 @@ static inline __device__ void atomicMin(int16_t *address, int16_t val) { AtomicM
163151
static inline __device__ void atomicMin(int64_t *address, int64_t val) { AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
164152
static inline __device__ void atomicMin( float *address, float val) { AtomicMinDecimalImpl< float, sizeof( float)>()(address, val); }
165153
static inline __device__ void atomicMin( double *address, double val) { AtomicMinDecimalImpl< double, sizeof( double)>()(address, val); }
166-
#ifdef CUDA_HALF_TENSOR
167-
static inline __device__ void atomicMin( half *address, half val) {}
168-
#endif

torch_scatter/kernel/common.cuh

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,3 @@ struct TensorInfo {
3535
} \
3636
THCudaCheck(cudaGetLastError()); \
3737
}
38-
39-
static inline __device__ bool eq(uint8_t a, uint8_t b) { return a == b; }
40-
static inline __device__ bool eq( int8_t a, int8_t b) { return a == b; }
41-
static inline __device__ bool eq(int16_t a, int16_t b) { return a == b; }
42-
static inline __device__ bool eq(int32_t a, int32_t b) { return a == b; }
43-
static inline __device__ bool eq(int64_t a, int64_t b) { return a == b; }
44-
static inline __device__ bool eq( float a, float b) { return a == b; }
45-
static inline __device__ bool eq( double a, double b) { return a == b; }
46-
#ifdef CUDA_HALF_TENSOR
47-
static inline __device__ bool eq(half a, half b) { return __half2float(a) == __half2float(b); }
48-
#endif

torch_scatter/kernel/kernel.cu

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ __global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
6464
KERNEL_LOOP(i, n) {
6565
int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
6666
IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, arg, &argOffset);
67-
if (eq(input.data[inputOffset], output.data[outputOffset])) arg.data[argOffset] = inputOffset % input.size[dim];
67+
if (input.data[inputOffset] == output.data[outputOffset]) arg.data[argOffset] = inputOffset % input.size[dim];
6868
}
6969
}
7070

@@ -78,4 +78,16 @@ __global__ void indexBackwardKernel(TensorInfo<Real> output, TensorInfo<int64_t>
7878
}
7979

8080
#include "generic/kernel.cu"
81-
#include "THCGenerateAllTypes.h"
81+
#include "THCGenerateFloatType.h"
82+
#include "generic/kernel.cu"
83+
#include "THCGenerateDoubleType.h"
84+
#include "generic/kernel.cu"
85+
#include "THCGenerateByteType.h"
86+
#include "generic/kernel.cu"
87+
#include "THCGenerateCharType.h"
88+
#include "generic/kernel.cu"
89+
#include "THCGenerateShortType.h"
90+
#include "generic/kernel.cu"
91+
#include "THCGenerateIntType.h"
92+
#include "generic/kernel.cu"
93+
#include "THCGenerateLongType.h"

0 commit comments

Comments
 (0)