#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

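// Note: these casting macros alias a group of scalars as one vector value
// (HALF2/BFLOAT2 = 32 bits, INT4/FLOAT4/LDST128BITS = 128 bits) so the compiler
// can emit a single vectorized load/store; the address passed in must be
// naturally aligned to the vector width. A minimal sanity check, added here
// only as an illustration (not part of the original file):
static_assert(sizeof(float4) == 16, "128-bit vectorized LD/ST assumes a 16-byte float4");
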
// -------------------------------------- FP32 --------------------------------------
// ElementWise Add
// grid(N/256), block(256)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f32_kernel(float* a, float* b, float* c, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) c[idx] = a[idx] + b[idx];
}

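// A minimal host-side launch sketch for the scalar FP32 kernel above, matching
// the grid(N/256), block(256) note (rounded up so a non-multiple N is still
// covered). The pointers d_a/d_b/d_c are hypothetical device buffers assumed to
// be cudaMalloc'd elsewhere; this wrapper is an illustration, not part of the
// original file.
void launch_elementwise_add_f32(float* d_a, float* d_b, float* d_c, int N) {
  dim3 block(256);
  dim3 grid((N + 256 - 1) / 256);
  elementwise_add_f32_kernel<<<grid, block>>>(d_a, d_b, d_c, N);
}
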
// ElementWise Add + Vec4
// grid(N/256), block(256/4)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f32x4_kernel(float* a, float* b, float* c, int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    float4 reg_a = FLOAT4(a[idx]);
    float4 reg_b = FLOAT4(b[idx]);
    float4 reg_c;
    reg_c.x = reg_a.x + reg_b.x;
    reg_c.y = reg_a.y + reg_b.y;
    reg_c.z = reg_a.z + reg_b.z;
    reg_c.w = reg_a.w + reg_b.w;
    FLOAT4(c[idx]) = reg_c;
  }
}

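// For the x4 variant each thread handles 4 floats, so the block is shrunk by
// 4x while each block still covers 256 elements, as the grid(N/256),
// block(256/4) note describes. The sketch below assumes N is a multiple of 4
// and that the buffers are 16-byte aligned (cudaMalloc guarantees this for the
// base pointer); the same pattern, divided by the per-thread element count,
// applies to the half2 and fp16x8 kernels further down. Illustration only, not
// part of the original file.
void launch_elementwise_add_f32x4(float* d_a, float* d_b, float* d_c, int N) {
  dim3 block(256 / 4);
  dim3 grid((N + 256 - 1) / 256);
  elementwise_add_f32x4_kernel<<<grid, block>>>(d_a, d_b, d_c, N);
}
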
// -------------------------------------- FP16 --------------------------------------
// ElementWise Add
// grid(N/256), block(256)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f16_kernel(half* a, half* b, half* c, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) c[idx] = __hadd(a[idx], b[idx]);
}

// ElementWise Add + Vec2 (half2)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f16x2_kernel(half* a, half* b, half* c, int N) {
  int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    half2 reg_a = HALF2(a[idx]);
    half2 reg_b = HALF2(b[idx]);
    half2 reg_c;
    reg_c.x = __hadd(reg_a.x, reg_b.x);
    reg_c.y = __hadd(reg_a.y, reg_b.y);
    HALF2(c[idx]) = reg_c;
  }
}

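// Note: the two scalar __hadd calls in the kernel above could equivalently be
// one packed instruction, which is the form the f16x8_pack kernel below uses.
// Illustration only:
//   half2 reg_c = __hadd2(reg_a, reg_b);
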
__global__ void elementwise_add_f16x8_kernel(half* a, half* b, half* c, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  // Manual unroll to improve the L2 cache hit rate.
  // L2 cache only: each memory issue loads 32 bytes (default).
  // L1 cache enabled: each memory issue loads 128 bytes (-Xptxas -dlcm=ca).
  // Why process fp16x8 within one thread? ref: https://zhuanlan.zhihu.com/p/641639133
  // 0. First, tid_0 loads 32 bytes in one memory issue and the data is cached in L2.
  // 1. Then, tid_1, ..., tid_3 hit L2 and load the data directly from L2.
  // Note: the loads below are not bounds-checked; only the stores are guarded.
  half2 reg_a_0 = HALF2(a[idx + 0]);
  half2 reg_a_1 = HALF2(a[idx + 2]);
  half2 reg_a_2 = HALF2(a[idx + 4]);
  half2 reg_a_3 = HALF2(a[idx + 6]);
  half2 reg_b_0 = HALF2(b[idx + 0]);
  half2 reg_b_1 = HALF2(b[idx + 2]);
  half2 reg_b_2 = HALF2(b[idx + 4]);
  half2 reg_b_3 = HALF2(b[idx + 6]);
  half2 reg_c_0, reg_c_1, reg_c_2, reg_c_3;
  reg_c_0.x = __hadd(reg_a_0.x, reg_b_0.x);
  reg_c_0.y = __hadd(reg_a_0.y, reg_b_0.y);
  reg_c_1.x = __hadd(reg_a_1.x, reg_b_1.x);
  reg_c_1.y = __hadd(reg_a_1.y, reg_b_1.y);
  reg_c_2.x = __hadd(reg_a_2.x, reg_b_2.x);
  reg_c_2.y = __hadd(reg_a_2.y, reg_b_2.y);
  reg_c_3.x = __hadd(reg_a_3.x, reg_b_3.x);
  reg_c_3.y = __hadd(reg_a_3.y, reg_b_3.y);
  if ((idx + 0) < N) { HALF2(c[idx + 0]) = reg_c_0; }
  if ((idx + 2) < N) { HALF2(c[idx + 2]) = reg_c_1; }
  if ((idx + 4) < N) { HALF2(c[idx + 4]) = reg_c_2; }
  if ((idx + 6) < N) { HALF2(c[idx + 6]) = reg_c_3; }
}

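// The comments in the kernel above mention -Xptxas -dlcm=ca, which is passed
// through to ptxas at compile time to change the default load-caching
// behaviour. A possible build command; the file name and -arch value are
// assumptions, adjust them for your source file and GPU:
//   nvcc -O3 -arch=sm_80 -Xptxas -dlcm=ca elementwise.cu -o elementwise
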
__global__ void elementwise_add_f16x8_pack_kernel(half* a, half* b, half* c, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  // Temporary, addressable storage (may be placed in .local space in PTX).
  half pack_a[8], pack_b[8], pack_c[8]; // 8 x 16 bits = 128 bits.
  // Reinterpret as float4 and load 128 bits in one memory issue.
  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits

  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    // 4 packed half2 additions via __hadd2.
    HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
  }
  // Reinterpret as float4 and store 128 bits in one memory issue.
  if ((idx + 7) < N) { LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]); }
}

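// A minimal host-side test harness sketch for the packed FP16 kernel above:
// allocate buffers, launch with 8 elements per thread, and print a few results.
// The names N and TILE, the chosen sizes, and main() itself are assumptions
// added for illustration; they are not part of the original file.
int main() {
  const int N = 1 << 20;     // multiple of 8, so every 128-bit pack is full
  const int TILE = 256 * 8;  // elements covered per block (256 threads x 8)
  size_t bytes = N * sizeof(half);

  half *h_a = (half*)malloc(bytes), *h_b = (half*)malloc(bytes), *h_c = (half*)malloc(bytes);
  for (int i = 0; i < N; ++i) { h_a[i] = __float2half(1.0f); h_b[i] = __float2half(2.0f); }

  half *d_a, *d_b, *d_c;
  cudaMalloc(&d_a, bytes); cudaMalloc(&d_b, bytes); cudaMalloc(&d_c, bytes);
  cudaMemcpy(d_a, h_a, bytes, cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b, bytes, cudaMemcpyHostToDevice);

  dim3 block(256);
  dim3 grid((N + TILE - 1) / TILE);
  elementwise_add_f16x8_pack_kernel<<<grid, block>>>(d_a, d_b, d_c, N);
  cudaMemcpy(h_c, d_c, bytes, cudaMemcpyDeviceToHost);

  // Expect 3.0 for every element.
  for (int i = 0; i < 8; ++i) printf("c[%d] = %f\n", i, __half2float(h_c[i]));

  cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
  free(h_a); free(h_b); free(h_c);
  return 0;
}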