
Fix Gemma3 Vision + Gemma3N audio inference on transformers 5.x #492

Open

danielhanchen wants to merge 2 commits into main from fix/gemma3n-audio-mel-mask

Conversation

@danielhanchen
Contributor

Summary

  • Fix Gemma3 Vision inference crash (index_copy_ dtype mismatch + flex_attention dtype validation) when using dtype=torch.float16 on transformers 5.x
  • Fix Gemma3N Conversational audio inference crash (AttributeError: 'Tensor' object has no attribute 'audio_mel_mask') on transformers 5.x (upstream transformers bug)

Changes

1. compiler.py -- RoPE mixed dtype fix

When FORCE_FLOAT32=1 (the Gemma3 float16 path), the model loads in bfloat16 but inference runs under float16 autocast. RoPE returns bfloat16 cos/sin (derived from the hidden_states dtype), the attention projections produce float16 q/k/v (under autocast), and apply_rotary_pos_emb promotes the float16 * bfloat16 product to float32. This breaks flex_attention (which requires matching q/k/v dtypes) and causes cache dtype mismatches.

Adds fix_apply_rotary_pos_emb_mixed_dtype() source transformation that inserts cos/sin dtype casting to match q dtype inside apply_rotary_pos_emb during compilation. Applied in both create_standalone_class() and create_new_function().
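
As a rough illustration only (not the exact implementation in compiler.py), such a source transformation can be as small as a regex that injects the casts right after the function signature:

    import re

    def fix_apply_rotary_pos_emb_mixed_dtype(source: str) -> str:
        # Sketch: inject cos/sin casts right after the apply_rotary_pos_emb
        # signature so both always match q's dtype before the rotation math runs.
        casts = "\n    cos = cos.to(q.dtype)\n    sin = sin.to(q.dtype)"
        return re.sub(
            r"(def apply_rotary_pos_emb\(.*?\):)",
            lambda m: m.group(1) + casts,
            source,
            count = 1,
            flags = re.DOTALL,
        )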

2. temporary_patches/misc.py -- Static cache dtype mismatch

Adds patch_static_cache_dtype_mismatch() that patches StaticLayer and StaticSlidingWindowLayer to handle key/value dtype mismatches during cache initialization and updates. Uses lower precision dtype for cache allocation and casts incoming states to match cache dtype. Belt-and-suspenders defense for the FORCE_FLOAT32 path.
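
The _resolve_dtype helper referenced in the diff excerpts further down is not shown in full here, so the version below is only an assumed illustration of the "lower precision wins" rule used for cache allocation:

    import torch

    def _resolve_dtype(key_states: torch.Tensor, value_states: torch.Tensor) -> torch.dtype:
        # Assumed precedence: allocate the cache in the lower-precision dtype so
        # incoming half-precision autocast states never have to be upcast.
        precedence = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
        return min(
            (key_states.dtype, value_states.dtype),
            key = lambda dtype: precedence.get(dtype, 99),
        )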

3. temporary_patches/gemma3n.py -- Audio mel mask fix

Upstream transformers bug in modeling_gemma3n.py:2114-2115: audio_features = audio_features.pooler_output overwrites the dataclass with a raw Tensor before .audio_mel_mask is extracted. Patches get_audio_features to attach audio_mel_mask onto the pooler_output tensor so it survives the variable reassignment.
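
For context, the general shape of such a patch looks roughly like the wrapper below. The real patch in temporary_patches/gemma3n.py goes through patch_function and rewrites the method source instead, so the wrapper and its name here are only illustrative:

    from transformers.models.gemma3n.modeling_gemma3n import Gemma3nModel

    def patch_gemma3n_audio_mel_mask_sketch():
        original_get_audio_features = Gemma3nModel.get_audio_features

        def get_audio_features(self, *args, **kwargs):
            audio_outputs = original_get_audio_features(self, *args, **kwargs)
            # Stash the mask on the tensor that Gemma3nModel.forward keeps after
            # it overwrites `audio_features` with `pooler_output`.
            if hasattr(audio_outputs, "audio_mel_mask") and audio_outputs.audio_mel_mask is not None:
                audio_outputs.pooler_output.audio_mel_mask = audio_outputs.audio_mel_mask
            return audio_outputs

        Gemma3nModel.get_audio_features = get_audio_features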

Test plan

  • Gemma3 4B Vision inference with dtype=torch.float16 on transformers 5.1.0 -- PASS
  • Gemma3 4B Vision multi-inference (cache reuse) -- PASS
  • Gemma3 4B bfloat16 inference regression -- PASS
  • Gemma3 4B training + post-training inference on transformers 5.1.0 -- PASS (loss=1.486)
  • Gemma3N audio_mel_mask patch verified (patch registered, upstream bug detected, fix applied)
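
For reviewers who want to exercise the float16 vision path locally, a rough repro sketch of the first test above (the checkpoint name, image URL, and prompt are assumptions, not taken from this PR):

    import torch
    from transformers import AutoProcessor, AutoModelForImageTextToText

    model_id = "google/gemma-3-4b-it"  # assumed checkpoint
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id, dtype = torch.float16, device_map = "auto",
    )

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/some_image.png"},
            {"type": "text", "text": "Describe this image."},
        ],
    }]
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt = True, tokenize = True,
        return_dict = True, return_tensors = "pt",
    ).to(model.device)
    output = model.generate(**inputs, max_new_tokens = 64)
    print(processor.decode(output[0], skip_special_tokens = True))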

Upstream bug in modeling_gemma3n.py: Gemma3nModel.forward overwrites
audio_features with pooler_output (a tensor) before extracting
audio_mel_mask from the dataclass, causing AttributeError.

Patch get_audio_features to attach audio_mel_mask onto the
pooler_output tensor so it survives the variable reassignment.

When using float16 dtype with Gemma3 (FORCE_FLOAT32 path), the model loads
in bfloat16 but inference uses float16 autocast. This creates a dtype chain
where RoPE returns bfloat16 cos/sin, attention projections produce float16
q/k/v under autocast, and apply_rotary_pos_emb promotes float16 * bfloat16
to float32. The result: float32 keys, float16 values, and flex_attention
rejecting mismatched q/k/v dtypes.
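
(For reference, the promotion step is easy to confirm in isolation; PyTorch's type-promotion rules upcast a float16/bfloat16 mix to float32:)

    import torch

    q = torch.randn(2, 4, dtype = torch.float16)     # autocast-produced projection
    cos = torch.randn(2, 4, dtype = torch.bfloat16)  # RoPE output in hidden_states dtype

    print((q * cos).dtype)  # torch.float32 -- mixed fp16/bf16 promotes to fp32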

Two fixes:

1. compiler.py: Add fix_apply_rotary_pos_emb_mixed_dtype() source
   transformation that inserts dtype casting (cos/sin to match q dtype)
   inside apply_rotary_pos_emb during compilation. Applied in both
   create_standalone_class() and create_new_function() to cover all
   compiled code paths.

2. misc.py: Add patch_static_cache_dtype_mismatch() that patches
   StaticLayer and StaticSlidingWindowLayer to handle key/value dtype
   mismatches during cache initialization and updates. Uses lower
   precision dtype for cache allocation and casts incoming states to
   match cache dtype. Only activates when UNSLOTH_FORCE_FLOAT32=1.
@gemini-code-assist
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request delivers crucial fixes to enhance the stability and correctness of Gemma3 Vision and Gemma3N Conversational audio inference within the transformers 5.x ecosystem. It addresses complex data type inconsistencies that led to crashes, particularly when using float16 precision, and provides a targeted workaround for an upstream bug affecting audio mask propagation. These changes ensure more robust and reliable model execution across various configurations.

Highlights

  • Gemma3 Vision Inference Fix: Fixed a crash in Gemma3 Vision inference caused by index_copy_ dtype mismatch and flex_attention dtype validation when using dtype=torch.float16 on transformers 5.x.
  • Gemma3N Audio Inference Fix: Resolved a crash in Gemma3N Conversational audio inference (AttributeError: 'Tensor' object has no attribute 'audio_mel_mask') on transformers 5.x, addressing an upstream bug.
  • RoPE Mixed Dtype Handling: Implemented a source transformation, fix_apply_rotary_pos_emb_mixed_dtype, to cast cos/sin dtypes to match q dtype within apply_rotary_pos_emb during compilation, preventing float32 intermediate results.
  • Static Cache Dtype Mismatch Patch: Added patch_static_cache_dtype_mismatch to temporary_patches/misc.py to correctly handle key/value dtype mismatches during cache initialization and updates in StaticLayer and StaticSlidingWindowLayer for FORCE_FLOAT32 paths.
  • Audio Mel Mask Preservation: Introduced a patch in temporary_patches/gemma3n.py to attach audio_mel_mask onto the pooler_output tensor within get_audio_features, ensuring it survives variable reassignment and is accessible later.


Changelog
  • unsloth_zoo/compiler.py
    • Added a new function fix_apply_rotary_pos_emb_mixed_dtype to handle mixed dtypes in Rotary Positional Embeddings.
    • Applied fix_apply_rotary_pos_emb_mixed_dtype within the create_new_function compilation flow.
    • Applied fix_apply_rotary_pos_emb_mixed_dtype within the create_standalone_class compilation flow.
  • unsloth_zoo/temporary_patches/gemma3n.py
    • Added patch_Gemma3nModel_get_audio_features to address an upstream bug by attaching audio_mel_mask to the pooler_output tensor.
    • Appended patch_Gemma3nModel_get_audio_features to the list of TEMPORARY_PATCHES.
  • unsloth_zoo/temporary_patches/misc.py
    • Added patch_static_cache_dtype_mismatch to resolve key/value dtype inconsistencies in StaticLayer and StaticSlidingWindowLayer.
    • Appended patch_static_cache_dtype_mismatch to the list of TEMPORARY_PATCHES.
Activity
  • No human activity (comments, reviews, approvals) has been recorded for this pull request yet.

Diff excerpt (unsloth_zoo/temporary_patches/gemma3n.py, tail of the patched get_audio_features):

        if hasattr(audio_outputs, "audio_mel_mask") and audio_outputs.audio_mel_mask is not None:
            audio_outputs.pooler_output.audio_mel_mask = audio_outputs.audio_mel_mask
            return audio_outputs
        return audio_outputs

    patch_function(transformers.models.gemma3n.modeling_gemma3n.Gemma3nModel, "get_audio_features", get_audio_features, match_level="relaxed")

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9da41f099a


Referenced diff lines (unsloth_zoo/temporary_patches/misc.py):

            # sees is_initialized=True and uses our casted states
            if key_states.dtype != value_states.dtype:
                target = _resolve_dtype(key_states, value_states)
                self.lazy_initialization(key_states.to(target), value_states.to(target))


P1: Keep lazy_initialization call compatible across transformers

This wrapper now calls lazy_initialization with both key and value tensors, but in transformers 4.x (StaticLayer/StaticSlidingWindowLayer) the method only accepts key_states; with UNSLOTH_FORCE_FLOAT32=1, the first cache update raises TypeError before inference can proceed on static-cache paths. Because this patch has no version/arity guard, it introduces a runtime regression for 4.x users (the same issue is repeated in patched_sw_update).
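
One possible way to guard against this, sketched with an assumed signature check (not code from this PR):

    import inspect

    def _lazy_init_compat(layer, key_states, value_states):
        # Hypothetical guard: transformers 4.x StaticLayer.lazy_initialization
        # takes only key_states, while newer releases also accept value_states.
        # Inspect the installed signature instead of assuming the 5.x arity.
        params = inspect.signature(type(layer).lazy_initialization).parameters
        if len(params) >= 3:  # self, key_states, value_states
            layer.lazy_initialization(key_states, value_states)
        else:
            layer.lazy_initialization(key_states)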


Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces several important fixes for Gemma3 Vision and Gemma3N audio inference on transformers 5.x, addressing crashes related to data type mismatches and an upstream bug. The changes are well-targeted and include:

  1. A fix for mixed-precision RoPE calculations in compiler.py.
  2. A patch for static cache data type mismatches in temporary_patches/misc.py.
  3. A workaround for an audio feature bug in temporary_patches/gemma3n.py.

The code is generally of high quality. I've included one suggestion to improve maintainability by refactoring duplicated code.

Comment on lines +1305 to +1338
    def patched_update(self, key_states, value_states, cache_kwargs=None):
        if not self.is_initialized:
            # Eagerly initialize with consistent dtype so the original update
            # sees is_initialized=True and uses our casted states
            if key_states.dtype != value_states.dtype:
                target = _resolve_dtype(key_states, value_states)
                self.lazy_initialization(key_states.to(target), value_states.to(target))
            else:
                self.lazy_initialization(key_states, value_states)
        # Cast incoming states to match the cache dtype
        if key_states.dtype != self.keys.dtype:
            key_states = key_states.to(self.keys.dtype)
        if value_states.dtype != self.values.dtype:
            value_states = value_states.to(self.values.dtype)
        return _orig_update(self, key_states, value_states, cache_kwargs)
    patched_update._unsloth_patched = True
    StaticLayer.update = patched_update

    # StaticSlidingWindowLayer has its own update method (not inherited)
    _orig_sw_update = StaticSlidingWindowLayer.update
    def patched_sw_update(self, key_states, value_states, cache_kwargs=None):
        if not self.is_initialized:
            if key_states.dtype != value_states.dtype:
                target = _resolve_dtype(key_states, value_states)
                self.lazy_initialization(key_states.to(target), value_states.to(target))
            else:
                self.lazy_initialization(key_states, value_states)
        if key_states.dtype != self.keys.dtype:
            key_states = key_states.to(self.keys.dtype)
        if value_states.dtype != self.values.dtype:
            value_states = value_states.to(self.values.dtype)
        return _orig_sw_update(self, key_states, value_states, cache_kwargs)
    patched_sw_update._unsloth_patched = True
    StaticSlidingWindowLayer.update = patched_sw_update
Contributor

medium

The logic within patched_update and patched_sw_update is identical. To improve maintainability and reduce code duplication, you can extract the common logic into a helper function. This will make the code cleaner and easier to manage in the future.

    def _patched_update_logic(self, key_states, value_states):
        if not self.is_initialized:
            # Eagerly initialize with consistent dtype so the original update
            # sees is_initialized=True and uses our casted states
            if key_states.dtype != value_states.dtype:
                target = _resolve_dtype(key_states, value_states)
                self.lazy_initialization(key_states.to(target), value_states.to(target))
            else:
                self.lazy_initialization(key_states, value_states)
        # Cast incoming states to match the cache dtype
        if key_states.dtype != self.keys.dtype:
            key_states = key_states.to(self.keys.dtype)
        if value_states.dtype != self.values.dtype:
            value_states = value_states.to(self.values.dtype)
        return key_states, value_states

    # Patch StaticLayer.update
    _orig_update = StaticLayer.update
    def patched_update(self, key_states, value_states, cache_kwargs=None):
        key_states, value_states = _patched_update_logic(self, key_states, value_states)
        return _orig_update(self, key_states, value_states, cache_kwargs)
    patched_update._unsloth_patched = True
    StaticLayer.update = patched_update

    # StaticSlidingWindowLayer has its own update method (not inherited)
    _orig_sw_update = StaticSlidingWindowLayer.update
    def patched_sw_update(self, key_states, value_states, cache_kwargs=None):
        key_states, value_states = _patched_update_logic(self, key_states, value_states)
        return _orig_sw_update(self, key_states, value_states, cache_kwargs)
    patched_sw_update._unsloth_patched = True
    StaticSlidingWindowLayer.update = patched_sw_update

def patch_static_cache_dtype_mismatch():
    """Fix StaticLayer/StaticSlidingWindowLayer index_copy_ dtype mismatch.

    When using float16 autocast on a bfloat16 model (FORCE_FLOAT32 path for Gemma3),
Collaborator


This was affecting gpt_oss.
That is why I had to do this conversion.
Once this lands, can we explore removing that?
