1+ import inspect
12import os
2- from typing import Any , Dict , Optional , Tuple , Union
3+ from typing import Any , Dict , List , Optional , Tuple , Union
34import time
45import torch
56from ..helpers import max_diff , string_type , string_diff
@@ -45,6 +46,69 @@ def get_inputs_for_task(task: str, config: Optional[Any] = None) -> Dict[str, An
4546 return f (model = None , config = config , ** kwargs )
4647
4748
49+ def split_args_kwargs (inputs : Any ) -> Tuple [Tuple [Any , ...], Dict [str , Any ]]:
50+ """Splits into args, kwargs."""
51+ if isinstance (inputs , dict ):
52+ return (), inputs
53+ if isinstance (inputs , tuple ) and len (inputs ) == 2 and isinstance (inputs [1 ], dict ):
54+ return inputs
55+ assert isinstance (inputs , tuple ), f"Unexpected inputs { string_type (inputs )} "
56+ return inputs , {}
57+
58+
59+ def make_inputs (
60+ args : Optional [Tuple [Any , ...]], kwargs : Optional [Dict [str , Any ]] = None
61+ ) -> Any :
62+ """Returns either args, kwargs or both depending on which ones are empty."""
63+ assert args or kwargs , "No input was given."
64+ if not args :
65+ return kwargs
66+ if not kwargs :
67+ return args
68+ return args , kwargs
69+
70+
71+ def filter_inputs (
72+ inputs : Any ,
73+ drop_names : List [str ],
74+ model : Optional [Union [torch .nn .Module , List [str ]]] = None ,
75+ dynamic_shapes : Optional [Any ] = None ,
76+ ):
77+ """
78+ Drops some inputs from the given inputs.
79+ It updates the dynamic shapes as well.
80+ """
81+ args , kwargs = split_args_kwargs (inputs )
82+ set_drop_names = set (drop_names )
83+ kwargs = {k : v for k , v in kwargs .items () if k not in set_drop_names }
84+ dyn = (
85+ {k : v for k , v in dynamic_shapes .items () if k not in set_drop_names }
86+ if dynamic_shapes and isinstance (dynamic_shapes , dict )
87+ else dynamic_shapes
88+ )
89+ if not args or all (i in kwargs for i in set_drop_names ):
90+ return make_inputs (args , kwargs ), dyn
91+ assert model , (
92+ f"we need the model to get the parameter name but model is None, "
93+ f"input_names={ drop_names } and args={ string_type (args )} "
94+ )
95+ pnames = (
96+ list (inspect .signature (model .forward ).parameters )
97+ if isinstance (model , torch .nn .Module )
98+ else model
99+ )
100+ new_args = []
101+ new_ds = []
102+ for i , a in enumerate (args ):
103+ if isinstance (dynamic_shapes , tuple ):
104+ new_ds .append (None if pnames [i ] in set_drop_names else dynamic_shapes [i ])
105+ new_args .append (None if pnames [i ] in set_drop_names else a )
106+ new_inputs = make_inputs (tuple (new_args ), kwargs )
107+ if new_ds :
108+ return new_inputs , tuple (new_ds )
109+ return new_inputs , dyn
110+
111+
48112def validate_model (
49113 model_id : str ,
50114 task : Optional [str ] = None ,
@@ -59,6 +123,7 @@ def validate_model(
59123 quiet : bool = False ,
60124 patch : bool = False ,
61125 dump_folder : Optional [str ] = None ,
126+ drop_inputs : Optional [List [str ]] = None ,
62127) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
63128 """
64129 Validates a model.
@@ -80,6 +145,7 @@ def validate_model(
80145 :param quiet: if quiet, catches exception if any issue
81146 :param patch: applies patches before exporting
82147 :param dump_folder: dumps everything in a subfolder of this one
148+ :param drop_inputs: drops this list of inputs (given their names)
83149 :return: two dictionaries, one with some metrics,
84150 another one with whatever the function produces
85151 """
@@ -112,6 +178,27 @@ def validate_model(
112178 else :
113179 data = get_untrained_model_with_inputs (model_id , verbose = verbose , task = task )
114180
181+ if drop_inputs :
182+ if verbose :
183+ print (f"[validate_model] drop inputs { drop_inputs !r} " )
184+ print (f"[validate_model] current inputs: { string_type (data ["inputs" ])} " )
185+ print (
186+ f"[validate_model] current dynnamic_shapes: "
187+ f"{ _ds_clean (data ["dynamic_shapes" ])} "
188+ )
189+ data ["inputs" ], data ["dynamic_shapes" ] = filter_inputs (
190+ data ["inputs" ],
191+ drop_names = drop_inputs ,
192+ model = data ["model" ],
193+ dynamic_shapes = data ["dynamic_shapes" ],
194+ )
195+ if verbose :
196+ print (f"[validate_model] new inputs: { string_type (data ["inputs" ])} " )
197+ print (
198+ f"[validate_model] new dynnamic_shapes: "
199+ f"{ _ds_clean (data ["dynamic_shapes" ])} "
200+ )
201+
115202 if not empty (dtype ):
116203 if isinstance (dtype , str ):
117204 dtype = getattr (torch , dtype )
@@ -338,18 +425,6 @@ def call_exporter(
338425 )
339426
340427
341- def split_args_kwargs (inputs : Any ) -> Tuple [Tuple [Any , ...], Dict [str , Any ]]:
342- """
343- Splits into args, kwargs.
344- """
345- if isinstance (inputs , dict ):
346- return (), inputs
347- if isinstance (inputs , tuple ) and len (inputs ) == 2 and isinstance (inputs [1 ], dict ):
348- return inputs
349- assert isinstance (inputs , tuple ), f"Unexpected inputs { string_type (inputs )} "
350- return inputs , {}
351-
352-
353428def call_torch_export_export (
354429 data : Dict [str , Any ],
355430 exporter : str ,
0 commit comments