Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions tests/unit_tests/test_weight_tying.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

# Import guard: per the PR discussion, torchtitan.models.common's package
# __init__ re-exports from moe/, which presumably imports triton at module
# level — so the whole import chain fails when triton is absent.
# NOTE(review): other unit tests import these submodules unguarded and pass
# in CI, so this guard may be removable — confirm before relying on it.
try:
    from torchtitan.models.common.attention import GQAttention
    from torchtitan.models.common.embedding import Embedding
    from torchtitan.models.common.feed_forward import (
        compute_ffn_hidden_dim,
        FeedForward,
    )
    from torchtitan.models.common.linear import Linear
    from torchtitan.models.common.rmsnorm import RMSNorm
    from torchtitan.models.common.rope import RoPE
    from torchtitan.models.llama3.model import Llama3Model, Llama3TransformerBlock

    # Flag consumed by the skipUnless decorator on the test class below.
    HAS_TORCHTITAN_MODELS = True
except Exception:
    HAS_TORCHTITAN_MODELS = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reasoning was that torchtitan/models/common/__init__.py re-exports from moe/, which imports triton at module level, so if triton isn't installed the import chain fails. However, I just noticed that many existing unit tests import from the same torchtitan.models.common.* submodules with no guard and they pass in CI, so I'll be removing this.


_SKIP_MSG = "torchtitan model imports not available (missing triton or other deps)"


def _make_config(enable_weight_tying: bool = False):
    """Build a tiny Llama3Model.Config so tests construct models quickly.

    Args:
        enable_weight_tying: forwarded to the model config; controls whether
            tok_embeddings and output share one weight tensor.
    """
    dim = 64
    n_heads = 4

    # Transformer block: minimal norms plus a small GQA attention config.
    block_config = Llama3TransformerBlock.Config(
        attention_norm=RMSNorm.Config(),
        ffn_norm=RMSNorm.Config(),
        feed_forward=FeedForward.Config(
            hidden_dim=compute_ffn_hidden_dim(dim, multiple_of=64),
        ),
        attention=GQAttention.Config(
            n_heads=n_heads,
            attn_backend="sdpa",
            rope_backend="complex",
        ),
    )

    # RoPE dim is the per-head dimension (model dim split across heads).
    rope_config = RoPE.Config(
        dim=dim // n_heads,
        max_seq_len=512,
        theta=500000,
        backend="complex",
        scaling="llama",
    )

    return Llama3Model.Config(
        dim=dim,
        n_layers=2,
        vocab_size=256,
        enable_weight_tying=enable_weight_tying,
        tok_embeddings=Embedding.Config(),
        norm=RMSNorm.Config(),
        output=Linear.Config(),
        layer=block_config,
        rope=rope_config,
    )


@unittest.skipUnless(HAS_TORCHTITAN_MODELS, _SKIP_MSG)
class TestLlama3WeightTying(unittest.TestCase):
    """Unit tests for the Llama3 embedding/output weight-tying option."""

    @staticmethod
    def _pp_trainer_config():
        """Mock trainer config with pipeline parallel degree 2 (PP enabled)."""
        from unittest.mock import MagicMock

        trainer_config = MagicMock()
        trainer_config.training.seq_len = 512
        trainer_config.parallelism.pipeline_parallel_degree = 2
        trainer_config.parallelism.context_parallel_degree = 1
        return trainer_config

    def test_weights_are_shared_when_tying_enabled(self):
        """tok_embeddings.weight and output.weight should share the same storage."""
        tied = Llama3Model(_make_config(enable_weight_tying=True))
        self.assertIs(
            tied.tok_embeddings.weight,
            tied.output.weight,
            "tok_embeddings.weight and output.weight must be the same tensor object",
        )

    def test_weights_are_independent_when_tying_disabled(self):
        """Without weight tying, tok_embeddings and output have separate weights."""
        untied = Llama3Model(_make_config(enable_weight_tying=False))
        self.assertIsNot(
            untied.tok_embeddings.weight,
            untied.output.weight,
            "tok_embeddings.weight and output.weight must be distinct tensor objects",
        )

    def test_weights_remain_tied_after_init_weights(self):
        """Weights must still be shared after calling init_weights."""
        tied = Llama3Model(_make_config(enable_weight_tying=True))
        tied.init_weights()
        self.assertIs(
            tied.tok_embeddings.weight,
            tied.output.weight,
            "tok_embeddings.weight and output.weight must remain tied after init_weights",
        )

    def test_pp_guard_raises_when_weight_tying_and_pp_enabled(self):
        """update_from_config must raise NotImplementedError when PP > 1 and weight tying is on."""
        config = _make_config(enable_weight_tying=True)
        with self.assertRaises(NotImplementedError):
            config.update_from_config(trainer_config=self._pp_trainer_config())

    def test_pp_guard_does_not_raise_without_weight_tying(self):
        """update_from_config must NOT raise when PP > 1 and weight tying is off."""
        config = _make_config(enable_weight_tying=False)
        # Should not raise.
        config.update_from_config(trainer_config=self._pp_trainer_config())


