Conversation

ebsmothers (Contributor)

This PR adds support for full bfloat16 training. In SFT it is pretty common to store everything in bfloat16 to save memory, with select tensors (logits, RoPE buffers, and activations) maintained in higher precision to preserve numerical accuracy. Separately, I think having this supported more generally would be useful for faster iteration -- e.g. it allows me to run Llama3 70B on a single node of H100s, which is otherwise not possible with the default config.

Assuming this is generally useful, I would like feedback on:

  1. Acceptable loss convergence: in the first 100 steps on Llama3 8B, full bf16 training brings the loss from 12.25 -> 8, as opposed to 12.25 -> 7 with fp32 training. Is this a concern? (As mentioned, for SFT this is less of an issue; happy to validate that statement if that's helpful.)
  2. Interaction with mixed precision training -- where is the right place to validate that these are not both set at once?
  3. Where to put the set_default_dtype API
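
For context on (3), here is a minimal sketch in plain PyTorch of what full bf16 training with a default-dtype switch plus an fp32 logits cast could look like. This is not the torchtitan trainer; `model`, `inputs`, and `targets` are placeholders.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Newly constructed params/buffers default to bf16, and the optimizer states
    # created from them are bf16 as well -- no extra fp32 copy of the weights.
    torch.set_default_dtype(torch.bfloat16)

    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 32000))
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    def training_step(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        logits = model(inputs.to(torch.bfloat16))
        # Upcast the logits before the loss to preserve numerical accuracy,
        # mirroring the .float() call in the diff below.
        loss = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss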

@meta-cla (bot) added the CLA Signed label on Aug 27, 2025
@@ -421,5 +421,5 @@ def forward(
         h = layer(h, self.freqs_cis)

         h = self.norm(h) if self.norm else h
-        output = self.output(h) if self.output else h
+        output = self.output(h).float() if self.output else h
Contributor:
If we set the training dtype during the training initialization, why not also do the output conversion in the trainer (train loop)?
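
(A rough sketch of what that could look like, with placeholder names rather than the actual torchtitan trainer code:)

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def train_step(model: nn.Module, input_ids: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # The model forward stays dtype-agnostic; the upcast to fp32 happens
        # here, right before the loss, instead of inside Transformer.forward.
        logits = model(input_ids)
        loss = F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), labels.reshape(-1))
        loss.backward()
        return loss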

Contributor Author:
Thanks, just removed.

put all parameters, gradients, and optimizer states in bfloat16, without an extra copy of fp32 weights.
In the case of full bf16 training, RoPE calculations and logits will still be in fp32.
"""

mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
Contributor:
What if mixed_precision_param is float32 but dtype is bfloat16? Should there be a check?

Contributor Author:
Yeah agreed. Do we want to do this somewhere in train.py? Lmk if you think there's a better place

Contributor:
mixed_precision_param comes from FSDP2. I think if FSDP2 can work with that, it's the user's responsibility to configure these options properly.

Contributor:
We also made it work with DDP/single device: #1303. I think at least a warning is required.

Contributor Author:
Sounds good. In that case I will leave this as is.

Contributor:

@fegin
autocast is not well supported in torchtitan anyway; I'm not sure if it is still maintained. See other issues like #1525.

But sure, having a warning sounds good.
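
For reference, such a warning could look roughly like this; the check location and field names are assumed for illustration, not taken from the final patch:

    import logging

    logger = logging.getLogger(__name__)

    def check_dtype_settings(dtype: str, mixed_precision_param: str) -> None:
        # Full bf16 training already stores parameters in bfloat16, so an fp32
        # mixed-precision param dtype is likely a misconfiguration.
        if dtype == "bfloat16" and mixed_precision_param == "float32":
            logger.warning(
                "Training dtype is bfloat16 but mixed_precision_param is float32; "
                "parameters are already stored in bf16, so check that this "
                "combination is intentional."
            )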
