diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py
index 86e0a529f..37a5896ab 100644
--- a/unsloth_zoo/compiler.py
+++ b/unsloth_zoo/compiler.py
@@ -1200,6 +1200,18 @@ def create_standalone_class(
     # Fix RotaryEmbeddings being in the wrong precision
     source = fix_rotary_embedding_dtype(source)
 
+    # ROCm: keep router linear inputs aligned with router weight dtype.
+    # Some compiled GPT-OSS router paths otherwise hit Float vs BF16 matmul.
+    if getattr(torch.version, "hip", None):
+        source = source.replace(
+            "router_logits = F.linear(hidden_states, self.weight, self.bias)",
+            "router_logits = F.linear(hidden_states if hidden_states.dtype == self.weight.dtype else hidden_states.to(self.weight.dtype), self.weight, self.bias)",
+        )
+        source = source.replace(
+            "router_logits = torch.nn.functional.linear(hidden_states, self.weight, self.bias)",
+            "router_logits = torch.nn.functional.linear(hidden_states if hidden_states.dtype == self.weight.dtype else hidden_states.to(self.weight.dtype), self.weight, self.bias)",
+        )
+
     return source
 
 
@@ -2766,7 +2778,6 @@ def compile_mamba_ssm(UNSLOTH_ENABLE_LOGGING=False):
     "Gemma3nTextModel",
     "Glm4MoeLiteNaiveMoe",
 ]
-
 FIX_GC_LAYER_CALLER_MODULES = [
     "WhisperDecoder",
 ]
diff --git a/unsloth_zoo/device_type.py b/unsloth_zoo/device_type.py
index ccebc3e2e..898f382cd 100644
--- a/unsloth_zoo/device_type.py
+++ b/unsloth_zoo/device_type.py
@@ -24,6 +24,7 @@
 ]
 
 import torch
+import os
 import functools
 from .utils import Version
 import inspect
@@ -77,6 +78,10 @@ def get_device_count():
 # HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB
 ALLOW_BITSANDBYTES : bool = True
 if DEVICE_TYPE == "hip":
+    # Disable AITER by default on ROCm to avoid JIT build locks and runtime faults.
+    # Users can override by explicitly setting env vars.
+    os.environ.setdefault("AITER_DISABLE", "1")
+    os.environ.setdefault("USE_ROCM_AITER_ROPE_BACKEND", "0")
     try:
         from bitsandbytes.nn.modules import Params4bit
         if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(Params4bit):
diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py
index 98e78f648..385aeba18 100644
--- a/unsloth_zoo/rl_replacements.py
+++ b/unsloth_zoo/rl_replacements.py
@@ -79,11 +79,22 @@ def chunked_hidden_states_selective_log_softmax(
 
     chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
     chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
+    hidden_dim = lm_head.shape[1]
+    vocab_dim = lm_head.shape[0]
 
     all_per_token_logps = []
 
     for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
-        chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+        # VLM GRPO paths can already pass logits here (shape [..., vocab]).
+        # In that case avoid projecting by lm_head again.
+        if chunk_hidden_states.shape[-1] == vocab_dim:
+            chunk_logits = chunk_hidden_states
+        elif chunk_hidden_states.shape[-1] == hidden_dim:
+            chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
+        else:
+            # Fallback: try projection path and let the underlying matmul raise a
+            # precise error if the dimensions are genuinely incompatible.
+            chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
 
         if logit_scale_multiply != 0.0:
             chunk_logits = chunk_logits * logit_scale_multiply
diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py
index 1388d78cc..e07d7437e 100644
--- a/unsloth_zoo/temporary_patches/gpt_oss.py
+++ b/unsloth_zoo/temporary_patches/gpt_oss.py
@@ -760,6 +760,12 @@ def forward(
         self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
     ) -> torch.Tensor:
         """Forward using grouped_mm or loop fallback with LoRA support."""
+        # Keep activations aligned with expert weights to avoid mixed-dtype matmul errors.
+        target_dtype = getattr(getattr(self.down_proj, "weight", None), "dtype", None)
+        if target_dtype is None:
+            target_dtype = self.dtype
+        if hidden_states is not None and hidden_states.dtype != target_dtype:
+            hidden_states = hidden_states.to(target_dtype)
         # Use optimized grouped_mm if available
         if _check_torch_grouped_mm_supported():
             return forward_native_grouped_mm(self, hidden_states, router_indices, routing_weights)
@@ -1253,6 +1259,12 @@ def moe_forward_inference_bf16(self, hidden_states):
     elif hasattr(down_proj, "weight"):
         down_proj = down_proj.weight
 
+    # Keep activations aligned with expert weights to avoid mixed-dtype
+    # bmm mismatches in inference kernels.
+    target_dtype = gate_up_proj.dtype if gate_up_proj is not None else down_proj.dtype
+    if hidden_states is not None and hidden_states.dtype != target_dtype:
+        hidden_states = hidden_states.to(target_dtype)
+
     return _moe_forward_inference_bf16_kernel(
         hidden_states,
         routing_weights,
diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py
index c4ff97629..16bcca64b 100644
--- a/unsloth_zoo/temporary_patches/misc.py
+++ b/unsloth_zoo/temporary_patches/misc.py
@@ -39,6 +39,216 @@
 import re
 import os
 
+def patch_disable_torchcodec_if_missing_ffmpeg():
+    """
+    Disable torchcodec in datasets if its shared libraries fail to load.
+    This allows falling back to soundfile-based decoding.
+    """
+    import sys
+    try:
+        import torchcodec
+        try:
+            from torchcodec._core import ops as _ops
+            _ops.load_torchcodec_shared_libraries()
+            return
+        except Exception:
+            pass
+    except Exception:
+        pass
+
+    try:
+        import datasets.config as _ds_config
+        _ds_config.TORCHCODEC_AVAILABLE = False
+    except Exception:
+        pass
+
+    if "torchcodec" in sys.modules:
+        try:
+            del sys.modules["torchcodec"]
+        except Exception:
+            pass
+pass
+TEMPORARY_PATCHES.append(patch_disable_torchcodec_if_missing_ffmpeg)
+
+def patch_datasets_audio_decode_example():
+    try:
+        import datasets.features.audio as audio_mod
+    except Exception as e:
+        return raise_error("datasets.features.audio", e)
+
+    if getattr(audio_mod.Audio.decode_example, "_unsloth_patched", False):
+        return
+
+    original_decode = audio_mod.Audio.decode_example
+
+    def decode_example(self, value: dict, token_per_repo_id: Optional[dict] = None):
+        try:
+            return original_decode(self, value, token_per_repo_id = token_per_repo_id)
+        except Exception:
+            try:
+                import soundfile as sf
+                import numpy as np
+                from io import BytesIO
+                path, bytes_value = (
+                    (value["path"], value["bytes"])
+                    if value.get("bytes") is not None
+                    else (value.get("path"), None)
+                )
+                if bytes_value is not None:
+                    data, sr = sf.read(BytesIO(bytes_value))
+                else:
+                    data, sr = sf.read(path)
+                if not isinstance(data, np.ndarray):
+                    data = np.asarray(data)
+
+                return {"array": data, "sampling_rate": sr}
+            except Exception:
+                raise
+    pass
+
+    decode_example._unsloth_patched = True
+    patch_function(
+        audio_mod.Audio,
+        "decode_example",
+        decode_example,
+        fullgraph = False,
+        force = True,
+        match_level = "relaxed",
+    )
+pass
+TEMPORARY_PATCHES.append(patch_datasets_audio_decode_example)
+
+def patch_deepseek_ocr_masked_scatter():
+    def _apply_patch(module):
+        DeepseekOCRModel = module.DeepseekOCRModel
+        if getattr(DeepseekOCRModel, "_unsloth_masked_scatter_patched", False):
+            return
+        try:
+            import sys
+            if module.__name__.endswith(".modeling_deepseekocr"):
+                parent_name = module.__name__.rsplit(".modeling_deepseekocr", 1)[0]
+                sys.modules.setdefault("deepseek_ocr", sys.modules.get(parent_name, module))
+                sys.modules.setdefault("deepseek_ocr.modeling_deepseekocr", module)
+        except Exception:
+            pass
+        orig_forward = DeepseekOCRModel.forward
+
+        def _forward(self, *args, **kwargs):
+            import torch
+            orig_ms = torch.Tensor.masked_scatter
+
+            def _safe_masked_scatter(tensor, mask, source):
+                try:
+                    return orig_ms(tensor, mask, source)
+                except RuntimeError as e:
+                    msg = str(e)
+                    if "masked_scatter" not in msg and "size of tensor" not in msg:
+                        raise
+                    try:
+                        out = tensor
+                        if mask.shape[0] != tensor.shape[0]:
+                            min_len = min(mask.shape[0], tensor.shape[0])
+                            mask = mask[:min_len]
+                            out = tensor[:min_len]
+                        target = int(mask.sum().item())
+                        src = source
+                        if src.numel() != target:
+                            if src.dim() > 1:
+                                src = src.reshape(-1, src.shape[-1])
+                            else:
+                                src = src.reshape(-1)
+                            if src.shape[0] > target:  # NOTE(review): nesting under the numel() check is inferred - confirm
+                                src = src[:target]
+                            elif src.shape[0] < target:
+                                pad_shape = (target - src.shape[0],) + src.shape[1:]
+                                pad = src.new_zeros(pad_shape)
+                                src = torch.cat([src, pad], dim=0)
+                        patched = orig_ms(out, mask, src)
+                        if patched.shape[0] != tensor.shape[0]:
+                            full = tensor.clone()
+                            full[:patched.shape[0]] = patched
+                            return full
+                        return patched
+                    except Exception:
+                        raise e
+
+            torch.Tensor.masked_scatter = _safe_masked_scatter
+            try:
+                return orig_forward(self, *args, **kwargs)
+            finally:
+                torch.Tensor.masked_scatter = orig_ms
+
+        DeepseekOCRModel.forward = _forward
+        DeepseekOCRModel._unsloth_masked_scatter_patched = True
+
+    import sys
+    for name, mod in list(sys.modules.items()):
+        if name.endswith(".modeling_deepseekocr"):
+            try:
+                _apply_patch(mod)
+                return
+            except Exception:
+                pass
+
+    try:
+        import transformers.dynamic_module_utils as _dmu
+        if not getattr(_dmu, "_unsloth_deepseek_ocr_hook", False):
+            _orig_get_class = _dmu.get_class_in_module
+
+            def _get_class_in_module(class_name, module_path, force_reload=False):
+                cls = _orig_get_class(class_name, module_path, force_reload=force_reload)
+                try:
+                    if module_path.endswith(".modeling_deepseekocr"):
+                        _apply_patch(importlib.import_module(module_path))
+                except Exception:
+                    pass
+                return cls
+
+            _dmu.get_class_in_module = _get_class_in_module
+            _dmu._unsloth_deepseek_ocr_hook = True
+    except Exception:
+        pass
+
+    # module not available yet; patch on import via meta_path hook
+    try:
+        import importlib.machinery as _machinery
+        import importlib.abc as _abc
+
+        class _DeepseekOCRHook(_abc.MetaPathFinder):
+            def find_spec(self, fullname, path, target=None):
+                if not fullname.endswith(".modeling_deepseekocr"):
+                    return None
+                spec = _machinery.PathFinder.find_spec(fullname, path)
+                if spec is None or spec.loader is None:
+                    return spec
+                orig_loader = spec.loader
+
+                class _Loader(_abc.Loader):
+                    def create_module(self, spec):
+                        if hasattr(orig_loader, "create_module"):
+                            return orig_loader.create_module(spec)
+                        return None
+
+                    def exec_module(self, module):
+                        if hasattr(orig_loader, "exec_module"):
+                            orig_loader.exec_module(module)
+                        else:
+                            module = orig_loader.load_module(fullname)
+                        try:
+                            _apply_patch(module)
+                        except Exception:  # best-effort patch; `return` inside `finally` also swallowed KeyboardInterrupt/SystemExit
+                            pass
+
+                spec.loader = _Loader()
+                return spec
+
+        if not any(type(h).__name__ == "_DeepseekOCRHook" for h in sys.meta_path):  # class is re-created on every call, so isinstance() never matched an earlier hook
+            sys.meta_path.insert(0, _DeepseekOCRHook())
+    except Exception as e:
+        return raise_error("DeepseekOCRModel", e)
+pass
+TEMPORARY_PATCHES.append(patch_deepseek_ocr_masked_scatter)
+
 def patch_ministral3_config_mapping():
     # Fix for Ministral-3 VL models which have text_config.model_type = "ministral3"
     # but transformers CONFIG_MAPPING doesn't have "ministral3" as a key
@@ -259,6 +469,41 @@ def forward(
 TEMPORARY_PATCHES.append(patch_CsmDepthDecoderForCausalLM_forward)
 
 
+def patch_rocm_disable_generate_cache():
+    try:
+        import transformers.generation.utils as generation_utils
+    except Exception as e:
+        return raise_error("GenerationMixin.generate", e)
+
+    if not getattr(getattr(torch, "version", None), "hip", None):
+        return
+
+    if getattr(generation_utils.GenerationMixin, "_unsloth_rocm_generate_patched", False):
+        return
+
+    original_generate = generation_utils.GenerationMixin.generate
+
+    def generate(self, *args, **kwargs):
+        kwargs["use_cache"] = False
+        # HIP-safe generation: drop cache-only kwargs that can route into
+        # unsupported codepaths and trigger assert_async failures.
+        for key in (
+            "past_key_values",
+            "cache_position",
+            "cache_implementation",
+            "cache_config",
+            "max_cache_length",
+            "cache_dtype",
+        ):
+            kwargs.pop(key, None)
+        return original_generate(self, *args, **kwargs)
+
+    generation_utils.GenerationMixin.generate = generate
+    generation_utils.GenerationMixin._unsloth_rocm_generate_patched = True
+pass
+TEMPORARY_PATCHES.append(patch_rocm_disable_generate_cache)
+
+
 def patch_CsmForConditionalGeneration_forward():
     try:
         import transformers.models.csm.modeling_csm
@@ -362,7 +607,10 @@ def forward(
             )
 
             depth_decoder_loss = depth_decoder_outputs.loss
-            loss = backbone_loss + depth_decoder_loss
+            if depth_decoder_loss is None:
+                loss = backbone_loss
+            else:
+                loss = backbone_loss + depth_decoder_loss
 
         return process_return(CsmOutputWithPast, {
             "loss" : loss,
@@ -382,7 +630,13 @@ def forward(
             "depth_decoder_attentions" : depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
         })
     pass
-    success = patch_function(transformers.models.csm.modeling_csm.CsmForConditionalGeneration, "forward", forward)
+    success = patch_function(
+        transformers.models.csm.modeling_csm.CsmForConditionalGeneration,
+        "forward",
+        forward,
+        force = True,
+        match_level = "relaxed",
+    )
     if success: return
 
     # New transformers removes output_attentions and output_hidden_states
@@ -409,10 +663,210 @@ def forward(
         kwargs = new_kwargs.pop('kwargs', dict())
         new_kwargs.update(kwargs)
         return old_forward(**new_kwargs)
-    patch_function(transformers.models.csm.modeling_csm.CsmForConditionalGeneration, "forward", forward)
+    patch_function(
+        transformers.models.csm.modeling_csm.CsmForConditionalGeneration,
+        "forward",
+        forward,
+        force = True,
+        match_level = "relaxed",
+    )
 pass
 TEMPORARY_PATCHES.append(patch_CsmForConditionalGeneration_forward)
 
+def patch_CsmAttention_force_eager():
+    if not getattr(torch.version, "hip", None):
+        return
+    try:
+        import transformers.models.csm.modeling_csm as csm
+        from transformers.models.csm.modeling_csm import apply_rotary_pos_emb, eager_attention_forward
+    except Exception as e:
+        return raise_error("CsmAttention.forward", e)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        squeeze_seq = False
+        if hidden_states.dim() == 2:
+            hidden_size = getattr(self.o_proj, "in_features", None)
+            if hidden_size and hidden_states.shape[-1] % hidden_size == 0 and hidden_states.shape[-1] != hidden_size:
+                hidden_states = hidden_states.view(hidden_states.shape[0], -1, hidden_size)
+            else:
+                hidden_states = hidden_states.unsqueeze(1)
+                squeeze_seq = True  # NOTE(review): nesting inferred - confirm this belongs to the else branch
+        elif hidden_states.dim() == 3:
+            pass
+        elif hidden_states.dim() != 3:
+            hidden_states = hidden_states.reshape(
+                hidden_states.shape[0], -1, hidden_states.shape[-1]
+            )
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        # Keep RoPE sequence axis aligned to query length on ROCm generation paths.
+        q_len = query_states.shape[-2]
+        if getattr(cos, "shape", None) is not None and len(cos.shape) >= 2 and cos.shape[-2] != q_len:
+            cos = cos[..., -q_len:, :]
+        if getattr(sin, "shape", None) is not None and len(sin.shape) >= 2 and sin.shape[-2] != q_len:
+            sin = sin[..., -q_len:, :]
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if key_states.size(2) != value_states.size(2):
+            min_len = min(key_states.size(2), value_states.size(2))
+            key_states = key_states[:, :, :min_len, :]
+            value_states = value_states[:, :, :min_len, :]
+            if attention_mask is not None and attention_mask.size(-1) != min_len:
+                attention_mask = attention_mask[..., :min_len]
+        if attention_mask is not None:
+            q_len = query_states.size(-2)
+            k_len = key_states.size(-2)
+            if attention_mask.dim() >= 4:
+                if attention_mask.size(-2) != q_len:
+                    attention_mask = attention_mask[..., -q_len:, :]
+                if attention_mask.size(-1) != k_len:
+                    attention_mask = attention_mask[..., -k_len:]
+
+        attn_output, attn_weights = eager_attention_forward(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        target_hidden = getattr(self.o_proj, "in_features", None)
+        if (
+            target_hidden
+            and attn_output.shape[-1] != target_hidden
+            and attn_output.shape[-1] % target_hidden == 0
+        ):
+            attn_output = attn_output.reshape(attn_output.shape[0], -1, target_hidden)
+            squeeze_seq = False
+        if squeeze_seq and attn_output.dim() >= 3 and attn_output.shape[1] == 1:
+            attn_output = attn_output.squeeze(1)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+    pass
+
+    patch_function(
+        csm.CsmAttention,
+        "forward",
+        forward,
+        force = True,
+        match_level = "relaxed",
+    )
+pass
+TEMPORARY_PATCHES.append(patch_CsmAttention_force_eager)
+
+def patch_whisper_feature_extractor_resample():
+    try:
+        from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor
+    except Exception as e:
+        return raise_error("WhisperFeatureExtractor.__call__", e)
+
+    original_call = WhisperFeatureExtractor.__call__
+
+    def _resample_array(arr, orig_sr, target_sr):
+        try:
+            import torch
+            import torchaudio.functional as F
+            tensor = torch.tensor(arr)
+            if tensor.dim() == 1:
+                tensor = tensor.unsqueeze(0)
+            resampled = F.resample(tensor, orig_sr, target_sr)
+            return resampled.squeeze(0).cpu().numpy()
+        except Exception:
+            return arr
+
+    def __call__(self, raw_speech, *args, sampling_rate=None, **kwargs):  # keyword-only so positional args keep their upstream slots
+        try:
+            return original_call(self, raw_speech, *args, sampling_rate = sampling_rate, **kwargs)
+        except ValueError as e:
+            if sampling_rate is None or "sampling rate of 16000" not in str(e):
+                raise
+            target_sr = 16000
+            if isinstance(raw_speech, (list, tuple)):
+                raw_speech = [
+                    _resample_array(x, sampling_rate, target_sr) for x in raw_speech
+                ]
+            else:
+                raw_speech = _resample_array(raw_speech, sampling_rate, target_sr)
+            return original_call(self, raw_speech, *args, sampling_rate = target_sr, **kwargs)
+    pass
+
+    patch_function(
+        WhisperFeatureExtractor,
+        "__call__",
+        __call__,
+        force = True,
+        match_level = "relaxed",
+    )
+pass
+TEMPORARY_PATCHES.append(patch_whisper_feature_extractor_resample)
+
+def patch_transformers_ffmpeg_read():
+    try:
+        from transformers.pipelines import audio_utils as audio_utils
+    except Exception as e:
+        return raise_error("transformers.pipelines.audio_utils.ffmpeg_read", e)
+
+    original_ffmpeg_read = audio_utils.ffmpeg_read
+
+    def ffmpeg_read(path, sampling_rate):
+        try:
+            return original_ffmpeg_read(path, sampling_rate)
+        except Exception:
+            try:
+                import soundfile as sf
+                import numpy as np
+                from io import BytesIO
+                if isinstance(path, (bytes, bytearray)):
+                    data, sr = sf.read(BytesIO(path))
+                else:
+                    data, sr = sf.read(path)
+                if data.ndim > 1:
+                    data = np.mean(data, axis=1)
+                if sampling_rate is not None and sr != sampling_rate:
+                    try:
+                        import torch
+                        import torchaudio.functional as F
+                        tensor = torch.tensor(data).unsqueeze(0)
+                        data = F.resample(tensor, sr, sampling_rate).squeeze(0).cpu().numpy()
+                    except Exception:
+                        pass
+                return data
+            except Exception:
+                raise
+    audio_utils.ffmpeg_read = ffmpeg_read
+    try:
+        import transformers.pipelines.automatic_speech_recognition as asr
+        if hasattr(asr, "ffmpeg_read"):
+            asr.ffmpeg_read = ffmpeg_read
+    except Exception:
+        pass
+pass
+TEMPORARY_PATCHES.append(patch_transformers_ffmpeg_read)
+
 def patch_transformers_masks():
     if os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1":