
Commit 88705cb (branch: GraniteCodeSupport, parent: 766bee9)

feat: Add support for attention and ff biases

Signed-off-by: Gabe Goodhart <[email protected]>

2 files changed: +24 / -16 lines

torchchat/cli/convert_hf_checkpoint.py (10 additions, 5 deletions)
@@ -81,10 +81,17 @@ def convert_hf_checkpoint(
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
+        "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
+        "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
+        "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
         "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
         "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
+        "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
+        "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
         "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
         "model.norm.weight": "norm.weight",
@@ -135,17 +142,15 @@ def load_safetensors():
         if "layers" in key:
             abstract_key = re.sub(r"(\d+)", "{}", key)
             layer_num = re.search(r"\d+", key).group(0)
-            new_key = weight_map[abstract_key]
-            if new_key is None:
-                continue
+            new_key = weight_map.get(abstract_key, abstract_key)
             new_key = new_key.format(layer_num)
         else:
-            new_key = weight_map[key]
+            new_key = weight_map.get(key, key)
 
         final_result[new_key] = value
 
     for key in tuple(final_result.keys()):
-        if "wq" in key:
+        if "wq.weight" in key:
             q = final_result[key]
             k = final_result[key.replace("wq", "wk")]
             v = final_result[key.replace("wq", "wv")]
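
Two behaviour changes in this hunk are easy to miss: unmapped keys now fall through unchanged instead of raising KeyError, and the q/k/v post-processing loop now matches only weight tensors, so it skips the newly introduced bias tensors. A standalone sketch with made-up tensor names:

weight_map = {"model.norm.weight": "norm.weight"}

# 1. Keys absent from the map pass through as-is instead of raising KeyError.
key = "model.some_unmapped_tensor"
print(weight_map.get(key, key))  # model.some_unmapped_tensor

# 2. The wq/wk/wv handling only triggers on weights, leaving bias tensors untouched.
final_result = {
    "layers.0.attention.wq.weight": "q_weight",
    "layers.0.attention.wq.bias": "q_bias",
}
print([k for k in final_result if "wq" in k])         # old check: matches weight and bias
print([k for k in final_result if "wq.weight" in k])  # new check: weight only
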

torchchat/model.py (14 additions, 11 deletions)
@@ -34,7 +34,7 @@
 try:
     # TODO: remove this after we figure out where in torchtune an `evaluate` module
     # is being imported, which is being confused with huggingface's `evaluate``.
-    import lm_eval # noqa
+    import lm_eval  # noqa
 except Exception:
     pass
 
@@ -278,6 +278,9 @@ class TransformerArgs:
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
+    # Optional biases
+    attention_bias: bool = False
+    feed_forward_bias: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
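
Both flags default to False, so existing model configs keep their current behaviour; a model whose projections carry biases opts in through its params. A trimmed-down sketch (the dataclass below is a stand-in with only the fields relevant here, and the params dict is illustrative, not a real params file from the repo):

from dataclasses import dataclass

@dataclass
class TransformerArgsSketch:
    dim: int = 4096
    attention_bias: bool = False      # adds biases to wq/wk/wv/wo
    feed_forward_bias: bool = False   # adds biases to w1/w2/w3

params = {"dim": 2048, "attention_bias": True, "feed_forward_bias": True}
args = TransformerArgsSketch(**params)
print(args)  # TransformerArgsSketch(dim=2048, attention_bias=True, feed_forward_bias=True)
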
@@ -394,7 +397,7 @@ def from_name(cls, name: str):
         config = [
             config
             for config in known_model_params
-            if config in str(name).upper() or config in str(name)
+            if config.upper() in str(name).upper() or config in str(name)
         ]
 
         # We may have two or more configs matched (e.g., "7B" and
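
The tightened condition matters when entries in known_model_params are not already upper-case: the name is upper-cased on the right, so a lower-case entry could never match before. A small illustration (the names below are invented, not the repository's known_model_params):

known_model_params = ["granite-code-3b"]
name = "Granite-Code-3B-Instruct"

old = [c for c in known_model_params if c in name.upper() or c in name]
new = [c for c in known_model_params if c.upper() in name.upper() or c in name]
print(old)  # [] -- a lower-case entry is never a substring of the upper-cased name
print(new)  # ['granite-code-3b'] -- both sides upper-cased before comparing
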
@@ -471,7 +474,7 @@ def build_model(self) -> nn.Module:
                 modules[name] = module_class(TransformerArgs.from_params(config_args))
             else:
                 modules[name] = module_class(**config_args)
-
+
         # Temporary add extra params to the DeepFusionModel.
         # TODO: Remove it once we can make fusion model configurable in model_param.
         if recipe.fusion_class == DeepFusionModel:
@@ -730,16 +733,16 @@ def __init__(self, config: TransformerArgs):
 
         # key, query, value projections for all heads, but in a batch
         # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
-        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias)
         self.wk = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )
         self.wv = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=False
+            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
         )
 
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias)
         self.kv_cache = None
 
         self.n_heads = config.n_heads
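
The practical effect is that the four projections only register bias parameters when the config asks for them, which lets bias tensors from the checkpoint load cleanly. A standalone sketch (the real Attention module carries more state; only the projections are mirrored here):

import torch.nn as nn

dim, n_heads, head_dim = 64, 4, 16

def make_projections(attention_bias: bool) -> nn.ModuleDict:
    # Same four projections as in the diff, with the bias flag threaded through.
    return nn.ModuleDict({
        "wq": nn.Linear(dim, n_heads * head_dim, bias=attention_bias),
        "wk": nn.Linear(dim, n_heads * head_dim, bias=attention_bias),
        "wv": nn.Linear(dim, n_heads * head_dim, bias=attention_bias),
        "wo": nn.Linear(dim, dim, bias=attention_bias),
    })

print(sorted(n for n, _ in make_projections(False).named_parameters()))  # *.weight only
print(sorted(n for n, _ in make_projections(True).named_parameters()))   # *.weight plus *.bias
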
@@ -852,9 +855,9 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
-        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
-        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
 
     def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
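
To see the feed-forward change end to end (a standalone sketch; the forward pass below follows the usual SwiGLU pattern and is assumed, since it is not part of this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForwardSketch(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, bias: bool = False) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

ffn = FeedForwardSketch(dim=64, hidden_dim=256, bias=True)
print(ffn(torch.randn(2, 64)).shape)  # torch.Size([2, 64])
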
