11import inspect
22import os
33import textwrap
4- from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
4+ from typing import Any , Callable , Dict , List , Optional , Sequence , Set , Tuple , Union
55import torch
66from .dynamic_shapes import ModelInputs
77from .onnx_plug import EagerDirectReplacementWithOnnx
@@ -340,6 +340,7 @@ def __init__(
340340 inline : bool = True ,
341341 convert_after_n_calls : int = 2 ,
342342 patch_kwargs : Optional [Dict [str , Any ]] = None ,
343+ skip_kwargs_names : Optional [Set [str ]] = None ,
343344 ):
344345 super ().__init__ ()
345346 self ._model_to_call = mod
@@ -354,6 +355,7 @@ def __init__(
354355 self ._patch_kwargs = patch_kwargs
355356 self ._method_src = None
356357 self .verbose = verbose
358+ self .skip_kwargs_names = skip_kwargs_names
357359 self ._to_onnx_kwargs = dict (
358360 input_names = input_names ,
359361 target_opset = target_opset ,
@@ -370,6 +372,7 @@ def __init__(
370372 onnx_plugs = onnx_plugs ,
371373 inline = inline ,
372374 )
375+ self ._export_done = False
373376
374377 def __str__ (self ) -> str :
375378 return self .__repr__ ()
@@ -381,14 +384,28 @@ def __repr__(self) -> str:
381384 )
382385
383386 def forward (self , * args , ** kwargs ):
384- self ._inputs .append ((args , kwargs ))
385- if self .verbose :
386- print (
387- f"[method_to_onnx] input[{ len (self ._inputs )- 1 } ]: "
388- f"{ string_type ((args , kwargs ), with_shape = True )} "
387+ if not self ._export_done :
388+ self ._inputs .append (
389+ (
390+ args ,
391+ (
392+ kwargs
393+ if not kwargs or not self .skip_kwargs_names
394+ else {
395+ k : v for k , v in kwargs .items () if k not in self .skip_kwargs_names
396+ }
397+ ),
398+ )
389399 )
390- if len (self ._inputs ) >= self ._convert_after_n_calls :
391- self ._convert_method_to_onnx ()
400+ if self .verbose :
401+ print (
402+ f"[method_to_onnx] input[{ len (self ._inputs )- 1 } ]: "
403+ f"{ string_type (self ._inputs [- 1 ], with_shape = True )} "
404+ )
405+ if len (self ._inputs ) >= self ._convert_after_n_calls :
406+ self ._convert_method_to_onnx ()
407+ del self ._inputs [:]
408+ self ._export_done = True
392409 return self ._method_call (* args , ** kwargs )
393410
394411 def _convert_method_to_onnx (self ):
@@ -473,6 +490,7 @@ def method_to_onnx(
473490 inline : bool = True ,
474491 convert_after_n_calls : int = 2 ,
475492 patch_kwargs : Optional [Dict [str , Any ]] = None ,
493+ skip_kwargs_names : Optional [Set [str ]] = None ,
476494) -> Callable :
477495 """
478496 Exports one method into ONNX for a module into ONNX.
@@ -499,8 +517,12 @@ def method_to_onnx(
499517 :param inline: inline local functions
500518 :param convert_after_n_calls: converts the model after this number of calls.
501519 :param patch_kwargs: patch arguments
520+ :param skip_kwargs_names: use default values for these parameters part of
521+ the signature of the method to export
502522 :return: the output of the selected exporter, usually a structure including
503523 an onnx model
524+
525+ See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
504526 """
505527 wrapped_model = _WrapperToExportMethodToOnnx (
506528 mod = mod ,
@@ -521,5 +543,6 @@ def method_to_onnx(
521543 inline = inline ,
522544 convert_after_n_calls = convert_after_n_calls ,
523545 patch_kwargs = patch_kwargs ,
546+ skip_kwargs_names = skip_kwargs_names ,
524547 )
525548 return wrapped_model
0 commit comments