
Commit 6b6b77c

feat: Use multipliers conditionally in the model architecture
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: 06566d9

1 file changed (+26, -3 lines)

torchchat/model.py

Lines changed: 26 additions & 3 deletions
@@ -728,13 +728,20 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int
         if self.tok_embeddings:
             x = self.tok_embeddings(x)
 
+        # For Granite architectures
+        if self.config.embedding_multiplier:
+            x = x * self.config.embedding_multiplier
+
         for _, layer in self.layers.items():
             x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)
 
         if self.norm:
             x = self.norm(x)
         if self.output:
             x = self.output(x)
+        # For granite architectures
+        if self.config.logits_scaling:
+            x = x / self.config.logits_scaling
         # print(f"output shape: {x.shape}")
         return x
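The two checks above are no-ops when the config leaves the fields unset, so llama configs are unaffected. Below is a minimal standalone sketch of that behavior, not the torchchat code itself; TinyConfig and the numeric values are hypothetical, and only the field names embedding_multiplier and logits_scaling come from the diff.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TinyConfig:
    # None (the llama default) disables both adjustments.
    embedding_multiplier: Optional[float] = None
    logits_scaling: Optional[float] = None


def scale_embeddings(x: torch.Tensor, config: TinyConfig) -> torch.Tensor:
    # Mirrors the check in Transformer.forward: a falsy value skips the multiply.
    if config.embedding_multiplier:
        x = x * config.embedding_multiplier
    return x


def scale_logits(x: torch.Tensor, config: TinyConfig) -> torch.Tensor:
    if config.logits_scaling:
        x = x / config.logits_scaling
    return x


x = torch.ones(2, 3)
assert torch.equal(scale_embeddings(x, TinyConfig()), x)   # llama path: unchanged

granite_like = TinyConfig(embedding_multiplier=12.0, logits_scaling=8.0)  # made-up values
print(scale_embeddings(x, granite_like))  # every element becomes 12.0
print(scale_logits(x, granite_like))      # every element becomes 0.125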

@@ -746,6 +753,12 @@ def __init__(self, config: TransformerArgs) -> None:
         self.feed_forward = FeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        # None for llama architecture, set for granite architectures
+        self.residual_multiplier = (
+            config.residual_multiplier
+            if config.residual_multiplier is not None
+            else 1.0
+        )
 
     def distribute(self, device_mesh: DeviceMesh):
         self.attention.distribute(device_mesh)
@@ -756,8 +769,8 @@ def forward(
     ) -> Tensor:
         h = x + self.attention(
             self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
-        )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        ) * self.residual_multiplier
+        out = h + self.feed_forward(self.ffn_norm(h)) * self.residual_multiplier
         return out
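Together with the __init__ change above, the block scales each branch's output before it is added back to the residual stream, falling back to 1.0 when residual_multiplier is unset. A rough sketch under those assumptions; the attention/FFN stand-ins below are hypothetical lambdas, not the real modules.

import torch


def block_forward(x, attention, feed_forward, residual_multiplier=None):
    # None -> 1.0 reproduces the llama behavior; granite configs supply a value.
    m = residual_multiplier if residual_multiplier is not None else 1.0
    h = x + attention(x) * m
    out = h + feed_forward(h) * m
    return out


x = torch.randn(2, 4)
attn = lambda t: t * 0.5    # stand-in for self.attention(self.attention_norm(x), ...)
ffn = lambda t: t * 0.25    # stand-in for self.feed_forward(self.ffn_norm(h))

# With the multiplier unset the block matches the pre-change computation.
assert torch.allclose(block_forward(x, attn, ffn), block_forward(x, attn, ffn, 1.0))

# An illustrative granite-style value damps each branch's contribution.
out = block_forward(x, attn, ffn, residual_multiplier=0.22)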

@@ -784,6 +797,7 @@ def __init__(self, config: TransformerArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.attention_scale = config.attention_multiplier
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
@@ -880,7 +894,16 @@ def forward(
 
         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(
+            query=q,
+            key=k,
+            value=v,
+            attn_mask=mask,
+            dropout_p=0.0,
+            # This is None (default) for llama architecture and set for granite
+            # architectures
+            scale=self.attention_scale,
+        )
 
         # -1 = self.dim
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
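When scale=None, F.scaled_dot_product_attention uses its default 1/sqrt(head_dim), so llama configs (attention_multiplier unset) keep the old behavior, while granite configs pass an explicit scale. A small sketch of that equivalence, assuming a PyTorch version (2.1+) whose SDPA accepts the scale keyword; the tensor shapes and the 0.0078125 value are made up for illustration.

import math

import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 2, 5, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# scale=None uses the default 1/sqrt(head_dim), i.e. the llama path.
default = F.scaled_dot_product_attention(q, k, v, scale=None)
explicit = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(head_dim))
assert torch.allclose(default, explicit, atol=1e-6)

# A granite config would instead pass its attention_multiplier here.
scaled = F.scaled_dot_product_attention(q, k, v, scale=0.0078125)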
