Description
I think many users will be surprised that the default SFT implementation does not normalize the loss by the total number of target tokens.
I certainly was, and I spent some time debugging training instability that ultimately seems to be caused by variable-length batches combined with a summed per-token loss.
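To make the scale problem concrete, here is a minimal pure-Python sketch. It is not the library's code, and the function names are hypothetical. It assumes every target token has the same negative log-likelihood, with a weight of 1 for target tokens and 0 for prompt tokens. The summed loss grows linearly with the number of target tokens, while the token-mean loss stays fixed.

```python
def summed_token_loss(token_nlls, weights):
    """Sum of weighted per-token negative log-likelihoods (no normalization)."""
    return sum(nll * w for nll, w in zip(token_nlls, weights))


def mean_token_loss(token_nlls, weights):
    """Weighted per-token NLL divided by the total target-token weight."""
    total_weight = sum(weights)
    if total_weight == 0:
        # No target tokens (e.g. prompt only): nothing to average.
        return 0.0
    return summed_token_loss(token_nlls, weights) / total_weight


# Two hypothetical batches with the same per-token loss (2.0) but a
# different number of target tokens (weight 1 = target, 0 = prompt).
short_batch = ([2.0] * 10, [0] * 4 + [1] * 6)
long_batch = ([2.0] * 600, [0] * 100 + [1] * 500)

for name, (nlls, w) in [("short", short_batch), ("long", long_batch)]:
    print(name, summed_token_loss(nlls, w), mean_token_loss(nlls, w))
# short 12.0 2.0
# long 1000.0 2.0
```

Because the gradient of the summed loss scales the same way, the 500-target-token batch pushes the parameters roughly 80 times harder than the 6-token batch under a fixed learning rate, even though the per-token fit is identical.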
I have a commit in my fork that adds an option to normalize the token-weighted loss (i.e., the mean loss over target tokens), but I'm not sure what you'd want the implementation to look like. Since it's a simple change, it might be easier for you to reimplement it in the style you prefer than to coordinate with me. :)
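Here is a minimal pure-Python sketch of how such an opt-in option could look. The function and the `normalize_by_target_tokens` parameter are hypothetical names used for illustration; they are not the library's API. Per-token weights are again assumed to be 1 for target tokens and 0 for prompt tokens. With the flag off, the summed behavior is preserved. With the flag on, the loss is divided by the total target-token weight across the whole batch.

```python
from typing import Sequence


def sft_batch_loss(
    batch_nlls: Sequence[Sequence[float]],
    batch_weights: Sequence[Sequence[float]],
    normalize_by_target_tokens: bool = False,
) -> float:
    """Weighted token NLL over a batch of sequences (hypothetical helper).

    normalize_by_target_tokens=False keeps the summed behavior.
    True divides by the total target-token weight across the whole batch,
    i.e. a mean over target tokens, not a mean of per-sequence means.
    """
    total_loss = 0.0
    total_weight = 0.0
    for nlls, weights in zip(batch_nlls, batch_weights):
        if len(nlls) != len(weights):
            raise ValueError("each sequence needs one weight per token")
        total_loss += sum(n * w for n, w in zip(nlls, weights))
        total_weight += sum(weights)
    if not normalize_by_target_tokens:
        return total_loss
    # Guard against batches with no target tokens (e.g. all prompt).
    return total_loss / total_weight if total_weight > 0 else 0.0


batch_nlls = [[4.0], [1.0, 1.0, 1.0]]
batch_weights = [[1], [1, 1, 1]]
print(sft_batch_loss(batch_nlls, batch_weights))  # 7.0
print(sft_batch_loss(batch_nlls, batch_weights, normalize_by_target_tokens=True))  # 1.75
```

Dividing by the batch-wide token count weights every target token equally. In this example, averaging the per-sequence means would instead give (4.0 + 1.0) / 2 = 2.5, because the short sequence gets more influence per token. That is a different, also common, choice.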
My change preserves the default behavior, but I think it might actually be better to switch the default for SFT. Even though such a change would be somewhat disruptive, it might be worth the short-term pain to eliminate a surprising default.