# Allow running this test file directly, outside a pytest/CI runner.
if __name__ == "__main__":
    unittest.main()
28 changes: 28 additions & 0 deletions torchtitan/models/llama3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class Config(Decoder.Config):
dim: int = 4096
n_layers: int = 32
vocab_size: int = 128256
enable_weight_tying: bool = False
layer: TransformerBlock.Config

def update_from_config(
Expand Down Expand Up @@ -108,6 +109,11 @@ def update_from_config(
f"Varlen attention is not supported with CP."
)

if self.enable_weight_tying and parallelism.pipeline_parallel_degree > 1:
raise NotImplementedError(
"Weight tying is not supported with Pipeline Parallel."
)

def get_nparams_and_flops(
self, model: nn.Module, seq_len: int
) -> tuple[int, int]:
Expand All @@ -118,3 +124,25 @@ def get_nparams_and_flops(
2 * (self.dim // self.layer.attention.n_heads),
seq_len,
)

def __init__(self, config: Config):
    """Construct the model; optionally tie output projection to embedding.

    When ``config.enable_weight_tying`` is set, the token-embedding weight
    and the output (LM head) weight become one shared parameter object.
    """
    super().__init__(config)
    self.enable_weight_tying = config.enable_weight_tying
    if not self.enable_weight_tying:
        return
    # Share a single parameter between the embedding and the LM head.
    self.tok_embeddings.weight = self.output.weight

def init_weights(
    self,
    *,
    buffer_device: torch.device | None = None,
    **kwargs,
):
    """Initialize weights, re-establishing the embedding/output tie first.

    Models materialized on the meta device can lose the parameter sharing
    set up in ``__init__``, so the tie is re-applied before delegating to
    the base-class initialization.
    """
    if self.enable_weight_tying:
        assert self.tok_embeddings is not None and self.output is not None
        # Re-link the shared parameter in case meta-device construction
        # broke the tying performed in __init__.
        self.tok_embeddings.weight = self.output.weight
    super().init_weights(buffer_device=buffer_device, **kwargs)
37 changes: 25 additions & 12 deletions torchtitan/models/llama3/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,29 +337,42 @@ def apply_fsdp(
reshard_after_forward_policy, pp_enabled
)

if model.tok_embeddings is not None:
if getattr(model, "enable_weight_tying", False):
# When weights are tied, tok_embeddings and output share the same parameter.
# Group them together in one FSDP unit to avoid duplicate all-gathers.
modules = [
m for m in (model.tok_embeddings, model.norm, model.output) if m is not None
]
# pyrefly: ignore [no-matching-overload]
fully_shard(
model.tok_embeddings,
modules,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
reshard_after_forward=reshard_after_forward_policy == "always",
)
else:
if model.tok_embeddings is not None:
# pyrefly: ignore [no-matching-overload]
fully_shard(
model.tok_embeddings,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
# As an optimization, do not reshard_after_forward the last layers by default
# since FSDP would prefetch them immediately after the forward pass
if model.norm is not None and model.output is not None:
# pyrefly: ignore [no-matching-overload]
fully_shard(
[model.norm, model.output],
**fsdp_config,
reshard_after_forward=reshard_after_forward_policy == "always",
)
# pyrefly: ignore [missing-attribute]
for layer_id, transformer_block in model.layers.items():
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
# As an optimization, do not reshard_after_forward the last layers by default
# since FSDP would prefetch them immediately after the forward pass
if model.norm is not None and model.output is not None:
# pyrefly: ignore [no-matching-overload]
fully_shard(
[model.norm, model.output],
**fsdp_config,
reshard_after_forward=reshard_after_forward_policy == "always",
)

fully_shard(model, **fsdp_config)

Expand Down
9 changes: 9 additions & 0 deletions torchtitan/models/llama3/state_dict_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,22 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
continue
new_key = new_key.format(layer_num)
else:
if self.model_config.enable_weight_tying and key == "output.weight":
continue
new_key = to_hf_map[key]

hf_state_dict[new_key] = value

return hf_state_dict

def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
if (
self.model_config.enable_weight_tying
and "lm_head.weight" not in hf_state_dict
):
assert "model.embed_tokens.weight" in hf_state_dict
hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"]

n_heads = self.model_config.layer.attention.n_heads
n_kv_heads = (
self.model_config.layer.attention.n_kv_heads
Expand Down