31 changes: 31 additions & 0 deletions unsloth_zoo/compiler.py
@@ -472,6 +472,31 @@ def higher_precision_sqrt_mean(source):
pass


def fix_apply_rotary_pos_emb_mixed_dtype(source):
# When FORCE_FLOAT32 is set, the model is bfloat16 but inference uses float16 autocast.
# RoPE forward returns cos/sin in bfloat16 (hidden_states dtype), but q/k are float16
# (from autocast). apply_rotary_pos_emb does float16 * bfloat16 = float32, breaking
# flex_attention's strict dtype validation and causing cache dtype mismatches.
# Fix: cast cos/sin to match q dtype inside apply_rotary_pos_emb.
if "apply_rotary_pos_emb" not in source:
return source
# Use regex to handle any indentation level
source = re.sub(
r"([ \t]*)(cos = cos\.unsqueeze\(unsqueeze_dim\)\n)"
r"(\1sin = sin\.unsqueeze\(unsqueeze_dim\)\n)"
r"(\1)(q_embed)",
r"\1cos = cos.unsqueeze(unsqueeze_dim)\n"
r"\1sin = sin.unsqueeze(unsqueeze_dim)\n"
r"\1if cos.dtype != q.dtype:\n"
r"\1 cos = cos.to(q.dtype)\n"
r"\1 sin = sin.to(q.dtype)\n"
r"\4\5",
source,
)
return source
pass
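
For reference, a sketch of roughly what the rewritten apply_rotary_pos_emb should look like once the substitution above has run, assuming the standard Hugging Face layout of the function (rotate_half comes from the surrounding modeling file; exact context lines can differ per model):

```python
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # Lines inserted by fix_apply_rotary_pos_emb_mixed_dtype: keep cos/sin in
    # q's dtype so float16 q/k are never promoted to float32 by bfloat16 cos/sin.
    if cos.dtype != q.dtype:
        cos = cos.to(q.dtype)
        sin = sin.to(q.dtype)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```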


def fix_rotary_embedding_dtype(source):
# Rotary Embeddings might be left in float32 since we upcast it
# We downcast it to float16 if we see float32 for X's dtype
@@ -766,6 +791,9 @@ def create_new_function(
# Fix all softmax low precisions to float32
new_source = higher_precision_softmax(new_source)

# Fix apply_rotary_pos_emb mixed dtype (bfloat16 cos/sin * float16 q/k = float32)
new_source = fix_apply_rotary_pos_emb_mixed_dtype(new_source)

if new_source[0] == " ":
spaces = new_source.find("def")
new_source = new_source.split("\n")
@@ -1200,6 +1228,9 @@ def create_standalone_class(
# Fix RotaryEmbeddings being in the wrong precision
source = fix_rotary_embedding_dtype(source)

# Fix apply_rotary_pos_emb mixed dtype (bfloat16 cos/sin * float16 q/k = float32)
source = fix_apply_rotary_pos_emb_mixed_dtype(source)

return source


35 changes: 35 additions & 0 deletions unsloth_zoo/temporary_patches/gemma3n.py
@@ -201,3 +201,38 @@ def get_placeholder_mask(
patch_function(transformers.models.gemma3n.modeling_gemma3n.Gemma3nModel, "get_placeholder_mask", get_placeholder_mask, match_level="relaxed")
pass
TEMPORARY_PATCHES.append(patch_Gemma3nModel_get_placeholder_mask)

def patch_Gemma3nModel_get_audio_features():
# Fix upstream transformers bug: Gemma3nModel.forward does
# audio_features = audio_features.pooler_output (overwrites dataclass with tensor)
# audio_mask = audio_features.audio_mel_mask (crashes - tensor has no audio_mel_mask)
# We patch get_audio_features to attach audio_mel_mask onto the pooler_output tensor.
try:
import transformers.models.gemma3n.modeling_gemma3n
from transformers.models.gemma3n.modeling_gemma3n import Gemma3nModel
except Exception as e:
return raise_error("Gemma3nModel.get_audio_features", e)

if not hasattr(Gemma3nModel, "get_audio_features"):
return

def get_audio_features(
self,
input_features: torch.Tensor,
input_features_mask: torch.Tensor,
**kwargs,
):
audio_outputs = self.audio_tower(
input_features, input_features_mask, return_dict=True, **kwargs
)
audio_embeds = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
audio_outputs.pooler_output = audio_embeds
# Attach audio_mel_mask to pooler_output tensor so it survives
# the variable reassignment in Gemma3nModel.forward
if hasattr(audio_outputs, "audio_mel_mask") and audio_outputs.audio_mel_mask is not None:
audio_outputs.pooler_output.audio_mel_mask = audio_outputs.audio_mel_mask
return audio_outputs
pass
patch_function(transformers.models.gemma3n.modeling_gemma3n.Gemma3nModel, "get_audio_features", get_audio_features, match_level="relaxed")
pass
TEMPORARY_PATCHES.append(patch_Gemma3nModel_get_audio_features)
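
To make the workaround concrete, a minimal sketch of why attaching the mask to the tensor survives the reassignment in Gemma3nModel.forward (shapes and names here are illustrative only, not taken from the model config):

```python
import torch

pooler_output = torch.randn(1, 188, 2048)               # stand-in for the pooled audio embeds
audio_mel_mask = torch.zeros(1, 188, dtype=torch.bool)   # stand-in for the mel mask

# torch.Tensor instances accept plain Python attribute assignment,
# so the mask travels with the tensor object itself.
pooler_output.audio_mel_mask = audio_mel_mask

# What the buggy upstream forward effectively does after the patch:
audio_features = pooler_output               # was: audio_features = audio_features.pooler_output
audio_mask = audio_features.audio_mel_mask   # resolves instead of raising AttributeError
```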
69 changes: 69 additions & 0 deletions unsloth_zoo/temporary_patches/misc.py
@@ -1269,3 +1269,72 @@ def _patched_safe_apply(model_config, tokenizer, conversation, *,
pass
pass
TEMPORARY_PATCHES.append(patch_vllm_safe_apply_chat_template)


def patch_static_cache_dtype_mismatch():
"""Fix StaticLayer/StaticSlidingWindowLayer index_copy_ dtype mismatch.

When using float16 autocast on a bfloat16 model (FORCE_FLOAT32 path for Gemma3),
Collaborator
This was affecting gpt_oss.
That is why I had to do this conversion.
Once this lands, can we explore removing that?

RoPE produces float32 keys (float16 * bfloat16) while values stay float16.
The cache is allocated with key_states.dtype (float32), causing index_copy_ to
fail when storing float16 values. Fix: eagerly initialize with consistent dtype
in update wrappers, then cast all incoming states before delegating.
"""
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
return
try:
from transformers.cache_utils import StaticLayer, StaticSlidingWindowLayer
except Exception:
return

if hasattr(StaticLayer.update, "_unsloth_patched"):
return

def _resolve_dtype(key_states, value_states):
"""Pick lower-precision dtype when key/value dtypes differ."""
if key_states.dtype == value_states.dtype:
return key_states.dtype
return (
value_states.dtype
if value_states.element_size() <= key_states.element_size()
else key_states.dtype
)

# Patch StaticLayer.update
_orig_update = StaticLayer.update
def patched_update(self, key_states, value_states, cache_kwargs=None):
if not self.is_initialized:
# Eagerly initialize with consistent dtype so the original update
# sees is_initialized=True and uses our casted states
if key_states.dtype != value_states.dtype:
target = _resolve_dtype(key_states, value_states)
self.lazy_initialization(key_states.to(target), value_states.to(target))

P1: Keep the lazy_initialization call compatible across transformers versions

This wrapper now calls lazy_initialization with both key and value tensors, but in transformers 4.x (StaticLayer/StaticSlidingWindowLayer) the method only accepts key_states; with UNSLOTH_FORCE_FLOAT32=1, the first cache update raises TypeError before inference can proceed on static-cache paths. Because this patch has no version/arity guard, it introduces a runtime regression for 4.x users (the same issue is repeated in patched_sw_update).

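If such a guard is added, one possible shape, sketched only as an illustration (the helper name _lazy_init_compat and the two-argument assumption for newer releases are hypothetical, not part of this PR):

```python
import inspect

def _lazy_init_compat(layer, key_states, value_states):
    # Bound-method signature excludes `self`; transformers 4.x exposes
    # lazy_initialization(key_states) only, so fall back to the single-argument call.
    n_params = len(inspect.signature(layer.lazy_initialization).parameters)
    if n_params >= 2:
        layer.lazy_initialization(key_states, value_states)
    else:
        layer.lazy_initialization(key_states)
```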

else:
self.lazy_initialization(key_states, value_states)
# Cast incoming states to match the cache dtype
if key_states.dtype != self.keys.dtype:
key_states = key_states.to(self.keys.dtype)
if value_states.dtype != self.values.dtype:
value_states = value_states.to(self.values.dtype)
return _orig_update(self, key_states, value_states, cache_kwargs)
patched_update._unsloth_patched = True
StaticLayer.update = patched_update

# StaticSlidingWindowLayer has its own update method (not inherited)
_orig_sw_update = StaticSlidingWindowLayer.update
def patched_sw_update(self, key_states, value_states, cache_kwargs=None):
if not self.is_initialized:
if key_states.dtype != value_states.dtype:
target = _resolve_dtype(key_states, value_states)
self.lazy_initialization(key_states.to(target), value_states.to(target))
else:
self.lazy_initialization(key_states, value_states)
if key_states.dtype != self.keys.dtype:
key_states = key_states.to(self.keys.dtype)
if value_states.dtype != self.values.dtype:
value_states = value_states.to(self.values.dtype)
return _orig_sw_update(self, key_states, value_states, cache_kwargs)
patched_sw_update._unsloth_patched = True
StaticSlidingWindowLayer.update = patched_sw_update
Comment on lines +1305 to +1338
Contributor

medium

The logic within patched_update and patched_sw_update is identical. To improve maintainability and reduce code duplication, you can extract the common logic into a helper function. This will make the code cleaner and easier to manage in the future.

    def _patched_update_logic(self, key_states, value_states):
        if not self.is_initialized:
            # Eagerly initialize with consistent dtype so the original update
            # sees is_initialized=True and uses our casted states
            if key_states.dtype != value_states.dtype:
                target = _resolve_dtype(key_states, value_states)
                self.lazy_initialization(key_states.to(target), value_states.to(target))
            else:
                self.lazy_initialization(key_states, value_states)
        # Cast incoming states to match the cache dtype
        if key_states.dtype != self.keys.dtype:
            key_states = key_states.to(self.keys.dtype)
        if value_states.dtype != self.values.dtype:
            value_states = value_states.to(self.values.dtype)
        return key_states, value_states

    # Patch StaticLayer.update
    _orig_update = StaticLayer.update
    def patched_update(self, key_states, value_states, cache_kwargs=None):
        key_states, value_states = _patched_update_logic(self, key_states, value_states)
        return _orig_update(self, key_states, value_states, cache_kwargs)
    patched_update._unsloth_patched = True
    StaticLayer.update = patched_update

    # StaticSlidingWindowLayer has its own update method (not inherited)
    _orig_sw_update = StaticSlidingWindowLayer.update
    def patched_sw_update(self, key_states, value_states, cache_kwargs=None):
        key_states, value_states = _patched_update_logic(self, key_states, value_states)
        return _orig_sw_update(self, key_states, value_states, cache_kwargs)
    patched_sw_update._unsloth_patched = True
    StaticSlidingWindowLayer.update = patched_sw_update

pass
TEMPORARY_PATCHES.append(patch_static_cache_dtype_mismatch)
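
For context on the failure mode described in the docstring of patch_static_cache_dtype_mismatch, a minimal standalone repro of the index_copy_ dtype error and the cast that avoids it (tensor shapes are illustrative, not tied to any model config):

```python
import torch

# Cache storage allocated from float32 key_states (the FORCE_FLOAT32 RoPE promotion),
# while value_states arrive in float16 from the autocast region.
cache_values = torch.zeros(1, 2, 8, 4, dtype=torch.float32)
value_states = torch.randn(1, 2, 1, 4, dtype=torch.float16)
cache_position = torch.tensor([0])

try:
    cache_values.index_copy_(2, cache_position, value_states)   # self/source dtype mismatch
except RuntimeError as err:
    print(err)  # index_copy_ requires matching dtypes

# Casting incoming states to the cache dtype, as the patched update does, succeeds.
cache_values.index_copy_(2, cache_position, value_states.to(cache_values.dtype))
```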