1+ import os
12from typing import Any , Dict , Optional , Tuple , Union
23import time
34import torch
def empty(value: Any) -> bool:
    """
    Tells if the value is empty.

    :param value: any value
    :return: True if *value* is None or a str, list, dict, tuple or set
        holding no element, False for anything else (numbers such as ``0``
        are never considered empty)
    """
    if value is None:
        return True
    if isinstance(value, (str, list, dict, tuple, set)):
        # An empty container is falsy, so emptiness is the negation of
        # its truth value (returning ``bool(value)`` would be inverted).
        return not value
    return False
@@ -22,8 +23,8 @@ def _ds_clean(v):
2223 str (v )
2324 .replace ("<class 'onnx_diagnostic.torch_models.hghub.model_inputs." , "" )
2425 .replace ("'>" , "" )
25- .replace ("_DimHint(type=<_DimHintType.DYNAMIC: 3>" , "DYNAMIC" )
26- .replace ("_DimHint(type=<_DimHintType.AUTO: 3>" , "AUTO" )
26+ .replace ("_DimHint(type=<_DimHintType.DYNAMIC: 3>) " , "DYNAMIC" )
27+ .replace ("_DimHint(type=<_DimHintType.AUTO: 3>) " , "AUTO" )
2728 )
2829
2930
@@ -52,6 +53,7 @@ def validate_model(
5253 optimization : Optional [str ] = None ,
5354 quiet : bool = False ,
5455 patch : bool = False ,
56+ dump_folder : Optional [str ] = None ,
5557) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
5658 """
5759 Validates a model.
@@ -72,11 +74,23 @@ def validate_model(
7274 depend on the exporter
7375 :param quiet: if quiet, catches exception if any issue
7476 :param patch: applies patches before exporting
77+ :param dump_folder: dumps everything in a subfolder of this one
7578 :return: two dictionaries, one with some metrics,
7679 another one with whatever the function produces
7780 """
7881 assert not trained , f"trained={ trained } not supported yet"
7982 summary : Dict [str , Union [int , float , str ]] = {}
83+ if dump_folder :
84+ folder_name = f"{ model_id .replace ('/' ,'-' )} -{ exporter } -{ optimization or '' } "
85+ dump_folder = os .path .join (dump_folder , folder_name )
86+ if not os .path .exists (dump_folder ):
87+ os .makedirs (dump_folder )
88+ summary ["dump_folder" ] = dump_folder
89+ summary ["dump_folder_name" ] = folder_name
90+ if verbose :
91+ print (f"[validate_model] dump into { folder_name !r} " )
92+ else :
93+ folder_name = None
8094 if verbose :
8195 print (f"[validate_model] validate model id { model_id !r} " )
8296 print ("[validate_model] get dummy inputs..." )
@@ -98,15 +112,15 @@ def validate_model(
98112 dtype = getattr (torch , dtype )
99113 if verbose :
100114 print (f"[validate_model] dtype conversion to { dtype } " )
101- data ["model" ] = to_any (data ["model" ], dtype )
102- data ["inputs" ] = to_any (data ["inputs" ], dtype )
115+ data ["model" ] = to_any (data ["model" ], dtype ) # type: ignore
116+ data ["inputs" ] = to_any (data ["inputs" ], dtype ) # type: ignore
103117 summary ["model_dtype" ] = str (dtype )
104118
105119 if not empty (device ):
106120 if verbose :
107121 print (f"[validate_model] device conversion to { device } " )
108- data ["model" ] = to_any (data ["model" ], device )
109- data ["inputs" ] = to_any (data ["inputs" ], device )
122+ data ["model" ] = to_any (data ["model" ], device ) # type: ignore
123+ data ["inputs" ] = to_any (data ["inputs" ], device ) # type: ignore
110124 summary ["model_device" ] = str (device )
111125
112126 summary ["time_create" ] = time .perf_counter () - begin
@@ -156,6 +170,7 @@ def validate_model(
156170 f"before: { hash_inputs } \n "
157171 f" after: { string_type (data ["inputs" ], with_shape = True )} "
158172 )
173+
159174 if exporter :
160175 print (
161176 f"[validate_model] export the model with { exporter !r} , "
@@ -164,10 +179,10 @@ def validate_model(
164179 if patch :
165180 if verbose :
166181 print ("[validate_model] applies patches before exporting" )
167- with bypass_export_some_errors (
182+ with bypass_export_some_errors ( # type: ignore
168183 patch_transformers = True , verbose = max (0 , verbose - 1 )
169184 ) as modificator :
170- data ["inputs_export" ] = modificator (data ["inputs" ])
185+ data ["inputs_export" ] = modificator (data ["inputs" ]) # type: ignore
171186
172187 if do_run :
173188 # We run a second time the model to check the patch did not
@@ -230,6 +245,25 @@ def validate_model(
230245 )
231246 summary .update (summary_export )
232247
248+ if dump_folder :
249+ if "exported_program" in data :
250+ ep = data ["exported_program" ]
251+ if verbose :
252+ print (f"[validate_model] dumps exported program in { dump_folder !r} ..." )
253+ with open (os .path .join (dump_folder , f"{ folder_name } .ep" ), "w" ) as f :
254+ f .write (str (ep ))
255+ with open (os .path .join (dump_folder , f"{ folder_name } .graph" ), "w" ) as f :
256+ f .write (str (ep .graph ))
257+ if verbose :
258+ print ("[validate_model] done (dump ep)" )
259+ if verbose :
260+ print (f"[validate_model] dumps statistics in { dump_folder !r} ..." )
261+ with open (os .path .join (dump_folder , f"{ folder_name } .stats" ), "w" ) as f :
262+ for k , v in sorted (summary .items ()):
263+ f .write (f":{ k } :{ v } ;\n " )
264+ if verbose :
265+ print ("[validate_model] done (dump)" )
266+
233267 if verbose :
234268 print ("[validate_model] done (final)" )
235269 return summary , data
@@ -281,7 +315,7 @@ def split_args_kwargs(inputs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
281315 return (), inputs
282316 if isinstance (inputs , tuple ) and len (inputs ) == 2 and isinstance (inputs [1 ], dict ):
283317 return inputs
284- assert isinstance (inputs , tuple ), f"Unexpectd inputs { string_type (inputs )} "
318+ assert isinstance (inputs , tuple ), f"Unexpected inputs { string_type (inputs )} "
285319 return inputs , {}
286320
287321
@@ -309,7 +343,7 @@ def call_torch_export_export(
309343 """
310344 assert "model" in data , f"model is missing from data: { sorted (data )} "
311345 assert "inputs_export" in data , f"inputs_export is missing from data: { sorted (data )} "
312- summary = {}
346+ summary : Dict [ str , Union [ str , int , float ]] = {}
313347 strict = "nostrict" not in exporter
314348 args , kwargs = split_args_kwargs (data ["inputs_export" ])
315349 ds = data .get ("dynamic_shapes" , None )
@@ -323,7 +357,7 @@ def call_torch_export_export(
323357 print (f"[call_torch_export_export] dynamic_shapes={ _ds_clean (ds )} " )
324358 print ("[call_torch_export_export] export..." )
325359 summary ["export_exporter" ] = exporter
326- summary ["export_optimization" ] = optimization
360+ summary ["export_optimization" ] = optimization or ""
327361 summary ["export_strict" ] = strict
328362 summary ["export_args" ] = string_type (args , with_shape = True )
329363 summary ["export_kwargs" ] = string_type (kwargs , with_shape = True )
0 commit comments