
Commit 4401a4f

fixed cuda 8 bug
1 parent e419e15 commit 4401a4f

2 files changed: +43 -39 lines changed


torch_scatter/kernel/THCAtomics.cuh

Lines changed: 37 additions & 33 deletions
@@ -100,54 +100,58 @@ struct TH_CONCAT_3(Atomic, NAME, DecimalImpl)<T, 8> { \
 #define OP(X, Y) Y + X
 ATOMIC_(Add)
 #undef OP
-static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) { AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
-static inline __device__ void atomicAdd( int8_t *address, int8_t val) { AtomicAddIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
-static inline __device__ void atomicAdd(int16_t *address, int16_t val) { AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
-static inline __device__ void atomicAdd(int64_t *address, int64_t val) { AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomAdd(uint8_t *address, uint8_t val) { AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomAdd( int8_t *address, int8_t val) { AtomicAddIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
+static inline __device__ void atomAdd(int16_t *address, int16_t val) { AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomAdd(int32_t *address, int32_t val) { atomicAdd(address, val); }
+static inline __device__ void atomAdd(int64_t *address, int64_t val) { AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomAdd( float *address, float val) { atomicAdd(address, val); }
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
-static inline __device__ void atomicAdd( double *address, double val) { AtomicAddDecimalImpl< double, sizeof( double)>()(address, val); }
-#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000)
-static inline __device__ void atomicAdd( double *address, double val) {}
+static inline __device__ void atomAdd( double *address, double val) { AtomicAddDecimalImpl< double, sizeof( double)>()(address, val); }
+#else
+static inline __device__ void atomAdd( double *address, double val) { atomicAdd(address, val); }
 #endif

 #define OP(X, Y) Y * X
 ATOMIC_(Mul)
 #undef OP
-static inline __device__ void atomicMul(uint8_t *address, uint8_t val) { AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
-static inline __device__ void atomicMul( int8_t *address, int8_t val) { AtomicMulIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
-static inline __device__ void atomicMul(int16_t *address, int16_t val) { AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
-static inline __device__ void atomicMul(int32_t *address, int32_t val) { AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val); }
-static inline __device__ void atomicMul(int64_t *address, int64_t val) { AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
-static inline __device__ void atomicMul( float *address, float val) { AtomicMulDecimalImpl< float, sizeof( float)>()(address, val); }
-static inline __device__ void atomicMul( double *address, double val) { AtomicMulDecimalImpl< double, sizeof( double)>()(address, val); }
+static inline __device__ void atomMul(uint8_t *address, uint8_t val) { AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomMul( int8_t *address, int8_t val) { AtomicMulIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
+static inline __device__ void atomMul(int16_t *address, int16_t val) { AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomMul(int32_t *address, int32_t val) { AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val); }
+static inline __device__ void atomMul(int64_t *address, int64_t val) { AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomMul( float *address, float val) { AtomicMulDecimalImpl< float, sizeof( float)>()(address, val); }
+static inline __device__ void atomMul( double *address, double val) { AtomicMulDecimalImpl< double, sizeof( double)>()(address, val); }

 #define OP(X, Y) Y / X
 ATOMIC_(Div)
 #undef OP
-static inline __device__ void atomicDiv(uint8_t *address, uint8_t val) { AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
-static inline __device__ void atomicDiv( int8_t *address, int8_t val) { AtomicDivIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
-static inline __device__ void atomicDiv(int16_t *address, int16_t val) { AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
-static inline __device__ void atomicDiv(int32_t *address, int32_t val) { AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val); }
-static inline __device__ void atomicDiv(int64_t *address, int64_t val) { AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
-static inline __device__ void atomicDiv( float *address, float val) { AtomicDivDecimalImpl< float, sizeof( float)>()(address, val); }
-static inline __device__ void atomicDiv( double *address, double val) { AtomicDivDecimalImpl< double, sizeof( double)>()(address, val); }
+static inline __device__ void atomDiv(uint8_t *address, uint8_t val) { AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomDiv( int8_t *address, int8_t val) { AtomicDivIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
+static inline __device__ void atomDiv(int16_t *address, int16_t val) { AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomDiv(int32_t *address, int32_t val) { AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val); }
+static inline __device__ void atomDiv(int64_t *address, int64_t val) { AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomDiv( float *address, float val) { AtomicDivDecimalImpl< float, sizeof( float)>()(address, val); }
+static inline __device__ void atomDiv( double *address, double val) { AtomicDivDecimalImpl< double, sizeof( double)>()(address, val); }

 #define OP(X, Y) max(Y, X)
 ATOMIC_(Max)
 #undef OP
-static inline __device__ void atomicMax(uint8_t *address, uint8_t val) { AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
-static inline __device__ void atomicMax( int8_t *address, int8_t val) { AtomicMaxIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
-static inline __device__ void atomicMax(int16_t *address, int16_t val) { AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
-static inline __device__ void atomicMax(int64_t *address, int64_t val) { AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
-static inline __device__ void atomicMax( float *address, float val) { AtomicMaxDecimalImpl< float, sizeof( float)>()(address, val); }
-static inline __device__ void atomicMax( double *address, double val) { AtomicMaxDecimalImpl< double, sizeof( double)>()(address, val); }
+static inline __device__ void atomMax(uint8_t *address, uint8_t val) { AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomMax( int8_t *address, int8_t val) { AtomicMaxIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
+static inline __device__ void atomMax(int16_t *address, int16_t val) { AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomMax(int32_t *address, int32_t val) { atomicMax(address, val); }
+static inline __device__ void atomMax(int64_t *address, int64_t val) { AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomMax( float *address, float val) { AtomicMaxDecimalImpl< float, sizeof( float)>()(address, val); }
+static inline __device__ void atomMax( double *address, double val) { AtomicMaxDecimalImpl< double, sizeof( double)>()(address, val); }

 #define OP(X, Y) min(Y, X)
 ATOMIC_(Min)
 #undef OP
-static inline __device__ void atomicMin(uint8_t *address, uint8_t val) { AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
-static inline __device__ void atomicMin( int8_t *address, int8_t val) { AtomicMinIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
-static inline __device__ void atomicMin(int16_t *address, int16_t val) { AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
-static inline __device__ void atomicMin(int64_t *address, int64_t val) { AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
-static inline __device__ void atomicMin( float *address, float val) { AtomicMinDecimalImpl< float, sizeof( float)>()(address, val); }
-static inline __device__ void atomicMin( double *address, double val) { AtomicMinDecimalImpl< double, sizeof( double)>()(address, val); }
+static inline __device__ void atomMin(uint8_t *address, uint8_t val) { AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomMin( int8_t *address, int8_t val) { AtomicMinIntegerImpl< int8_t, sizeof( int8_t)>()(address, val); }
+static inline __device__ void atomMin(int16_t *address, int16_t val) { AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomMin(int32_t *address, int32_t val) { atomicMin(address, val); }
+static inline __device__ void atomMin(int64_t *address, int64_t val) { AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomMin( float *address, float val) { AtomicMinDecimalImpl< float, sizeof( float)>()(address, val); }
+static inline __device__ void atomMin( double *address, double val) { AtomicMinDecimalImpl< double, sizeof( double)>()(address, val); }
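
The reason for the rename: starting with CUDA 8.0, when compiling for compute capability 6.0 and up, CUDA itself provides an atomicAdd(double *, double) overload, so the identically named overloads this header used to define collide with the built-in. Renaming the wrappers from atomic* to atom* sidesteps the clash and lets them forward to a hardware atomic wherever CUDA supplies one (int32/float add, int32 max/min, and double add on sm_60+). Below is a minimal standalone sketch of the same dispatch pattern; atomAddDouble is a hypothetical name for illustration, and the fallback is the standard compare-and-swap loop from the CUDA programming guide:

#include <cuda.h>          // defines CUDA_VERSION
#include <cuda_runtime.h>

// Hypothetical wrapper: use the native double atomic where it exists,
// otherwise emulate it by retrying atomicCAS on the 64-bit integer
// reinterpretation of the value.
static inline __device__ double atomAddDouble(double *address, double val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 && CUDA_VERSION >= 8000
  return atomicAdd(address, val);  // native on Pascal (sm_60) and newer
#else
  unsigned long long int *addr = (unsigned long long int *)address;
  unsigned long long int old = *addr, assumed;
  do {
    assumed = old;
    old = atomicCAS(addr, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);  // retry if another thread wrote in between
  return __longlong_as_double(old);
#endif
}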

torch_scatter/kernel/kernel.cu

Lines changed: 6 additions & 6 deletions
@@ -18,7 +18,7 @@ __global__ void mulKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
   KERNEL_LOOP(i, n) {
     int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
     IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
-    atomicMul(&output.data[outputOffset], input.data[inputOffset]);
+    atomMul(&output.data[outputOffset], input.data[inputOffset]);
   }
 }

@@ -27,7 +27,7 @@ __global__ void divKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
   KERNEL_LOOP(i, n) {
     int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
     IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
-    atomicDiv(&output.data[outputOffset], input.data[inputOffset]);
+    atomDiv(&output.data[outputOffset], input.data[inputOffset]);
   }
 }

@@ -36,8 +36,8 @@ __global__ void meanKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, T
   KERNEL_LOOP(i, n) {
     int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int countOffset = 0;
     IndexToScatterOffsets4<Real, Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, count, &countOffset);
-    atomicAdd(&output.data[outputOffset], input.data[inputOffset]);
-    atomicAdd(&count.data[countOffset], 1);
+    atomAdd(&output.data[outputOffset], input.data[inputOffset]);
+    atomAdd(&count.data[countOffset], 1);
   }
 }
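
Note that the mean kernel only accumulates: each hit adds the input value into the output slot and bumps a per-slot counter. Turning sums into means takes a separate division pass, which is outside this diff; a hypothetical sketch of that finalization step (name and signature assumed for illustration, not from this repo):

// Divide each accumulated sum by its element count; slots nothing was
// scattered into keep their initial value.
__global__ void divideByCountKernel(float *output, const float *count, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    if (count[i] > 0) output[i] /= count[i];
  }
}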

@@ -46,7 +46,7 @@ __global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
   KERNEL_LOOP(i, n) {
     int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
     IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
-    atomicMax(&output.data[outputOffset], input.data[inputOffset]);
+    atomMax(&output.data[outputOffset], input.data[inputOffset]);
   }
 }

@@ -55,7 +55,7 @@ __global__ void minKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
   KERNEL_LOOP(i, n) {
     int outputOffset = 0; int indexOffset = 0; int inputOffset = 0;
     IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
-    atomicMin(&output.data[outputOffset], input.data[inputOffset]);
+    atomMin(&output.data[outputOffset], input.data[inputOffset]);
   }
 }
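
All five kernels share the same shape: map the flat loop index to an output slot via the index tensor, then combine with a single atomic so concurrent writes to the same slot serialize instead of racing. A self-contained sketch of that pattern for scatter-add, stripped of the TensorInfo/Dims machinery (names and launch parameters are illustrative, not from this repo):

#include <cstdint>
#include <cuda_runtime.h>

// Each thread handles elements i, i + stride, i + 2*stride, ... (a
// grid-stride loop, like KERNEL_LOOP above) and accumulates each input
// value into output[index[i]] atomically.
__global__ void scatterAddKernel(float *output, const int64_t *index,
                                 const float *input, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    atomicAdd(&output[index[i]], input[i]);
  }
}

// Example launch: scatterAddKernel<<<(n + 255) / 256, 256>>>(out, idx, in, n);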
