Fix Gemma3N audio notebooks to use reentrant checkpointing #215

Open
danielhanchen wants to merge 1 commit into main from dh/fix-gemma3n-audio-reentrant-checkpointing

@danielhanchen
Contributor

Summary

Switch use_reentrant=False to use_reentrant=True in Gemma3N Audio notebooks (both Colab and Kaggle versions).

Non-reentrant gradient checkpointing triggers AOT autograd compilation of the backward pass. That compilation fails on the Gemma3N audio conformer because variable-length audio tensors cause stride mismatches:

AssertionError: expected size 2==2, stride 1928==1936 at dim=0
This error most often comes from a incorrect fake (aka meta) kernel for a custom op.

The audio conformer's conv/norm layers produce tensors whose strides vary with audio clip duration (e.g. 241 vs 242 mel frames), and the compiled backward graph expects fixed strides.
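
The numbers in the assertion are consistent with this explanation. As a rough illustration, the sketch below computes row-major (contiguous) strides for a hypothetical tensor of shape (2, 8, T). The shape is an assumption chosen only because 8 * 241 = 1928 and 8 * 242 = 1936; the PR does not state the real tensor shape.

```python
def contiguous_strides(shape):
    """Return row-major (C-contiguous) strides, in elements, for a shape."""
    strides = []
    step = 1
    # Walk the dimensions from innermost to outermost.
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


# Hypothetical shapes: (batch=2, channels=8, mel_frames=T). Assumed, not from the PR.
strides_241 = contiguous_strides((2, 8, 241))
strides_242 = contiguous_strides((2, 8, 242))

# Size at dim 0 is 2 in both cases, but the dim-0 stride differs by one frame's worth
# of elements per channel, which is the shape of mismatch the assertion reports.
print(strides_241[0], strides_242[0])
```

A graph traced with one clip length bakes in the first stride, so a clip that is one mel frame longer no longer matches.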

Changes

  • nb/Gemma3N_(4B)-Audio.ipynb: use_reentrant=False -> use_reentrant=True
  • nb/Kaggle-Gemma3N_(4B)-Audio.ipynb: use_reentrant=False -> use_reentrant=True
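
For anyone carrying a local copy of these notebooks, the same change can be applied mechanically. This is a minimal sketch, not the tooling used for this PR: `enable_reentrant` is a hypothetical helper, and it assumes the flag appears in code cells either as `use_reentrant=False` or as a dict entry such as `"use_reentrant": False`.

```python
import json
import re

# Matches `use_reentrant=False`, `use_reentrant = False`, and `"use_reentrant": False`.
_PATTERN = re.compile(r"""(use_reentrant["']?\s*[=:]\s*)False\b""")


def enable_reentrant(notebook_json: str):
    """Return (patched notebook JSON, number of replacements). Hypothetical helper."""
    nb = json.loads(notebook_json)
    replaced = 0
    for cell in nb.get("cells", []):
        if cell.get("cell_type") != "code":
            continue
        source = cell.get("source", [])
        # The notebook format allows `source` to be a string or a list of lines.
        lines = [source] if isinstance(source, str) else source
        new_lines = []
        for line in lines:
            new_line, n = _PATTERN.subn(r"\1True", line)
            replaced += n
            new_lines.append(new_line)
        cell["source"] = new_lines[0] if isinstance(source, str) else new_lines
    return json.dumps(nb, indent=1), replaced


# Tiny in-memory notebook used only for demonstration.
demo = json.dumps({
    "cells": [
        {"cell_type": "code",
         "source": ['gradient_checkpointing_kwargs = {"use_reentrant": False},\n']},
        {"cell_type": "markdown", "source": ["use_reentrant=False is left alone here\n"]},
    ]
})
patched, count = enable_reentrant(demo)
print(count)
```

Only code cells are touched, so prose that mentions the old setting is left unchanged.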

Companion to unslothai/unsloth#4629, which also adds a server-side guard in vision.py that forces use_reentrant=True for Gemma3N regardless of the notebook setting.
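
The idea of such a guard can be sketched as follows. This is not the code from unslothai/unsloth#4629: the function name, its arguments, and the `"gemma3n"` model-type check are all assumptions made for illustration.

```python
from typing import Optional


def apply_checkpointing_guard(model_type: str, checkpointing_kwargs: Optional[dict]) -> dict:
    """Hypothetical guard: override use_reentrant for Gemma3N, pass others through."""
    # Copy so the caller's dict is never mutated.
    kwargs = dict(checkpointing_kwargs or {})
    # Assumption: the model is identified by a type string starting with "gemma3n".
    if model_type.lower().startswith("gemma3n"):
        kwargs["use_reentrant"] = True
    return kwargs


# A notebook that still asks for non-reentrant checkpointing gets overridden.
print(apply_checkpointing_guard("gemma3n", {"use_reentrant": False}))
# Other model types keep whatever the caller requested.
print(apply_checkpointing_guard("llama", {"use_reentrant": False}))
```

Forcing the value in the library means older copies of the notebooks keep working even if they still pass use_reentrant=False.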

Test plan

  • Gemma3N audio training with use_reentrant=True: trains successfully (loss 12.9 -> 0.56 over 100 steps on LibriSpeech)
  • No changes to any other notebooks

@gemini-code-assist
Contributor

Note

Gemini is unable to generate a review for this pull request due to the file types involved not being currently supported.

Non-reentrant gradient checkpointing (use_reentrant=False) causes
AOT autograd stride assertion failures during backward pass with
Gemma3N audio conformer due to variable-length audio tensors:

    AssertionError: expected size 2==2, stride 1928==1936 at dim=0

Switch to use_reentrant=True which avoids AOT autograd compilation
of the backward pass entirely.

Companion to unslothai/unsloth#4629 which adds a server-side guard.
@danielhanchen danielhanchen force-pushed the dh/fix-gemma3n-audio-reentrant-checkpointing branch from c72a24c to 9774858 Compare March 27, 2026 06:35