@@ -216,6 +216,7 @@ __global__ void mat_transpose_f32x4_shared_row2col2d_kernel(float *x, float *y,
216
216
}
217
217
}
218
218
219
+
219
220
__global__ void mat_transpose_f32x4_shared_bcf_col2row2d_kernel (float *x,
220
221
float *y,
221
222
const int row,
@@ -296,6 +297,44 @@ __global__ void mat_transpose_f32x4_shared_bcf_row2col2d_kernel(float *x,
296
297
y[(out_y + 3 ) * row + out_x] = smem_val.w ;
297
298
}
298
299
}
300
+
301
+
302
// Transposes a row-major `row` x `col` matrix `x` into a row-major
// `col` x `row` matrix `y` (i.e. y[j * row + i] = x[i * col + j]),
// staging a 4-row strip per thread in shared memory and merging each
// thread's 4 contiguous output elements into one vectorized float4 store.
//
// NOTE(review): launch expectations inferred from the indexing — confirm
// against the host-side TORCH_BINDING_MAT_TRANSPOSE2D wrapper:
//   * blockDim.x == blockDim.y == WARP_SIZE_S (tile read below indexes
//     tile[local_x * 4 + k][local_y], so the block must be square);
//   * each thread covers 4 rows of x, so the y-grid is sized over row / 4;
//   * `row` should be a multiple of 4 so the float4 store through `y`
//     stays 16-byte aligned ((out_x * row + out_y) must be divisible by 4).
__global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(
    float *x, float *y, const int row, const int col) {
  const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
  const int local_x = threadIdx.x;
  const int local_y = threadIdx.y;
  // +PAD on the inner dimension staggers the column reads below across
  // banks, matching the other _bcf (bank-conflict-free) kernels.
  __shared__ float tile[WARP_SIZE_S * 4][WARP_SIZE_S + PAD];

  const bool in_bounds = (global_y * 4 < row && global_x < col);

  // Phase 1: load a 4-row strip of x into the shared tile.
  if (in_bounds) {
    float4 x_val;
    x_val.x = x[(global_y * 4) * col + global_x];
    x_val.y = x[(global_y * 4 + 1) * col + global_x];
    x_val.z = x[(global_y * 4 + 2) * col + global_x];
    x_val.w = x[(global_y * 4 + 3) * col + global_x];
    tile[local_y * 4][local_x] = x_val.x;
    tile[local_y * 4 + 1][local_x] = x_val.y;
    tile[local_y * 4 + 2][local_x] = x_val.z;
    tile[local_y * 4 + 3][local_x] = x_val.w;
  }

  // FIX: the original placed __syncthreads() inside the bounds check.
  // A barrier inside divergent control flow is undefined behavior when any
  // thread of the block fails the condition (partial edge tiles), so the
  // barrier is hoisted out where every thread reaches it unconditionally.
  __syncthreads();

  // Phase 2: read the transposed 4-element column from shared memory and
  // merge the 4 contiguous outputs into a single float4 store.
  if (in_bounds) {
    float4 smem_val;
    smem_val.x = tile[local_x * 4][local_y];
    smem_val.y = tile[local_x * 4 + 1][local_y];
    smem_val.z = tile[local_x * 4 + 2][local_y];
    smem_val.w = tile[local_x * 4 + 3][local_y];

    const int gid_x = blockIdx.x * blockDim.x;
    const int gid_y = blockIdx.y * blockDim.y * 4;
    const int out_y = gid_y + local_x * 4; // contiguous in y; multiple of 4
    const int out_x = gid_x + local_y;
    reinterpret_cast<float4 *>(y)[(out_x * row + out_y) / 4] =
        FLOAT4(smem_val);
  }
}
337
+
299
338
// TODO: may support double buffer pipeline mat transpose ?
300
339
// TODO: may support fp16 mat transpose ?
301
340
@@ -361,6 +400,9 @@ TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_col2row, torch::kFloat32, float,
361
400
1 , 4 )
362
401
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_row2col, torch::kFloat32 , float ,
363
402
4 , 1 )
403
+ TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_merge_write_row2col, torch::kFloat32 , float ,
404
+ 4 , 1 )
405
+
364
406
// TODO: may support double buffer pipeline mat transpose ?
365
407
// TODO: may support fp16 mat transpose ?
366
408
@@ -400,6 +442,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
400
442
// shared memory optimize with bcf
401
443
TORCH_BINDING_COMMON_EXTENSION (mat_transpose_f32x4_shared_bcf_col2row2d)
402
444
TORCH_BINDING_COMMON_EXTENSION (mat_transpose_f32x4_shared_bcf_row2col2d)
445
+ TORCH_BINDING_COMMON_EXTENSION (mat_transpose_f32x4_shared_bcf_merge_write_row2col2d)
403
446
// CuTe implementations
404
447
TORCH_BINDING_COMMON_EXTENSION (mat_transpose_cute_col2row_reg)
405
448
TORCH_BINDING_COMMON_EXTENSION (mat_transpose_cute_row2col_reg)
0 commit comments