Skip to content

Commit 6953b39

Browse files
authored
add out_f32x4_shared_bcf_merge_write_row2col(2d) (#339)
1 parent 0cf539d commit 6953b39

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

kernels/mat-transpose/mat_transpose.cu

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ __global__ void mat_transpose_f32x4_shared_row2col2d_kernel(float *x, float *y,
216216
}
217217
}
218218

219+
219220
__global__ void mat_transpose_f32x4_shared_bcf_col2row2d_kernel(float *x,
220221
float *y,
221222
const int row,
@@ -296,6 +297,44 @@ __global__ void mat_transpose_f32x4_shared_bcf_row2col2d_kernel(float *x,
296297
y[(out_y + 3) * row + out_x] = smem_val.w;
297298
}
298299
}
300+
301+
302+
// Transpose a row-major (row x col) matrix x into row-major y (col x row).
// Each thread stages a 4x1 column strip of x into shared memory; after the
// block barrier it emits the transposed 1x4 strip of y with a single
// vectorized float4 store ("merged write").
//
// Launch expectations (per this file's other 2D kernels): square block,
// blockDim.x == blockDim.y == WARP_SIZE_S, grid tiling col x (row / 4) --
// TODO confirm against the TORCH_BINDING_MAT_TRANSPOSE2D(..., 4, 1) wrapper.
// Precondition: row % 4 == 0, otherwise the float4 store at
// y + out_x * row + out_y is misaligned -- assumed, verify in the host code.
__global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(
    float *x, float *y, const int row, const int col) {
  const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
  const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
  const int local_x = threadIdx.x;
  const int local_y = threadIdx.y;
  // PAD on the minor dimension offsets the bank mapping of column accesses.
  __shared__ float tile[WARP_SIZE_S * 4][WARP_SIZE_S + PAD];

  // Stage a 4-row strip of x into shared memory. The guard checks the last
  // row of the strip (global_y * 4 + 3): the previous `global_y * 4 < row`
  // check allowed the +1..+3 loads to read past the end of x when row is
  // not a multiple of 4.
  if (global_x < col && global_y * 4 + 3 < row) {
    float4 x_val;
    x_val.x = x[(global_y * 4 + 0) * col + global_x];
    x_val.y = x[(global_y * 4 + 1) * col + global_x];
    x_val.z = x[(global_y * 4 + 2) * col + global_x];
    x_val.w = x[(global_y * 4 + 3) * col + global_x];
    tile[local_y * 4 + 0][local_x] = x_val.x;
    tile[local_y * 4 + 1][local_x] = x_val.y;
    tile[local_y * 4 + 2][local_x] = x_val.z;
    tile[local_y * 4 + 3][local_x] = x_val.w;
  }
  // Barrier hoisted out of the conditional: a __syncthreads() inside a
  // divergent branch is undefined behavior on edge-tile blocks where some
  // threads fail the load guard.
  __syncthreads();

  // Transposed coordinates of the 1x4 strip this thread writes to y.
  const int out_y = blockIdx.y * blockDim.y * 4 + local_x * 4;
  const int out_x = blockIdx.x * blockDim.x + local_y;
  // This guard is exactly the load guard of the producer thread with
  // swapped local indices (tx = local_y, ty = local_x), so the tile cells
  // read below are guaranteed to have been written before the barrier.
  if (out_x < col && out_y + 3 < row) {
    float4 smem_val;
    smem_val.x = tile[local_x * 4 + 0][local_y];
    smem_val.y = tile[local_x * 4 + 1][local_y];
    smem_val.z = tile[local_x * 4 + 2][local_y];
    smem_val.w = tile[local_x * 4 + 3][local_y];
    // Single 16-byte store: (out_x * row + out_y) % 4 == 0 given out_y % 4
    // == 0 and the row % 4 == 0 precondition.
    reinterpret_cast<float4 *>(y)[(out_x * row + out_y) / 4] = FLOAT4(smem_val);
  }
}
337+
299338
// TODO: may support double buffer pipeline mat transpose ?
300339
// TODO: may support fp16 mat transpose ?
301340

@@ -361,6 +400,9 @@ TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_col2row, torch::kFloat32, float,
361400
1, 4)
362401
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_row2col, torch::kFloat32, float,
363402
4, 1)
403+
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_merge_write_row2col, torch::kFloat32, float,
404+
4, 1)
405+
364406
// TODO: may support double buffer pipeline mat transpose ?
365407
// TODO: may support fp16 mat transpose ?
366408

@@ -400,6 +442,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
400442
// shared memory optimize with bcf
401443
TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_col2row2d)
402444
TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_row2col2d)
445+
TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_merge_write_row2col2d)
403446
// CuTe implementations
404447
TORCH_BINDING_COMMON_EXTENSION(mat_transpose_cute_col2row_reg)
405448
TORCH_BINDING_COMMON_EXTENSION(mat_transpose_cute_row2col_reg)

kernels/mat-transpose/mat_transpose.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ def transpose_copy_compiled(input: torch.Tensor, out: torch.Tensor):
127127
"f32x4_shared_bcf_row2col(2d)",
128128
y,
129129
)
130+
run_benchmark(
131+
lib.mat_transpose_f32x4_shared_bcf_merge_write_row2col2d,
132+
x,
133+
"f32x4_shared_bcf_merge_write_row2col(2d)",
134+
y,
135+
)
130136
run_benchmark(
131137
lib.mat_transpose_cute_col2row_reg,
132138
x,

0 commit comments

Comments
 (0)