
[Bug] Gemma3 fine-tuning: ConstantVariable(str: 'Missing required positional argument: x') [And more!] #3996

@kabachuha

Description

  1. Did you update? (pip install --upgrade unsloth unsloth_zoo) Yes
  2. Colab, Kaggle, or local/cloud? Local
  3. Number of GPUs used (per nvidia-smi): 2
  4. Which notebook? https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(27B)_A100-Conversational.ipynb
  5. Unsloth / TRL / transformers / PyTorch versions: 2026.1.4 / 0.24.0 / 4.57.6 / 2.10.0+cu128
  6. Which trainer (SFTTrainer, GRPOTrainer, etc.)? SFTTrainer

Minimal code to reproduce the error: just the Gemma3_(27B)_A100-Conversational notebook, with the only change being multi-GPU loading via device_map = "balanced" (a sketch of that change follows).
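
For reference, a minimal sketch of that single change, assuming the notebook's FastModel loader; the model id and the other arguments below are illustrative, not copied from the notebook:

```python
from unsloth import FastModel

# Same loading cell as the notebook, except device_map = "balanced",
# which spreads the 27B model across both GPUs instead of one.
model, tokenizer = FastModel.from_pretrained(
    model_name     = "unsloth/gemma-3-27b-it",  # illustrative model id
    max_seq_length = 2048,                      # illustrative
    load_in_4bit   = True,
    device_map     = "balanced",                # the only change vs. the stock notebook
)
```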

The error trace itself:

---------------------------------------------------------------------------
ObservedTypeError                         Traceback (most recent call last)
ObservedTypeError: ConstantVariable(str: 'Missing required positional argument: x')

The above exception was the direct cause of the following exception:

Unsupported                               Traceback (most recent call last)
Cell In[14], line 1
----> 1 trainer_stats = trainer.train()

File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/UnslothSFTTrainer.py:64, in prepare_for_training_mode.<locals>.wrapper(self, *args, **kwargs)
     62 if hasattr(self, 'model') and hasattr(self.model, "for_training"):
     63     self.model.for_training(use_gradient_checkpointing=use_gc)
---> 64 output = f(self, *args, **kwargs)
     65 # Restore previous mode when possible
     66 if hasattr(self, 'model') and hasattr(self.model, "for_inference"):

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/trainer.py:2325, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2323         hf_hub_utils.enable_progress_bars()
   2324 else:
-> 2325     return inner_training_loop(
   2326         args=args,
   2327         resume_from_checkpoint=resume_from_checkpoint,
   2328         trial=trial,
   2329         ignore_keys_for_eval=ignore_keys_for_eval,
   2330     )

File <string>:328, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/UnslothSFTTrainer.py:1220, in _UnslothSFTTrainer.training_step(self, *args, **kwargs)
   1218 def training_step(self, *args, **kwargs):
   1219     with self.maybe_activation_offload_context:
-> 1220         return super().training_step(*args, **kwargs)

File <string>:40, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/UnslothSFTTrainer.py:1209, in _UnslothSFTTrainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   1206 def compute_loss(
   1207     self, model, inputs, return_outputs = False, num_items_in_batch = None
   1208 ):
-> 1209     outputs = super().compute_loss(
   1210         model,
   1211         inputs,
   1212         return_outputs = return_outputs,
   1213         num_items_in_batch = num_items_in_batch,
   1214     )
   1215     return outputs

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth/models/_utils.py:1661, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
   1654     name = inner_model.__class__.__name__
   1656     logger.warning_once(
   1657         f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"
   1658         "Using gradient accumulation will be very slightly less accurate.\n"
   1659         "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
   1660     )
-> 1661 outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
   1662 return outputs

File <string>:36, in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    818 def forward(*args, **kwargs):
--> 819     return model_forward(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    806 def __call__(self, *args, **kwargs):
--> 807     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     41 @functools.wraps(func)
     42 def decorate_autocast(*args, **kwargs):
     43     with autocast_instance:
---> 44         return func(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/peft/peft_model.py:1923, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1921     with self._enable_peft_forward_hooks(**kwargs):
   1922         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1923         return self.base_model(
   1924             input_ids=input_ids,
   1925             attention_mask=attention_mask,
   1926             inputs_embeds=inputs_embeds,
   1927             labels=labels,
   1928             output_attentions=output_attentions,
   1929             output_hidden_states=output_hidden_states,
   1930             return_dict=return_dict,
   1931             **kwargs,
   1932         )
   1934 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1935 if attention_mask is not None:
   1936     # concat prompt attention mask

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:311, in BaseTuner.forward(self, *args, **kwargs)
    310 def forward(self, *args: Any, **kwargs: Any):
--> 311     return self.model.forward(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/hooks.py:175, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    173         output = module._old_forward(*args, **kwargs)
    174 else:
--> 175     output = module._old_forward(*args, **kwargs)
    176 return module._hf_hook.post_forward(module, output)

File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/unsloth_compiled_module_gemma3.py:902, in Gemma3ForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **lm_kwargs)
    884 def forward(
    885     self,
    886     input_ids: Optional[torch.LongTensor] = None,
   (...)    900     **lm_kwargs,
    901 ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
--> 902     return Gemma3ForConditionalGeneration_forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **lm_kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/_dynamo/external_utils.py:203, in get_nonrecursive_disable_wrapper.<locals>.nonrecursive_disable_wrapper(*args, **kwargs)
    199 if torch.compiler.is_exporting():
    200     raise RuntimeError(
    201         "Non-recursive torch.compiler.disable is not supported with torch.export."
    202     )
--> 203 return fn(*args, **kwargs)

File /media/kabachuha/xiangliu/Ivyel-2/unsloth_compiled_cache/unsloth_compiled_module_gemma3.py:712, in Gemma3ForConditionalGeneration_forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **lm_kwargs)
    707 output_hidden_states = (
    708     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    709 )
    710 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
--> 712 outputs = self.model(
    713     input_ids=input_ids,
    714     pixel_values=pixel_values,
    715     token_type_ids=token_type_ids,
    716     attention_mask=attention_mask,
    717     position_ids=position_ids,
    718     past_key_values=past_key_values,
    719     inputs_embeds=inputs_embeds,
    720     use_cache=use_cache,
    721     labels=mask_attention_mask_out(labels = labels, attention_mask = attention_mask),
    722     output_attentions=output_attentions,
    723     output_hidden_states=output_hidden_states,
    724     return_dict=return_dict,
    725     cache_position=cache_position,
    726     **lm_kwargs,
    727 )
    729 hidden_states = outputs[0]
    730 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/utils/generic.py:918, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    916 if return_dict_passed is not None:
    917     return_dict = return_dict_passed
--> 918 output = func(self, *args, **kwargs)
    919 if not return_dict and not isinstance(output, tuple):
    920     output = output.to_tuple()

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:957, in Gemma3Model.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **lm_kwargs)
    951     # Create the masks
    952     causal_mask_mapping = {
    953         "full_attention": create_causal_mask(**mask_kwargs),
    954         "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
    955     }
--> 957 outputs = self.language_model(
    958     attention_mask=causal_mask_mapping,
    959     position_ids=position_ids,
    960     past_key_values=past_key_values,
    961     inputs_embeds=inputs_embeds,
    962     use_cache=use_cache,
    963     output_attentions=output_attentions,
    964     output_hidden_states=output_hidden_states,
    965     return_dict=True,
    966     cache_position=cache_position,
    967     **lm_kwargs,
    968 )
    970 return Gemma3ModelOutputWithPast(
    971     last_hidden_state=outputs.last_hidden_state,
    972     past_key_values=outputs.past_key_values if use_cache else None,
   (...)    975     image_hidden_states=image_features if pixel_values is not None else None,
    976 )

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/utils/generic.py:1072, in check_model_inputs.<locals>.wrapped_fn.<locals>.wrapper(self, *args, **kwargs)
   1069                 monkey_patched_layers.append((module, original_forward))
   1071 try:
-> 1072     outputs = func(self, *args, **kwargs)
   1073 except TypeError as original_exception:
   1074     # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
   1075     # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception
   1076     # Otherwise -> we're probably missing `**kwargs` in the decorated function
   1077     kwargs_without_recordable = {k: v for k, v in kwargs.items() if k not in recordable_keys}

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:570, in Gemma3TextModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, cache_position, **kwargs)
    567 if output_hidden_states:
    568     all_hidden_states += (hidden_states,)
--> 570 layer_outputs = decoder_layer(
    571     hidden_states,
    572     position_embeddings_global=position_embeddings_global,
    573     position_embeddings_local=position_embeddings_local,
    574     attention_mask=causal_mask_mapping[decoder_layer.attention_type],
    575     position_ids=position_ids,
    576     past_key_values=past_key_values,
    577     output_attentions=output_attentions,
    578     use_cache=use_cache,
    579     cache_position=cache_position,
    580     **kwargs,
    581 )
    583 hidden_states = layer_outputs[0]
    585 if output_attentions:

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/modeling_layers.py:93, in GradientCheckpointingLayer.__call__(self, *args, **kwargs)
     90         message = message.rstrip(",") + "."
     91         logger.warning_once(message)
---> 93     return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
     94 return super().__call__(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/_compile.py:54, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     51     disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False)
     52     fn.__dynamo_disable = disable_fn  # type: ignore[attr-defined]
---> 54 return disable_fn(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1181, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
   1171         with fx_traceback.annotate(
   1172             {
   1173                 "_torchdynamo_disable": True,
   (...)   1178             }
   1179         ):
   1180             return fn(*args, **kwargs)
-> 1181     return fn(*args, **kwargs)
   1182 finally:
   1183     set_eval_frame(None)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/utils/checkpoint.py:505, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, early_stop, *args, **kwargs)
    500     if context_fn is not noop_context_fn or debug is not False:
    501         raise ValueError(
    502             "Passing `context_fn` or `debug` is only supported when "
    503             "use_reentrant=False."
    504         )
--> 505     return CheckpointFunction.apply(function, preserve, *args)
    506 else:
    507     gen = _checkpoint_without_reentrant_generator(
    508         function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
    509     )

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/autograd/function.py:583, in Function.apply(cls, *args, **kwargs)
    580 if not torch._C._are_functorch_transforms_active():
    581     # See NOTE: [functorch vjp and autograd interaction]
    582     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 583     return super().apply(*args, **kwargs)  # type: ignore[misc]
    585 if not is_setup_ctx_defined:
    586     raise RuntimeError(
    587         "In order to use an autograd.Function with functorch transforms "
    588         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    589         "staticmethod. For more details, please see "
    590         "https://pytorch.org/docs/main/notes/extending.func.html"
    591     )

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth_zoo/gradient_checkpointing.py:498, in UnslothCheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
    495 if ctx._requires_gradient: ctx.save_for_backward(*tensor_inputs)
    497 with torch.no_grad():
--> 498     outputs = run_function(*args)
    500 if use_gpu_buffer: MAIN_STREAM.wait_stream(EXTRA_STREAM)
    501 return outputs

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/hooks.py:175, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    173         output = module._old_forward(*args, **kwargs)
    174 else:
--> 175     output = module._old_forward(*args, **kwargs)
    176 return module._hf_hook.post_forward(module, output)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
    168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
    169     # DeprecationWarning is ignored by default, so we use FutureWarning instead
    170     warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:382, in Gemma3DecoderLayer.forward(self, hidden_states, position_embeddings_global, position_embeddings_local, attention_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, **kwargs)
    379 else:
    380     position_embeddings = position_embeddings_global
--> 382 hidden_states, self_attn_weights = self.self_attn(
    383     hidden_states=hidden_states,
    384     position_embeddings=position_embeddings,
    385     attention_mask=attention_mask,
    386     position_ids=position_ids,
    387     past_key_values=past_key_values,
    388     output_attentions=output_attentions,
    389     use_cache=use_cache,
    390     cache_position=cache_position,
    391     **kwargs,
    392 )
    393 hidden_states = self.post_attention_layernorm(hidden_states)
    394 hidden_states = residual + hidden_states

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1776, in Module._wrapped_call_impl(self, *args, **kwargs)
   1774     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1775 else:
-> 1776     return self._call_impl(*args, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py:1787, in Module._call_impl(self, *args, **kwargs)
   1782 # If we don't have any hooks, we want to skip the rest of the logic in
   1783 # this function, and just call forward.
   1784 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1785         or _global_backward_pre_hooks or _global_backward_hooks
   1786         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1787     return forward_call(*args, **kwargs)
   1789 result = None
   1790 called_always_called_hooks = set()

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/hooks.py:175, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    173         output = module._old_forward(*args, **kwargs)
    174 else:
--> 175     output = module._old_forward(*args, **kwargs)
    176 return module._hf_hook.post_forward(module, output)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gemma.py:765, in patch_Gemma3Attention_generic.<locals>.forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
    756 def forward(
    757     self,
    758     hidden_states: torch.Tensor,
   (...)    763     **kwargs: KWARGS_TYPE,
    764 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
--> 765     return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gemma.py:680, in patch_Gemma3Attention_generic.<locals>.forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
    650 # 2. Upcast Q, K, V for norm and RoPE, and then transpose for attention
    651 # (bsz, num_specific_heads, q_len, head_dim)
    652 """ ####### REPLACED WITH TORCH_COMPILED_MODULE
    653 query_states_fp32 = query_states_fp16.view(query_hidden_shape).to(torch.float32).transpose(1, 2)
    654 key_states_fp32   = key_states_fp16.view(kv_hidden_shape).to(torch.float32).transpose(1, 2)
   (...)    671 query_states_fp32, key_states_fp32 = apply_rotary_pos_emb(query_states_fp32, key_states_fp32, cos = cos_fp32, sin = sin_fp32)
    672 """
    673 (
    674     query_states_fp32,
    675     key_states_fp32,
    676     value_states_fp32,
    677     cos_fp32,
    678     sin_fp32,
    679     attn_mask_for_sdpa,
--> 680 ) = prepare(
    681     hidden_states,
    682     query_states_fp16,
    683     key_states_fp16,
    684     value_states_fp16,
    685     query_hidden_shape,
    686     kv_hidden_shape,
    687     position_embeddings,
    688     attention_mask,
    689     self.q_norm,
    690     self.k_norm,
    691 )
    693 # 5. KV Cache update (using fp32 K, V)
    694 if past_key_value is not None:

File /media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:963, in _TorchDynamoContext.__call__.<locals>.compile_wrapper(*args, **kwargs)
    961         cur_exn = cur_exn.__cause__
    962     # pyrefly: ignore [invalid-inheritance]
--> 963     raise e.with_traceback(None) from e.__cause__  # User compiler error
    964 except ShortenTraceback as e:
    965     # Failures in the backend likely don't have useful
    966     # data in the TorchDynamo frames, so we strip them out.
    967     raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1

Unsupported: Observed exception
  Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
  Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

  Developer debug context: raised exception TypeError([ConstantVariable(str: 'Missing required positional argument: x')])

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html

from user code:
   File "/media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/unsloth_zoo/temporary_patches/gemma.py", line 589, in prepare
    query_norm_out_fp16 = q_norm(query_states_fp32) # self.q_norm doesn't use auto compiler
  File "/media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/media/kabachuha/holodok01/miniconda3/envs/simpletuner/lib/python3.12/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
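
For anyone reproducing this, a sketch of enabling the extra diagnostics the message asks for, using the environment variable names quoted above (setting them in the shell before Python starts is the safest option):

```python
import os

# Variable names taken from the error message above. Exporting them in the
# shell before launching is most reliable; setting them at the very top of
# the notebook, before importing torch/unsloth, usually works as well.
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCH_LOGS"] = "+dynamo"
```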
