|
1 | 1 | import datetime |
2 | 2 | import inspect |
3 | 3 | import os |
4 | | -from typing import Any, Dict, List, Optional, Tuple, Union |
| 4 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
5 | 5 | import time |
6 | 6 | import onnx |
7 | 7 | import torch |
@@ -180,6 +180,23 @@ def version_summary() -> Dict[str, Union[int, float, str]]: |
180 | 180 | return summary |
181 | 181 |
|
182 | 182 |
|
| 183 | +def _quiet_or_not_quiet( |
| 184 | + quiet: bool, suffix: str, summary: Dict[str, Any], data: Dict[str, Any], fct: Callable |
| 185 | +) -> Any: |
| 186 | + begin = time.perf_counter() |
| 187 | + if quiet: |
| 188 | + try: |
| 189 | + return fct() |
| 190 | + except Exception as e: |
| 191 | + summary[f"ERR_{suffix}"] = str(e) |
| 192 | + data[f"ERR_{suffix}"] = e |
| 193 | + summary[f"time_{suffix}"] = time.perf_counter() - begin |
| 194 | + return summary, {} |
| 195 | + res = fct() |
| 196 | + summary[f"time_{suffix}"] = time.perf_counter() - begin |
| 197 | + return res |
| 198 | + |
| 199 | + |
183 | 200 | def validate_model( |
184 | 201 | model_id: str, |
185 | 202 | task: Optional[str] = None, |
@@ -266,21 +283,19 @@ def validate_model( |
266 | 283 | print("[validate_model] get dummy inputs...") |
267 | 284 | summary["model_id"] = model_id |
268 | 285 |
|
269 | | - begin = time.perf_counter() |
270 | | - if quiet: |
271 | | - try: |
272 | | - data = get_untrained_model_with_inputs( |
273 | | - model_id, verbose=verbose, task=task, same_as_pretrained=trained |
| 286 | + data = _quiet_or_not_quiet( |
| 287 | + quiet, |
| 288 | + "create", |
| 289 | + summary, |
| 290 | + None, |
| 291 | + ( |
| 292 | + lambda mid=model_id, v=verbose, task=task, tr=trained: ( |
| 293 | + get_untrained_model_with_inputs( |
| 294 | + mid, verbose=v, task=task, same_as_pretrained=tr |
| 295 | + ) |
274 | 296 | ) |
275 | | - except Exception as e: |
276 | | - summary["ERR_create"] = str(e) |
277 | | - data["ERR_create"] = e |
278 | | - summary["time_create"] = time.perf_counter() - begin |
279 | | - return summary, {} |
280 | | - else: |
281 | | - data = get_untrained_model_with_inputs( |
282 | | - model_id, verbose=verbose, task=task, same_as_pretrained=trained |
283 | | - ) |
| 297 | + ), |
| 298 | + ) |
284 | 299 |
|
285 | 300 | if drop_inputs: |
286 | 301 | if verbose: |
@@ -316,7 +331,6 @@ def validate_model( |
316 | 331 | data["inputs"] = to_any(data["inputs"], device) # type: ignore |
317 | 332 | summary["model_device"] = str(device) |
318 | 333 |
|
319 | | - summary["time_create"] = time.perf_counter() - begin |
320 | 334 | for k in ["task", "size", "n_weights"]: |
321 | 335 | summary[f"model_{k.replace('_','')}"] = data[k] |
322 | 336 | summary["model_inputs"] = string_type(data["inputs"], with_shape=True) |
|
0 commit comments