Skip to content

Commit 9f9bf0b

Browse files
committed
better stats
1 parent de59acd commit 9f9bf0b

File tree

2 files changed

+55
-6
lines changed

2 files changed

+55
-6
lines changed

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ def _get_inputs_gemma3(
183183
)
184184
dummies = {k: v for k, v in dummies.items() if k in shapes}
185185
expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}
186-
assert expected & set(
187-
dummies
188-
), f"Unable to find expected inputs {expected} in loaded inputs {set(dummines)}"
186+
assert expected & set(dummies), (
187+
f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
188+
)
189189

190190
inputs = dict(
191191
input_ids=input_ids,

onnx_diagnostic/torch_models/validate.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gc
12
import datetime
23
import inspect
34
import os
@@ -113,6 +114,8 @@ def _make_folder_name(
113114
subfolder: Optional[str] = None,
114115
opset: Optional[int] = None,
115116
drop_inputs: Optional[List[str]] = None,
117+
same_as_pretrained: bool = False,
118+
use_pretrained: bool = False,
116119
) -> str:
117120
"Creates a filename unique based on the given options."
118121
els = [model_id.replace("/", "_")]
@@ -141,6 +144,10 @@ def _make_folder_name(
141144
if drop_inputs:
142145
ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
143146
els.append(f"I-{ii.upper()}")
147+
if use_pretrained:
148+
els.append("TRAINED")
149+
elif same_as_pretrained:
150+
els.append("SAMESIZE")
144151
return "-".join(els)
145152

146153

@@ -246,12 +253,21 @@ def _quiet_or_not_quiet(
246253
begin = time.perf_counter()
247254
t = fct()
248255
times.append(time.perf_counter() - begin)
249-
a = np.array(times)
256+
a = np.array(times, dtype=np.float64)
257+
a.sort()
258+
i5 = max(1, a.shape[0] * 5 // 100)
259+
i2 = max(1, a.shape[0] * 2 // 100)
250260
summary[f"time_{suffix}_latency"] = a.mean()
251261
summary[f"time_{suffix}_latency_std"] = a.std()
252262
summary[f"time_{suffix}_latency_min"] = a.min()
253-
summary[f"time_{suffix}_latency_min"] = a.max()
263+
summary[f"time_{suffix}_latency_max"] = a.max()
264+
summary[f"time_{suffix}_latency_098"] = a[-i2]
265+
summary[f"time_{suffix}_latency_095"] = a[-i5]
266+
summary[f"time_{suffix}_latency_005"] = a[i5]
267+
summary[f"time_{suffix}_latency_002"] = a[i2]
254268
summary[f"time_{suffix}_n"] = len(a)
269+
summary[f"time_{suffix}_latency_098"] = a[i2:-i2].mean()
270+
255271
return res
256272

257273

@@ -392,12 +408,14 @@ def validate_model(
392408
if ``runtime == 'ref'``,
393409
``orteval10`` increases the verbosity.
394410
"""
411+
validation_begin = time.perf_counter()
395412
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
396413
model_id,
397414
subfolder,
398415
same_as_pretrained=same_as_pretrained,
399416
use_pretrained=use_pretrained,
400417
)
418+
time_preprocess_model_id = time.perf_counter() - validation_begin
401419
default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
402420
if isinstance(patch, bool):
403421
patch_kwargs = default_patch if patch else dict(patch=False)
@@ -438,6 +456,7 @@ def validate_model(
438456
version_exporter=exporter or "",
439457
version_runtime=runtime,
440458
version_inputs2=inputs2,
459+
time_preprocess_model_id=time_preprocess_model_id,
441460
)
442461
)
443462
if opset:
@@ -454,6 +473,8 @@ def validate_model(
454473
subfolder=subfolder,
455474
opset=opset,
456475
drop_inputs=drop_inputs,
476+
use_pretrained=use_pretrained,
477+
same_as_pretrained=same_as_pretrained,
457478
)
458479
dump_folder = os.path.join(dump_folder, folder_name)
459480
if not os.path.exists(dump_folder):
@@ -486,7 +507,7 @@ def validate_model(
486507
mop = model_options or {}
487508
data = _quiet_or_not_quiet(
488509
quiet,
489-
"create",
510+
"create_torch_model",
490511
summary,
491512
None,
492513
(
@@ -663,19 +684,23 @@ def validate_model(
663684
print("[validate_model] --")
664685

665686
if do_run:
687+
validation_begin = time.perf_counter()
688+
666689
_validate_do_run_model(
667690
data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
668691
)
669692
if inputs2:
670693
_validate_do_run_model(
671694
data, summary, "inputs2", "run2", "run_expected2", verbose, 1, 0, quiet
672695
)
696+
summary["time_total_validation_torch"] = time.perf_counter() - validation_begin
673697

674698
if exporter:
675699
print(
676700
f"[validate_model] -- export the model with {exporter!r}, "
677701
f"optimization={optimization!r}"
678702
)
703+
exporter_begin = time.perf_counter()
679704
if patch_kwargs:
680705
if verbose:
681706
print(
@@ -718,7 +743,9 @@ def validate_model(
718743
dump_folder=dump_folder,
719744
output_names=output_names,
720745
)
746+
721747
summary.update(summary_export)
748+
summary["time_total_exporter"] = time.perf_counter() - exporter_begin
722749

723750
dump_stats = None
724751
if dump_folder:
@@ -759,6 +786,8 @@ def validate_model(
759786
data["onnx_filename"] = onnx_filename
760787
summary["time_onnx_save"] = duration
761788
summary.update(compute_statistics(onnx_filename))
789+
del epo
790+
762791
if verbose:
763792
print(f"[validate_model] dumps statistics in {dump_folder!r}...")
764793
dump_stats = os.path.join(dump_folder, f"{folder_name}.stats")
@@ -781,6 +810,20 @@ def validate_model(
781810
return summary, data
782811

783812
if do_run:
813+
# Let's move the model to CPU to make sure it frees GPU memory.
814+
if verbose:
815+
# It does not really work for the time being and the model
816+
# gets loaded twice, one by torch, one by onnxruntime
817+
print("[validation_model] -- delete the model")
818+
for key in ["model", "onnx_program", "config"]:
819+
if key in data:
820+
del data[key]
821+
if "cuda" in device.lower():
822+
torch.cuda.empty_cache()
823+
gc.collect()
824+
print("[validation_model] -- done")
825+
826+
validation_begin = time.perf_counter()
784827
summary_valid, data = validate_onnx_model(
785828
data=data,
786829
quiet=quiet,
@@ -792,6 +835,7 @@ def validate_model(
792835
ort_logs=ort_logs,
793836
)
794837
summary.update(summary_valid)
838+
summary["time_total_validation_onnx"] = time.perf_counter() - validation_begin
795839

796840
if ortfusiontype and "onnx_filename" in data:
797841
assert (
@@ -855,10 +899,12 @@ def validate_model(
855899
summary.update(summary_valid)
856900

857901
_compute_final_statistics(summary)
902+
summary["time_total"] = time.perf_counter() - validation_begin
858903

859904
if verbose:
860905
print("[validate_model] -- done (final)")
861906
if dump_stats:
907+
# Dumps again the statistics.
862908
with open(dump_stats, "w") as f:
863909
for k, v in sorted(summary.items()):
864910
f.write(f":{k}:{v};\n")
@@ -2020,4 +2066,7 @@ def _compute_final_statistics(summary: Dict[str, Any]):
20202066
stats["stat_estimated_speedup_ort"] = (
20212067
summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
20222068
)
2069+
stats["stat_estimated_speedup_ort_098"] = (
2070+
summary["time_run_latency_098"] / summary["time_run_onnx_ort_latency_098"]
2071+
)
20232072
summary.update(stats)

0 commit comments

Comments (0)