-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfused_layernorm_v3.cu
More file actions
157 lines (136 loc) · 5.62 KB
/
fused_layernorm_v3.cu
File metadata and controls
157 lines (136 loc) · 5.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
// fused_layernorm_v3.cu — two-level warp shuffle, float4 loads, multi-row blocks
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <limits>
// Fold a second partial Welford accumulator (b_mean, b_M2, b_cnt) into the
// first one (mean_a, M2_a, cnt_a), in place.
//
//   mean_a : running mean of accumulator A        (updated)
//   M2_a   : running sum of squared deviations    (updated)
//   cnt_a  : number of samples folded into A      (updated)
//   b_*    : the same three quantities for accumulator B (read only)
//
// When both accumulators are empty the call leaves A untouched, which also
// avoids the division by zero below.
__device__ __forceinline__ void welford_merge(
    float& mean_a, float& M2_a, int& cnt_a,
    float b_mean, float b_M2, int b_cnt)
{
    const int total = cnt_a + b_cnt;
    if (total != 0) {
        const float gap     = b_mean - mean_a;
        const float total_f = (float)total;
        // M2 is updated first: it needs the *old* cnt_a, and `gap` was taken
        // from the old mean, so the two updates are independent of each other.
        M2_a   = M2_a + b_M2 + gap * gap * cnt_a * b_cnt / total_f;
        mean_a = mean_a + gap * b_cnt / total_f;
        cnt_a  = total;
    }
}
// True when `p` lies on a 16-byte boundary, i.e. may be dereferenced as float4.
__device__ __forceinline__ bool ln_is_aligned16(const void* p)
{
    return (reinterpret_cast<size_t>(p) & 15u) == 0;
}

// Fold a single sample into a running Welford accumulator (mean, M2, cnt).
__device__ __forceinline__ void ln_welford_push(
    float val, float& mean, float& M2, int& cnt)
{
    cnt++;
    float delta = val - mean;
    mean += delta / cnt;
    M2 += delta * (val - mean);
}

// Shuffle-based reduction of one Welford accumulator per lane across a full
// warp. All 32 lanes must call this; only lane 0 holds the merged result on
// return (higher lanes hold partial merges).
__device__ __forceinline__ void ln_welford_warp_reduce(
    float& mean, float& M2, int& cnt, int lane)
{
    for (int offset = 16; offset > 0; offset >>= 1) {
        float b_mean = __shfl_down_sync(0xffffffffu, mean, offset);
        float b_M2   = __shfl_down_sync(0xffffffffu, M2, offset);
        int   b_cnt  = __shfl_down_sync(0xffffffffu, cnt, offset);
        if (lane < offset)
            welford_merge(mean, M2, cnt, b_mean, b_M2, b_cnt);
    }
}

