Add weight tying support for Llama3 #2580
Open
dean-mccoppin wants to merge 2 commits into pytorch:main from
Conversation
Ties tok_embeddings.weight to output.weight via enable_weight_tying config flag. Follows the same pattern as Qwen3 (pytorch#1590). Closes pytorch#1524.
tianyu-l (Contributor) requested changes · Mar 15, 2026
IIUC the existing model registry doesn't have the Llama 3.2 1B / 3B models, which are the only variants with weight tying enabled. Please add those models to llama3/__init__.py. You can refer to the exact configs in the earlier attempt #1376.
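A registry entry along these lines would address the review comment. This is a hedged sketch, not torchtitan's actual `TransformerModelArgs` API: the `ModelArgs` dataclass and field names below are stand-ins, and the dimensions are taken from the publicly documented Llama 3.2 1B/3B architectures (verify against #1376 before using).

```python
# Sketch only: ModelArgs is a stand-in for the real model-args class in
# llama3/__init__.py; field names and defaults are assumptions.
from dataclasses import dataclass


@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    rope_theta: float = 500000.0
    enable_weight_tying: bool = False  # the new flag this PR adds


# Llama 3.2 1B / 3B are the only variants that tie embeddings to output.
llama3_configs = {
    "3.2-1B": ModelArgs(dim=2048, n_layers=16, n_heads=32, n_kv_heads=8,
                        enable_weight_tying=True),
    "3.2-3B": ModelArgs(dim=3072, n_layers=28, n_heads=24, n_kv_heads=8,
                        enable_weight_tying=True),
}
```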
```python
        HAS_TORCHTITAN_MODELS = True
    except Exception:
        HAS_TORCHTITAN_MODELS = False
```
dean-mccoppin (Contributor, Author):
My reasoning was that torchtitan/models/common/__init__.py re-exports from moe/, which imports triton at module level, so the import chain fails when triton isn't installed. But I just noticed that many existing unit tests import the same torchtitan.models.common.* submodules with no guard and pass in CI, so I'll be removing this.
Llama 3.2 1B and 3B are the only Llama variants with weight tying, so they belong in the registry. Without them the feature has no real entry point. Also dropped the try/except guard in test_weight_tying.py, which was inconsistent with every other unit test here and silently skips on broken imports.
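The hazard with that guard can be shown generically: a try/except around an import cannot tell a genuinely missing dependency apart from a broken one, so a real regression in the import chain would just flip the flag and silently skip the tests. A minimal illustration (module names here are stand-ins, not the actual test code):

```python
# Why try/except import guards are risky in tests: a broken dependency
# and a missing module are indistinguishable, so tests skip instead of fail.
import importlib


def has_module(name: str) -> bool:
    """Mimics the removed HAS_TORCHTITAN_MODELS guard."""
    try:
        importlib.import_module(name)
        return True
    except Exception:  # swallows ImportError *and* any real bug in the module
        return False


# Both outcomes look like "feature unavailable" to the test runner:
print(has_module("json"))                          # stdlib module, importable
print(has_module("definitely_not_installed_xyz"))  # missing, silently False
```

An unconditional import, as in the other unit tests, turns either failure mode into a loud test error instead.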
Implements enable_weight_tying for Llama3, sharing tok_embeddings.weight with output.weight. It mirrors the Qwen3 implementation from #1590 (thanks!)
Changes cover model.py (config field, tying in init/init_weights, PP guard), parallelize.py (grouped FSDP unit for tied params), state_dict_adapter.py (skip/reconstruct output.weight for HF conversion), and a new unit test file.
Closes #1524
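The core tying step the description refers to can be sketched framework-free. In the PR itself this happens with torch parameters inside model.py, guarded for pipeline parallelism (where `output` may live on a different stage and be `None`); the class names and shapes below are purely illustrative:

```python
# Framework-free sketch of weight tying: output.weight and
# tok_embeddings.weight become the *same* object, so any update to one
# is visible through the other. Class names are illustrative stand-ins.
class Embedding:
    def __init__(self, vocab: int, dim: int):
        self.weight = [[0.0] * dim for _ in range(vocab)]


class Linear:
    def __init__(self, dim: int, vocab: int):
        self.weight = [[0.0] * dim for _ in range(vocab)]


class Model:
    def __init__(self, vocab: int = 8, dim: int = 4,
                 enable_weight_tying: bool = True):
        self.tok_embeddings = Embedding(vocab, dim)
        # Under pipeline parallelism the last stage may be elsewhere,
        # in which case output is None and tying must be skipped.
        self.output = Linear(dim, vocab)
        if enable_weight_tying and self.output is not None:
            self.output.weight = self.tok_embeddings.weight  # shared storage


m = Model()
print(m.output.weight is m.tok_embeddings.weight)  # True: one tensor, two names
```

Because the tied pair shares storage, it also has to be treated as one unit by FSDP sharding and skipped/reconstructed during HF state-dict conversion, which is what the parallelize.py and state_dict_adapter.py changes handle.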