Skip to content

Commit ceda986

Browse files
committed
Add Llama 3.2 1B and 3B to model registry, clean up test imports
Llama 3.2 1B and 3B are the only Llama variants with weight tying, so they belong in the registry. Without them the feature has no real entry point. Also dropped the try/except guard in test_weight_tying.py, which was inconsistent with every other unit test here and silently skips on broken imports.
1 parent 365773c commit ceda986

File tree

2 files changed

+67
-18
lines changed

2 files changed

+67
-18
lines changed

tests/unit_tests/test_weight_tying.py

Lines changed: 7 additions & 18 deletions
Original file line number · Diff line number · Diff line change
@@ -6,23 +6,13 @@
66

77
import unittest
88

9-
try:
10-
from torchtitan.models.common.attention import GQAttention
11-
from torchtitan.models.common.embedding import Embedding
12-
from torchtitan.models.common.feed_forward import (
13-
compute_ffn_hidden_dim,
14-
FeedForward,
15-
)
16-
from torchtitan.models.common.linear import Linear
17-
from torchtitan.models.common.rmsnorm import RMSNorm
18-
from torchtitan.models.common.rope import RoPE
19-
from torchtitan.models.llama3.model import Llama3Model, Llama3TransformerBlock
20-
21-
HAS_TORCHTITAN_MODELS = True
22-
except Exception:
23-
HAS_TORCHTITAN_MODELS = False
24-
25-
_SKIP_MSG = "torchtitan model imports not available (missing triton or other deps)"
9+
from torchtitan.models.common.attention import GQAttention
10+
from torchtitan.models.common.embedding import Embedding
11+
from torchtitan.models.common.feed_forward import compute_ffn_hidden_dim, FeedForward
12+
from torchtitan.models.common.linear import Linear
13+
from torchtitan.models.common.rmsnorm import RMSNorm
14+
from torchtitan.models.common.rope import RoPE
15+
from torchtitan.models.llama3.model import Llama3Model, Llama3TransformerBlock
2616

2717

2818
def _make_config(enable_weight_tying: bool = False):
@@ -56,7 +46,6 @@ def _make_config(enable_weight_tying: bool = False):
5646
)
5747

5848

59-
@unittest.skipUnless(HAS_TORCHTITAN_MODELS, _SKIP_MSG)
6049
class TestLlama3WeightTying(unittest.TestCase):
6150
def test_weights_are_shared_when_tying_enabled(self):
6251
"""tok_embeddings.weight and output.weight should share the same storage."""

torchtitan/models/llama3/__init__.py

Lines changed: 60 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -113,6 +113,66 @@
113113
scaling="llama",
114114
),
115115
),
116+
"1B": Llama3Model.Config(
117+
dim=2048,
118+
n_layers=16,
119+
enable_weight_tying=True,
120+
tok_embeddings=Embedding.Config(),
121+
norm=RMSNorm.Config(),
122+
output=Linear.Config(),
123+
layer=Llama3TransformerBlock.Config(
124+
attention_norm=RMSNorm.Config(),
125+
ffn_norm=RMSNorm.Config(),
126+
feed_forward=FeedForward.Config(
127+
hidden_dim=compute_ffn_hidden_dim(
128+
2048, multiple_of=1024, ffn_dim_multiplier=1.5
129+
),
130+
),
131+
attention=GQAttention.Config(
132+
n_heads=32,
133+
n_kv_heads=8,
134+
attn_backend="sdpa",
135+
rope_backend="complex",
136+
),
137+
),
138+
rope=RoPE.Config(
139+
dim=2048 // 32,
140+
max_seq_len=131072,
141+
theta=500000,
142+
backend="complex",
143+
scaling="llama",
144+
),
145+
),
146+
"3B": Llama3Model.Config(
147+
dim=3072,
148+
n_layers=28,
149+
enable_weight_tying=True,
150+
tok_embeddings=Embedding.Config(),
151+
norm=RMSNorm.Config(),
152+
output=Linear.Config(),
153+
layer=Llama3TransformerBlock.Config(
154+
attention_norm=RMSNorm.Config(),
155+
ffn_norm=RMSNorm.Config(),
156+
feed_forward=FeedForward.Config(
157+
hidden_dim=compute_ffn_hidden_dim(
158+
3072, multiple_of=1024, ffn_dim_multiplier=1.0
159+
),
160+
),
161+
attention=GQAttention.Config(
162+
n_heads=24,
163+
n_kv_heads=8,
164+
attn_backend="sdpa",
165+
rope_backend="complex",
166+
),
167+
),
168+
rope=RoPE.Config(
169+
dim=3072 // 24,
170+
max_seq_len=131072,
171+
theta=500000,
172+
backend="complex",
173+
scaling="llama",
174+
),
175+
),
116176
"8B": Llama3Model.Config(
117177
dim=4096,
118178
n_layers=32,

0 commit comments

Comments (0)