// Row-wise LayerNorm: out[r, :] = gamma * (x[r, :] - mean_r) * rsqrt(var_r + eps) + beta
//
// Launch layout (set up by fused_layernorm_v3_cuda):
//   grid  = ceil(B / blockDim.y) blocks in x
//   block = (blockDim.x, blockDim.y); threadIdx.y selects the row handled
//           inside the block, threadIdx.x strides over that row's elements.
//   blockDim.x must be a multiple of 32 and at most 1024, so that each
//   hardware warp belongs to exactly one row and nwarps <= 32.
//   dynamic shared memory = blockDim.y * (blockDim.x / 32) * 3 * 4 bytes.
//
// Parameters:
//   x, out      : row-major, contiguous [B, N] float buffers
//   gamma, beta : contiguous [N] float buffers
//   B, N        : number of rows / row length (N must be > 0)
//   eps         : added to the variance before rsqrt
//
// float4 accesses are used only for rows/buffers whose address is actually
// 16-byte aligned; everything else goes through the scalar loops.
__global__ void fused_layernorm_v3_kernel(
    const float* __restrict__ x,
    const float* __restrict__ gamma,
    const float* __restrict__ beta,
    float* __restrict__ out,
    int B, int N, float eps)
{
    // The grid is rounded up, so the last block may own rows >= B. Those
    // threads must not touch global memory, but they still have to reach both
    // __syncthreads() calls, hence a predicate instead of an early return.
    const long long row = (long long)blockIdx.x * blockDim.y + threadIdx.y;
    const bool row_valid = row < (long long)B;
    const int  n_row     = row_valid ? N : 0;  // 0 => every loop below is empty

    const int tid    = threadIdx.x;
    const int lane   = tid & 31;
    const int warp   = tid >> 5;
    const int nwarps = blockDim.x >> 5;

    // 64-bit offset: row * N overflows int for large tensors.
    const size_t row_off = row_valid ? (size_t)row * (size_t)N : 0;
    const float* x_row   = x + row_off;
    float*       out_row = out + row_off;

    // Phase 1: per-thread online Welford over this thread's share of the row.
    float mean = 0.f, M2 = 0.f;
    int cnt = 0;

    // When N % 4 != 0 most rows do not start on a 16-byte boundary, so the
    // decision to vectorize is taken per row from the real address.
    const bool vec_read  = ln_is_aligned16(x_row);
    const int  vec_iters = vec_read ? n_row / 4 : 0;
    const float4* x4 = reinterpret_cast<const float4*>(x_row);

    for (int i = tid; i < vec_iters; i += blockDim.x) {
        float4 v = x4[i];
        ln_welford_push(v.x, mean, M2, cnt);
        ln_welford_push(v.y, mean, M2, cnt);
        ln_welford_push(v.z, mean, M2, cnt);
        ln_welford_push(v.w, mean, M2, cnt);
    }
    // Scalar tail; covers the whole row when the vector path is disabled.
    for (int i = vec_iters * 4 + tid; i < n_row; i += blockDim.x)
        ln_welford_push(x_row[i], mean, M2, cnt);

    // Phase 2: warp-level reduction via shuffles (no shared memory).
    ln_welford_warp_reduce(mean, M2, cnt, lane);

    // Phase 3: warp leaders publish their partials.
    // smem layout per row slot: [nwarps means | nwarps M2s | nwarps counts].
    // Counts are stored as int (same 4-byte footprint as float) so they are
    // not rounded when they exceed 2^24.
    extern __shared__ float smem[];
    float* s_mean = smem + threadIdx.y * nwarps * 3;
    float* s_M2   = s_mean + nwarps;
    int*   s_cnt  = reinterpret_cast<int*>(s_M2 + nwarps);
    if (lane == 0) {
        s_mean[warp] = mean;
        s_M2[warp]   = M2;
        s_cnt[warp]  = cnt;
    }
    __syncthreads();  // barrier 1: all warp leaders have written

    // Phase 4: the first warp of each row merges the per-warp partials with a
    // second shuffle reduction; lanes >= nwarps contribute empty accumulators.
    if (warp == 0) {
        const bool has_partial = lane < nwarps;
        mean = has_partial ? s_mean[lane] : 0.f;
        M2   = has_partial ? s_M2[lane]   : 0.f;
        cnt  = has_partial ? s_cnt[lane]  : 0;
        ln_welford_warp_reduce(mean, M2, cnt, lane);
        if (lane == 0) {
            s_mean[0] = mean;
            s_M2[0]   = M2 / N;  // biased variance of the row
        }
    }
    __syncthreads();  // barrier 2: final statistics visible to the whole row

    const float final_mean    = s_mean[0];
    const float final_inv_std = rsqrtf(s_M2[0] + eps);

    // Phase 5: normalize and write. x is read a second time here (likely
    // still cached, though that is not guaranteed). The float4 path
    // additionally requires out, gamma and beta to be 16-byte aligned.
    const bool vec_write = vec_read && ln_is_aligned16(out_row) &&
                           ln_is_aligned16(gamma) && ln_is_aligned16(beta);
    const int vec_out = vec_write ? n_row / 4 : 0;
    float4*       out4 = reinterpret_cast<float4*>(out_row);
    const float4* g4   = reinterpret_cast<const float4*>(gamma);
    const float4* b4   = reinterpret_cast<const float4*>(beta);
    for (int i = tid; i < vec_out; i += blockDim.x) {
        float4 xv = x4[i];
        float4 gv = g4[i];
        float4 bv = b4[i];
        float4 ov;
        ov.x = gv.x * (xv.x - final_mean) * final_inv_std + bv.x;
        ov.y = gv.y * (xv.y - final_mean) * final_inv_std + bv.y;
        ov.z = gv.z * (xv.z - final_mean) * final_inv_std + bv.z;
        ov.w = gv.w * (xv.w - final_mean) * final_inv_std + bv.w;
        out4[i] = ov;
    }
    // Scalar tail / scalar fallback.
    for (int i = vec_out * 4 + tid; i < n_row; i += blockDim.x)
        out_row[i] = gamma[i] * (x_row[i] - final_mean) * final_inv_std + beta[i];
}

