
Commit 9ccc161

add utils.set_determinism for reproducibility (#576)
This PR:

1. Adds a set_determinism function that sets the seed for Python, PyTorch, and CUDA, and applies deterministic cuDNN settings.
2. If seed is None, no deterministic settings are used. This may be important because turning off cuDNN benchmarking to ensure determinism can also hurt performance.
3. Note that for the None case, we revert to / ensure cuDNN is set back to non-deterministic algorithms with benchmarking/tuning enabled, in case people are toggling between modes.

This lack of determinism negatively impacted work with AWS, where we ended up with variations in loss curves while running fp8 for our joint blog; they appeared to be caused by fp8 but were instead likely due to not having determinism in titan.

Testing: I ran multiple small runs with 7B while rotating between three seeds and saw consistent ending loss points matching the respective seeds.

This PR does not set per-worker seeds for the dataloader since we do not shuffle at the moment, but that could become a future source of randomness that will need to be handled if we shuffle in the future.
1 parent: eef8bb2

File tree

3 files changed (+34 lines, -2 lines)


torchtitan/config_manager.py

Lines changed: 6 additions & 1 deletion
@@ -357,7 +357,12 @@ def __init__(self):
             default=50,
             help="Python garbage control scheduling interval, in steps",
         )
-
+        self.parser.add_argument(
+            "--training.seed",
+            type=int,
+            default=None,
+            help="Implement reproducibility by setting a Python, PyTorch and CUDA seed",
+        )
         # checkpointing configs
         self.parser.add_argument(
             "--checkpoint.enable_checkpoint",

torchtitan/utils.py

Lines changed: 19 additions & 1 deletion
@@ -9,7 +9,7 @@
 import subprocess
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import Union
+from typing import Optional, Union

 import torch
 import torch.distributed._functional_collectives as funcol

@@ -36,6 +36,24 @@ def _warn_overwrite_env(env, val):
     os.environ[env] = val


+def set_determinism(seed: Optional[int]) -> None:
+    """
+    Set Python, PyTorch, CUDA seeds and cudnn settings for reproducibility
+    """
+    if seed is not None:
+        # CPU and GPU determinism
+        torch.manual_seed(seed)
+        # set deterministic cudnn algorithms
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+        # set Python seed
+        os.environ["PYTHONHASHSEED"] = str(seed)
+    else:
+        # ensure we turn off deterministic cudnn algorithms
+        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.benchmark = True
+
+
 def set_pg_timeouts(timeout, world_mesh):
     """
     Sets the timeout for all PGs in the provided mesh, and the default (world) group.
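
As a quick sanity check, a minimal usage sketch of the new helper might look like the following; it assumes torchtitan is importable so the function is reachable as torchtitan.utils.set_determinism, and the 1337 seed is just an example value.

import torch

from torchtitan import utils  # assumes torchtitan is on the Python path

# Re-seeding with the same value should reproduce the same random draws.
utils.set_determinism(1337)
a = torch.randn(4)
utils.set_determinism(1337)
b = torch.randn(4)
assert torch.equal(a, b)

# Passing None reverts cudnn to benchmarking / non-deterministic algorithms.
utils.set_determinism(None)
assert torch.backends.cudnn.benchmark is True
assert torch.backends.cudnn.deterministic is False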

train.py

Lines changed: 9 additions & 0 deletions
@@ -56,6 +56,15 @@ def main(job_config: JobConfig):
     # take control of garbage collection to avoid stragglers
     gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

+    # set determinism; use seed == None to skip deterministic training
+    utils.set_determinism(job_config.training.seed)
+    if job_config.training.seed is None:
+        logger.info("Deterministic training off")
+    else:
+        logger.info(
+            f"Deterministic training on. Using seed: {job_config.training.seed}"
+        )
+
     # init distributed
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
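
The commit message calls out per-worker dataloader seeding as future work once shuffling lands. Not part of this PR, but for reference, a common PyTorch pattern for that case looks roughly like the sketch below; the dataset, seed value, and worker count are made up for the example, and in titan the seed would come from job_config.training.seed.

import random

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Tiny stand-in dataset used only for this example."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(2)


def seed_worker(worker_id: int) -> None:
    # Each dataloader worker gets a distinct base seed from PyTorch; fold it
    # into Python's RNG so worker-local randomness is reproducible too.
    worker_seed = torch.initial_seed() % 2**32
    random.seed(worker_seed)


if __name__ == "__main__":
    g = torch.Generator()
    g.manual_seed(42)  # example seed; would come from the training config

    loader = DataLoader(
        ToyDataset(),
        batch_size=2,
        shuffle=True,
        num_workers=2,
        worker_init_fn=seed_worker,
        generator=g,  # fixes the shuffle order across runs
    )
    for batch in loader:
        pass  # identical shuffle order on every run with the same seed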
