
Commit e105424

[Optimization] Implement fused add rmsnorm (#1667)
1 parent 8d17774 commit e105424

9 files changed, +166 -61 lines changed

csrc/layernorm.cpp

Lines changed: 10 additions & 0 deletions
@@ -6,9 +6,19 @@ void rms_norm(
   torch::Tensor& weight,
   float epsilon);

+void fused_add_rms_norm(
+  torch::Tensor& input,
+  torch::Tensor& residual,
+  torch::Tensor& weight,
+  float epsilon);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "rms_norm",
     &rms_norm,
     "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+  m.def(
+    "fused_add_rms_norm",
+    &fused_add_rms_norm,
+    "In-place fused Add and RMS Normalization");
 }
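
Once the extension is rebuilt, the new binding is callable directly from Python. A minimal sketch, assuming the compiled module is importable as `vllm.layernorm_ops` (the module the existing `rms_norm` binding is used through elsewhere in the repo); shapes and the epsilon value are illustrative:

import torch
from vllm import layernorm_ops  # assumed import path for the compiled extension

hidden_size = 4096
x = torch.randn(8, hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")

# In place: `residual` becomes x + residual, and `x` becomes
# RMSNorm(x + residual) * weight, per the binding's docstring above.
layernorm_ops.fused_add_rms_norm(x, residual, weight, 1e-6)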

csrc/layernorm_kernels.cu

Lines changed: 55 additions & 0 deletions
@@ -34,6 +34,36 @@ __global__ void rms_norm_kernel(
   }
 }

+// TODO: Further optimize this kernel.
+template<typename scalar_t>
+__global__ void fused_add_rms_norm_kernel(
+  scalar_t* __restrict__ input,         // [..., hidden_size]
+  scalar_t* __restrict__ residual,      // [..., hidden_size]
+  const scalar_t* __restrict__ weight,  // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) input[blockIdx.x * hidden_size + idx];
+    x += (float) residual[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) residual[blockIdx.x * hidden_size + idx];
+    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+  }
+}
+
 } // namespace vllm

 void rms_norm(
@@ -60,3 +90,28 @@ void rms_norm(
         hidden_size);
   });
 }
+
+void fused_add_rms_norm(
+  torch::Tensor& input,     // [..., hidden_size]
+  torch::Tensor& residual,  // [..., hidden_size]
+  torch::Tensor& weight,    // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "fused_add_rms_norm_kernel",
+    [&] {
+      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        input.data_ptr<scalar_t>(),
+        residual.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(),
+        epsilon,
+        num_tokens,
+        hidden_size);
+    });
+}
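
For reference, a small PyTorch sketch of what the kernel computes (an assumed equivalent, not code from this commit): the first loop accumulates the variance of `input + residual` and stores the sum back into `residual`; the second loop normalizes the rounded sum and writes the scaled result into `input`.

import torch

def fused_add_rms_norm_ref(input: torch.Tensor,
                           residual: torch.Tensor,
                           weight: torch.Tensor,
                           epsilon: float) -> None:
    # Accumulate x = input + residual in float32, like the kernel's first loop.
    x = input.float() + residual.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(variance + epsilon)
    # The kernel stores the un-normalized sum back into `residual` in the
    # original dtype...
    residual.copy_(x.to(residual.dtype))
    # ...and the second loop re-reads that rounded value before normalizing.
    x = residual.float()
    input.copy_((x * inv_rms).to(input.dtype) * weight)

Writing the un-normalized sum into `residual` is what lets the model code below hand it straight to the next layer's norm without a separate elementwise add.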

vllm/model_executor/layers/layernorm.py

Lines changed: 15 additions & 1 deletion
@@ -1,4 +1,6 @@
 """Custom normalization layers."""
+from typing import Optional, Tuple, Union
+
 import torch
 import torch.nn as nn

@@ -21,7 +23,19 @@ def __init__(
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            layernorm_ops.fused_add_rms_norm(
+                x,
+                residual,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return x, residual
         out = torch.empty_like(x)
         layernorm_ops.rms_norm(
             out,
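
A quick usage sketch of the two call forms (shapes, dtypes, and the `hidden_size` constructor argument are illustrative): without a `residual` the layer allocates and returns a fresh output as before; with one, it runs the fused op in place and returns both the normalized tensor and the updated residual.

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

norm = RMSNorm(hidden_size=4096).half().cuda()
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Unfused path: unchanged behavior, returns a new tensor.
out = norm(x)

# Fused path: `x` and `residual` are overwritten in place;
# `out` is RMSNorm(x + residual) * weight, `new_residual` is x + residual.
residual = torch.randn_like(x)
out, new_residual = norm(x, residual)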

vllm/model_executor/models/baichuan.py

Lines changed: 15 additions & 10 deletions
@@ -225,25 +225,28 @@ def forward(
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states

         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual


 class BaiChuanModel(nn.Module):
@@ -276,20 +279,22 @@ def forward(
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
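
The refactored layer no longer performs `residual + hidden_states` itself; instead it threads the un-normalized running sum alongside `hidden_states`, and each add happens inside the next (fused) norm call. A toy sketch of why the two data flows agree, using a plain RMSNorm reference and stand-in linear layers for `self_attn`/`mlp` (all names here are illustrative, not from this commit):

import torch

def rms_norm_ref(x, w, eps=1e-6):
    # Plain (unfused) RMSNorm reference.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * w

torch.manual_seed(0)
H = 16
attn = torch.nn.Linear(H, H, bias=False)   # stand-in for self_attn
mlp = torch.nn.Linear(H, H, bias=False)    # stand-in for mlp
w1, w2, w_final = torch.rand(H), torch.rand(H), torch.rand(H)
x = torch.randn(2, H)

# Old data flow: add the residual explicitly after each sub-block.
h = x
h = h + attn(rms_norm_ref(h, w1))
h = h + mlp(rms_norm_ref(h, w2))
old_out = rms_norm_ref(h, w_final)

# New data flow: carry (hidden, residual); each add is folded into the next norm.
residual = x                                   # input_layernorm with residual=None
h = attn(rms_norm_ref(residual, w1))
residual = h + residual                        # post_attention_layernorm folds the add
h = mlp(rms_norm_ref(residual, w2))
new_out = rms_norm_ref(h + residual, w_final)  # the final self.norm folds the last add

assert torch.allclose(old_out, new_out, atol=1e-5)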

vllm/model_executor/models/internlm.py

Lines changed: 15 additions & 10 deletions
@@ -155,25 +155,28 @@ def forward(
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states

         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual


 class InternLMModel(nn.Module):
@@ -208,20 +211,22 @@ def forward(
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

vllm/model_executor/models/llama.py

Lines changed: 15 additions & 10 deletions
@@ -197,25 +197,28 @@ def forward(
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states

         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual


 class LlamaModel(nn.Module):
@@ -248,20 +251,22 @@ def forward(
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

vllm/model_executor/models/mistral.py

Lines changed: 15 additions & 10 deletions
@@ -191,25 +191,28 @@ def forward(
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        hidden_states = residual + hidden_states

         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual


 class MistralModel(nn.Module):
@@ -243,20 +246,22 @@ def forward(
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             if cache_events is None:
                 cache_event = None
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
                 cache_event,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
