Skip to content

Commit a18bae2

Browse files
committed
add periodic integration test
1 parent 7400448 commit a18bae2

File tree

9 files changed

+277
-3
lines changed

9 files changed

+277
-3
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
# Periodic 8-GPU integration test for the flux experiment.
name: 8 GPU Integration Test

on:
  push:
    branches: [ main ]
    paths:
      - 'torchtitan/experiments/flux/**'
  pull_request:
    paths:
      - 'torchtitan/experiments/flux/**'
  schedule:
    # Runs every 12 hours
    - cron: '0 */12 * * *'

concurrency:
  group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

defaults:
  run:
    shell: bash -l -eo pipefail {0}

jobs:
  build-test:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
      runner: linux.g5.48xlarge.nvidia.gpu
      gpu-arch-type: cuda
      gpu-arch-version: "12.6"
      # This image is faster to clone than the default, but it lacks CC needed by triton
      # (1m25s vs 2m37s).
      docker-image: torchtitan-ubuntu-20.04-clang12
      repository: pytorch/torchtitan
      upload-artifact: outputs
      script: |
        set -eux

        # The generic Linux job chooses to use base env, not the one setup by the image
        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
        conda activate "${CONDA_ENV}"

        pip config --user set global.progress_bar off

        python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126

        mkdir artifacts-to-be-uploaded
        python ./torchtitan/experiments/flux/tests/flux_integration_tests.py artifacts-to-be-uploaded --ngpu 8
-1.1 MB
Binary file not shown.

