#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <iostream> // for std::cout in CHECK_TORCH_TENSOR_DTYPE
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4 *>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2 *>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
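// Note: HALF2/BFLOAT2 reinterpret two consecutive 16-bit elements as one
// packed 2-lane vector, so a pair moves in a single 32-bit load/store; the
// address handed to the cast must be 4-byte aligned for this to be valid.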

// -------------------------------------- FP16 --------------------------------------
// Warp Reduce Sum
template <const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ half warp_reduce_sum_f16(half val) {
#pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}
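
// For kWarpSize = 32 the loop above is a butterfly reduction: lanes exchange
// partial sums at XOR distances 16, 8, 4, 2, 1, so after five steps every
// lane holds the full 32-lane sum (the kernels below only consume lane 0's).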

// HGEMV: Warp HGEMV K32
// Assumes K is a multiple of 32; each warp handles one row.
// grid(M/4), block(32, 4): blockDim.x = 32 = WARP_SIZE, blockDim.y = 4
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
__global__ void hgemv_k32_f16_kernel(half* a, half* x, half* y, int M, int K) {
  int tx = threadIdx.x;         // 0~31
  int ty = threadIdx.y;         // 0~3
  int bx = blockIdx.x;          // 0~M/4
  int lane = tx % WARP_SIZE;    // 0~31
  int m = bx * blockDim.y + ty; // (0~M/4) * 4 + (0~3)
  if (m < M) {
    half sum = 0.0f;
    int NUM_WARPS = (K + WARP_SIZE - 1) / WARP_SIZE;
#pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      // When NUM_WARPS >= 2, each lane first accumulates its partial sums
      // across the K/WARP_SIZE chunks of the current row.
      int k = w * WARP_SIZE + lane;
      sum += a[m * K + k] * x[k];
    }
    sum = warp_reduce_sum_f16<WARP_SIZE>(sum);
    if (lane == 0) y[m] = sum;
  }
}
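
// Access pattern: within each chunk, consecutive lanes read consecutive
// elements of row m, so the loads from both a[] and x[] are fully coalesced.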

// HGEMV: Warp HGEMV K128 + half2x2
// Assumes K is a multiple of 128; each thread handles 4 elements via two
// half2 vector loads, so one warp covers 128 elements per iteration.
// grid(M/4), block(32, 4): blockDim.x = 32 = WARP_SIZE, blockDim.y = 4
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
__global__ void hgemv_k128_f16x4_kernel(half* a, half* x, half* y, int M, int K) {
  // Each thread handles 4 elements, so one warp covers 128 elements.
  int tx = threadIdx.x;         // 0~31
  int ty = threadIdx.y;         // 0~3
  int bx = blockIdx.x;          // 0~M/4
  int lane = tx % WARP_SIZE;    // 0~31
  int m = blockDim.y * bx + ty; // (0~M/4) * 4 + (0~3)

  if (m < M) {
    half sum = 0.0f;
    // process 4*WARP_SIZE elements per warp.
    int NUM_WARPS = (((K + WARP_SIZE - 1) / WARP_SIZE) + 4 - 1) / 4;
#pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      int k = (w * WARP_SIZE + lane) * 4;
      half2 reg_x_0 = HALF2(x[k + 0]);
      half2 reg_x_1 = HALF2(x[k + 2]);
      half2 reg_a_0 = HALF2(a[m * K + k + 0]);
      half2 reg_a_1 = HALF2(a[m * K + k + 2]);
      sum += (reg_x_0.x * reg_a_0.x + reg_x_0.y * reg_a_0.y
              + reg_x_1.x * reg_a_1.x + reg_x_1.y * reg_a_1.y);
    }
    sum = warp_reduce_sum_f16<WARP_SIZE>(sum);
    if (lane == 0) y[m] = sum;
  }
}
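
// Note: HALF2() here reinterprets &x[k] and &a[m * K + k] as half2*, which
// assumes those addresses are 4-byte aligned; since k is always a multiple
// of 4, this holds whenever the underlying storage is at least 4-byte
// aligned (true for freshly allocated contiguous torch.half tensors).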

// HGEMV: Warp HGEMV K16
// Assumes K == 16 (< 32): each warp handles 2 rows of 16 elements each.
// NUM_THREADS=128, NUM_WARPS=NUM_THREADS/WARP_SIZE;
// NUM_ROWS=NUM_WARPS * ROW_PER_WARP, grid(M/NUM_ROWS), block(32, NUM_WARPS)
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
template <const int ROW_PER_WARP = 2>
__global__ void hgemv_k16_f16_kernel(half* A, half* x, half* y, int M, int K) {
  constexpr int K_WARP_SIZE = (WARP_SIZE + ROW_PER_WARP - 1) / ROW_PER_WARP;
  int tx = threadIdx.x;       // 0~31
  int ty = threadIdx.y;       // 0~NUM_WARPS-1
  int bx = blockIdx.x;        // 0~M/NUM_ROWS (NUM_ROWS=NUM_WARPS * ROW_PER_WARP)
  int lane = tx % WARP_SIZE;  // 0~31
  int k = lane % K_WARP_SIZE; // 0~15
  // global row of A: MxK and y: Mx1, blockDim.y=NUM_WARPS
  int m = (blockDim.y * bx + ty) * ROW_PER_WARP + lane / K_WARP_SIZE;
  if (m < M) {
    half sum = A[m * K + k] * x[k];
    sum = warp_reduce_sum_f16<K_WARP_SIZE>(sum);
    // Note: the write condition is k == 0, not lane == 0, so that each
    // group of K_WARP_SIZE lanes writes out its own row.
    if (k == 0) y[m] = sum;
  }
}
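
// Worked mapping for ROW_PER_WARP = 2 (K_WARP_SIZE = 16): lanes 0~15 of a
// warp cover columns 0~15 of row r while lanes 16~31 cover row r + 1; the
// shuffle reduction runs at width 16, so the two half-warps reduce
// independently and lanes 0 and 16 each write one result.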

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGIFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGIFY(func), &func, STRINGIFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                   \
  if (((T).options().dtype() != (th_type))) {                  \
    std::cout << "Tensor Info:" << (T).options() << std::endl; \
    throw std::runtime_error("values must be " #th_type);      \
  }

#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1)             \
  if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
    throw std::runtime_error("Tensor size mismatch!");  \
  }

#define ASSERT_K_IS_MULTIPLE_OF(V) \
  if (K % (V) != 0) { throw std::runtime_error("K must be multiple of " #V); }

#define ASSERT_K_IS_EQUAL_OF(V) \
  if (K != (V)) { throw std::runtime_error("K must be " #V); }

void hgemv_k32_f16(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_MULTIPLE_OF(32)

  dim3 block(32, 4);
  dim3 grid((M + 4 - 1) / 4);

  hgemv_k32_f16_kernel<<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

void hgemv_k128_f16x4(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_MULTIPLE_OF(128)

  dim3 block(32, 4);
  dim3 grid((M + 4 - 1) / 4);

  hgemv_k128_f16x4_kernel<<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

void hgemv_k16_f16(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_EQUAL_OF(16)

  constexpr int NUM_THREADS = 128;
  constexpr int ROW_PER_WARP = 2;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; // 4
  constexpr int NUM_ROWS = NUM_WARPS * ROW_PER_WARP; // 4 * 2 = 8

  dim3 block(32, NUM_WARPS);
  dim3 grid((M + NUM_ROWS - 1) / NUM_ROWS);

  hgemv_k16_f16_kernel<ROW_PER_WARP><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k32_f16)
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k128_f16x4)
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k16_f16)
}
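
// A minimal standalone correctness check for hgemv_k32_f16_kernel; a sketch,
// not part of the extension build. It only compiles when
// HGEMV_STANDALONE_TEST is defined (the macro name, the sizes below, and the
// data pattern are all illustrative choices). Note the torch headers at the
// top are unconditional, so building this standalone still needs the torch
// include paths; a half-capable GPU (sm_53+) is assumed.
#ifdef HGEMV_STANDALONE_TEST
int main() {
  const int M = 256, K = 256; // K must be a multiple of 32 for this kernel
  std::vector<half> h_a(M * K), h_x(K), h_y(M);
  for (int i = 0; i < M * K; ++i) h_a[i] = __float2half((i % 7) * 0.1f);
  for (int i = 0; i < K; ++i) h_x[i] = __float2half((i % 5) * 0.1f);

  half *d_a, *d_x, *d_y;
  cudaMalloc(&d_a, M * K * sizeof(half));
  cudaMalloc(&d_x, K * sizeof(half));
  cudaMalloc(&d_y, M * sizeof(half));
  cudaMemcpy(d_a, h_a.data(), M * K * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemcpy(d_x, h_x.data(), K * sizeof(half), cudaMemcpyHostToDevice);

  dim3 block(32, 4);
  dim3 grid((M + 4 - 1) / 4);
  hgemv_k32_f16_kernel<<<grid, block>>>(d_a, d_x, d_y, M, K);
  // cudaMemcpy synchronizes with the kernel before copying the result back.
  cudaMemcpy(h_y.data(), d_y, M * sizeof(half), cudaMemcpyDeviceToHost);

  // Compare against a float CPU reference; fp16 accumulation over K = 256
  // terms loses precision, so only a loose tolerance is meaningful here.
  float max_abs_err = 0.0f;
  for (int m = 0; m < M; ++m) {
    float ref = 0.0f;
    for (int k = 0; k < K; ++k)
      ref += __half2float(h_a[m * K + k]) * __half2float(h_x[k]);
    float diff = __half2float(h_y[m]) - ref;
    if (diff < 0.0f) diff = -diff;
    if (diff > max_abs_err) max_abs_err = diff;
  }
  printf("hgemv_k32_f16 max abs err: %f\n", max_abs_err);

  cudaFree(d_a); cudaFree(d_x); cudaFree(d_y);
  return 0;
}
#endif // HGEMV_STANDALONE_TEST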