Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,18 @@ def create_standalone_class(
# Fix RotaryEmbeddings being in the wrong precision
source = fix_rotary_embedding_dtype(source)

# ROCm: keep router linear inputs aligned with router weight dtype.
# Some compiled GPT-OSS router paths otherwise hit Float vs BF16 matmul.
if getattr(torch.version, "hip", None):
source = source.replace(
"router_logits = F.linear(hidden_states, self.weight, self.bias)",
"router_logits = F.linear(hidden_states if hidden_states.dtype == self.weight.dtype else hidden_states.to(self.weight.dtype), self.weight, self.bias)",
)
source = source.replace(
"router_logits = torch.nn.functional.linear(hidden_states, self.weight, self.bias)",
"router_logits = torch.nn.functional.linear(hidden_states if hidden_states.dtype == self.weight.dtype else hidden_states.to(self.weight.dtype), self.weight, self.bias)",
)

return source


Expand Down Expand Up @@ -2766,7 +2778,6 @@ def compile_mamba_ssm(UNSLOTH_ENABLE_LOGGING=False):
"Gemma3nTextModel",
"Glm4MoeLiteNaiveMoe",
]

FIX_GC_LAYER_CALLER_MODULES = [
"WhisperDecoder",
]
Expand Down
5 changes: 5 additions & 0 deletions unsloth_zoo/device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
]

import torch
import os
import functools
from .utils import Version
import inspect
Expand Down Expand Up @@ -77,6 +78,10 @@ def get_device_count():
# HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB
ALLOW_BITSANDBYTES : bool = True
if DEVICE_TYPE == "hip":
# Disable AITER by default on ROCm to avoid JIT build locks and runtime faults.
# Users can override by explicitly setting env vars.
os.environ.setdefault("AITER_DISABLE", "1")
os.environ.setdefault("USE_ROCM_AITER_ROPE_BACKEND", "0")
try:
from bitsandbytes.nn.modules import Params4bit
if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(Params4bit):
Expand Down
13 changes: 12 additions & 1 deletion unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,22 @@ def chunked_hidden_states_selective_log_softmax(

chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
hidden_dim = lm_head.shape[1]
vocab_dim = lm_head.shape[0]

all_per_token_logps = []

for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
# VLM GRPO paths can already pass logits here (shape [..., vocab]).
# In that case avoid projecting by lm_head again.
if chunk_hidden_states.shape[-1] == vocab_dim:
chunk_logits = chunk_hidden_states
elif chunk_hidden_states.shape[-1] == hidden_dim:
chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
else:
# Fallback: try projection path and let the underlying matmul raise a
# precise error if the dimensions are genuinely incompatible.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the elif and else branches be combined into a single call?

chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()

if logit_scale_multiply != 0.0:
chunk_logits = chunk_logits * logit_scale_multiply
Expand Down
12 changes: 12 additions & 0 deletions unsloth_zoo/temporary_patches/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,12 @@ def forward(
self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
) -> torch.Tensor:
"""Forward using grouped_mm or loop fallback with LoRA support."""
# Keep activations aligned with expert weights to avoid mixed-dtype matmul errors.
target_dtype = getattr(getattr(self.down_proj, "weight", None), "dtype", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we still seeing any errors? I remember checking both fp16 and bf16.

if target_dtype is None:
target_dtype = self.dtype
if hidden_states is not None and hidden_states.dtype != target_dtype:
hidden_states = hidden_states.to(target_dtype)
# Use optimized grouped_mm if available
if _check_torch_grouped_mm_supported():
return forward_native_grouped_mm(self, hidden_states, router_indices, routing_weights)
Expand Down Expand Up @@ -1253,6 +1259,12 @@ def moe_forward_inference_bf16(self, hidden_states):
elif hasattr(down_proj, "weight"):
down_proj = down_proj.weight

# Keep activations aligned with expert weights to avoid mixed-dtype
# bmm mismatches in inference kernels.
target_dtype = gate_up_proj.dtype if gate_up_proj is not None else down_proj.dtype
if hidden_states is not None and hidden_states.dtype != target_dtype:
hidden_states = hidden_states.to(target_dtype)

return _moe_forward_inference_bf16_kernel(
hidden_states,
routing_weights,
Expand Down
Loading