- Did you update? pip install --upgrade unsloth unsloth_zoo: Yes
- Colab, Kaggle, or local / cloud: Local
- Number of GPUs used (see nvidia-smi): 2
- Which notebook? Please link! https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(27B)_A100-Conversational.ipynb
- Which Unsloth version, TRL version, transformers version, PyTorch version? 2026.1.4 (Unsloth), 0.24.0 (TRL), 4.57.6 (transformers), 2.10.0+cu128 (PyTorch)
- Which trainer? SFTTrainer, GRPOTrainer, etc.: SFTTrainer

Minimal code to reproduce the error: just the Gemma3_(27B)_A100-Conversational notebook linked above, with the only change being multi-GPU via device_map = "balanced" (see the sketch below).
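For reference, the modification amounts to a single argument in the model-loading cell. The sketch below is an assumption of how that cell looks, not a copy of it: the checkpoint name, sequence length, and quantization flag are placeholders from the usual Unsloth notebooks; only device_map = "balanced" is the actual change reported here.

```python
# Minimal sketch of the only change made to the notebook (assumptions: the
# notebook loads the model via unsloth.FastModel.from_pretrained and uses the
# unsloth/gemma-3-27b-it checkpoint; all other cells are left untouched).
from unsloth import FastModel

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-27b-it",  # placeholder checkpoint name
    max_seq_length = 2048,                  # placeholder value
    load_in_4bit = True,                    # placeholder value
    device_map = "balanced",  # the only change: shard the model across both GPUs
)
```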
The error trace itself:
---------------------------------------------------------------------------
ObservedTypeError Traceback (most recent call last)
ObservedTypeError: ConstantVariable(str: 'Missing required positional argument: x')
The above exception was the direct cause of the following exception:
Unsupported Traceback (most recent call last)
Cell In[14], line 1
----> 1 trainer_stats = trainer.train()
File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/UnslothSFTTrainer.py:64, in prepare_for_training_mode.<locals>.wrapper(self, *args, **kwargs)
62 if hasattr(self, 'model') and hasattr(self.model, "for_training"):
63 self.model.for_training(use_gradient_checkpointing=use_gc)
---> 64 output = f(self, *args, **kwargs)
65 # Restore previous mode when possible
66 if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/trainer.py:2325, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2323 hf_hub_utils.enable_progress_bars()
2324 else:
-> 2325 return inner_training_loop(
2326 args=args,
2327 resume_from_checkpoint=resume_from_checkpoint,
2328 trial=trial,
2329 ignore_keys_for_eval=ignore_keys_for_eval,
2330 )
File <string>:328, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/UnslothSFTTrainer.py:1220, in _UnslothSFTTrainer.training_step(self, *args, **kwargs)
1218 def training_step(self, *args, **kwargs):
1219 with self.maybe_activation_offload_context:
-> 1220 return super().training_step(*args, **kwargs)
File <string>:40, in _unsloth_training_step(self, model, inputs, num_items_in_batch)
File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/UnslothSFTTrainer.py:1209, in _UnslothSFTTrainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
1206 def compute_loss(
1207 self, model, inputs, return_outputs = False, num_items_in_batch = None
1208 ):
-> 1209 outputs = super().compute_loss(
1210 model,
1211 inputs,
1212 return_outputs = return_outputs,
1213 num_items_in_batch = num_items_in_batch,
1214 )
1215 return outputs
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth/models/_utils.py:1661, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
1654 name = inner_model.__class__.__name__
1656 logger.warning_once(
1657 f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"
1658 "Using gradient accumulation will be very slightly less accurate.\n"
1659 "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
1660 )
-> 1661 outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
1662 return outputs
File <string>:36, in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
818 def forward(*args, **kwargs):
--> 819 return model_forward(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
806 def __call__(self, *args, **kwargs):
--> 807 return convert_to_fp32(self.model_forward(*args, **kwargs))
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
41 @functools.wraps(func)
42 def decorate_autocast(*args, **kwargs):
43 with autocast_instance:
---> 44 return func(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/peft/peft_model.py:1923, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1921 with self._enable_peft_forward_hooks(**kwargs):
1922 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1923 return self.base_model(
1924 input_ids=input_ids,
1925 attention_mask=attention_mask,
1926 inputs_embeds=inputs_embeds,
1927 labels=labels,
1928 output_attentions=output_attentions,
1929 output_hidden_states=output_hidden_states,
1930 return_dict=return_dict,
1931 **kwargs,
1932 )
1934 batch_size = _get_batch_size(input_ids, inputs_embeds)
1935 if attention_mask is not None:
1936 # concat prompt attention mask
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:311, in BaseTuner.forward(self, *args, **kwargs)
310 def forward(self, *args: Any, **kwargs: Any):
--> 311 return self.model.forward(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/hooks.py:175, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/unsloth_compiled_module_gemma3.py:902, in Gemma3ForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **lm_kwargs)
884 def forward(
885 self,
886 input_ids: Optional[torch.LongTensor] = None,
(...) 900 **lm_kwargs,
901 ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
--> 902 return Gemma3ForConditionalGeneration_forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **lm_kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/_dynamo/external_utils.py:203, in get_nonrecursive_disable_wrapper.<locals>.nonrecursive_disable_wrapper(*args, **kwargs)
199 if torch.compiler.is_exporting():
200 raise RuntimeError(
201 "Non-recursive torch.compiler.disable is not supported with torch.export."
202 )
--> 203 return fn(*args, **kwargs)
File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/unsloth_compiled_module_gemma3.py:712, in Gemma3ForConditionalGeneration_forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **lm_kwargs)
707 output_hidden_states = (
708 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709 )
710 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
--> 712 outputs = self.model(
713 input_ids=input_ids,
714 pixel_values=pixel_values,
715 token_type_ids=token_type_ids,
716 attention_mask=attention_mask,
717 position_ids=position_ids,
718 past_key_values=past_key_values,
719 inputs_embeds=inputs_embeds,
720 use_cache=use_cache,
721 labels=mask_attention_mask_out(labels = labels, attention_mask = attention_mask),
722 output_attentions=output_attentions,
723 output_hidden_states=output_hidden_states,
724 return_dict=return_dict,
725 cache_position=cache_position,
726 **lm_kwargs,
727 )
729 hidden_states = outputs[0]
730 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/utils/generic.py:918, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
916 if return_dict_passed is not None:
917 return_dict = return_dict_passed
--> 918 output = func(self, *args, **kwargs)
919 if not return_dict and not isinstance(output, tuple):
920 output = output.to_tuple()
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:957, in Gemma3Model.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **lm_kwargs)
951 # Create the masks
952 causal_mask_mapping = {
953 "full_attention": create_causal_mask(**mask_kwargs),
954 "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
955 }
--> 957 outputs = self.language_model(
958 attention_mask=causal_mask_mapping,
959 position_ids=position_ids,
960 past_key_values=past_key_values,
961 inputs_embeds=inputs_embeds,
962 use_cache=use_cache,
963 output_attentions=output_attentions,
964 output_hidden_states=output_hidden_states,
965 return_dict=True,
966 cache_position=cache_position,
967 **lm_kwargs,
968 )
970 return Gemma3ModelOutputWithPast(
971 last_hidden_state=outputs.last_hidden_state,
972 past_key_values=outputs.past_key_values if use_cache else None,
(...) 975 image_hidden_states=image_features if pixel_values is not None else None,
976 )
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/utils/generic.py:1072, in check_model_inputs.<locals>.wrapped_fn.<locals>.wrapper(self, *args, **kwargs)
1069 monkey_patched_layers.append((module, original_forward))
1071 try:
-> 1072 outputs = func(self, *args, **kwargs)
1073 except TypeError as original_exception:
1074 # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
1075 # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception
1076 # Otherwise -> we're probably missing `**kwargs` in the decorated function
1077 kwargs_without_recordable = {k: v for k, v in kwargs.items() if k not in recordable_keys}
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:570, in Gemma3TextModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, cache_position, **kwargs)
567 if output_hidden_states:
568 all_hidden_states += (hidden_states,)
--> 570 layer_outputs = decoder_layer(
571 hidden_states,
572 position_embeddings_global=position_embeddings_global,
573 position_embeddings_local=position_embeddings_local,
574 attention_mask=causal_mask_mapping[decoder_layer.attention_type],
575 position_ids=position_ids,
576 past_key_values=past_key_values,
577 output_attentions=output_attentions,
578 use_cache=use_cache,
579 cache_position=cache_position,
580 **kwargs,
581 )
583 hidden_states = layer_outputs[0]
585 if output_attentions:
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/modeling_layers.py:93, in GradientCheckpointingLayer.__call__(self, *args, **kwargs)
90 message = message.rstrip(",") + "."
91 logger.warning_once(message)
---> 93 return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
94 return super().__call__(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/_compile.py:54, in _disable_dynamo.<locals>.inner(*args, **kwargs)
51 disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False)
52 fn.__dynamo_disable = disable_fn # type: ignore[attr-defined]
---> 54 return disable_fn(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1181, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
1171 with fx_traceback.annotate(
1172 {
1173 "_torchdynamo_disable": True,
(...) 1178 }
1179 ):
1180 return fn(*args, **kwargs)
-> 1181 return fn(*args, **kwargs)
1182 finally:
1183 set_eval_frame(None)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/utils/checkpoint.py:505, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, early_stop, *args, **kwargs)
500 if context_fn is not noop_context_fn or debug is not False:
501 raise ValueError(
502 "Passing `context_fn` or `debug` is only supported when "
503 "use_reentrant=False."
504 )
--> 505 return CheckpointFunction.apply(function, preserve, *args)
506 else:
507 gen = _checkpoint_without_reentrant_generator(
508 function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
509 )
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/autograd/function.py:583, in Function.apply(cls, *args, **kwargs)
580 if not torch._C._are_functorch_transforms_active():
581 # See NOTE: [functorch vjp and autograd interaction]
582 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 583 return super().apply(*args, **kwargs) # type: ignore[misc]
585 if not is_setup_ctx_defined:
586 raise RuntimeError(
587 "In order to use an autograd.Function with functorch transforms "
588 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
589 "staticmethod. For more details, please see "
590 "https://pytorch.org/docs/main/notes/extending.func.html"
591 )
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth_zoo/gradient_checkpointing.py:498, in UnslothCheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
495 if ctx._requires_gradient: ctx.save_for_backward(*tensor_inputs)
497 with torch.no_grad():
--> 498 outputs = run_function(*args)
500 if use_gpu_buffer: MAIN_STREAM.wait_stream(EXTRA_STREAM)
501 return outputs
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/hooks.py:175, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
169 # DeprecationWarning is ignored by default, so we use FutureWarning instead
170 warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:382, in Gemma3DecoderLayer.forward(self, hidden_states, position_embeddings_global, position_embeddings_local, attention_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, **kwargs)
379 else:
380 position_embeddings = position_embeddings_global
--> 382 hidden_states, self_attn_weights = self.self_attn(
383 hidden_states=hidden_states,
384 position_embeddings=position_embeddings,
385 attention_mask=attention_mask,
386 position_ids=position_ids,
387 past_key_values=past_key_values,
388 output_attentions=output_attentions,
389 use_cache=use_cache,
390 cache_position=cache_position,
391 **kwargs,
392 )
393 hidden_states = self.post_attention_layernorm(hidden_states)
394 hidden_states = residual + hidden_states
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
1774 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1775 else:
-> 1776 return self._call_impl(*args, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
1782 # If we don't have any hooks, we want to skip the rest of the logic in
1783 # this function, and just call forward.
1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1785 or _global_backward_pre_hooks or _global_backward_hooks
1786 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787 return forward_call(*args, **kwargs)
1789 result = None
1790 called_always_called_hooks = set()
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/hooks.py:175, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gemma.py:765, in patch_Gemma3Attention_generic.<locals>.forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
756 def forward(
757 self,
758 hidden_states: torch.Tensor,
(...) 763 **kwargs: KWARGS_TYPE,
764 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
--> 765 return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gemma.py:680, in patch_Gemma3Attention_generic.<locals>.forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
650 # 2. Upcast Q, K, V for norm and RoPE, and then transpose for attention
651 # (bsz, num_specific_heads, q_len, head_dim)
652 """ ####### REPLACED WITH TORCH_COMPILED_MODULE
653 query_states_fp32 = query_states_fp16.view(query_hidden_shape).to(torch.float32).transpose(1, 2)
654 key_states_fp32 = key_states_fp16.view(kv_hidden_shape).to(torch.float32).transpose(1, 2)
(...) 671 query_states_fp32, key_states_fp32 = apply_rotary_pos_emb(query_states_fp32, key_states_fp32, cos = cos_fp32, sin = sin_fp32)
672 """
673 (
674 query_states_fp32,
675 key_states_fp32,
676 value_states_fp32,
677 cos_fp32,
678 sin_fp32,
679 attn_mask_for_sdpa,
--> 680 ) = prepare(
681 hidden_states,
682 query_states_fp16,
683 key_states_fp16,
684 value_states_fp16,
685 query_hidden_shape,
686 kv_hidden_shape,
687 position_embeddings,
688 attention_mask,
689 self.q_norm,
690 self.k_norm,
691 )
693 # 5. KV Cache update (using fp32 K, V)
694 if past_key_value is not None:
File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:963, in _TorchDynamoContext.__call__.<locals>.compile_wrapper(*args, **kwargs)
961 cur_exn = cur_exn.__cause__
962 # pyrefly: ignore [invalid-inheritance]
--> 963 raise e.with_traceback(None) from e.__cause__ # User compiler error
964 except ShortenTraceback as e:
965 # Failures in the backend likely don't have useful
966 # data in the TorchDynamo frames, so we strip them out.
967 raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
Unsupported: Observed exception
Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: raised exception TypeError([ConstantVariable(str: 'Missing required positional argument: x')])
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html
from user code:
File "/media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gemma.py", line 589, in prepare
query_norm_out_fp16 = q_norm(query_states_fp32) # self.q_norm doesn't use auto compiler
File "/media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
File "/media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/hooks.py", line 175, in new_forward
output = module._old_forward(*args, **kwargs)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
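For anyone reproducing this, the extra Dynamo diagnostics suggested at the end of the trace can be enabled as below. This is a minimal sketch; the environment variables are read when torch initializes, so they need to be set before torch is imported (e.g. in the very first notebook cell or in the shell that launches the run).

```python
# Minimal sketch: enable the extra Dynamo diagnostics the trace suggests.
# Must run before torch is imported anywhere in the process.
import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"   # prints Dynamo's internal stack trace
os.environ["TORCH_LOGS"] = "+dynamo"      # verbose developer-level Dynamo logging
```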