
Fix NaN validation loss in SFT training by correcting label masking logic#335

Open
sfc-gh-aponnusamy wants to merge 8 commits into main from ac-sft-fix

Conversation


@sfc-gh-aponnusamy sfc-gh-aponnusamy commented Jan 5, 2026

Summary

This PR fixes a bug where NaN validation loss was occurring during SQL autocompletion SFT training. The root cause was incorrect label masking that caused all labels to be set to -100 (ignored), resulting in NaN loss during training.

Changes

1. Fixed get_assistant_start_end_indices() in sft_factory.py

Problem: The previous implementation searched for assistant content from the beginning of the conversation text each time. This could incorrectly match content that appeared earlier in the conversation (e.g., in user context/history).

Solution: Now tracks a search_start position and processes all messages in order, ensuring assistant content is matched after the preceding user message rather than at its first occurrence anywhere in the text.
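A minimal sketch of the corrected search logic described above (function signature, message format, and variable names are assumptions based on this description, not the actual sft_factory.py source):

```python
def get_assistant_start_end_indices(messages, conversation_text):
    """Return (start, end) character ranges for each assistant message,
    searching forward through the text so that duplicate content earlier
    in the conversation is never matched by mistake."""
    ranges = []
    search_start = 0  # advances past every message already located
    for message in messages:
        content = message["content"]
        idx = conversation_text.find(content, search_start)
        if idx == -1:
            if message["role"] == "assistant":
                ranges.append((-1, -1))  # not found; caller must handle
            continue
        if message["role"] == "assistant":
            ranges.append((idx, idx + len(content)))
        search_start = idx + len(content)  # never re-match earlier text
    return ranges
```

For example, with a user turn "say hi" followed by an assistant turn "hi", a naive search from position 0 would match the "hi" inside "say hi"; advancing search_start past the user message finds the real assistant occurrence instead.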

2. Fixed get_masked_labels() in sft_factory.py

Problem: The token inclusion condition required tokens to be fully contained within assistant ranges (id_s >= s and id_e <= e). This was too strict for short assistant content where tokenizer offsets can span wider than the actual content.

Solution: Changed to an overlap condition (id_s < e and id_e > s) that includes a token if its character span overlaps any assistant range. Also added handling for invalid ranges (s == -1, meaning the content was not found).
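The overlap check described above can be sketched as follows (names and the exact call shape are assumptions from this description; in practice the offsets would come from a tokenizer's offset_mapping):

```python
IGNORE_INDEX = -100  # tokens with this label are ignored by cross-entropy

def get_masked_labels(input_ids, offsets, assistant_ranges):
    """Keep a token's label only if its character span overlaps any
    assistant range; all other tokens are masked with -100."""
    labels = []
    for token_id, (id_s, id_e) in zip(input_ids, offsets):
        keep = False
        for (s, e) in assistant_ranges:
            if s == -1:  # content was not found; skip invalid range
                continue
            if id_s < e and id_e > s:  # overlap, not full containment
                keep = True
                break
        labels.append(token_id if keep else IGNORE_INDEX)
    return labels
```

With the old containment test, a token whose offsets extend slightly past a short assistant span (e.g. span (5, 9) and token offsets (4, 8)) would be masked; the overlap test keeps it.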

3. Added Debug Logging

  • Label masking debug: Set the DEBUG_LABEL_MASKING=1 environment variable to enable detailed logging when labels are unexpectedly all masked or very few remain unmasked
  • NaN loss detection: Added logging in the evaluation loop to detect and report when NaN/Inf losses occur, including batch and label statistics
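A rough sketch of what these two debug hooks could look like (only the DEBUG_LABEL_MASKING flag name comes from this PR; the helper names and surrounding trainer code are assumptions):

```python
import logging
import math
import os

logger = logging.getLogger(__name__)

# Opt-in flag, as described in the PR: export DEBUG_LABEL_MASKING=1
DEBUG_LABEL_MASKING = os.environ.get("DEBUG_LABEL_MASKING") == "1"

def check_labels(labels, ignore_index=-100):
    """Warn when a sample's labels are entirely masked and return the
    count of non-masked labels."""
    non_masked = sum(1 for label in labels if label != ignore_index)
    if DEBUG_LABEL_MASKING and non_masked == 0:
        logger.warning("All %d labels are masked; loss will be NaN", len(labels))
    return non_masked

def check_eval_loss(loss, step):
    """Report non-finite losses seen during the evaluation loop."""
    if math.isnan(loss) or math.isinf(loss):
        logger.warning("Non-finite eval loss %r at step %d", loss, step)
        return False
    return True
```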

Files Changed

  • arctic_training/data/sft_factory.py - Label masking fixes and debug logging
  • arctic_training/trainer/trainer.py - NaN/Inf evaluation loss detection

Root Cause

The NaN loss was caused by all labels being masked (set to -100), which happens when:

  1. The assistant content search found the wrong occurrence of text (earlier in conversation)
  2. The token overlap logic was too strict (requiring full containment vs overlap)

When all labels are -100, the cross-entropy loss computation has no valid targets, resulting in NaN.
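The NaN mechanism can be illustrated in isolation (a pure-Python stand-in, not the actual training code): a mean-reduced cross-entropy averages over only the non-ignored targets, so zero valid targets yields a 0/0 division.

```python
import math

def mean_ce(per_token_losses, labels, ignore_index=-100):
    """Mean cross-entropy over non-ignored targets; NaN when every
    target is ignored, mirroring a 0/0 mean reduction."""
    valid = [loss for loss, label in zip(per_token_losses, labels)
             if label != ignore_index]
    if not valid:
        return float("nan")  # no valid targets: the 0/0 case
    return sum(valid) / len(valid)
```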

@sfc-gh-aponnusamy sfc-gh-aponnusamy changed the title Ac sft fix Fix NaN validation loss in SFT training by correcting label masking logic Jan 5, 2026
@sfc-gh-aponnusamy sfc-gh-aponnusamy marked this pull request as ready for review January 5, 2026 19:28

@sfc-gh-sbekman sfc-gh-sbekman left a comment


The trainer part is perfect, the data part I'd ask for someone who works a lot with instruct data to validate.

