@@ -70,6 +70,9 @@ def build_test_list():
7070 # Checkpointing tests
7171 OverrideDefinitions (
7272 [
73+ [
74+ "--checkpoint.enable_checkpoint" ,
75+ ],
7376 [
7477 "--checkpoint.enable_checkpoint" ,
7578 "--training.steps 20" ,
@@ -99,28 +102,30 @@ def build_test_list():
99102 "Checkpoint Integration Test - Save Model Weights Only bf16" ,
100103 "model_weights_only_bf16" ,
101104 ),
105+ # Parallelism tests. Note: running with DDP only will cause OOM
102106 OverrideDefinitions (
103107 [
104108 [
105- "--parallelism.data_parallel_shard_degree=1 " ,
106- "--parallelism.data_parallel_replicate_degree=4 " ,
109+ "--parallelism.data_parallel_shard_degree=8 " ,
110+ "--parallelism.data_parallel_replicate_degree=1 " ,
107111 ]
108112 ],
109- "DDP " ,
110- "ddp " ,
111- ngpu = 4 ,
113+ "FSDP " ,
114+ "fsdp " ,
115+ ngpu = 8 ,
112116 ),
113117 OverrideDefinitions (
114118 [
115119 [
116- "--parallelism.data_parallel_shard_degree=2 " ,
120+ "--parallelism.data_parallel_shard_degree=4 " ,
117121 "--parallelism.data_parallel_replicate_degree=2" ,
118122 ]
119123 ],
120124 "HSDP" ,
121125 "hsdp" ,
122- ngpu = 4 ,
126+ ngpu = 8 ,
123127 ),
128+ # Inference tests
124129 # OverrideDefinitions(
125130 # [
126131 # [
@@ -169,7 +174,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
169174 # "PROMPT='What is the meaning of life?' "
170175 # f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json"
171176 # )
172- # TODO: migrate the generate image script
177+ # TODO: Add the generate image script
173178 cmd = None
174179
175180 result = _run_cmd (cmd )
0 commit comments