// Host entry point: LayerNorm over the last dimension of a 2-D float32 CUDA
// tensor.
//
//   x     : [B, N] float32 CUDA tensor (made contiguous if it is not)
//   gamma : N-element float32 CUDA tensor on the same device (scale)
//   beta  : N-element float32 CUDA tensor on the same device (shift)
//   eps   : variance stabilizer
//
// Returns a new contiguous [B, N] tensor. Raises (via TORCH_CHECK) on invalid
// inputs or when the kernel launch is rejected. The kernel is enqueued on the
// current CUDA stream of x's device; no host synchronization is performed.
torch::Tensor fused_layernorm_v3_cuda(
    torch::Tensor x, torch::Tensor gamma, torch::Tensor beta, float eps = 1e-5f)
{
    TORCH_CHECK(x.is_cuda() && gamma.is_cuda() && beta.is_cuda(),
                "fused_layernorm_v3: x, gamma and beta must be CUDA tensors");
    TORCH_CHECK(gamma.device() == x.device() && beta.device() == x.device(),
                "fused_layernorm_v3: x, gamma and beta must be on the same device");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                gamma.scalar_type() == torch::kFloat32 &&
                beta.scalar_type() == torch::kFloat32,
                "fused_layernorm_v3: only float32 tensors are supported");
    TORCH_CHECK(x.dim() == 2, "fused_layernorm_v3: x must be 2-D [B, N], got ",
                x.dim(), " dims");
    TORCH_CHECK(gamma.numel() == x.size(1) && beta.numel() == x.size(1),
                "fused_layernorm_v3: gamma and beta must have N = ", x.size(1),
                " elements");
    TORCH_CHECK(x.size(0) <= std::numeric_limits<int>::max() &&
                x.size(1) <= std::numeric_limits<int>::max(),
                "fused_layernorm_v3: dimensions exceed the supported int range");

    // The kernel indexes raw pointers assuming dense row-major storage.
    x     = x.contiguous();
    gamma = gamma.contiguous();
    beta  = beta.contiguous();

    const int B = static_cast<int>(x.size(0));
    const int N = static_cast<int>(x.size(1));
    auto out = torch::empty_like(x);

    // Nothing to compute; a zero-sized grid would be an invalid launch and
    // N == 0 would divide by zero inside the kernel.
    if (B == 0 || N == 0)
        return out;

    constexpr int ROWS_PER_BLOCK = 4;
    constexpr int THREADS_X = 256;
    constexpr int NWARPS = THREADS_X / 32;  // 8
    static_assert(THREADS_X % 32 == 0, "rows must not share a warp");
    static_assert(NWARPS <= 32, "second-level reduction uses a single warp");
    static_assert(THREADS_X * ROWS_PER_BLOCK <= 1024, "block too large");

    const dim3 block(THREADS_X, ROWS_PER_BLOCK);
    const dim3 grid((B + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK);
    // smem: [mean + M2 + cnt] x nwarps x ROWS_PER_BLOCK
    const size_t smem = ROWS_PER_BLOCK * NWARPS * 3 * sizeof(float);

    // Launch on the tensor's device and on PyTorch's current stream so the
    // kernel is ordered with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(x.device());
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    fused_layernorm_v3_kernel<<<grid, block, smem, stream>>>(
        x.data_ptr<float>(), gamma.data_ptr<float>(),
        beta.data_ptr<float>(), out.data_ptr<float>(), B, N, eps);

    // Kernel launches do not return a status; surface bad configurations now.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "fused_layernorm_v3: kernel launch failed: ",
                cudaGetErrorString(launch_err));
    return out;
}
// Python bindings. Argument names and the eps default are declared here
// because pybind11 cannot see C++ default arguments: without them Python
// callers would have to pass eps explicitly and could not use keywords.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_layernorm_v3", &fused_layernorm_v3_cuda,
          "Fused LayerNorm v3 — float4 + two-level warp shuffle + multi-row blocks",
          pybind11::arg("x"), pybind11::arg("gamma"), pybind11::arg("beta"),
          pybind11::arg("eps") = 1e-5f);
}