
Commit 5953686

JRosenkranz authored and joerunde committed
added mlp and attn bias option to flash and paged llama models (#85)
#### Motivation

The `Calico` models set the mlp and attention bias to true, but both were hard-coded to false in the flash and paged llama implementations. This change uses the config params added in huggingface/transformers#30031 to set those values properly.

#### Modifications

- added `attention_bias` and `mlp_bias` to the config for the Flash and Paged Llama implementations (default is `False`)
- set the bias in attention and mlp to the config value

#### Result

Models that contain attention and mlp bias should now load properly.

---------

Signed-off-by: Joshua Rosenkranz <[email protected]>
Signed-off-by: Joe Runde <[email protected]>
Co-authored-by: Joe Runde <[email protected]>
1 parent 6e68de5 commit 5953686

File tree: 3 files changed (+23, -10 lines)

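For context on the upstream flag names this commit mirrors, the sketch below assumes a transformers release that already includes huggingface/transformers#30031; the exact version requirement is an assumption on my part, not something stated in the commit.

```python
# Illustrative only: requires a transformers release that includes
# huggingface/transformers#30031 (the PR the commit message cites).
from transformers import LlamaConfig

# Defaults match the behavior the flash/paged code used to hard-code.
default_cfg = LlamaConfig()
print(default_cfg.attention_bias, default_cfg.mlp_bias)  # False False

# A Calico-style checkpoint opts in to biased attention and MLP projections
# via its config.json; constructing the config directly shows the same flags.
calico_like_cfg = LlamaConfig(attention_bias=True, mlp_bias=True)
print(calico_like_cfg.attention_bias, calico_like_cfg.mlp_bias)  # True True
```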

server/text_generation_server/inference_engine/tgis_native.py
Lines changed: 5 additions & 0 deletions

```diff
@@ -98,6 +98,11 @@ def __init__(
                 model_class = FlashRWForCausalLM

         elif model_type == "llama":
+            # See: https://github.com/ibm-granite/vllm_granite/blob/main/vllm/model_executor/models/llama.py#L353-L354
+            if self._config.tie_word_embeddings:
+                aliases = {
+                    "lm_head.weight": ["model.embed_tokens.weight"]
+                }
             if PAGED_ATTENTION:
                 from text_generation_server.models.custom_modeling.paged_llama_modeling import PagedLlamaForCausalLM
                 model_class = PagedLlamaForCausalLM
```
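The new alias tells the weight loader that `lm_head.weight` may be satisfied by `model.embed_tokens.weight` when the embeddings are tied, since tied checkpoints often omit a separate LM-head tensor. Below is a minimal sketch of that fallback idea; `resolve_tensor` and the plain-dict checkpoint are illustrative stand-ins, not the actual TGIS weight-loading API.

```python
from typing import Dict, List

import torch


def resolve_tensor(
    tensors: Dict[str, torch.Tensor],
    name: str,
    aliases: Dict[str, List[str]],
) -> torch.Tensor:
    """Return `name` from the checkpoint, falling back to any listed alias."""
    if name in tensors:
        return tensors[name]
    for alias in aliases.get(name, []):
        if alias in tensors:
            return tensors[alias]
    raise KeyError(f"{name} not found (aliases tried: {aliases.get(name, [])})")


# A tied-embedding checkpoint that ships only the embedding tensor.
checkpoint = {"model.embed_tokens.weight": torch.zeros(8, 4)}
aliases = {"lm_head.weight": ["model.embed_tokens.weight"]}
lm_head = resolve_tensor(checkpoint, "lm_head.weight", aliases)
```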

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
Lines changed: 9 additions & 5 deletions

```diff
@@ -63,6 +63,8 @@ def __init__(
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
+        attention_bias=False,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +86,8 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.attention_bias = attention_bias
+        self.mlp_bias = mlp_bias

         super().__init__(
             pad_token_id=pad_token_id,
@@ -168,7 +172,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))


 class FlashLlamaAttention(torch.nn.Module):
@@ -209,13 +213,13 @@ def __init__(
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )

     def forward(
@@ -298,13 +302,13 @@ def __init__(self, prefix, config, weights):
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
```
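Every changed hunk follows the same pattern: the projection bias now comes from the config instead of a hard-coded `False` (or `None` in `_load_gqa`). A simplified sketch of that wiring with plain `torch.nn.Linear` follows; the real modules load sharded weights through `TensorParallelColumnLinear`/`TensorParallelRowLinear`, which is omitted here, and `SimplifiedLlamaMLP` is a hypothetical stand-in rather than the actual class.

```python
from types import SimpleNamespace

import torch
from torch import nn


class SimplifiedLlamaMLP(nn.Module):
    """Toy module showing how a config-driven bias flag is threaded through.

    The attention projections (q/k/v/o) follow the same pattern with
    config.attention_bias in the real code.
    """

    def __init__(self, config):
        super().__init__()
        # gate/up/down projections honor config.mlp_bias instead of bias=False
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
        self.act = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Standard SwiGLU-style MLP: down(silu(gate(x)) * up(x))
        return self.down_proj(self.act(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))


cfg = SimpleNamespace(hidden_size=64, intermediate_size=128, mlp_bias=True)
mlp = SimplifiedLlamaMLP(cfg)
out = mlp(torch.randn(2, 64))  # bias terms are present because cfg.mlp_bias is True
```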

server/text_generation_server/models/custom_modeling/paged_llama_modeling.py
Lines changed: 9 additions & 5 deletions

```diff
@@ -64,6 +64,8 @@ def __init__(
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
+        attention_bias=False,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -85,6 +87,8 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.attention_bias = attention_bias
+        self.mlp_bias = mlp_bias

         super().__init__(
             pad_token_id=pad_token_id,
@@ -169,7 +173,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

-    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
+    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))


 class PagedLlamaAttention(torch.nn.Module):
@@ -207,13 +211,13 @@ def __init__(
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
-            bias=False,
+            bias=config.attention_bias,
         )

     def forward(
@@ -280,13 +284,13 @@ def __init__(self, prefix, config, weights):
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
-            bias=False,
+            bias=config.mlp_bias,
        )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=config.mlp_bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
```
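The paged implementation receives the identical change as the flash implementation: `config.attention_bias` drives the fused q/k/v projection, the GQA loader, and `o_proj`, while `config.mlp_bias` drives `gate_proj`/`up_proj` and `down_proj`, keeping the two code paths consistent. Because both new flags default to `False`, existing Llama checkpoints keep the previous bias-free behavior.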
