1- from typing import Any , Dict , Optional , Union
1+ import torch
2+ from typing import Any , Dict , Optional , Tuple , Union
3+ import time
4+ from ..helpers import string_type
5+ from .hghub import get_untrained_model_with_inputs
26from .hghub .model_inputs import random_input_kwargs
37
48
9+ def _ds_clean (v ):
10+ return (
11+ str (v )
12+ .replace ("<class 'onnx_diagnostic.torch_models.hghub.model_inputs." , "" )
13+ .replace ("'>" , "" )
14+ .replace ("_DimHint(type=<_DimHintType.DYNAMIC: 3>" , "DYNAMIC" )
15+ .replace ("_DimHint(type=<_DimHintType.AUTO: 3>" , "AUTO" )
16+ )
17+
18+
519def get_inputs_for_task (task : str , config : Optional [Any ] = None ) -> Dict [str , Any ]:
620 """
721 Returns dummy inputs for a specific task.
@@ -18,12 +32,90 @@ def validate_model(
1832 model_id : str ,
1933 task : Optional [str ] = None ,
2034 do_run : bool = False ,
21- do_export : bool = False ,
35+ exporter : Optional [ str ] = None ,
2236 do_same : bool = False ,
2337 verbose : int = 0 ,
24- ) -> Dict [str , Union [int , float , str ]]:
38+ dtype : Optional [Union [str , torch .dtype ]] = None ,
39+ device : Optional [Union [str , torch .device ]] = None ,
40+ trained : bool = False ,
41+ optimization : Optional [str ] = None ,
42+ quiet : bool = False ,
43+ ) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
2544 """
2645 Validates a model.
2746
28-
47+ :param model_id: model id to validate
48+ :param task: task used to generate the necessary inputs,
49+ can be left empty to use the default task for this model
50+ if it can be determined
51+ :param do_run: checks the model works with the defined inputs
52+ :param exporter: exporter the model using this exporter,
53+ available list: ``export-strict``, ``export-nostrict``, ``onnx``
54+ :param do_same: checks the discrepancies of the exported model
55+ :param verbose: verbosity level
56+ :param dtype: uses this dtype to check the model
57+ :param device: do the verification on this device
58+ :param trained: use the trained model, not the untrained one
59+ :param optimization: optimization to apply to the exported model,
60+ depend on the the exporter
61+ :param quiet: if quiet, catches exception if any issue
62+ :return: two dictionaries, one with some metrics,
63+ another one with whatever the function produces
2964 """
65+ assert not trained , f"trained={ trained } not supported yet"
66+ assert not dtype , f"dtype={ dtype } not supported yet"
67+ assert not device , f"device={ device } not supported yet"
68+ summary = {}
69+ if verbose :
70+ print (f"[validate_model] validate model id { model_id !r} " )
71+ print ("[validate_model] get dummy inputs..." )
72+ summary ["model_id" ] = model_id
73+ begin = time .perf_counter ()
74+ if quiet :
75+ try :
76+ data = get_untrained_model_with_inputs (model_id , verbose = verbose , task = task )
77+ except Exception as e :
78+ summary ["ERR_create" ] = e
79+ summary ["time_create" ] = time .perf_counter () - begin
80+ return summary , {}
81+ else :
82+ data = get_untrained_model_with_inputs (model_id , verbose = verbose , task = task )
83+ summary ["time_create" ] = time .perf_counter () - begin
84+ for k in ["task" , "size" , "n_weights" ]:
85+ summary [f"model_{ k .replace ('_' ,'' )} " ] = data [k ]
86+ summary ["model_inputs" ] = string_type (data ["inputs" ], with_shape = True )
87+ summary ["model_shapes" ] = _ds_clean (str (data ["dynamic_shapes" ]))
88+ summary ["model_class" ] = data ["model" ].__class__ .__name__
89+ summary ["model_config_class" ] = data ["configuration" ].__class__ .__name__
90+ summary ["model_config" ] = str (data ["configuration" ].to_dict ()).replace (" " , "" )
91+ summary ["model_id" ] = model_id
92+ if verbose :
93+ print (f"[validate_model] task={ data ["task" ]} " )
94+ print (f"[validate_model] size={ data ["size" ]} " )
95+ print (f"[validate_model] n_weights={ data ["n_weights" ]} " )
96+ print (f"[validate_model] n_weights={ data ["n_weights" ]} " )
97+ for k , v in data ["inputs" ].items ():
98+ print (f"[validate_model] +INPUT { k } ={ string_type (v , with_shape = True )} " )
99+ for k , v in data ["dynamic_shapes" ].items ():
100+ print (f"[validate_model] +SHAPE { k } ={ _ds_clean (v )} " )
101+ if do_run :
102+ if verbose :
103+ print ("[validate_model] run the model..." )
104+ begin = time .perf_counter ()
105+ if quiet :
106+ try :
107+ expected = data ["model" ](** data ["inputs" ])
108+ except Exception as e :
109+ summary ["ERR_run" ] = e
110+ summary ["time_run" ] = time .perf_counter () - begin
111+ return summary , data
112+ else :
113+ expected = data ["model" ](** data ["inputs" ])
114+ summary ["time_run" ] = time .perf_counter () - begin
115+ summary ["model_expected" ] = string_type (expected , with_shape = True )
116+ if verbose :
117+ print ("[validate_model] run the model" )
118+ data ["expected" ] = expected
119+ if verbose :
120+ print ("[validate_model] done." )
121+ return summary , data
0 commit comments