Skip to content

Commit 9ae191a

Browse files
fix: update mat_transpose_f32_row2col2d_kernel to make it actually row2col (#404)
refactor: clean up mat_transpose.cu by removing unused includes and correcting kernel comments for clarity
1 parent 10fd701 commit 9ae191a

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

kernels/mat-transpose/mat_transpose.cu

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include <algorithm>
2-
#include <cuda_bf16.h>
32
#include <cuda_fp16.h>
43
#include <cuda_fp8.h>
54
#include <cuda_runtime.h>
@@ -16,16 +15,10 @@
1615
#define INT4(value) (reinterpret_cast<int4 *>(&(value))[0])
1716
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
1817
#define HALF2(value) (reinterpret_cast<half2 *>(&(value))[0])
19-
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
2018
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
21-
#define MAX_EXP_F32 88.3762626647949f
22-
#define MIN_EXP_F32 -88.3762626647949f
23-
#define MAX_EXP_F16 __float2half(11.089866488461016f)
24-
#define MIN_EXP_F16 __float2half(-9.704060527839234f)
2519

2620
// FP32
27-
// col2row means read x[row][col] and
28-
// write y[col][row] row2col means read x[col][row] and write y[row][col]
21+
// col2row means read x[row][col] and write y[col][row]
2922
__global__ void mat_transpose_f32_col2row_kernel(float *x, float *y,
3023
const int row, const int col) {
3124
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -36,6 +29,7 @@ __global__ void mat_transpose_f32_col2row_kernel(float *x, float *y,
3629
}
3730
}
3831

32+
// row2col means read x[col][row] and write y[row][col]
3933
__global__ void mat_transpose_f32_row2col_kernel(float *x, float *y,
4034
const int row, const int col) {
4135
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -104,8 +98,8 @@ __global__ void mat_transpose_f32_col2row2d_kernel(float *x, float *y,
10498
__global__ void mat_transpose_f32_row2col2d_kernel(float *x, float *y,
10599
const int row,
106100
const int col) {
107-
const int global_y = blockIdx.x * blockDim.x + threadIdx.x;
108-
const int global_x = blockIdx.y * blockDim.y + threadIdx.y;
101+
const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
102+
const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
109103
if (global_y < col && global_x < row) {
110104
y[global_y * row + global_x] = x[global_x * col + global_y];
111105
}

0 commit comments

Comments
 (0)