fix: normalize SFT weights per-example for consistent gradient magnitudes#334
Open
bledden wants to merge 2 commits into thinking-machines-lab:main from
Conversation
fix: normalize SFT weights per-example for consistent gradient magnitudes

Normalize the per-token weights in `datum_from_model_input_weights` to sum to 1.0 per example. This produces a token-mean loss instead of a token-sum loss, making gradient magnitudes consistent across batches with variable-length sequences. The old token-sum behavior caused batches with more target tokens to produce proportionally larger gradients, effectively making the learning rate fluctuate with sequence length. This was especially problematic for datasets with high variance in per-batch token counts. The `normalize_weights` parameter defaults to `True` but can be set to `False` to recover the old behavior. Fixes thinking-machines-lab#271
…ghts

Tests cover `normalize_weights=True`/`False`, the all-zero weights edge case, a single non-zero weight, target token left-shifting, and `max_length` truncation interactions.
Summary
- Normalize the per-token weights in `datum_from_model_input_weights` to sum to 1.0 per example
- Add a `normalize_weights` parameter (default `True`) to opt out if needed via `normalize_weights=False`

Motivation
The current token-sum loss causes batches with more target tokens to produce proportionally larger gradients. With variable-length sequences (common in chat/instruction datasets), this effectively makes the learning rate fluctuate with per-batch token count. As @joschu noted, this is mostly masked by Adam but can cause instability on datasets with high variance in sequence length.
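As a toy illustration of the issue (hedged numbers, not the repo's actual loss code): with token-sum weighting, an example's contribution to the loss, and hence its gradient, grows linearly with its target token count, while per-example normalization makes it length-invariant.

```python
import numpy as np

# Toy illustration (not the repo's loss code): per_token_loss stands in
# for -log p(token); weights are the SFT target mask for one example.
def weighted_loss(per_token_loss, weights):
    return float(np.dot(per_token_loss, weights))

per_tok = 2.0
short_w = [1.0] * 5    # 5 target tokens
long_w = [1.0] * 50    # 50 target tokens

# Token-sum: the long example contributes 10x the loss (and gradient).
sum_short = weighted_loss([per_tok] * 5, short_w)    # 10.0
sum_long = weighted_loss([per_tok] * 50, long_w)     # 100.0

# Per-example normalization: both contribute the same token-mean loss.
norm = lambda w: [x / sum(w) for x in w]
mean_short = weighted_loss([per_tok] * 5, norm(short_w))   # 2.0
mean_long = weighted_loss([per_tok] * 50, norm(long_w))    # 2.0
print(sum_short, sum_long, mean_short, mean_long)
```

With normalized weights, every example contributes equally to the batch gradient regardless of length, which is what keeps the effective learning rate stable.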
Breaking change note
This PR defaults `normalize_weights=True`, which silently changes loss behavior for all existing callers of `datum_from_model_input_weights` (SFT, preference, DPO, VLM classifier). The old token-sum behavior is recoverable by passing `normalize_weights=False`. This aligns with the direction discussed in #271, but if a more gradual rollout is preferred (e.g., threading the option through `ChatDatasetBuilderCommonConfig` so users can toggle it per-dataset, or defaulting to `False` for backward compatibility), I'm happy to adjust.

Test plan
- `normalize_weights=False` recovers old behavior

Fixes #271
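The test-plan behaviors can be sketched like this; `normalize_example_weights` below is a hypothetical stand-in for the described logic, not the actual `datum_from_model_input_weights` signature:

```python
import numpy as np

def normalize_example_weights(weights, normalize_weights=True):
    """Scale one example's per-token weights to sum to 1.0.

    Hypothetical helper mirroring the behavior this PR describes;
    normalize_weights=False keeps the old token-sum weights.
    """
    w = np.asarray(weights, dtype=np.float64)
    if not normalize_weights:
        return w                  # old token-sum behavior
    total = w.sum()
    if total == 0.0:
        return w                  # all-zero weights: nothing to normalize
    return w / total

print(normalize_example_weights([1.0, 1.0, 2.0]))  # sums to 1.0
print(normalize_example_weights([0.0, 0.0]))       # stays all-zero
print(normalize_example_weights([3.0], normalize_weights=False))  # unchanged
```

Guarding the all-zero case avoids a divide-by-zero for fully masked examples, which is why it appears as an explicit edge case in the tests.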