Commit edc46f3

fix comments (#346)
1 parent 0d3e128 commit edc46f3

11 files changed: +7 additions, -32 deletions

kernels/elementwise/elementwise.cu

Lines changed: 0 additions & 2 deletions
@@ -120,8 +120,6 @@ __global__ void elementwise_add_f16x8_pack_kernel(half *a, half *b, half *c,
   }
 }
 
-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));
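
For context, the STRINGFY / TORCH_BINDING_COMMON_EXTENSION pair that survives this hunk is what registers the kernel wrappers with PyTorch: each macro call expands to m.def("name", &name, "name"). A minimal sketch of how it is typically invoked (the module body and wrapper names below are illustrative, not taken from this commit):

#include <torch/extension.h>

// Hypothetical module body: each TORCH_BINDING_COMMON_EXTENSION call exposes
// the C++ wrapper to Python under its own function name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16)
}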

kernels/elu/elu.cu

Lines changed: 0 additions & 5 deletions
@@ -32,13 +32,11 @@
 
 // ELU computation function
 // FP32
-// --------------------------------------
 __device__ __forceinline__ float elu(float x) {
   return x > 0.f ? x : ALPHA * (expf(x) - 1.f);
 }
 
 // FP16
-// --------------------------------------
 __device__ __forceinline__ half elu_half(half x) {
   return __hgt(x, __float2half(0.f))
              ? x
@@ -47,7 +45,6 @@ __device__ __forceinline__ half elu_half(half x) {
 
 // CUDA kernel functions
 // FP32
-// --------------------------------------
 __global__ void elu_f32_kernel(float *x, float *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)
@@ -68,7 +65,6 @@ __global__ void elu_f32x4_kernel(float *x, float *y, int N) {
 }
 
 // FP16
-// --------------------------------------
 __global__ void elu_f16_kernel(half *x, half *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)
@@ -129,7 +125,6 @@ __global__ void elu_f16x8_pack_kernel(half *x, half *y, int N) {
   }
 }
 
-// PyTorch binding code
 #define TORCH_BINDING_ELU(packed_type, th_type, element_type, n_elements) \
   void elu_##packed_type(torch::Tensor x, torch::Tensor y) { \
     CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
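
As a side note, the __device__ elu() shown in the first hunk computes elu(x) = x for x > 0 and ALPHA * (exp(x) - 1) otherwise. A tiny host-side reference, useful when checking kernel output; the ALPHA default here is an assumption, not taken from the diff:

#include <cmath>

// Assumed default; elu.cu defines its own ALPHA macro elsewhere in the file.
#ifndef ALPHA
#define ALPHA 1.0f
#endif

// CPU reference matching the __device__ elu() in the hunk above.
inline float elu_ref(float x) {
  return x > 0.f ? x : ALPHA * (std::exp(x) - 1.f);
}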

kernels/embedding/embedding.cu

Lines changed: 0 additions & 2 deletions
@@ -78,8 +78,6 @@ __global__ void embedding_f16x8_pack_kernel(const int *idx, half *weight,
       LDST128BITS(weight[offset + 8 * tx]);
 }
 
-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));

kernels/gelu/gelu.cu

Lines changed: 0 additions & 2 deletions
@@ -182,8 +182,6 @@ __global__ void gelu_f16x8_pack_kernel(half *x, half *y, int N) {
   }
 }
 
-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));

kernels/hardshrink/hardshrink.cu

Lines changed: 0 additions & 1 deletion
@@ -131,7 +131,6 @@ __global__ void hardshrink_f16x8_pack_kernel(half *x, half *y, int N) {
   }
 }
 
-// PyTorch binding code
 #define TORCH_BINDING_HARDSHRINK(packed_type, th_type, element_type, \
                                  n_elements) \
   void hardshrink_##packed_type(torch::Tensor x, torch::Tensor y) { \

kernels/hardswish/hardswish.cu

Lines changed: 0 additions & 1 deletion
@@ -136,7 +136,6 @@ __global__ void hardswish_f16x8_pack_kernel(half *x, half *y, int N) {
   }
 }
 
-// PyTorch binding code
 #define TORCH_BINDING_HARDSWISH(packed_type, th_type, element_type, \
                                 n_elements) \
   void hardswish_##packed_type(torch::Tensor x, torch::Tensor y) { \

kernels/histogram/histogram.cu

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ __global__ void histogram_i32x4_kernel(int *a, int *y, int N) {
   }
 }
 
-// PyTorch bindings for custom kernel
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));

kernels/layer-norm/layer_norm.cu

Lines changed: 0 additions & 2 deletions
@@ -456,8 +456,6 @@ __global__ void layer_norm_f16x8_pack_f32_kernel(half *x, half *y, float g,
   // TODO: support non 8-multiple K here
 }
 
-// --------------------- PyTorch bindings for custom kernel
-// -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
   m.def(STRINGFY(func), &func, STRINGFY(func));

kernels/nvidia-nsight/elementwise.cu

Lines changed: 4 additions & 4 deletions
@@ -16,8 +16,8 @@
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
 #define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
 
-// -------------------------------------- FP32
-// -------------------------------------- ElementWise Add grid(N/256),
+// FP32
+// ElementWise Add grid(N/256),
 // block(256) a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
 __global__ void elementwise_add_f32_kernel(float *a, float *b, float *c,
                                            int N) {
@@ -44,8 +44,8 @@ __global__ void elementwise_add_f32x4_kernel(float *a, float *b, float *c,
   }
 }
 
-// -------------------------------------- FP16
-// -------------------------------------- ElementWise Add grid(N/256),
+// FP16
+// ElementWise Add grid(N/256),
 // block(256) a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
 __global__ void elementwise_add_f16_kernel(half *a, half *b, half *c, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
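
The LDST128BITS macro kept in the first hunk is what the f32x4 variants use for vectorized access. A minimal sketch of that pattern (the kernel body below is illustrative, not this file's exact code): each thread handles four consecutive floats through one 128-bit load and one 128-bit store.

// Sketch only: one thread per 4 elements; assumes a, b, c come from cudaMalloc
// (16-byte aligned) and that the tail where N is not a multiple of 4 is
// handled separately.
__global__ void elementwise_add_f32x4_sketch(float *a, float *b, float *c,
                                             int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx + 3 < N) {
    float4 ra = LDST128BITS(a[idx]); // one 128-bit load covers a[idx..idx+3]
    float4 rb = LDST128BITS(b[idx]);
    float4 rc;
    rc.x = ra.x + rb.x;
    rc.y = ra.y + rb.y;
    rc.z = ra.z + rb.z;
    rc.w = ra.w + rb.w;
    LDST128BITS(c[idx]) = rc;        // one 128-bit store to c[idx..idx+3]
  }
}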

kernels/nvidia-nsight/relu.cu

Lines changed: 3 additions & 4 deletions
@@ -14,8 +14,8 @@
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
 #define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
 
-// -------------------------------------- FP32
-// -------------------------------------- Relu x: N, y: N y=max(0,x)
+// FP32
+// Relu x: N, y: N y=max(0,x)
 // grid(N/256), block(K=256)
 __global__ void relu_f32_kernel(float *x, float *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -38,8 +38,7 @@ __global__ void relu_f32x4_kernel(float *x, float *y, int N) {
   }
 }
 
-// -------------------------------------- FP16
-// --------------------------------------
+// FP16
 __global__ void relu_f16_kernel(half *x, half *y, int N) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < N)
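
The "grid(N/256), block(K=256)" comment retained above describes the usual launch shape for these one-thread-per-element kernels. A hedged host-side sketch (the helper name and the d_x/d_y device pointers are assumptions, not part of this file):

#include <cuda_runtime.h>

// Sketch: launch relu_f32_kernel over N elements, rounding the grid size up
// so every element is covered; d_x and d_y are device buffers of N floats.
void launch_relu_f32(float *d_x, float *d_y, int N) {
  dim3 block(256);
  dim3 grid((N + block.x - 1) / block.x);
  relu_f32_kernel<<<grid, block>>>(d_x, d_y, N);
}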
