#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// -------------------------------------- FP32 --------------------------------------
// Warp Reduce Sum
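
The added LDST128BITS macro aliases an address as a float4 so that 8 fp16 values (8 x 16 bits = 128 bits) move in a single vectorized load or store. A minimal sketch of the access pattern it enables; the kernel name and copy example below are illustrative, not part of this commit, and assume the buffers are 16-byte aligned:

// Illustrative only: move 8 half values per thread with one 128-bit load
// and one 128-bit store through the LDST128BITS alias.
__global__ void copy_f16x8_kernel(half* dst, half* src, int N) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
  if (idx + 7 < N) {
    half pack[8];                                  // 8 x 16 bits = 128 bits
    LDST128BITS(pack[0]) = LDST128BITS(src[idx]);  // one 128-bit load
    LDST128BITS(dst[idx]) = LDST128BITS(pack[0]);  // one 128-bit store
  }
}
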
@@ -123,7 +124,7 @@ __global__ void dot_prod_f16_f32_kernel(half* a, half* b, float* y, int N) {
  if (tid == 0) atomicAdd(y, prod);
}

-template<const int NUM_THREADS = 256>
+template<const int NUM_THREADS = 256/2>
__global__ void dot_prod_f16x2_f32_kernel(half* a, half* b, float* y, int N) {
  int tid = threadIdx.x;
  int idx = (blockIdx.x * NUM_THREADS + tid) * 2; // 2 half elements per thread
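
The default thread count drops to 256/2 = 128 because each thread now owns one half2 pair, so a block still covers 256 half elements. A worked instance of the indexing, with an illustrative block and thread id:

// Worked example (illustrative ids): with NUM_THREADS = 256/2 = 128,
// thread 5 of block 3 starts at half element (3 * 128 + 5) * 2 = 778
// and also consumes element 779.
static_assert((3 * (256 / 2) + 5) * 2 == 778, "two half elements per thread");
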
@@ -148,6 +149,38 @@ __global__ void dot_prod_f16x2_f32_kernel(half* a, half* b, float* y, int N) {
  if (tid == 0) atomicAdd(y, prod);
}

+template<const int NUM_THREADS = 256/8>
+__global__ void dot_prod_f16x8_pack_f32_kernel(half* a, half* b, float* y, int N) {
+  int tid = threadIdx.x;
+  int idx = (blockIdx.x * NUM_THREADS + tid) * 8; // 8 half elements per thread
+  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
+  __shared__ float reduce_smem[NUM_WARPS];
+  // temporary registers (may spill to addressable .local space in PTX)
+  half pack_a[8], pack_b[8]; // 8 x 16 bits = 128 bits
+  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
+  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits
+  const half z = __float2half(0.0f);
+
+  half prod_f16 = z;
+  #pragma unroll
+  for (int i = 0; i < 8; i += 2) {
+    half2 v = __hmul2(HALF2(pack_a[i]), HALF2(pack_b[i]));
+    prod_f16 += (((idx + i) < N) ? (v.x + v.y) : z);
+  }
+
+  int warp = tid / WARP_SIZE;
+  int lane = tid % WARP_SIZE;
+  // perform the warp-synchronous reduction.
+  float prod = warp_reduce_sum_f16_f32<WARP_SIZE>(prod_f16);
+  // warp leaders store their partial sums to shared memory.
+  if (lane == 0) reduce_smem[warp] = prod;
+  __syncthreads(); // make sure the data is in shared memory.
+  // the first warp computes the final sum.
+  prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
+  if (warp == 0) prod = warp_reduce_sum_f32<NUM_WARPS>(prod);
+  if (tid == 0) atomicAdd(y, prod);
+}
+

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
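
The packed kernel relies on warp_reduce_sum_f16_f32 and warp_reduce_sum_f32, which are defined earlier in this file and fall outside this hunk. For readers of the diff, a typical butterfly-shuffle implementation of such helpers looks like the sketch below; the file's own versions are assumed to follow this pattern but may differ in detail (naming, fp16 intermediates, etc.).

// Sketch only, not the file's actual helpers: warp-level butterfly reductions
// built on __shfl_xor_sync. Assumes WARP_SIZE == 32 and <cuda_fp16.h>.
template<const int kWarpSize = 32>
__device__ __forceinline__ float warp_reduce_sum_f32_sketch(float val) {
  #pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask); // add the lane `mask` steps away
  }
  return val;
}

template<const int kWarpSize = 32>
__device__ __forceinline__ float warp_reduce_sum_f16_f32_sketch(half val) {
  float v = __half2float(val); // promote once, accumulate across the warp in fp32
  #pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    v += __shfl_xor_sync(0xffffffff, v, mask);
  }
  return v;
}
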
@@ -159,8 +192,42 @@ if(((T).options().dtype() != (th_type))) { \
  throw std::runtime_error("values must be " #th_type); \
}

-#define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
-if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }
+#define LANUCH_DOT_PROD_KERNEL(NT, packed_type, acc_type, element_type) \
+  dot_prod_##packed_type##_##acc_type##_kernel<(NT)><<<grid, block>>>( \
+    reinterpret_cast<element_type*>(a.data_ptr()), \
+    reinterpret_cast<element_type*>(b.data_ptr()), \
+    prod.data_ptr<float>(), N);
+
+#define DISPATCH_DOT_PROD_KERNEL(K, packed_type, acc_type, element_type, n_elements) \
+  const int NT = (K)/(n_elements); \
+  dim3 block(NT); \
+  dim3 grid((S)); \
+  switch (NT) \
+  { \
+    case 32: \
+      LANUCH_DOT_PROD_KERNEL(32, packed_type, acc_type, element_type) \
+      break; \
+    case 64: \
+      LANUCH_DOT_PROD_KERNEL(64, packed_type, acc_type, element_type) \
+      break; \
+    case 128: \
+      LANUCH_DOT_PROD_KERNEL(128, packed_type, acc_type, element_type) \
+      break; \
+    case 256: \
+      LANUCH_DOT_PROD_KERNEL(256, packed_type, acc_type, element_type) \
+      break; \
+    case 512: \
+      LANUCH_DOT_PROD_KERNEL(512, packed_type, acc_type, element_type) \
+      break; \
+    case 1024: \
+      LANUCH_DOT_PROD_KERNEL(1024, packed_type, acc_type, element_type) \
+      break; \
+    default: \
+      throw std::runtime_error( \
+        "only support (K)/(n_elements): 32/64/128/256/512/1024"); \
+      break; \
+  }
+

#define TORCH_BINDING_DOT_PROD(packed_type, acc_type, th_type, element_type, n_elements) \
torch::Tensor dot_prod_##packed_type##_##acc_type(torch::Tensor a, torch::Tensor b) { \
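
To make the dispatch concrete with an illustrative shape: for a 2-D input of [S, K] = [4096, 2048] and the f16x8_pack binding (n_elements = 8), NT = 2048 / 8 = 256, so the switch takes the case 256 label and launches dot_prod_f16x8_pack_f32_kernel<256> with grid(4096) and block(256). The same arithmetic restated as compile-time checks (the shape is an assumption, not from this commit):

// Illustrative shape only: [S, K] = [4096, 2048], 8 half elements per thread.
static_assert(2048 / 8 == 256, "NT = K / n_elements selects the case 256 label");
static_assert(4096 * 2048 == 4096 * 256 * 8,
              "grid(S) blocks * block(NT) threads * 8 elements cover N = S * K");
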
@@ -169,30 +236,49 @@ torch::Tensor dot_prod_##packed_type##_##acc_type(torch::Tensor a, torch::Tensor
  auto options = torch::TensorOptions().dtype(torch::kFloat32).device( \
    torch::kCUDA, 0); \
  auto prod = torch::zeros({1}, options); \
-  const int N = a.size(0); \
-  CHECK_TORCH_TENSOR_SHAPE(b, N) \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  dot_prod_##packed_type##_##acc_type##_kernel< \
-    NUM_THREADS_PER_BLOCK><<<grid, block>>>( \
+  const int ndim = a.dim(); \
+  if (ndim != 2) { \
+    int N = 1; \
+    for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+    dim3 block(256); \
+    dim3 grid(((N + 256 - 1) / 256) / (n_elements)); \
+    dot_prod_##packed_type##_##acc_type##_kernel< \
+      256><<<grid, block>>>( \
      reinterpret_cast<element_type*>(a.data_ptr()), \
      reinterpret_cast<element_type*>(b.data_ptr()), \
      prod.data_ptr<float>(), N); \
+  } else { \
+    const int S = a.size(0); \
+    const int K = a.size(1); \
+    const int N = S * K; \
+    if ((K/(n_elements)) <= 1024) { \
+      DISPATCH_DOT_PROD_KERNEL(K, packed_type, acc_type, element_type, n_elements) \
+    } else { \
+      int N = 1; \
+      for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+      dim3 block(256); \
+      dim3 grid(((N + 256 - 1) / 256) / (n_elements)); \
+      dot_prod_##packed_type##_##acc_type##_kernel< \
+        256><<<grid, block>>>( \
+        reinterpret_cast<element_type*>(a.data_ptr()), \
+        reinterpret_cast<element_type*>(b.data_ptr()), \
+        prod.data_ptr<float>(), N); \
+    } \
+  } \
  return prod; \
}

// packed_type, acc_type, th_type, element_type, n_elements_per_pack
-TORCH_BINDING_DOT_PROD(f32,        f32, torch::kFloat32, float, 1)
-TORCH_BINDING_DOT_PROD(f32x4,      f32, torch::kFloat32, float, 4)
-TORCH_BINDING_DOT_PROD(f16,        f32, torch::kHalf,    half,  1)
-TORCH_BINDING_DOT_PROD(f16x2,      f32, torch::kHalf,    half,  2)
-
+TORCH_BINDING_DOT_PROD(f32,        f32, torch::kFloat32, float, 1)
+TORCH_BINDING_DOT_PROD(f32x4,      f32, torch::kFloat32, float, 4)
+TORCH_BINDING_DOT_PROD(f16,        f32, torch::kHalf,    half,  1)
+TORCH_BINDING_DOT_PROD(f16x2,      f32, torch::kHalf,    half,  2)
+TORCH_BINDING_DOT_PROD(f16x8_pack, f32, torch::kHalf,    half,  8)

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f32_f32)
  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f32x4_f32)
  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16_f32)
  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16x2_f32)
+  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16x8_pack_f32)
}
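
For completeness, a standalone harness that exercises the new packed kernel directly, without going through the PyTorch binding. This is a sketch only: it assumes it is appended to this .cu file (so WARP_SIZE, the warp-reduce helpers, and the macros above are in scope), that N is a multiple of 8, and the test values are illustrative. Since the kernel accumulates per-thread partials in fp16, some drift against the fp32 reference is expected for large N.

// Sketch of a direct test for dot_prod_f16x8_pack_f32_kernel (assumes this
// file's includes and helpers are available in the same translation unit).
#include <cstdio>
#include <vector>

static void test_dot_prod_f16x8_pack(int N = 4096) { // N must be a multiple of 8 here
  std::vector<half> ha(N), hb(N);
  float ref = 0.0f;
  for (int i = 0; i < N; ++i) {
    float va = 0.01f * (i % 7), vb = 0.02f * (i % 5);
    ha[i] = __float2half(va);
    hb[i] = __float2half(vb);
    ref += va * vb; // fp32 reference on the host
  }
  half *da = nullptr, *db = nullptr;
  float *dy = nullptr;
  cudaMalloc((void**)&da, N * sizeof(half));
  cudaMalloc((void**)&db, N * sizeof(half));
  cudaMalloc((void**)&dy, sizeof(float));
  cudaMemcpy(da, ha.data(), N * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb.data(), N * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemset(dy, 0, sizeof(float));
  // Mirror the default template parameter: 256/8 = 32 threads, 8 halfs per thread.
  constexpr int NT = 256 / 8;
  dim3 block(NT), grid((N + NT * 8 - 1) / (NT * 8));
  dot_prod_f16x8_pack_f32_kernel<NT><<<grid, block>>>(da, db, dy, N);
  float out = 0.0f;
  cudaMemcpy(&out, dy, sizeof(float), cudaMemcpyDeviceToHost);
  printf("gpu=%f cpu=%f\n", out, ref);
  cudaFree(da); cudaFree(db); cudaFree(dy);
}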