@@ -105,6 +105,7 @@ def get_untrained_model_with_inputs(
105105 # outputs even with the same inputs in training mode.
106106 model .eval ()
107107 res = fct (model , config , ** kwargs )
108+
108109 res ["input_kwargs" ] = kwargs
109110 res ["model_kwargs" ] = mkwargs
110111
@@ -118,19 +119,24 @@ def get_untrained_model_with_inputs(
118119 update = {}
119120 for k , v in res .items ():
120121 if k .startswith (("inputs" , "dynamic_shapes" )) and isinstance (v , dict ):
121- update [k ] = filter_out_unexpected_inputs (model , v )
122+ update [k ] = filter_out_unexpected_inputs (model , v , verbose = verbose )
122123 res .update (update )
123124 return res
124125
125126
126- def filter_out_unexpected_inputs (model : torch .nn .Module , kwargs : Dict [str , Any ]):
127+ def filter_out_unexpected_inputs (
128+ model : torch .nn .Module , kwargs : Dict [str , Any ], verbose : int = 0
129+ ):
127130 """
128131 Removes input names in kwargs if no parameter names was found in ``model.forward``.
129132 """
130133 sig = inspect .signature (model .forward )
131134 allowed = set (sig .parameters )
132- kwargs = {k : v for k , v in kwargs .items () if k in allowed }
133- return kwargs
135+ new_kwargs = {k : v for k , v in kwargs .items () if k in allowed }
136+ diff = set (kwargs ) - set (new_kwargs )
137+ if diff and verbose :
138+ print (f"[filter_out_unexpected_inputs] removed { diff } " )
139+ return new_kwargs
134140
135141
136142def compute_model_size (model : torch .nn .Module ) -> Tuple [int , int ]:
0 commit comments