tests/integration_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def main():
558558
parser = argparse.ArgumentParser()
559559
parser.add_argument("output_dir")
560560
parser.add_argument(
561-
"--config_dir", default="./torchtitan/models/llama3/train_configs"
561+
"--config_dir", default="./torchtitan/experiments/flux/train_configs"
562562
)
563563
parser.add_argument(
564564
"--test",

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import torchtitan.experiments.flux # noqa: F401
78
import torchtitan.experiments.llama4 # noqa: F401
89
import torchtitan.experiments.simple_fsdp # noqa: F401

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class TextToImageDatasetConfig:
116116
data_processor=_cc12m_wds_data_processor,
117117
),
118118
"cc12m-test": TextToImageDatasetConfig(
119-
path="tests/assets/cc12m_test",
119+
path="torchtitan/experiments/flux/tests/assets/cc12m_test",
120120
loader=lambda path: load_dataset(
121121
path, split="train", data_files={"train": "*.tar"}, streaming=True
122122
),

torchtitan/experiments/flux/tests/__init__.py

Whitespace-only changes.
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
import subprocess
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# tomllib is in the stdlib from Python 3.11; fall back to the tomli
# backport (same API) on older interpreters.
try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib
@dataclass
class OverrideDefinitions:
    """
    This class is used to define the override definitions for the integration tests.

    Each element of ``override_args`` is one set of extra CLI flags; a flavor
    with k elements launches k sequential training runs.
    """

    # NOTE(review): ``tuple(tuple(" "))`` evaluates to ``(' ',)`` — a single
    # blank override — kept as-is for parity with tests/integration_tests.py.
    override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
    test_descr: str = "default"
    test_name: str = "default"
    ngpu: int = 4
    model_flavor: str = "flux-debug"

    def __repr__(self):
        # Log-friendly identity: the human-readable description.
        return self.test_descr
def build_test_list():
    """
    key is the config file name and value is a list of OverrideDefinitions
    that is used to generate variations of integration tests based on the
    same root config file.
    """
    integration_tests_flavors = defaultdict(list)
    integration_tests_flavors["debug_model.toml"] = [
        # basic tests
        OverrideDefinitions(
            [
                [
                    "--profiling.enable_profiling",
                    "--metrics.enable_tensorboard",
                ],
            ],
            "default",
            "default",
        ),
        # Compile tests
        OverrideDefinitions(
            [
                [
                    "--training.compile",
                ],
            ],
            "1D compile",
            "1d_compile",
        ),
        # Checkpointing tests
        OverrideDefinitions(
            [
                [
                    "--checkpoint.enable_checkpoint",
                    "--training.steps 20",
                ],
            ],
            "Checkpoint Integration Test - Save Load Full Checkpoint",
            "full_checkpoint",
        ),
        OverrideDefinitions(
            [
                [
                    "--checkpoint.enable_checkpoint",
                    "--checkpoint.model_weights_only",
                ],
            ],
            "Checkpoint Integration Test - Save Model Weights Only fp32",
            "model_weights_only_fp32",
        ),
        OverrideDefinitions(
            [
                [
                    "--checkpoint.enable_checkpoint",
                    "--checkpoint.model_weights_only",
                    "--checkpoint.export_dtype bfloat16",
                ],
            ],
            "Checkpoint Integration Test - Save Model Weights Only bf16",
            "model_weights_only_bf16",
        ),
        OverrideDefinitions(
            [
                [
                    "--parallelism.data_parallel_shard_degree=1",
                    "--parallelism.data_parallel_replicate_degree=4",
                ]
            ],
            "DDP",
            "ddp",
            ngpu=4,
        ),
        OverrideDefinitions(
            [
                [
                    "--parallelism.data_parallel_shard_degree=2",
                    "--parallelism.data_parallel_replicate_degree=2",
                ]
            ],
            "HSDP",
            "hsdp",
            ngpu=4,
        ),
        # TODO: re-enable once the image-generation script is migrated to flux.
        # OverrideDefinitions(
        #     [
        #         [
        #             "--checkpoint.enable_checkpoint",
        #         ],
        #         [
        #             # placeholder for the generation script's generate step
        #         ],
        #     ],
        #     "Generation script test",
        #     "test_generate",
        #     ngpu=2,
        # ),
    ]
    return integration_tests_flavors
141+
def _run_cmd(cmd):
142+
return subprocess.run([cmd], text=True, shell=True)
143+
144+
def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
    """Launch one training run per override-arg set of *test_flavor*.

    Raises:
        Exception: if any launched run exits with a non-zero status.
    """
    # run_test supports sequence of tests.
    test_name = test_flavor.test_name
    dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
    model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
    all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

    for idx, override_arg in enumerate(test_flavor.override_args):
        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./torchtitan/experiments/flux/run_train.sh"
        # dump compile trace for debugging purpose
        cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
        cmd += " " + dump_folder_arg
        cmd += " " + model_flavor_arg
        if override_arg:
            cmd += " " + " ".join(override_arg)
        logger.info(
            f"=====Flux Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
        )

        # save checkpoint (idx == 0) and load it for generation (idx == 1)
        if test_name == "test_generate_image" and idx == 1:
            # TODO: migrate the generate-image script, then run it here, e.g.:
            # cmd = (
            #     f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
            #     f"CHECKPOINT_DIR={output_dir}/{test_name}/checkpoint/step-10 "
            #     "PROMPT='What is the meaning of life?' "
            #     f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json"
            # )
            # Skip this step entirely: the original set cmd = None and fell
            # through to _run_cmd(None), which would raise a TypeError.
            continue

        result = _run_cmd(cmd)
        # NOTE(review): stdout is not captured by _run_cmd, so result.stdout is
        # None here; the child's output already streamed to the console.
        logger.info(result.stdout)
        if result.returncode != 0:
            raise Exception(
                f"Flux Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}"
            )
def run_tests(args):
    """Run every matching test flavor for each integration-test config.

    Scans ``args.config_dir`` for ``.toml`` files whose ``[job]`` table sets
    ``use_for_integration_test = true``, then runs the flavors selected by
    ``args.test`` that fit within ``args.ngpu`` GPUs.
    """
    integration_tests_flavors = build_test_list()
    for config_file in os.listdir(args.config_dir):
        # Guard clauses keep the loop flat instead of nesting five levels deep.
        if not config_file.endswith(".toml"):
            continue

        full_path = os.path.join(args.config_dir, config_file)
        with open(full_path, "rb") as f:
            config = tomllib.load(f)
            is_integration_test = config["job"].get("use_for_integration_test", False)
        if not is_integration_test:
            continue

        for test_flavor in integration_tests_flavors[config_file]:
            if args.test != "all" and test_flavor.test_name != args.test:
                continue
            if args.ngpu < test_flavor.ngpu:
                logger.info(
                    f"Skipping test {test_flavor.test_name} that requires {test_flavor.ngpu} gpus,"
                    f" because --ngpu arg is {args.ngpu}"
                )
            else:
                run_test(test_flavor, full_path, args.output_dir)
def main():
    """Parse CLI arguments and launch the flux integration-test suite."""
    parser = argparse.ArgumentParser()
    parser.add_argument("output_dir")
    parser.add_argument(
        "--config_dir", default="./torchtitan/experiments/flux/train_configs"
    )
    parser.add_argument(
        "--test",
        default="all",
        help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)",
    )
    parser.add_argument("--ngpu", default=8, type=int)
    args = parser.parse_args()

    # exist_ok avoids the exists()-then-makedirs() race of the original;
    # an already-populated directory is still rejected so results never mix.
    os.makedirs(args.output_dir, exist_ok=True)
    if os.listdir(args.output_dir):
        raise RuntimeError("Please provide an empty output directory.")
    run_tests(args)
if __name__ == "__main__":
    main()

torchtitan/experiments/flux/tests/unit_tests/__init__.py

Whitespace-only changes.

torchtitan/experiments/flux/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ seq_len = 512
3838
max_norm = 2.0 # grad norm clipping
3939
steps = 10
4040
compile = false
41-
dataset = "cc12m-wds"
41+
dataset = "cc12m-test"
4242
classifer_free_guidance_prob = 0.1
4343
img_size = 256
4444

0 commit comments

Comments
 (0)