Skip to content

Commit ead5d01

Browse files
committed
fix total time metrics
1 parent 1006017 commit ead5d01

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,12 +400,17 @@ def get_parser_validate() -> ArgumentParser:
400400
401401
position_ids is usually not needed, they can be removed by adding:
402402
403-
--drop position_ids
403+
--drop position_ids
404404
405405
The behaviour may be modified compare the original configuration,
406406
the following argument can be rope_scaling to dynamic:
407407
408-
--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
408+
--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
409+
410+
You can profile the command line by running:
411+
412+
pyinstrument -m onnx_diagnostic validate ...
413+
pyinstrument -r html -o profile.html -m onnx_diagnostic validate ...
409414
"""
410415
),
411416
formatter_class=RawTextHelpFormatter,

onnx_diagnostic/torch_models/validate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -419,14 +419,14 @@ def validate_model(
419419
such as ``input_empty_cache``
420420
which refers to a set of inputs using an empty cache.
421421
"""
422-
validation_begin = time.perf_counter()
422+
main_validation_begin = time.perf_counter()
423423
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
424424
model_id,
425425
subfolder,
426426
same_as_pretrained=same_as_pretrained,
427427
use_pretrained=use_pretrained,
428428
)
429-
time_preprocess_model_id = time.perf_counter() - validation_begin
429+
time_preprocess_model_id = time.perf_counter() - main_validation_begin
430430
default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
431431
if isinstance(patch, bool):
432432
patch_kwargs = default_patch if patch else dict(patch=False)
@@ -921,7 +921,7 @@ def validate_model(
921921
summary.update(summary_valid)
922922

923923
_compute_final_statistics(summary)
924-
summary["time_total"] = time.perf_counter() - validation_begin
924+
summary["time_total"] = time.perf_counter() - main_validation_begin
925925

926926
if verbose:
927927
print("[validate_model] -- done (final)")

0 commit comments

Comments
 (0)