-
Notifications
You must be signed in to change notification settings - Fork 749
Add weight tying support for Llama3 #2580
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dean-mccoppin
wants to merge
2
commits into
pytorch:main
Choose a base branch
from
dean-mccoppin:feat/llama3-weight-tying
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+230
−12
Open
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
|
|
||
| try: | ||
| from torchtitan.models.common.attention import GQAttention | ||
| from torchtitan.models.common.embedding import Embedding | ||
| from torchtitan.models.common.feed_forward import ( | ||
| compute_ffn_hidden_dim, | ||
| FeedForward, | ||
| ) | ||
| from torchtitan.models.common.linear import Linear | ||
| from torchtitan.models.common.rmsnorm import RMSNorm | ||
| from torchtitan.models.common.rope import RoPE | ||
| from torchtitan.models.llama3.model import Llama3Model, Llama3TransformerBlock | ||
|
|
||
| HAS_TORCHTITAN_MODELS = True | ||
| except Exception: | ||
| HAS_TORCHTITAN_MODELS = False | ||
|
|
||
| _SKIP_MSG = "torchtitan model imports not available (missing triton or other deps)" | ||
|
|
||
|
|
||
| def _make_config(enable_weight_tying: bool = False): | ||
| return Llama3Model.Config( | ||
| dim=64, | ||
| n_layers=2, | ||
| vocab_size=256, | ||
| enable_weight_tying=enable_weight_tying, | ||
| tok_embeddings=Embedding.Config(), | ||
| norm=RMSNorm.Config(), | ||
| output=Linear.Config(), | ||
| layer=Llama3TransformerBlock.Config( | ||
| attention_norm=RMSNorm.Config(), | ||
| ffn_norm=RMSNorm.Config(), | ||
| feed_forward=FeedForward.Config( | ||
| hidden_dim=compute_ffn_hidden_dim(64, multiple_of=64), | ||
| ), | ||
| attention=GQAttention.Config( | ||
| n_heads=4, | ||
| attn_backend="sdpa", | ||
| rope_backend="complex", | ||
| ), | ||
| ), | ||
| rope=RoPE.Config( | ||
| dim=64 // 4, | ||
| max_seq_len=512, | ||
| theta=500000, | ||
| backend="complex", | ||
| scaling="llama", | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| @unittest.skipUnless(HAS_TORCHTITAN_MODELS, _SKIP_MSG) | ||
| class TestLlama3WeightTying(unittest.TestCase): | ||
| def test_weights_are_shared_when_tying_enabled(self): | ||
| """tok_embeddings.weight and output.weight should share the same storage.""" | ||
| model = Llama3Model(_make_config(enable_weight_tying=True)) | ||
| self.assertIs( | ||
| model.tok_embeddings.weight, | ||
| model.output.weight, | ||
| "tok_embeddings.weight and output.weight must be the same tensor object", | ||
| ) | ||
|
|
||
| def test_weights_are_independent_when_tying_disabled(self): | ||
| """Without weight tying, tok_embeddings and output have separate weights.""" | ||
| model = Llama3Model(_make_config(enable_weight_tying=False)) | ||
| self.assertIsNot( | ||
| model.tok_embeddings.weight, | ||
| model.output.weight, | ||
| "tok_embeddings.weight and output.weight must be distinct tensor objects", | ||
| ) | ||
|
|
||
| def test_weights_remain_tied_after_init_weights(self): | ||
| """Weights must still be shared after calling init_weights.""" | ||
| model = Llama3Model(_make_config(enable_weight_tying=True)) | ||
| model.init_weights() | ||
| self.assertIs( | ||
| model.tok_embeddings.weight, | ||
| model.output.weight, | ||
| "tok_embeddings.weight and output.weight must remain tied after init_weights", | ||
| ) | ||
|
|
||
| def test_pp_guard_raises_when_weight_tying_and_pp_enabled(self): | ||
| """update_from_config must raise NotImplementedError when PP > 1 and weight tying is on.""" | ||
| from unittest.mock import MagicMock | ||
|
|
||
| config = _make_config(enable_weight_tying=True) | ||
|
|
||
| trainer_config = MagicMock() | ||
| trainer_config.training.seq_len = 512 | ||
| trainer_config.parallelism.pipeline_parallel_degree = 2 | ||
| trainer_config.parallelism.context_parallel_degree = 1 | ||
|
|
||
| with self.assertRaises(NotImplementedError): | ||
| config.update_from_config(trainer_config=trainer_config) | ||
|
|
||
| def test_pp_guard_does_not_raise_without_weight_tying(self): | ||
| """update_from_config must NOT raise when PP > 1 and weight tying is off.""" | ||
| from unittest.mock import MagicMock | ||
|
|
||
| config = _make_config(enable_weight_tying=False) | ||
|
|
||
| trainer_config = MagicMock() | ||
| trainer_config.training.seq_len = 512 | ||
| trainer_config.parallelism.pipeline_parallel_degree = 2 | ||
| trainer_config.parallelism.context_parallel_degree = 1 | ||
|
|
||
| # Should not raise | ||
| config.update_from_config(trainer_config=trainer_config) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's this for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my reasoning was because torchtitan/models/common/init.py re-exports from moe/, which imports triton at module level, and if triton isnt installed then the import chain fails. but i did just notice that many existing unit tests import from the same torchtitan.models.common.* submodules with no guard and they pass in CI, so I'll be removing this