Commit b1dc330

Refactor script to use 'overwrites' variable for command-line arguments in training scripts (#1473)
The goal of this PR is to add command-line argument support to the bash training scripts. `run_train.sh` already supported `overrides`, but the `multinode_trainer.slurm` script did not. The `overrides` flag adds support for commands like: `sbatch multinode_trainer.slurm --job.description="TEST_RUN"`

However, the current `overrides` implementation has a problem: when passing an argument that contains a space, such as `"TEST RUN"` instead of `"TEST_RUN"`, the variable `job.description` only receives `TEST`, and the training script throws an error because it does not recognize the argument `RUN`, which arrives as a separate word. To address this, I simplified the code to forward the additional overrides directly through `"$@"`. This fixes commands such as: `sbatch multinode_trainer.slurm --job.description="TEST RUN"`
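The word-splitting issue can be reproduced in isolation. The sketch below uses a hypothetical `print_args` helper (not from the repo) standing in for the training launcher; it shows why forwarding `"$@"` preserves an argument containing a space, while expanding an `overrides` string built from `$*` re-splits it.

```shell
#!/usr/bin/env bash
# Hypothetical helper standing in for the training launcher:
# prints each argument it receives on its own line, angle-bracketed.
print_args() {
  for arg in "$@"; do
    printf '<%s>\n' "$arg"
  done
}

set -- --job.description="TEST RUN"   # simulate the script's arguments

# Old approach: flatten the arguments into a single string ...
overrides="$*"
# ... then expand it unquoted: the shell re-splits on the space,
# so the launcher sees two arguments instead of one.
print_args $overrides
# <--job.description=TEST>
# <RUN>

# New approach: forward the original argument array verbatim.
print_args "$@"
# <--job.description=TEST RUN>
```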
1 parent ad9849c commit b1dc330

File tree

2 files changed: +4 additions, −9 deletions

multinode_trainer.slurm

Lines changed: 2 additions & 2 deletions
```diff
@@ -34,7 +34,7 @@ export LOGLEVEL=INFO
 export FI_PROVIDER="efa"
 # Ensure that P2P is available
 # export NCCL_P2P_DISABLE=1
-export NCCL_IB_DISABLE=1
+# export NCCL_IB_DISABLE=1

 # debugging flags (optional)
 export NCCL_DEBUG=WARN
@@ -59,5 +59,5 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.t
 dcgmi profile --pause
 # adjust sbatch --ntasks and sbatch --nodes above and --nnodes below
 # to your specific node count, and update target launch file.
-srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./torchtitan/train.py --job.config_file ${CONFIG_FILE}
+srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./torchtitan/train.py --job.config_file ${CONFIG_FILE} "$@"
 dcgmi profile --resume
```

run_train.sh

Lines changed: 2 additions & 7 deletions
```diff
@@ -7,22 +7,17 @@

 set -ex

-# use envs as local overrides for convenience
+# use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}

-overrides=""
-if [ $# -ne 0 ]; then
-overrides="$*"
-fi
-
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
--m torchtitan.train --job.config_file ${CONFIG_FILE} $overrides
+-m torchtitan.train --job.config_file ${CONFIG_FILE} "$@"
```
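For context, the `${VAR:-default}` expansions that `run_train.sh` relies on fall back to the default only when the variable is unset or empty, which is what makes the `LOG_RANK=0,1 NGPU=4 ./run_train.sh` env-override style work. A minimal sketch of the pattern (the `unset` is only for the demonstration):

```shell
#!/usr/bin/env bash
# Sketch of the ${VAR:-default} pattern used by run_train.sh.
unset NGPU                 # ensure the variable is unset for the demo
NGPU=${NGPU:-"8"}          # no caller value, so the default applies
echo "$NGPU"               # prints: 8

NGPU=4                     # a caller-supplied value ...
NGPU=${NGPU:-"8"}          # ... is left untouched by the expansion
echo "$NGPU"               # prints: 4
```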

0 commit comments
