From dba286de95fed5e03d9e1a75cdc9a9b795e70558 Mon Sep 17 00:00:00 2001
From: Evan Smothers
Date: Wed, 27 Aug 2025 11:34:37 -0700
Subject: [PATCH 1/2] [RFC] Support full bf16 training

---
 torchtitan/config/job_config.py         |  7 ++++++
 torchtitan/models/llama3/model/model.py |  2 +-
 torchtitan/tools/utils.py               | 30 ++++++++++++++++++++++++-
 torchtitan/train.py                     |  7 ++++--
 4 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index a2247aa21..7c51ac2b6 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -204,6 +204,13 @@ class Training:
     Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP
     """
 
+    dtype: Literal["bfloat16", "float32"] = "float32"
+    """
+    torch dtype for training. In contrast to mixed precision training, setting training.dtype=bfloat16 will
+    put all parameters, gradients, and optimizer states in bfloat16, without an extra copy of fp32 weights.
+    In the case of full bf16 training, RoPE calculations and logits will still be in fp32.
+    """
+
     mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
     """
     torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index f2284920a..723a669c0 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -421,5 +421,5 @@ def forward(
             h = layer(h, self.freqs_cis)
 
         h = self.norm(h) if self.norm else h
-        output = self.output(h) if self.output else h
+        output = self.output(h).float() if self.output else h
         return output
diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py
index 070cc2938..37273a71d 100644
--- a/torchtitan/tools/utils.py
+++ b/torchtitan/tools/utils.py
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import gc
 import subprocess
 import time
 from dataclasses import dataclass
-from typing import Optional
+from typing import Generator, Optional
 
 import torch
 from torch._utils import _get_available_device_type, _get_device_module
@@ -174,3 +175,30 @@ def check_if_feature_in_pytorch(
             f"{min_nightly_version}. Please upgrade a newer version to include the "
             f"change in ({pull_request}) for correct {feature_name}."
         )
+
+
+@contextlib.contextmanager
+def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
+    """
+    Context manager to set torch's default dtype.
+
+    Args:
+        dtype (torch.dtype): The desired default dtype inside the context manager.
+
+    Returns:
+        ContextManager: context manager for setting default dtype.
+
+    Example:
+        >>> with set_default_dtype(torch.bfloat16):
+        >>>     x = torch.tensor([1.0, 2.0, 3.0])
+        >>>     x.dtype
+        torch.bfloat16
+
+
+    """
+    old_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        torch.set_default_dtype(old_dtype)
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 9b69fd679..545957acc 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -22,7 +22,7 @@
     build_metrics_processor,
     ensure_pp_loss_visible,
 )
-from torchtitan.config import ConfigManager, JobConfig
+from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.models.attention import init_attention_mask
 from torchtitan.protocols.model_converter import build_model_converters
@@ -154,7 +154,10 @@ def __init__(self, job_config: JobConfig):
         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
-        with torch.device("meta"):
+        with (
+            torch.device("meta"),
+            utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
+        ):
             model = self.train_spec.model_cls(model_args)
 
         # Build the collection of model converters. No-op if `model.converters` empty

From e683589e16ebf0df558ed31adedd7f297f3eadbe Mon Sep 17 00:00:00 2001
From: Evan Smothers
Date: Wed, 27 Aug 2025 18:07:17 -0700
Subject: [PATCH 2/2] remove .float()

---
 torchtitan/models/llama3/model/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index 723a669c0..f2284920a 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -421,5 +421,5 @@ def forward(
             h = layer(h, self.freqs_cis)
 
         h = self.norm(h) if self.norm else h
-        output = self.output(h).float() if self.output else h
+        output = self.output(h) if self.output else h
         return output
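
Not part of the patches above: a minimal runnable sketch of the pattern this RFC relies on, namely constructing a model while torch's default dtype is temporarily set to bfloat16 so that floating-point parameters (and, downstream, gradients and optimizer states) are created directly in bf16 with no fp32 master copy. The set_default_dtype helper below mirrors the shape of the torchtitan/tools/utils.py hunk; TinyModel is a made-up stand-in for the real self.train_spec.model_cls.

import contextlib
from typing import Generator

import torch
import torch.nn as nn


@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    # Same structure as the helper added in torchtitan/tools/utils.py:
    # swap the global default dtype, restore the old one on exit even on error.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)


class TinyModel(nn.Module):
    # Hypothetical stand-in for the model class built in train.py.
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)


# Parameters created inside the context come out in bfloat16 directly; unlike
# FSDP mixed precision, there is no separate fp32 copy of the weights.
with set_default_dtype(torch.bfloat16):
    model = TinyModel()

print(next(model.parameters()).dtype)  # torch.bfloat16
print(torch.get_default_dtype())       # torch.float32 (restored on exit)

In train.py this composes with torch.device("meta"), so the bf16 parameters are also created without allocating real storage up front. Whether the final logits are upcast to fp32 is a separate question, which is the .float() call that PATCH 2/2 backs out.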