
Commit 9874e84

Change freqs_cis from persistent buffer to non-persistent buffer (#1600)
## Context

Since PP no longer needs `freqs_cis` to be a persistent buffer, and `torch.compile` now works with non-persistent buffers, change `freqs_cis` from a persistent buffer to a non-persistent one. This way, the checkpointer no longer needs to explicitly exclude `freqs_cis` when loading.

## Test

1. llama3 model with torch.compile ✅
2. llama4 model with torch.compile ✅
3. deepseek-v3 model with torch.compile ✅
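As a minimal standalone sketch (toy module and names, not torchtitan code) of the PyTorch behavior this change relies on: a buffer registered with `persistent=False` never appears in `state_dict()`, so nothing has to be popped before saving.

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        # persistent=False: the buffer can be recomputed, so keep it out of checkpoints
        self.register_buffer("freqs_cis", torch.randn(8, 2), persistent=False)


sd = Toy().state_dict()
assert "weight" in sd             # parameters are saved as usual
assert "freqs_cis" not in sd      # non-persistent buffers are skipped automatically
```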
1 parent 084d307

File tree (7 files changed: +15 −32 lines)

- scripts/generate/test_generate.py
- tests/unit_tests/test_checkpoint.py
- torchtitan/components/checkpoint.py
- torchtitan/experiments/llama4/model/model.py
- torchtitan/models/deepseek_v3/infra/parallelize.py
- torchtitan/models/deepseek_v3/model/model.py
- torchtitan/models/llama3/model/model.py
scripts/generate/test_generate.py

Lines changed: 0 additions & 3 deletions

@@ -24,7 +24,6 @@
     parallelize_module,
     RowwiseParallel,
 )
-from torchtitan.components.checkpoint import excluded_parameters_for_model_only
 from torchtitan.components.metrics import build_device_memory_monitor
 from torchtitan.config import ConfigManager
 from torchtitan.distributed import ParallelDims, utils as dist_utils
@@ -143,8 +142,6 @@ def test_generate(
     model.eval()

     state_dict = model.state_dict()
-    for k in excluded_parameters_for_model_only:
-        state_dict.pop(k, None)

     # Checkpoint Loading
     begin = time.monotonic()

tests/unit_tests/test_checkpoint.py

Lines changed: 2 additions & 2 deletions

@@ -562,7 +562,7 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank):

     @mock.patch("torch.distributed.get_rank", return_value=0)
     @mock.patch("torchtitan.components.checkpoint.dcp.save")
-    def test_excluded_parameters_not_saved(self, mock_save, mock_rank):
+    def test_non_persist_buffer_not_saved(self, mock_save, mock_rank):
         """Test that freqs_cis is not saved"""

         # Create a fake model with freqs_cis and other parameters
@@ -572,7 +572,7 @@ def __init__(self):
                 self.weight = nn.Parameter(torch.randn(2, 2))
                 self.bias = nn.Parameter(torch.randn(2))
                 # Register freqs_cis as a buffer (common pattern in transformer models)
-                self.register_buffer("freqs_cis", torch.randn(10, 5))
+                self.register_buffer("freqs_cis", torch.randn(10, 5), persistent=False)
                 self.other_param = nn.Parameter(torch.randn(3, 3))

         fake_model = FakeModelWithFreqsCis()

torchtitan/components/checkpoint.py

Lines changed: 0 additions & 9 deletions

@@ -55,12 +55,6 @@ class AsyncMode(str, enum.Enum):
     ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


-# For now, we will manually pop the freqs_cis buffer, as we made this permanent
-# temporarily and we don't want to include it in the exported state_dict.
-# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-excluded_parameters_for_model_only = {"freqs_cis"}
-
-
 class ModelWrapper(Stateful):
     def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
@@ -70,9 +64,6 @@ def _get_state_dict(self) -> dict[str, Any]:
         state_dict = {
             k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
         }
-        # Exclude parameters that should not be saved
-        for excluded_key in excluded_parameters_for_model_only:
-            state_dict.pop(excluded_key, None)
         return state_dict

     def state_dict(self) -> dict[str, Any]:

torchtitan/experiments/llama4/model/model.py

Lines changed: 3 additions & 8 deletions

@@ -391,14 +391,9 @@ def __init__(self, model_args: TransformerModelArgs):

         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

-        # TODO persistent should be set to false, since this buffer can be recomputed.
-        # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
-        # compile or pipeline-tracer will not correctly handle non-persistent buffers,
-        # so we need to fix that. (2) if we initialize pipeline-parallel models from
-        # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
-        # initialized by the checkpoint, or we need to add a separate initializer for
-        # just the non-persistent buffers that is called after loading checkpoints.
-        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+        self.register_buffer(
+            "freqs_cis", self._precompute_freqs_cis(), persistent=False
+        )

         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.n_layers):
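A hedged standalone sketch (toy module, not the llama4 code above) of the commit message's claim that `torch.compile` now handles non-persistent buffers: compiling a forward pass that slices the registered buffer.

```python
import torch
import torch.nn as nn


class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        # same registration pattern as the model above, with a toy buffer
        self.register_buffer("freqs_cis", torch.randn(16, 8), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # read the cached buffer inside the compiled region
        return x + self.freqs_cis[: x.shape[0]]


compiled = torch.compile(Rope())
out = compiled(torch.randn(4, 8))  # compiles and runs with the non-persistent buffer
```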

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 6 additions & 1 deletion

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import torch
 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import Replicate, Shard
@@ -18,7 +19,11 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.expert_parallel import NoParallel
-from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp, apply_moe_ep_tp
+from torchtitan.experiments.llama4.infra.parallelize import (
+    apply_compile,
+    apply_fsdp,
+    apply_moe_ep_tp,
+)
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger
torchtitan/models/deepseek_v3/model/model.py

Lines changed: 1 addition & 1 deletion

@@ -322,7 +322,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         self.max_seq_len = model_args.max_seq_len
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
         self.register_buffer(
-            "freqs_cis", precompute_freqs_cis(model_args), persistent=True
+            "freqs_cis", precompute_freqs_cis(model_args), persistent=False
         )

         self.layers = torch.nn.ModuleDict()

torchtitan/models/llama3/model/model.py

Lines changed: 3 additions & 8 deletions

@@ -335,14 +335,9 @@ def __init__(self, model_args: TransformerModelArgs):

         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

-        # TODO persistent should be set to false, since this buffer can be recomputed.
-        # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
-        # compile or pipeline-tracer will not correctly handle non-persistent buffers,
-        # so we need to fix that. (2) if we initialize pipeline-parallel models from
-        # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
-        # initialized by the checkpoint, or we need to add a separate initializer for
-        # just the non-persistent buffers that is called after loading checkpoints.
-        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
+        self.register_buffer(
+            "freqs_cis", self._precompute_freqs_cis(), persistent=False
+        )

         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.n_layers):
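On the loading side, a small standalone sketch (toy module, not the torchtitan `Transformer`) of why checkpoints written without `freqs_cis` still load strictly: non-persistent buffers are not expected keys in `load_state_dict`, and the value computed in `__init__` is simply kept.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        # recomputable at construction time, so persistent=False as above
        self.register_buffer("freqs_cis", torch.randn(16, 4), persistent=False)


saved = TinyModel().state_dict()              # holds only proj.weight / proj.bias
restored = TinyModel()
restored.load_state_dict(saved, strict=True)  # no "missing key: freqs_cis" error
```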
