
Commit 0d3e128

fix comments (#345)
1 parent 955557b commit 0d3e128

18 files changed (+64 lines, -97 lines)

kernels/elementwise/elementwise.cu

Lines changed: 4 additions & 4 deletions
@@ -17,8 +17,8 @@
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
 #define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])

-// -------------------------------------- FP32
-// -------------------------------------- ElementWise Add grid(N/256),
+// FP32
+// ElementWise Add grid(N/256),
 // block(256) a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
 __global__ void elementwise_add_f32_kernel(float *a, float *b, float *c,
                                            int N) {
@@ -45,8 +45,8 @@ __global__ void elementwise_add_f32x4_kernel(float *a, float *b, float *c,
   }
 }

-// -------------------------------------- FP16
-// -------------------------------------- ElementWise Add grid(N/256),
+// FP16
+// ElementWise Add grid(N/256),
 // block(256) a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
 __global__ void elementwise_add_f16_kernel(half *a, half *b, half *c, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
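The restored comment pins down the launch shape and semantics: grid(N/256), block(256), c = elementwise_add(a, b) over Nx1 tensors. A minimal sketch of how that maps to code; the kernel body and the host-side launcher are assumptions, since only the signature and the index line appear in the hunks:

// FP32 elementwise add matching the grid(N/256), block(256) comment.
__global__ void elementwise_add_f32_sketch(float *a, float *b, float *c,
                                           int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x; // one element per thread
  if (idx < N)
    c[idx] = a[idx] + b[idx];
}

// Hypothetical host-side launch; rounding the grid up covers N % 256 != 0.
void launch_elementwise_add_f32(float *a, float *b, float *c, int N) {
  dim3 block(256);
  dim3 grid((N + 256 - 1) / 256);
  elementwise_add_f32_sketch<<<grid, block>>>(a, b, c, N);
}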

kernels/elu/elu.cu

Lines changed: 4 additions & 4 deletions
@@ -31,13 +31,13 @@
 m.def(STRINGFY(func), &func, STRINGFY(func));

 // ELU compute functions
-// -------------------------------------- FP32
+// FP32
 // --------------------------------------
 __device__ __forceinline__ float elu(float x) {
   return x > 0.f ? x : ALPHA * (expf(x) - 1.f);
 }

-// -------------------------------------- FP16
+// FP16
 // --------------------------------------
 __device__ __forceinline__ half elu_half(half x) {
   return __hgt(x, __float2half(0.f))
@@ -46,7 +46,7 @@ __device__ __forceinline__ half elu_half(half x) {
 }

 // CUDA kernel functions
-// -------------------------------------- FP32
+// FP32
 // --------------------------------------
 __global__ void elu_f32_kernel(float *x, float *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -67,7 +67,7 @@ __global__ void elu_f32x4_kernel(float *x, float *y, int N) {
   }
 }

-// -------------------------------------- FP16
+// FP16
 // --------------------------------------
 __global__ void elu_f16_kernel(half *x, half *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
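The FP32 device function above carries the full definition: ELU(x) = x for x > 0, otherwise ALPHA * (exp(x) - 1). A self-contained sketch of how the scalar kernel wires it up; the ALPHA value and the kernel body are assumptions, only the signatures are visible in the hunks:

#define ALPHA 1.0f // assumed value; the real macro is defined elsewhere in elu.cu

__device__ __forceinline__ float elu_sketch(float x) {
  // Identity for positive inputs, alpha * (e^x - 1) for the rest.
  return x > 0.f ? x : ALPHA * (expf(x) - 1.f);
}

__global__ void elu_f32_sketch(float *x, float *y, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N)
    y[idx] = elu_sketch(x[idx]);
}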

kernels/gelu/gelu.cu

Lines changed: 4 additions & 4 deletions
@@ -57,8 +57,8 @@ __inline__ __device__ float gelu_none_approximate(float x) {
   return x * 0.5 * (1 + erff(x * M_SQRT1_2));
 }

-// -------------------------------------- FP32
-// -------------------------------------- GELU tanh approximate: x, y:x 0.5 * x
+// FP32
+// GELU tanh approximate: x, y:x 0.5 * x
 // * (1.0 + tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x))) grid(N/256),
 // block(K=256)
 __global__ void gelu_f32_kernel(float *x, float *y, int N) {
@@ -91,8 +91,8 @@ __global__ void gelu_f32x4_kernel(float *x, float *y, int N) {
   }
 }

-// -------------------------------------- FP16
-// -------------------------------------- GELU approximate: x, y:x 0.5 * x *
+// FP16
+// GELU approximate: x, y:x 0.5 * x *
 // (1.0 + tanh(0.7978845608 (x + 0.044715 * x * x * x))) Vec4
 __global__ void gelu_f16_kernel(half *x, half *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
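The restored comment spells out the tanh approximation: y = 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2))), where 0.7978845608 ≈ sqrt(2/pi). A small sketch of that formula as an FP32 device function; the helper name is illustrative, not from the file:

// Tanh-approximate GELU, exactly as the comment above states it.
__device__ __forceinline__ float gelu_tanh_approx_sketch(float x) {
  // 0.7978845608 * x * (1 + 0.044715 * x^2) == sqrt(2/pi) * (x + 0.044715 * x^3)
  float inner = 0.7978845608f * x * (1.0f + 0.044715f * x * x);
  return 0.5f * x * (1.0f + tanhf(inner));
}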

kernels/hardshrink/hardshrink.cu

Lines changed: 4 additions & 8 deletions
@@ -31,8 +31,7 @@
 m.def(STRINGFY(func), &func, STRINGFY(func));

 // HARDSHRINK compute functions
-// -------------------------------------- FP32
-// --------------------------------------
+// FP32
 __device__ __forceinline__ float hardshrink(float x) {
   if (x > LAMBD || x < -LAMBD) {
     return x;
@@ -41,8 +40,7 @@ __device__ __forceinline__ float hardshrink(float x) {
   }
 }

-// -------------------------------------- FP16
-// --------------------------------------
+// FP16
 __device__ __forceinline__ half hardshrink_half(half x) {
   if (x > __float2half(LAMBD) || x < __float2half(-LAMBD)) {
     return x;
@@ -52,8 +50,7 @@ __device__ __forceinline__ half hardshrink_half(half x) {
 }

 // CUDA kernel functions
-// -------------------------------------- FP32
-// --------------------------------------
+// FP32
 __global__ void hardshrink_f32_kernel(float *x, float *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)
@@ -73,8 +70,7 @@ __global__ void hardshrink_f32x4_kernel(float *x, float *y, int N) {
   }
 }

-// -------------------------------------- FP16
-// --------------------------------------
+// FP16
 __global__ void hardshrink_f16_kernel(half *x, half *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)
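Besides the scalar kernels, the hunks reference hardshrink_f32x4_kernel. A hedged sketch of the float4-vectorized pattern such *_f32x4 kernels typically follow, reusing the hardshrink device function shown above; the body is an assumption (only the name appears in the diff) and the N % 4 != 0 tail is ignored:

#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])

__global__ void hardshrink_f32x4_sketch(float *x, float *y, int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x); // 4 floats per thread
  if (idx < N) {
    float4 reg_x = FLOAT4(x[idx]); // one 128-bit load
    float4 reg_y;
    reg_y.x = hardshrink(reg_x.x);
    reg_y.y = hardshrink(reg_x.y);
    reg_y.z = hardshrink(reg_x.z);
    reg_y.w = hardshrink(reg_x.w);
    FLOAT4(y[idx]) = reg_y; // one 128-bit store
  }
}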

kernels/hardswish/hardswish.cu

Lines changed: 4 additions & 8 deletions
@@ -32,8 +32,7 @@
 m.def(STRINGFY(func), &func, STRINGFY(func));

 // HARDSWISH compute functions
-// -------------------------------------- FP32
-// --------------------------------------
+// FP32
 __device__ __forceinline__ float hardswish(float x) {
   if (x >= THRESHOLD_A) {
     return x;
@@ -44,8 +43,7 @@ __device__ __forceinline__ float hardswish(float x) {
   }
 }

-// -------------------------------------- FP16
-// --------------------------------------
+// FP16
 __device__ __forceinline__ half hardswish_half(half x) {
   if (x > __float2half(THRESHOLD_A)) {
     return x;
@@ -57,8 +55,7 @@ __device__ __forceinline__ half hardswish_half(half x) {
 }

 // CUDA kernel functions
-// -------------------------------------- FP32
-// --------------------------------------
+// FP32
 __global__ void hardswish_f32_kernel(float *x, float *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)
@@ -78,8 +75,7 @@ __global__ void hardswish_f32x4_kernel(float *x, float *y, int N) {
   }
 }

-// -------------------------------------- FP16
-// --------------------------------------
+// FP16
 __global__ void hardswish_f16_kernel(half *x, half *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)
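The hunks only show the upper branch of hardswish (return x when x >= THRESHOLD_A). For reference, the full piecewise form, assuming the usual bounds of +3 and -3 for the thresholds; the actual macro values are not visible in this diff:

// Piecewise hardswish sketch with assumed thresholds of +3 and -3.
__device__ __forceinline__ float hardswish_sketch(float x) {
  if (x >= 3.0f)
    return x;                   // upper region: identity
  if (x <= -3.0f)
    return 0.0f;                // lower region: zero
  return x * (x + 3.0f) / 6.0f; // middle region: x * (x + 3) / 6
}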

kernels/histogram/histogram.cu

Lines changed: 1 addition & 2 deletions
@@ -35,8 +35,7 @@ __global__ void histogram_i32x4_kernel(int *a, int *y, int N) {
   }
 }

-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
+// PyTorch bindings for custom kernel
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));
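The hunk only touches the binding comment, but histogram_i32x4_kernel names the underlying technique: per-thread binning with atomic increments. A hedged scalar sketch of that idea; the body is an assumption, and inputs are assumed to already be valid bin indices:

__global__ void histogram_i32_sketch(int *a, int *y, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N)
    atomicAdd(&y[a[idx]], 1); // contended but correct per-bin count
}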

kernels/layer-norm/layer_norm.cu

Lines changed: 4 additions & 4 deletions
@@ -17,8 +17,8 @@
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
 #define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])

-// -------------------------------------- FP32
-// -------------------------------------- Warp Reduce Sum
+// FP32
+// Warp Reduce Sum
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ float warp_reduce_sum_f32(float val) {
 #pragma unroll
@@ -119,8 +119,8 @@ __global__ void layer_norm_f32x4_kernel(float *x, float *y, float g, float b,
   FLOAT4(y[idx]) = reg_y;
 }

-// -------------------------------------- FP16
-// -------------------------------------- Warp Reduce Sum: Half
+// FP16
+// Warp Reduce Sum: Half
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ half warp_reduce_sum_f16_f16(half val) {
 #pragma unroll
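The "Warp Reduce Sum" template above (its signature and #pragma unroll are in the hunk) is the standard butterfly reduction over shuffle intrinsics. A sketch of the loop it most likely wraps, assuming WARP_SIZE is 32; the loop body is an assumption, since the hunk cuts off at #pragma unroll:

template <const int kWarpSize = 32>
__device__ __forceinline__ float warp_reduce_sum_f32_sketch(float val) {
#pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    // Butterfly exchange: each lane adds its partner's partial sum.
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val; // every lane ends up holding the warp-wide sum
}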

kernels/mat-transpose/mat_transpose.cu

Lines changed: 10 additions & 21 deletions
@@ -23,8 +23,8 @@
 #define MAX_EXP_F16 __float2half(11.089866488461016f)
 #define MIN_EXP_F16 __float2half(-9.704060527839234f)

-// -------------------------------------- FP32
-// -------------------------------------- col2row means read x[row][col] and
+// FP32
+// col2row means read x[row][col] and
 // write y[col][row] row2col means read x[col][row] and write y[row][col]
 __global__ void mat_transpose_f32_col2row_kernel(float *x, float *y,
                                                  const int row, const int col) {
@@ -216,7 +216,6 @@ __global__ void mat_transpose_f32x4_shared_row2col2d_kernel(float *x, float *y,
   }
 }

-
 __global__ void mat_transpose_f32x4_shared_bcf_col2row2d_kernel(float *x,
                                                                  float *y,
                                                                  const int row,
@@ -298,11 +297,8 @@ __global__ void mat_transpose_f32x4_shared_bcf_row2col2d_kernel(float *x,
   }
 }

-
-__global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(float *x,
-                                                                             float *y,
-                                                                             const int row,
-                                                                             const int col) {
+__global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(
+    float *x, float *y, const int row, const int col) {
   const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
   const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
   const int local_x = threadIdx.x;
@@ -328,18 +324,13 @@ __global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(floa
     smem_val.w = tile[local_x * 4 + 3][local_y];

     const int gid_x = blockIdx.x * blockDim.x;
-    const int gid_y = blockIdx.y * blockDim.y * 4;
+    const int gid_y = blockIdx.y * blockDim.y * 4;
     const int out_y = gid_y + local_x * 4;
     const int out_x = gid_x + local_y;
     reinterpret_cast<float4 *>(y)[(out_x * row + out_y) / 4] = FLOAT4(smem_val);
   }
 }

-// TODO: may support double buffer pipeline mat transpose ?
-// TODO: may support fp16 mat transpose ?
-
-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));
@@ -373,7 +364,7 @@ __global__ void mat_transpose_f32x4_shared_bcf_merge_write_row2col2d_kernel(floa
     dim3 block(WARP_SIZE_S, WARP_SIZE_S); \
     dim3 grid((N + WARP_SIZE_S - 1) / (WARP_SIZE_S * n_element_col), \
               (M + WARP_SIZE_S - 1) / (WARP_SIZE_S * n_element_row)); \
-    mat_transpose_##tag##2d_kernel <<< grid, \
+    mat_transpose_##tag##2d_kernel < < < grid, \
         block >>> (reinterpret_cast<element_type *>(x.data_ptr()), \
                    reinterpret_cast<element_type *>(y.data_ptr()), M, N); \
   }
@@ -400,11 +391,8 @@ TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_col2row, torch::kFloat32, float,
                               1, 4)
 TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_row2col, torch::kFloat32, float,
                               4, 1)
-TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_merge_write_row2col, torch::kFloat32, float,
-                              4, 1)
-
-// TODO: may support double buffer pipeline mat transpose ?
-// TODO: may support fp16 mat transpose ?
+TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_merge_write_row2col,
+                              torch::kFloat32, float, 4, 1)

 // CuTe implentations
 extern void mat_transpose_cute_col2row_reg(torch::Tensor, torch::Tensor);
@@ -442,7 +430,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // shared memory optimize with bcf
   TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_col2row2d)
   TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_row2col2d)
-  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_merge_write_row2col2d)
+  TORCH_BINDING_COMMON_EXTENSION(
+      mat_transpose_f32x4_shared_bcf_merge_write_row2col2d)
   // CuTe implentations
   TORCH_BINDING_COMMON_EXTENSION(mat_transpose_cute_col2row_reg)
   TORCH_BINDING_COMMON_EXTENSION(mat_transpose_cute_row2col_reg)
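The comment restored at the top of this file defines the naming: col2row reads x[row][col] and writes y[col][row]; row2col reads x[col][row] and writes y[row][col]. A minimal 2D-indexed sketch of the col2row case; the indexing scheme is an assumption, since the kernel body itself is not part of this diff:

// Naive col2row transpose: y[col][row] = x[row][col], one element per thread.
__global__ void mat_transpose_f32_col2row_sketch(float *x, float *y,
                                                 const int row, const int col) {
  const int global_x = blockIdx.x * blockDim.x + threadIdx.x; // column in x
  const int global_y = blockIdx.y * blockDim.y + threadIdx.y; // row in x
  if (global_x < col && global_y < row) {
    y[global_x * row + global_y] = x[global_y * col + global_x];
  }
}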

kernels/nms/nms.cu

Lines changed: 0 additions & 2 deletions
@@ -58,8 +58,6 @@ __global__ void nms_kernel(const float *boxes, const float *scores, int *keep,
     return;
 }

-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));

kernels/reduce/block_all_reduce.cu

Lines changed: 10 additions & 10 deletions
@@ -25,8 +25,8 @@
 // MatMul FP8 -> Tensor Cores

 // CUDA review (0x00): learn block all reduce step by step: from FP32 to FP16/BF16, then to FP8
-// -------------------------------------- FP32
-// -------------------------------------- Warp Reduce Sum
+// FP32
+// Warp Reduce Sum
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ float warp_reduce_sum_f32(float val) {
 #pragma unroll
@@ -93,8 +93,8 @@ __global__ void block_all_reduce_sum_f32x4_f32_kernel(float *a, float *y,
   atomicAdd(y, sum);
 }

-// -------------------------------------- FP16
-// -------------------------------------- Warp Reduce Sum: Half
+// FP16
+// Warp Reduce Sum: Half
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ half warp_reduce_sum_f16_f16(half val) {
 #pragma unroll
@@ -301,8 +301,8 @@ __global__ void block_all_reduce_sum_f16x8_pack_f32_kernel(half *a, float *y,
   atomicAdd(y, sum);
 }

-// -------------------------------------- BF16
-// -------------------------------------- Warp Reduce Sum: Half
+// BF16
+// Warp Reduce Sum: Half
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ __nv_bfloat16
 warp_reduce_sum_bf16_bf16(__nv_bfloat16 val) {
@@ -520,8 +520,8 @@ __global__ void block_all_reduce_sum_bf16x8_pack_f32_kernel(__nv_bfloat16 *a,
   atomicAdd(y, sum);
 }

-// -------------------------------------- FP8
-// --------------------------------------
+// FP8
+//
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ half
 warp_reduce_sum_fp8_e4m3_f16(__nv_fp8_storage_t val) {
@@ -680,8 +680,8 @@ block_all_reduce_sum_fp8_e5m2x16_pack_f16_kernel(__nv_fp8_storage_t *a,
   atomicAdd(y, __half2float(sum));
 }

-// -------------------------------------- INT8
-// --------------------------------------
+// INT8
+//
 template <const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ int32_t warp_reduce_sum_i8_i32(int8_t val) {
   int32_t val_i32 = static_cast<int32_t>(val);
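Every kernel excerpted above ends in atomicAdd(y, sum), the tell for the shared block-all-reduce structure: warp-level shuffle reduction, one partial per warp staged in shared memory, a final fold by warp 0, then a single atomic into the global accumulator. A hedged FP32 sketch of that structure, assuming a 256-thread block and the warp_reduce_sum_f32 template from the first hunk:

template <const int NUM_THREADS = 256>
__global__ void block_all_reduce_sum_f32_sketch(float *a, float *y, int N) {
  constexpr int NUM_WARPS = (NUM_THREADS + 32 - 1) / 32;
  __shared__ float reduce_smem[NUM_WARPS];

  int tid = threadIdx.x;
  int idx = blockIdx.x * NUM_THREADS + tid;
  float sum = (idx < N) ? a[idx] : 0.0f;

  int warp = tid / 32, lane = tid % 32;
  sum = warp_reduce_sum_f32<32>(sum); // reduce within each warp
  if (lane == 0)
    reduce_smem[warp] = sum;          // stage one partial per warp
  __syncthreads();

  if (warp == 0) {                    // warp 0 folds the per-warp partials
    sum = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
    sum = warp_reduce_sum_f32<NUM_WARPS>(sum);
    if (tid == 0)
      atomicAdd(y, sum);              // accumulate across blocks
  }
}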
