Skip to content

Commit 0369053

Browse files
rhmukundansowmen
authored and committed
Updating Configs for LLAMA3 70B LoRa (NVIDIA-NeMo#2292)
Signed-off-by: Raghav Hrishikeshan Mukundan <rmukundan@nvidia.com>
Signed-off-by: sowmen <sowmendipta@gmail.com>
1 parent 84c46b0 commit 0369053

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

scripts/performance/configs/llama/llama3_llm_finetune.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def llama3_70b_lora_config_gb300(precision: str = "bf16", config_variant: str =
205205
peft="lora",
206206
precision_config=precision_config,
207207
packed_sequence=True,
208-
seq_length=2048,
208+
seq_length=4096,
209209
)
210210
set_llama3_common_peft_configs(cfg)
211211
set_workload_base_configs(cfg, base_cfg)
@@ -234,11 +234,14 @@ def llama3_70b_lora_config_gb200(precision: str = "bf16", config_variant: str =
234234
)
235235
precision_config = get_precision_config(precision)
236236

237+
# BF16 uses seq_length=2048, FP8 variants use seq_length=4096
238+
seq_length = 2048 if precision.lower() == "bf16" else 4096
239+
237240
cfg = llama3_70b_finetune_config(
238241
peft="lora",
239242
precision_config=precision_config,
240243
packed_sequence=True,
241-
seq_length=2048,
244+
seq_length=seq_length,
242245
)
243246
set_llama3_common_peft_configs(cfg)
244247
set_workload_base_configs(cfg, base_cfg)

scripts/performance/configs/llama/llama3_workload_base_configs.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -554,55 +554,72 @@
554554
BASE_LLAMA3_70B_CONFIG,
555555
num_gpus=8,
556556
peft="lora",
557-
# pipeline_model_parallel_size=4,
558-
# virtual_pipeline_model_parallel_size=20,
557+
tensor_model_parallel_size=1,
558+
pipeline_model_parallel_size=1,
559+
context_parallel_size=1,
559560
micro_batch_size=1,
560-
global_batch_size=64,
561+
global_batch_size=32,
561562
cuda_graph_impl="transformer_engine",
562563
cuda_graph_scope="mlp",
563564
)
564565

565566
LLAMA3_70B_LORA_CONFIG_GB300_BF16_V1 = _LLAMA3_70B_LORA_CONFIG_GB300
566567
LLAMA3_70B_LORA_CONFIG_GB300_FP8_CS_V1 = _LLAMA3_70B_LORA_CONFIG_GB300
567568
LLAMA3_70B_LORA_CONFIG_GB300_FP8_MX_V1 = replace(
568-
LLAMA3_70B_LORA_CONFIG_GB300_FP8_CS_V1,
569-
pipeline_model_parallel_size=2, # PP=1 is OOM
569+
_LLAMA3_70B_LORA_CONFIG_GB300,
570+
pipeline_model_parallel_size=2,
570571
)
571572

572573

573574
_LLAMA3_70B_LORA_CONFIG_GB200 = replace(
574575
BASE_LLAMA3_70B_CONFIG,
575576
num_gpus=8,
576577
peft="lora",
577-
pipeline_model_parallel_size=4,
578-
virtual_pipeline_model_parallel_size=20,
578+
tensor_model_parallel_size=1,
579+
pipeline_model_parallel_size=1,
580+
context_parallel_size=1,
579581
micro_batch_size=1,
580582
global_batch_size=64,
581583
cuda_graph_impl="transformer_engine",
582584
cuda_graph_scope="mlp",
583585
)
584586

585587
LLAMA3_70B_LORA_CONFIG_GB200_BF16_V1 = _LLAMA3_70B_LORA_CONFIG_GB200
586-
LLAMA3_70B_LORA_CONFIG_GB200_FP8_CS_V1 = _LLAMA3_70B_LORA_CONFIG_GB200
588+
LLAMA3_70B_LORA_CONFIG_GB200_FP8_CS_V1 = replace(
589+
BASE_LLAMA3_70B_CONFIG,
590+
num_gpus=8,
591+
peft="lora",
592+
tensor_model_parallel_size=1,
593+
pipeline_model_parallel_size=2,
594+
context_parallel_size=1,
595+
micro_batch_size=1,
596+
global_batch_size=32,
597+
cuda_graph_impl="transformer_engine",
598+
cuda_graph_scope="mlp",
599+
)
587600
LLAMA3_70B_LORA_CONFIG_GB200_FP8_MX_V1 = LLAMA3_70B_LORA_CONFIG_GB200_FP8_CS_V1
588601

589602

590603
_LLAMA3_70B_LORA_CONFIG_H100 = replace(
591604
BASE_LLAMA3_70B_CONFIG,
592605
num_gpus=8,
593606
peft="lora",
594-
tensor_model_parallel_size=2,
607+
tensor_model_parallel_size=1,
595608
pipeline_model_parallel_size=4,
609+
context_parallel_size=1,
596610
virtual_pipeline_model_parallel_size=20,
597611
micro_batch_size=1,
598612
global_batch_size=32,
599613
)
600614

601615
LLAMA3_70B_LORA_CONFIG_H100_BF16_V1 = replace(
602616
_LLAMA3_70B_LORA_CONFIG_H100,
603-
recompute_num_layers=2,
617+
recompute_num_layers=1,
618+
)
619+
LLAMA3_70B_LORA_CONFIG_H100_FP8_CS_V1 = replace(
620+
_LLAMA3_70B_LORA_CONFIG_H100,
621+
tensor_model_parallel_size=2,
604622
)
605-
LLAMA3_70B_LORA_CONFIG_H100_FP8_CS_V1 = _LLAMA3_70B_LORA_CONFIG_H100
606623

607624

608625
__all__ = [

0 commit comments

Comments (0)