#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <iostream>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])

// -------------------------------------- FP16 --------------------------------------
// HGEMM: Block Tile + K Tile, with smem
// Block Tile (BM, BN) + K Tile (BK=32)
// grid((N + BN - 1) / BN, (M + BM - 1) / BM), block(BN, BM)
// a: MxK, b: KxN, c: MxN, compute: c = a * b, all row major
template<const int BM=32, const int BN=32, const int BK=32>
__global__ void hgemm_sliced_k_f16_kernel(half* a, half* b, half* c, int M, int N, int K) {
  // [1] Block Tile: each 32x32 thread block computes one 32x32 tile of c
  // [2] K Tile: use shared memory and split K into chunks of size BK
  __shared__ half s_a[BM][BK], s_b[BK][BN];

  int bx = blockIdx.x;
  int by = blockIdx.y;
  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int tid = threadIdx.y * blockDim.x + tx; // tid within the block
  // Load values into shared memory: the 32x32 threads cooperate to fetch
  // data along the row direction of both a and b into s_a and s_b
  // (2 x 32x32 x 2 bytes = 4KB). With 32x32 threads per block loading a
  // 32x32 tile from global memory to shared memory, each thread loads
  // exactly 1 element.
  int load_smem_a_m = tid / 32; // 0~31, tid / 32, tid / BM, threadIdx.y
  int load_smem_a_k = tid % 32; // 0~31, tid % 32, tid % BK, threadIdx.x
  int load_smem_b_k = tid / 32; // 0~31, tid / 32, tid / BK, threadIdx.y
  int load_smem_b_n = tid % 32; // 0~31, tid % 32, tid % BN, threadIdx.x
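  // Illustrative: with blockDim = (32, 32), tid / 32 == ty and tid % 32 == tx,
  // so each thread loads exactly one element into s_a[ty][tx] and one into s_b[ty][tx].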
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
  // if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  half sum = __float2half(0.f);
  for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {
    int load_gmem_a_k = bk * BK + load_smem_a_k;
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    s_a[load_smem_a_m][load_smem_a_k] = a[load_gmem_a_addr];
    int load_gmem_b_k = bk * BK + load_smem_b_k;
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    s_b[load_smem_b_k][load_smem_b_n] = b[load_gmem_b_addr];
    __syncthreads();
    #pragma unroll
    for (int k = 0; k < BK; ++k) {
      int comp_smem_a_m = load_smem_a_m;
      int comp_smem_b_n = load_smem_b_n;
      sum += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];
    }
    __syncthreads();
  }
  int store_gmem_c_m = load_gmem_a_m;
  int store_gmem_c_n = load_gmem_b_n;
  int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
  c[store_gmem_c_addr] = sum;
}
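// For validation context, a minimal host-side reference is sketched below
// (a hypothetical helper, not part of the original bindings): a naive
// row-major GEMM with fp32 accumulation, assuming a, b, and c have already
// been converted from half to float on the host.
static void hgemm_cpu_reference(const float* a, const float* b, float* c,
                                int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += a[m * K + k] * b[k * N + n];
      c[m * N + n] = acc; // c = a * b, all row major
    }
  }
}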

// HGEMM: Block Tile + Thread Tile + K Tile + half2x2, with smem
// BK:TILE_K=8 BM=BN=128
// TM=TN=8, increasing arithmetic intensity per thread; BM/TM=16, BN/TN=16
// dim3 blockDim(BN/TN, BM/TM);
// dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM)
template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8>
__global__ void hgemm_t_8x8_sliced_k_f16x4_kernel(half* a, half* b, half* c, int M, int N, int K) {
  // [1] Block Tile: each 16x16 thread block computes one 128x128 target tile of C
  // [2] Thread Tile: each thread computes TM*TN (8x8) elements, increasing arithmetic intensity
  // [3] K Tile: split K into chunks of size BK and iterate (K + BK - 1) / BK times,
  //     accumulating partial products into each of the TM*TN elements per iteration
  // [4] Vectorize: use half2 loads/stores to reduce the instruction count

  // 16x16 = 256 threads in total, each computing an 8x8 tile of elements
  int bx = blockIdx.x;
  int by = blockIdx.y;
  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int tid = threadIdx.y * blockDim.x + tx; // tid within the block
  __shared__ half s_a[BM][BK], s_b[BK][BN]; // 2*128*8*2=4KB

  // 0. First compute the shared-memory indices.
  // Mapping from tid to the smem tile s_a[BM][BK]: BM=128, BK=8, read row by row; A is row major.
  // Each row of s_a holds 8 elements; at 4 elements per thread, a row needs 2 threads,
  // so 128 rows need exactly 128x2 = 256 threads.
  int load_smem_a_m = tid / 2; // tid/2, (128/8)*(128/8)=256 threads per block, tid/2->[0,128), BM=128 0~127
  int load_smem_a_k = (tid % 2 == 0) ? 0 : 4; // (tid%2 == 0) ? 0 : 4, col of s_a 0,4
  // Mapping from tid to the smem tile s_b[BK][BN]: BK=8, BN=128, read row by row; B is row major.
  // Each row of s_b holds 128 elements; at 4 elements per thread, a row needs 32 threads,
  // so 8 rows need 32x8 = 256 threads.
  int load_smem_b_k = tid / 32; // tid/32, row of s_b, 256/32=8 rows, 0~7
  int load_smem_b_n = (tid % 32) * 4; // (tid % 32) * 4, col of s_b 0,4,...,124
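  // Worked example (illustrative): for tid = 37, load_smem_a_m = 37 / 2 = 18,
  // load_smem_a_k = 4 (odd tid), load_smem_b_k = 37 / 32 = 1, and
  // load_smem_b_n = (37 % 32) * 4 = 20.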
  // 1. Then compute the global-memory indices.
  // Row in global A for the elements loaded into s_a; each block produces one BM*BN tile of C.
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c

  half r_c[TM][TN] = {__float2half(0.0f)}; // 8x8 accumulators
  // 2. Tile the K dimension into chunks of size BK.
  for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {
    // Load a BM*BK (128x8) tile into smem s_a, vectorized as two half2 loads.
    int load_gmem_a_k = bk * BK + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    HALF2(s_a[load_smem_a_m][load_smem_a_k + 0]) = HALF2(a[load_gmem_a_addr + 0]);
    HALF2(s_a[load_smem_a_m][load_smem_a_k + 2]) = HALF2(a[load_gmem_a_addr + 2]);
    // Load a BK*BN (8x128) tile into smem s_b, vectorized as two half2 loads.
    int load_gmem_b_k = bk * BK + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    HALF2(s_b[load_smem_b_k][load_smem_b_n + 0]) = HALF2(b[load_gmem_b_addr + 0]);
    HALF2(s_b[load_smem_b_k][load_smem_b_n + 2]) = HALF2(b[load_gmem_b_addr + 2]);
    __syncthreads();
    #pragma unroll
    for (int k = 0; k < BK; k++) {
      // 3. Each thread computes TM*TN (8x8) elements of the BM*BN (128x128) tile.
      #pragma unroll
      for (int m = 0; m < TM; m++) {
        #pragma unroll
        for (int n = 0; n < TN; n++) {
          // k ranges over 0~BK-1 (0~7); ty and tx range from 0 to 15, 16x8=128
          int comp_smem_a_m = ty * TM + m; // 16 threads cover the M direction: 128/TM(8)=16
          int comp_smem_b_n = tx * TN + n; // 16 threads cover the N direction: 128/TN(8)=16
          r_c[m][n] += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];
        }
      }
    }
    __syncthreads();
  }

  #pragma unroll
  for (int m = 0; m < TM; ++m) {
    int store_gmem_c_m = by * BM + ty * TM + m;
    #pragma unroll
    for (int n = 0; n < TN; n += 2) {
      int store_gmem_c_n = bx * BN + tx * TN + n;
      int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
      HALF2(c[store_gmem_c_addr]) = HALF2(r_c[m][n]);
    }
  }
}
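// Illustrative launch shape: for M = N = K = 1024 with the defaults above,
// gridDim = (1024 / 128, 1024 / 128) = (8, 8) and blockDim = (128 / 8, 128 / 8)
// = (16, 16), i.e. 64 blocks of 256 threads each.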

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                 \
if(((T).options().dtype() != (th_type))) {                   \
  std::cout << "Tensor Info:" << (T).options() << std::endl; \
  throw std::runtime_error("values must be "#th_type);       \
}

#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1)           \
if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
  throw std::runtime_error("Tensor size mismatch!");  \
}

// HGEMM: Block Tile + K Tile, with smem
// Block Tile (BM, BN) + K Tile (BK=32)
// grid((N + BN - 1) / BN, (M + BM - 1) / BM), block(BN, BM)
// a: MxK, b: KxN, c: MxN, compute: c = a * b, all row major
void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int BM = 32;
  constexpr int BN = 32;
  constexpr int BK = 32;

  dim3 block(BN, BM);
  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

  hgemm_sliced_k_f16_kernel<BM, BN, BK><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
}

// HGEMM: Block Tile + Thread Tile + K Tile + half2x2, with smem
// BK:TILE_K=8 BM=BN=128
// TM=TN=8, increasing arithmetic intensity per thread; BM/TM=16, BN/TN=16
// dim3 blockDim(BN/TN, BM/TM);
// dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM)
void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int BM = 128;
  constexpr int BN = 128;
  constexpr int BK = 8;
  constexpr int TM = 8;
  constexpr int TN = 8;

  dim3 block(BN/TN, BM/TM);
  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

  hgemm_t_8x8_sliced_k_f16x4_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(hgemm_sliced_k_f16)
  TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k_f16x4)
}
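
// Example usage (illustrative, assuming this file is built as a PyTorch CUDA
// extension named `hgemm_lib`, e.g. via torch.utils.cpp_extension.load):
//
//   import torch
//   a = torch.randn(256, 512, dtype=torch.half, device='cuda')
//   b = torch.randn(512, 256, dtype=torch.half, device='cuda')
//   c = torch.empty(256, 256, dtype=torch.half, device='cuda')
//   hgemm_lib.hgemm_sliced_k_f16(a, b, c)
//   torch.testing.assert_close(c, a @ b, atol=1e-2, rtol=1e-2)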