[RFC] Support full bf16 training #1646
base: main
Conversation
```diff
@@ -421,5 +421,5 @@ def forward(
             h = layer(h, self.freqs_cis)

         h = self.norm(h) if self.norm else h
-        output = self.output(h) if self.output else h
+        output = self.output(h).float() if self.output else h
```
If we set the training dtype during the training initialization, why not also do the output conversion in the trainer (train loop)?
This is already done in the loss function: https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/loss.py#L21
Also see #642
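For reference, the upcast being discussed is roughly the following (a minimal sketch of the linked loss function; the exact code in the repo may differ slightly):

```python
import torch

def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcast logits to fp32 before the cross-entropy for numerical stability,
    # so the model can emit bf16 logits without an explicit .float() in forward().
    return torch.nn.functional.cross_entropy(
        pred.flatten(0, 1).float(), labels.flatten(0, 1)
    )
```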
Thanks, just removed
```python
    put all parameters, gradients, and optimizer states in bfloat16, without an extra copy of fp32 weights.
    In the case of full bf16 training, RoPE calculations and logits will still be in fp32.
    """

    mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
```
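For readers without the full diff, a minimal sketch of how such a config could be laid out, assuming a hypothetical `dtype` field next to the existing FSDP2 mixed-precision knob (everything except `mixed_precision_param` is illustrative, not necessarily what this PR uses):

```python
from dataclasses import dataclass
from typing import Literal

import torch

# Illustrative mapping from config strings to torch dtypes.
TORCH_DTYPE_MAP = {"float32": torch.float32, "bfloat16": torch.bfloat16}

@dataclass
class Training:
    dtype: Literal["bfloat16", "float32"] = "float32"
    """Dtype that parameters, gradients, and optimizer states are stored in."""

    mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
    """Param dtype FSDP2 casts to for all-gather/compute (existing knob)."""

# The trainer would resolve the string once, before building the model.
train_dtype = TORCH_DTYPE_MAP[Training().dtype]
```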
What if `mixed_precision_param` is `float32` but `dtype` is `bfloat16`? Should there be a check?
Yeah, agreed. Do we want to do this somewhere in train.py? Let me know if you think there's a better place.
`mixed_precision_param` comes from FSDP2. I think if FSDP2 can work with that combination, it's the user's responsibility to configure it properly.
We also made it work with DDP and single-device training: #1303. I think at least a warning is required.
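Something like the following would cover that case with a warning rather than a hard error (helper and field names here are hypothetical, not taken from the PR):

```python
import logging

logger = logging.getLogger(__name__)

def warn_on_dtype_mismatch(dtype: str, mixed_precision_param: str) -> None:
    # If parameters are already stored in bf16, requesting fp32 mixed-precision
    # params cannot recover full precision, so flag the likely misconfiguration.
    if dtype == "bfloat16" and mixed_precision_param == "float32":
        logger.warning(
            "training dtype is bfloat16 but mixed_precision_param is float32; "
            "parameters are already bf16, so fp32 mixed precision has no effect."
        )
```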
Sounds good. In that case I'll leave this as is.
This PR adds support for full bfloat16 training. In SFT it is pretty common to store everything in bfloat16 to save memory, with select tensors (logits, RoPE buffers, and activations) kept in higher precision to preserve numerical accuracy. Separately, I think having this supported more generally would be useful for faster iteration -- e.g. it allows me to run Llama3 70B on a single node of H100s, which is otherwise not possible with the default config.
Assuming this is generally useful, I would like feedback on:
- the `set_default_dtype` API
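For illustration, a minimal sketch of how a `set_default_dtype` context manager could work, built on `torch.set_default_dtype` (the module and buffer below are illustrative, not the PR's actual code):

```python
from contextlib import contextmanager

import torch
import torch.nn as nn

@contextmanager
def set_default_dtype(dtype: torch.dtype):
    # Temporarily change the dtype new floating-point tensors are created with,
    # so parameters come out of __init__ directly in bf16 and no fp32 copy is made.
    old = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old)

with set_default_dtype(torch.bfloat16):
    layer = nn.Linear(4096, 4096)  # illustrative module
assert layer.weight.dtype == torch.bfloat16

# Precision-sensitive buffers (e.g. RoPE frequencies) are created with an explicit
# fp32 dtype afterwards, matching the "RoPE stays in fp32" behavior described above.
freqs = torch.arange(0, 64, dtype=torch.float32)
```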