
Add weight tying support for Llama3 #2580

Open
dean-mccoppin wants to merge 2 commits into pytorch:main from dean-mccoppin:feat/llama3-weight-tying

Conversation

@dean-mccoppin
Contributor

Implements enable_weight_tying for Llama3, sharing tok_embeddings.weight with output.weight. This mirrors the Qwen3 implementation from #1590 (thanks!).

Changes cover model.py (config field, tying in __init__/init_weights, PP guard), parallelize.py (grouped FSDP unit for tied params), state_dict_adapter.py (skip/reconstruct output.weight for HF conversion), and a new unit test file.

Closes #1524

Ties tok_embeddings.weight to output.weight via enable_weight_tying config flag.
Follows the same pattern as Qwen3 (pytorch#1590).

Closes pytorch#1524.
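The tying described above can be sketched as follows. This is a simplified stand-in, not the actual torchtitan model class: the attribute names (tok_embeddings, output, enable_weight_tying) follow the PR description, but the surrounding module is minimal.

```python
import torch
import torch.nn as nn

class TinyTransformer(nn.Module):
    """Minimal sketch of embedding/LM-head weight tying (not the real Llama3 model)."""

    def __init__(self, vocab_size: int, dim: int, enable_weight_tying: bool = False):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        if enable_weight_tying:
            # Reassigning the Parameter makes both modules share one tensor,
            # so gradient updates through either path hit the same storage.
            self.output.weight = self.tok_embeddings.weight

model = TinyTransformer(vocab_size=128, dim=16, enable_weight_tying=True)
assert model.output.weight is model.tok_embeddings.weight
```

This works because nn.Embedding stores its weight as (vocab_size, dim) and nn.Linear as (out_features, in_features), so the two shapes coincide and a single Parameter can serve both.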
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 15, 2026
Contributor

@tianyu-l tianyu-l left a comment


IIUC the existing model registry doesn't have the Llama 3.2 1B / 3B models, which are the only variants with weight tying enabled. Please add those models to llama3/__init__.py. You can refer to the exact configs in the earlier attempt #1376.
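A registry entry along these lines is what the reviewer is asking for. This is a hedged sketch only: the dataclass and the hyperparameter values below are illustrative placeholders, not the actual torchtitan model args or the verified Llama 3.2 configs (see #1376 for those).

```python
from dataclasses import dataclass

@dataclass
class ModelArgs:
    """Simplified stand-in for torchtitan's per-variant model args."""
    dim: int
    n_layers: int
    enable_weight_tying: bool = False

# Placeholder registry mapping variant names to configs; only the
# enable_weight_tying=True flag is the point being illustrated here.
llama3_configs = {
    "1B": ModelArgs(dim=2048, n_layers=16, enable_weight_tying=True),
    "3B": ModelArgs(dim=3072, n_layers=28, enable_weight_tying=True),
}
```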


    HAS_TORCHTITAN_MODELS = True
except Exception:
    HAS_TORCHTITAN_MODELS = False
Contributor


what's this for?

Contributor Author


My reasoning was that torchtitan/models/common/__init__.py re-exports from moe/, which imports triton at module level, so if triton isn't installed the whole import chain fails. But I just noticed that many existing unit tests import from the same torchtitan.models.common.* submodules with no guard and pass in CI, so I'll be removing this.
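For reference, the guard being removed follows the standard optional-import pattern. The sketch below is generic, not the PR's code: the flag name is a placeholder, and triton is used only because it is the dependency mentioned in the thread; any module name behaves the same way.

```python
# Optional-import guard pattern: attempt the import once at module level
# and record the outcome in a flag, instead of letting a missing optional
# dependency break the import chain for every consumer.
try:
    import triton  # noqa: F401 -- may legitimately be absent
    HAS_TRITON = True
except Exception:
    # ImportError (or anything raised during module init) lands here,
    # so downstream code can branch on the flag instead of crashing.
    HAS_TRITON = False
```

The downside, as the thread concludes, is that in a test file this silently skips on broken imports rather than failing loudly, which is why the guard was dropped.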

Llama 3.2 1B and 3B are the only Llama variants with weight tying, so
they belong in the registry. Without them the feature has no real entry
point.

Also dropped the try/except guard in test_weight_tying.py, which was
inconsistent with every other unit test here and silently skips on
broken imports.

Labels

CLA Signed This label is managed by the Meta Open Source bot.


Development

Successfully merging this pull request may close these issues.

Weight tying between embedding and LM head layer

2 participants