
fix: normalize SFT weights per-example for consistent gradient magnitudes #334

Open
bledden wants to merge 2 commits into thinking-machines-lab:main from bledden:fix/normalize-sft-weights

Conversation


@bledden bledden commented Jan 31, 2026

Summary

  • Normalize per-token weights in datum_from_model_input_weights to sum to 1.0 per example
  • Produces a token-mean loss instead of a token-sum loss, making gradient magnitudes consistent across batches with variable-length sequences
  • Adds a normalize_weights parameter (default True); pass normalize_weights=False to opt out
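Conceptually, the change amounts to the following (a minimal sketch in plain Python; the real datum_from_model_input_weights operates on tensors and its signature differs, and normalize_example_weights is a hypothetical name for illustration):

```python
def normalize_example_weights(weights, normalize_weights=True):
    """Scale one example's per-token weights so they sum to 1.0.

    `weights` is a list of floats, one per token; zeros mark tokens
    excluded from the loss. Illustrative sketch only, not the actual
    implementation in this PR.
    """
    if not normalize_weights:
        # Old token-sum behavior: weights pass through unchanged.
        return list(weights)
    total = sum(weights)
    if total == 0:
        # All-zero weights contribute nothing to the loss; return them
        # unchanged rather than dividing by zero.
        return list(weights)
    return [w / total for w in weights]
```

With normalized weights, the weighted sum of per-token losses becomes a weighted mean, so an example's contribution no longer scales with its target-token count.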

Motivation

The current token-sum loss causes batches with more target tokens to produce proportionally larger gradients. With variable-length sequences (common in chat/instruction datasets), this effectively makes the learning rate fluctuate with per-batch token count. As @joschu noted, this is mostly masked by Adam but can cause instability with high-variance sequence length datasets.
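To make the scaling concrete, here is a toy numeric comparison (assumed identical per-token loss for two examples of different lengths; the numbers are illustrative, not from the codebase):

```python
# With token-sum weighting, an example with 4x the target tokens
# contributes a 4x larger loss (and gradient), even when its per-token
# loss is identical. Token-mean weighting removes that dependence.
per_token_loss = 2.0  # assume the same per-token loss for both examples

short_sum = per_token_loss * 10   # 10 target tokens -> loss 20.0
long_sum = per_token_loss * 40    # 40 target tokens -> loss 80.0

# With weights normalized to sum to 1.0 per example, each weight is
# 1/num_tokens, and both examples contribute the same loss.
short_mean = per_token_loss * 10 * (1 / 10)  # -> 2.0
long_mean = per_token_loss * 40 * (1 / 40)   # -> 2.0
```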

Breaking change note

This PR defaults normalize_weights=True, which silently changes loss behavior for all existing callers of datum_from_model_input_weights (SFT, preference, DPO, VLM classifier). The old token-sum behavior is recoverable by passing normalize_weights=False. This aligns with the direction discussed in #271, but if a more gradual rollout is preferred — e.g., threading the option through ChatDatasetBuilderCommonConfig so users can toggle it per-dataset, or defaulting to False for backward compatibility — happy to adjust.

Test plan

  • Verify weights sum to 1.0 per example after normalization
  • Verify normalize_weights=False recovers old behavior
  • Verify all-zero weights are handled safely (no division by zero)
  • Verify single non-zero weight edge case
  • Verify target tokens are correctly left-shifted
  • Verify max_length truncation interacts correctly with normalization
  • Verify existing recipe smoke tests pass
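The first few checks can be sketched as pytest-style tests against a stand-in normalizer (normalize here is a hypothetical helper standing in for the weight-normalization step, not the repo's API):

```python
import math


def normalize(weights):
    """Stand-in for per-example weight normalization."""
    total = sum(weights)
    return [w / total for w in weights] if total else list(weights)


def test_weights_sum_to_one():
    out = normalize([0.0, 1.0, 1.0, 2.0])
    assert math.isclose(sum(out), 1.0)


def test_all_zero_weights_are_safe():
    # No division by zero; all-zero weights pass through unchanged.
    assert normalize([0.0, 0.0, 0.0]) == [0.0, 0.0, 0.0]


def test_single_nonzero_weight():
    # A lone non-zero weight normalizes to exactly 1.0.
    assert normalize([0.0, 5.0, 0.0]) == [0.0, 1.0, 0.0]
```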

Fixes #271

…udes

Normalize the per-token weights in datum_from_model_input_weights to
sum to 1.0 per example. This produces a token-mean loss instead of a
token-sum loss, making gradient magnitudes consistent across batches
with variable-length sequences.

The old token-sum behavior caused batches with more target tokens to
produce proportionally larger gradients, effectively making the learning
rate fluctuate with sequence length. This was especially problematic for
datasets with high variance in per-batch token counts.

The normalize_weights parameter defaults to True but can be set to False
to recover the old behavior.

Fixes thinking-machines-lab#271
…ghts

Tests cover normalize_weights=True/False, all-zero weights edge case,
single non-zero weight, target token left-shifting, and max_length
truncation interactions.


Development

Successfully merging this pull request may close these issues.

Unnormalized token-sum loss is a surprising default for SFT
