Fix NaN loss in SFT by using token-based label masking #339

Draft
sfc-gh-aponnusamy wants to merge 6 commits into main from ac-sft-fix-token-based-label-masking

Conversation

@sfc-gh-aponnusamy
Collaborator

Problem:

  • Character-based text matching for assistant content could find the wrong position when a short response also appeared earlier in the conversation
  • End-of-turn tokens (e.g., <|im_end|>) were not included in labels
  • These issues caused all labels to be masked, resulting in NaN loss

Solution:

  • Replace character-position matching with token pattern matching
  • Find the assistant_start marker tokens and mark everything up to and including turn_end (see the sketch after this list)
  • Add pre-defined configs for popular models (Qwen/ChatML, Llama3, etc.)
  • Auto-detect model family from tokenizer name and special tokens
  • Heuristic fallback for unknown models
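
A minimal sketch of the token-pattern approach, assuming ChatML-style markers; the marker strings and helper names below are illustrative, not the exact API of `arctic_training/data/chat_markers.py`:

```python
# Illustrative sketch only: marker strings and function names are assumptions,
# not the exact implementation in arctic_training/data/chat_markers.py.
from typing import List

IGNORE_INDEX = -100

def find_subsequence(ids: List[int], pattern: List[int], start: int = 0) -> int:
    """Return the first index of `pattern` inside `ids` at or after `start`, or -1."""
    for i in range(start, len(ids) - len(pattern) + 1):
        if ids[i:i + len(pattern)] == pattern:
            return i
    return -1

def token_based_labels(input_ids: List[int], tokenizer,
                       assistant_start: str = "<|im_start|>assistant",
                       turn_end: str = "<|im_end|>") -> List[int]:
    """Unmask assistant turns (including the end-of-turn token); mask everything else."""
    start_ids = tokenizer.encode(assistant_start, add_special_tokens=False)
    end_ids = tokenizer.encode(turn_end, add_special_tokens=False)

    labels = [IGNORE_INDEX] * len(input_ids)
    pos = 0
    while (s := find_subsequence(input_ids, start_ids, pos)) != -1:
        content_start = s + len(start_ids)               # first token after the marker
        e = find_subsequence(input_ids, end_ids, content_start)
        content_end = e + len(end_ids) if e != -1 else len(input_ids)
        labels[content_start:content_end] = input_ids[content_start:content_end]
        pos = content_end
    return labels
```

Because matching runs over token IDs rather than character offsets, a short assistant reply that also appears verbatim earlier in the conversation can no longer shift the unmasked range.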

New files:

  • arctic_training/data/chat_markers.py: Chat marker detection and token-based labels

Config options:

  • chat_template_family: Explicitly specify the model family (optional; auto-detected when omitted, see the detection sketch below). Available: chatml, llama3, llama2, mistral_v3, phi3, gemma, deepseek, deepseek_v2, vicuna, zephyr, command_r
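
A plausible shape for the explicit-override-plus-auto-detection flow, under the assumption that detection keys off the tokenizer name and its special tokens; the lookup table and return values below are illustrative:

```python
# Illustrative sketch; the real mapping in chat_markers.py may differ.
from typing import Optional

def detect_chat_template_family(tokenizer, chat_template_family: Optional[str] = None) -> str:
    if chat_template_family is not None:                  # explicit config wins
        return chat_template_family

    name = (getattr(tokenizer, "name_or_path", "") or "").lower()
    if "qwen" in name:
        return "chatml"
    if "llama-3" in name or "llama3" in name:
        return "llama3"
    if "phi-3" in name or "phi3" in name:
        return "phi3"

    # Fall back to special tokens when the name is uninformative.
    specials = set(getattr(tokenizer, "additional_special_tokens", None) or [])
    if "<|im_start|>" in specials:
        return "chatml"
    if "<|start_header_id|>" in specials:
        return "llama3"
    return "unknown"                                      # triggers the heuristic fallback
```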

sfc-gh-aponnusamy and others added 6 commits January 8, 2026 13:07
Problem:
- Character-based text matching for assistant content could find wrong
  positions when short responses appeared earlier in conversation
- End-of-turn tokens (e.g., <|im_end|>) were not included in labels
- These issues caused all labels to be masked, resulting in NaN loss

Solution:
- Replace character-position matching with token pattern matching
- Find assistant_start marker tokens and mark everything until turn_end
- Add pre-defined configs for popular models (Qwen/ChatML, Llama3, etc.)
- Auto-detect model family from tokenizer name and special tokens
- Heuristic fallback for unknown models

New files:
- arctic_training/data/chat_markers.py: Chat marker detection and token-based labels

Config options:
- chat_template_family: Explicitly specify model family (optional, auto-detected)
  Available: chatml, llama3, llama2, mistral_v3, phi3, gemma, deepseek,
  deepseek_v2, vicuna, zephyr, command_r
- Add protection against division by zero when total_good_tokens is 0
- Handle NaN loss from Liger kernel on all-masked batches
- Add warnings to help diagnose data issues with empty/masked outputs
- This fixes NaN losses that occur when packed samples have mostly
  non-assistant content distributed across sequence parallel ranks
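
A hedged sketch of that guard; the variable and function names are placeholders rather than the actual sft_trainer.py code:

```python
import logging

import torch

logger = logging.getLogger(__name__)

def masked_mean_loss(per_token_loss: torch.Tensor, labels: torch.Tensor,
                     ignore_index: int = -100) -> torch.Tensor:
    """Average the loss over unmasked tokens, guarding the all-masked case.

    Assumes per_token_loss is already aligned with labels (same shape).
    """
    good = labels != ignore_index
    total_good_tokens = int(good.sum())
    if total_good_tokens == 0:
        logger.warning("All labels in this batch are %d; emitting a zero loss. "
                       "Check packing and label masking.", ignore_index)
        # 0 / 0 would give NaN; return a zero that is still attached to the graph.
        return per_token_loss.sum() * 0.0
    return per_token_loss[good].sum() / total_good_tokens
```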
…ency

## NaN Loss Handling

- Add NaN/Inf checking to SP==1 path that was missing protection
  (lines 49-72 in sft_trainer.py)
- Ensure zero loss is connected to computation graph to maintain
  proper gradient flow for DeepSpeed optimizer state
- Use outputs.logits.sum() * 0.0 pattern when available, fallback
  to torch.zeros with requires_grad=True
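
Assuming `outputs` is the model's forward output and `device` the local training device, the pattern described above would look roughly like this (not the verbatim sft_trainer.py code):

```python
import torch

def graph_connected_zero_loss(outputs, device: torch.device) -> torch.Tensor:
    """Return a zero loss that still participates in the autograd graph."""
    logits = getattr(outputs, "logits", None)
    if logits is not None:
        return logits.sum() * 0.0                               # zero, but attached to the graph
    return torch.zeros((), device=device, requires_grad=True)   # fallback leaf tensor
```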

## Marker Tokenization Fix

- Fix inconsistency where get_token_based_labels_with_ignore_empty_think
  used _tokenize_marker() while get_token_based_labels() used
  _tokenize_marker_without_trailing_whitespace()
- This caused different content_start positions, potentially masking
  wrong token ranges when ignore_empty_think=True

## Other

- Add torch.cuda.empty_cache() after evaluation to free memory
… turn's end marker

When detect_markers_heuristic falls back to extracting markers by finding
'X' in the conversation, the asst_turn variable was including the previous
turn's end marker (e.g., '<|im_end|><|im_start|>assistant' instead of just
'<|im_start|>assistant').

Fix: In the fallback case, look for known end markers in the segment
between X and Y, and extract only the portion after the end marker as
the assistant_start. This prevents pattern matching failures that would
result in all labels being masked to -100 and NaN loss during training.
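
A rough sketch of that extraction step, with an assumed list of known end markers standing in for whatever the heuristic actually checks:

```python
# Illustrative only; the marker list and helper name are assumptions.
KNOWN_END_MARKERS = ("<|im_end|>", "<|eot_id|>", "</s>")

def strip_previous_turn_end(segment: str) -> str:
    """Drop a leading previous-turn end marker from an assistant-start segment."""
    # e.g. strip_previous_turn_end("<|im_end|>\n<|im_start|>assistant")
    #      -> "<|im_start|>assistant"
    for end_marker in KNOWN_END_MARKERS:
        idx = segment.rfind(end_marker)
        if idx != -1:
            return segment[idx + len(end_marker):].lstrip("\n")
    return segment
```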
- chat_markers.py: Remove over-engineered assistant turn detection fallback
  logic; use simple conditional instead of complex marker heuristics
- chat_markers.py: Use consistent _tokenize_marker() for assistant/user
  start markers in get_token_based_labels_with_ignore_empty_think()
- sft_trainer.py: Remove redundant NaN loss handling in non-SP forward path
- sft_trainer.py: Fix zero loss creation by using fresh tensor instead of
  loss * 0.0 (since NaN * 0 = NaN, multiplying NaN loss by zero still fails)
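
A two-line illustration of the arithmetic behind that last bullet:

```python
import torch

nan_loss = torch.tensor(float("nan"))
print(nan_loss * 0.0)              # tensor(nan) -- multiplying NaN by zero is still NaN
print(torch.zeros_like(nan_loss))  # tensor(0.)  -- a fresh zero tensor avoids this
```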
…nk handling

Use _tokenize_marker_without_trailing_whitespace() for assistant_start and user_start markers to match the tokenization method used in get_token_based_labels(), preventing inconsistencies in label masking.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>