Skip to content

Commit 458d879

Browse files
committed
integrition test v1
1 parent a18bae2 commit 458d879

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

torchtitan/components/checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,8 @@ def _save_last_step(self, curr_step: int) -> None:
562562
# For now, we will manually pop the freqs_cis buffer, as we made this permanent
563563
# temporarily and we don't want to include it in the exported state_dict.
564564
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L348
565-
self.states.pop("freqs_cis")
565+
if self.states and self.states.get("freqs_cis") is not None:
566+
self.states.pop("freqs_cis")
566567

567568
if self.export_dtype != torch.float32:
568569
self.states = {

torchtitan/experiments/flux/tests/flux_integration_tests.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ def build_test_list():
7070
# Checkpointing tests
7171
OverrideDefinitions(
7272
[
73+
[
74+
"--checkpoint.enable_checkpoint",
75+
],
7376
[
7477
"--checkpoint.enable_checkpoint",
7578
"--training.steps 20",
@@ -99,28 +102,30 @@ def build_test_list():
99102
"Checkpoint Integration Test - Save Model Weights Only bf16",
100103
"model_weights_only_bf16",
101104
),
105+
# Parallelism tests. Note: Run DDP only will cause OOM
102106
OverrideDefinitions(
103107
[
104108
[
105-
"--parallelism.data_parallel_shard_degree=1",
106-
"--parallelism.data_parallel_replicate_degree=4",
109+
"--parallelism.data_parallel_shard_degree=8",
110+
"--parallelism.data_parallel_replicate_degree=1",
107111
]
108112
],
109-
"DDP",
110-
"ddp",
111-
ngpu=4,
113+
"FSDP",
114+
"fsdp",
115+
ngpu=8,
112116
),
113117
OverrideDefinitions(
114118
[
115119
[
116-
"--parallelism.data_parallel_shard_degree=2",
120+
"--parallelism.data_parallel_shard_degree=4",
117121
"--parallelism.data_parallel_replicate_degree=2",
118122
]
119123
],
120124
"HSDP",
121125
"hsdp",
122-
ngpu=4,
126+
ngpu=8,
123127
),
128+
# Inference tests
124129
# OverrideDefinitions(
125130
# [
126131
# [
@@ -169,7 +174,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
169174
# "PROMPT='What is the meaning of life?' "
170175
# f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json"
171176
# )
172-
# TODO: migrate the generate image script
177+
# TODO: Add the generate image script
173178
cmd = None
174179

175180
result = _run_cmd(cmd)

0 commit comments

Comments
 (0)