@@ -23,8 +23,8 @@
 #define MAX_EXP_F16 __float2half(11.089866488461016f)
 #define MIN_EXP_F16 __float2half(-9.704060527839234f)
 
-// -------------------------------------- FP32
-// -------------------------------------- col2row means read x[row][col] and
+// FP32
+// col2row means read x[row][col] and
 // write y[col][row]; row2col means read x[col][row] and write y[row][col]
 __global__ void mat_transpose_f32_col2row_kernel(float *x, float *y,
                                                  const int row, const int col) {
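
The col2row / row2col naming in the comment above is easiest to see in a kernel body. The naive FP32 body sits outside this hunk, so the following is a minimal sketch reconstructed from the comment's convention; the `_sketch` suffix marks it as illustrative, not the committed code:

// Minimal sketch, one thread per element: read x in row-major order
// ("read x[row][col]") and scatter into the transposed position
// ("write y[col][row]"). Reads coalesce; writes are strided by `row`.
__global__ void mat_transpose_f32_col2row_sketch(float *x, float *y,
                                                 const int row, const int col) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < row * col) {
    const int r = idx / col; // source row
    const int c = idx % col; // source column
    y[c * row + r] = x[r * col + c];
  }
}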
@@ -216,7 +216,6 @@ __global__ void mat_transpose_f32x4_shared_row2col2d_kernel(float *x, float *y,
   }
 }
 
-
 __global__ void mat_transpose_f32x4_shared_bcf_col2row2d_kernel(float *x,
                                                                 float *y,
                                                                 const int row,
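
For context on the `bcf` (bank-conflict-free) family touched here: the idea is to pad the shared-memory tile so that threads reading a tile column land in distinct banks. A minimal sketch of the technique, not the committed kernel; the 32x32 tile and square n x n input are assumptions:

__global__ void bcf_tile_transpose_sketch(const float *x, float *y,
                                          const int n) {
  // The +1 pad skews each row by one bank; without it, the column read
  // (tile[tx][ty] below) hits one bank 32 times and serializes.
  __shared__ float tile[32][32 + 1];
  const int tx = threadIdx.x, ty = threadIdx.y;
  tile[ty][tx] = x[(blockIdx.y * 32 + ty) * n + blockIdx.x * 32 + tx];
  __syncthreads();
  y[(blockIdx.x * 32 + ty) * n + blockIdx.y * 32 + tx] = tile[tx][ty];
}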
@@ -298,11 +297,8 @@ __global__ void mat_transpose_f32x4_shared_bcf_row2col2d_kernel(float *x,
   }
 }
 
-
-__global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(float *x,
-                                                                            float *y,
-                                                                            const int row,
-                                                                            const int col) {
+__global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(
+    float *x, float *y, const int row, const int col) {
   const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
   const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
   const int local_x = threadIdx.x;
@@ -328,18 +324,13 @@ __global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(floa
   smem_val.w = tile[local_x * 4 + 3][local_y];
 
   const int gid_x = blockIdx.x * blockDim.x;
-  const int gid_y = blockIdx.y * blockDim.y * 4;
+  const int gid_y = blockIdx.y * blockDim.y * 4;
   const int out_y = gid_y + local_x * 4;
   const int out_x = gid_x + local_y;
   reinterpret_cast<float4 *>(y)[(out_x * row + out_y) / 4] = FLOAT4(smem_val);
  }
 }
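
The point of the merge_write variant just above: each thread gathers four transposed elements from the tile into a float4 so the store to y becomes one 16-byte transaction. A worked example of the index math, assuming blockDim = (16, 16):

// Take blockIdx = (1, 2) and threadIdx = (3, 5):
//   gid_x = 1 * 16     = 16
//   gid_y = 2 * 16 * 4 = 128
//   out_y = 128 + 3*4  = 140
//   out_x = 16 + 5     = 21
// The float4 store lands at element out_x * row + out_y, i.e. it fills
// y[21][140..143] (y has leading dimension `row`) in a single transaction.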
 
-// TODO: may support double buffer pipeline mat transpose ?
-// TODO: may support fp16 mat transpose ?
-
-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func)                                  \
   m.def(STRINGFY(func), &func, STRINGFY(func));
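
STRINGFY plus TORCH_BINDING_COMMON_EXTENSION simply register a function under its own name; hand-expanding one instantiation makes the pattern explicit:

// TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_col2row) expands to:
m.def("mat_transpose_f32_col2row", &mat_transpose_f32_col2row,
      "mat_transpose_f32_col2row");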
@@ -373,7 +364,7 @@ __global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(floa
     dim3 block(WARP_SIZE_S, WARP_SIZE_S);                                     \
     dim3 grid((N + WARP_SIZE_S - 1) / (WARP_SIZE_S * n_element_col),          \
               (M + WARP_SIZE_S - 1) / (WARP_SIZE_S * n_element_row));         \
-    mat_transpose_##tag##2d_kernel<<<grid,                                    \
+    mat_transpose_##tag##2d_kernel<<<grid,                                    \
         block>>>(reinterpret_cast<element_type *>(x.data_ptr()),              \
                  reinterpret_cast<element_type *>(y.data_ptr()), M, N);       \
   }
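
A worked example of the launch shape this macro computes. WARP_SIZE_S = 16 is an assumption (it is defined outside this diff), and the last two macro arguments are taken to be n_element_row and n_element_col:

// For a 1024 x 1024 tensor bound with (n_element_row, n_element_col) = (4, 1):
//   block  = dim3(16, 16)                  // 256 threads
//   grid.x = (1024 + 15) / (16 * 1) = 64
//   grid.y = (1024 + 15) / (16 * 4) = 16   // each thread covers 4 rows
// so the grid shrinks along whichever axis a thread vectorizes over.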
@@ -400,11 +391,8 @@ TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_col2row, torch::kFloat32, float,
                               1, 4)
 TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_row2col, torch::kFloat32, float,
                               4, 1)
-TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_merge_write_row2col, torch::kFloat32, float,
-                              4, 1)
-
-// TODO: may support double buffer pipeline mat transpose ?
-// TODO: may support fp16 mat transpose ?
+TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_merge_write_row2col,
+                              torch::kFloat32, float, 4, 1)
 
 // CuTe implementations
 extern void mat_transpose_cute_col2row_reg(torch::Tensor, torch::Tensor);
@@ -442,7 +430,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // shared memory optimized with bcf
   TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_col2row2d)
   TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_row2col2d)
-  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_merge_write_row2col2d)
+  TORCH_BINDING_COMMON_EXTENSION(
+      mat_transpose_f32x4_shared_bcf_merge_write_row2col2d)
   // CuTe implementations
   TORCH_BINDING_COMMON_EXTENSION(mat_transpose_cute_col2row_reg)
   TORCH_BINDING_COMMON_EXTENSION(mat_transpose_cute_row2col_reg)