7 changes: 7 additions & 0 deletions torchtitan/config/job_config.py
@@ -204,6 +204,13 @@ class Training:
Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP
"""

dtype: Literal["bfloat16", "float32"] = "float32"
"""
torch dtype for training. In contrast to mixed precision training, setting training.dtype = "bfloat16"
puts all parameters, gradients, and optimizer states in bfloat16, without keeping an extra fp32 copy
of the weights. Even in full bf16 training, RoPE calculations and logits are still computed in fp32.
"""

mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
Contributor:
What if mixed_precision_param is float32 but dtype is bfloat16? Shouldn't there be a check?

Contributor Author:
Yeah agreed. Do we want to do this somewhere in train.py? Lmk if you think there's a better place

Contributor:
mixed_precision_param comes from FSDP2. If FSDP2 can work with that combination, I think it's the user's responsibility to configure it properly.

Contributor:
We also make it work with DDP/single device: #1303. I think a warning is at least required.

Contributor Author:
Sounds good. In that case I will leave this as is

Contributor:
@fegin
autocast is not well supported in torchtitan anyway. I'm not sure if it is still maintained. See other issues like #1525.

But sure, having a warning sounds good.
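A minimal sketch of the consistency warning discussed above; the function name and wording are hypothetical and not part of this PR:

import logging

logger = logging.getLogger(__name__)

def warn_on_dtype_mismatch(training_dtype: str, mixed_precision_param: str) -> None:
    # Hypothetical check: full-bf16 parameters combined with fp32 mixed precision
    # means the fp32 compute copy is upcast from bf16 storage, which is usually
    # not what the user intended.
    if training_dtype == "bfloat16" and mixed_precision_param == "float32":
        logger.warning(
            "training.dtype=bfloat16 with mixed_precision_param=float32: "
            "parameters are stored in bf16, so upcasting them to fp32 for compute "
            "does not recover full precision; check that this is intentional."
        )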

"""
torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/model/model.py
@@ -421,5 +421,5 @@ def forward(
h = layer(h, self.freqs_cis)

h = self.norm(h) if self.norm else h
-output = self.output(h) if self.output else h
+output = self.output(h).float() if self.output else h
Contributor:
If we set the training dtype during the training initialization, why not also do the output conversion in the trainer (train loop)?

Contributor Author:
Thanks, just removed

return output
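Whichever side the cast ends up on (the model's forward or the trainer), the reason for fp32 logits is the numerical stability of the log-softmax inside the cross-entropy. A small sketch, not part of this diff:

import torch
import torch.nn.functional as F

vocab, hidden = 128, 64
h = torch.randn(2, 8, hidden, dtype=torch.bfloat16)
output_proj = torch.nn.Linear(hidden, vocab, bias=False, dtype=torch.bfloat16)

logits = output_proj(h).float()  # upcast logits to fp32 before the loss
targets = torch.randint(0, vocab, (2, 8))
loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
print(logits.dtype, loss.dtype)  # torch.float32 torch.float32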
30 changes: 29 additions & 1 deletion torchtitan/tools/utils.py
@@ -4,11 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gc
import subprocess
import time
from dataclasses import dataclass
-from typing import Optional
+from typing import Generator, Optional

import torch
from torch._utils import _get_available_device_type, _get_device_module
@@ -174,3 +175,30 @@ def check_if_feature_in_pytorch(
f"{min_nightly_version}. Please upgrade a newer version to include the "
f"change in ({pull_request}) for correct {feature_name}."
)


@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
"""
Context manager to set torch's default dtype.

Args:
dtype (torch.dtype): The desired default dtype inside the context manager.

Returns:
ContextManager: context manager for setting default dtype.

Example:
>>> with set_default_dtype(torch.bfloat16):
...     x = torch.tensor([1.0, 2.0, 3.0])
...     print(x.dtype)
torch.bfloat16
"""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(old_dtype)
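A small usage sketch (not part of the diff) showing that the default dtype applies to floating-point parameters and factory calls, while integer tensors are unaffected and the previous default is restored on exit:

import torch
from torchtitan.tools.utils import set_default_dtype

with set_default_dtype(torch.bfloat16):
    linear = torch.nn.Linear(8, 8)   # parameters created in bf16
    ints = torch.arange(4)           # integer factory calls stay int64

print(linear.weight.dtype)         # torch.bfloat16
print(ints.dtype)                  # torch.int64
print(torch.get_default_dtype())   # torch.float32, restored outside the context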
7 changes: 5 additions & 2 deletions torchtitan/train.py
@@ -22,7 +22,7 @@
build_metrics_processor,
ensure_pp_loss_visible,
)
-from torchtitan.config import ConfigManager, JobConfig
+from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.models.attention import init_attention_mask
from torchtitan.protocols.model_converter import build_model_converters
@@ -154,7 +154,10 @@ def __init__(self, job_config: JobConfig):
logger.info(
f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
)
with torch.device("meta"):
with (
torch.device("meta"),
utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
):
model = self.train_spec.model_cls(model_args)

# Build the collection of model converters. No-op if `model.converters` empty
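A minimal sketch of what this combination achieves, using a toy model in place of the real train spec (not part of the diff): the model is built on the meta device with bf16 parameters, so no host memory is allocated and no fp32 copy ever exists; the real trainer materializes the model later.

import torch
from torchtitan.tools.utils import set_default_dtype

with torch.device("meta"), set_default_dtype(torch.bfloat16):
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())

print(model[0].weight.device, model[0].weight.dtype)  # meta torch.bfloat16

# Materializing later (e.g. after sharding) keeps the dtype chosen above.
model = model.to_empty(device="cpu")
print(model[0].weight.dtype)  # torch.bfloat16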