Skip to content

Commit ddbbdb3

Browse files
committed
fix summary naming
1 parent 3430eb5 commit ddbbdb3

File tree

6 files changed

+8
-9
lines changed

6 files changed

+8
-9
lines changed

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def test_text_generation(self):
2424
patch=True,
2525
)
2626
self.assertIsInstance(summary, dict)
27 -        # token generation
27 +        # multi-turn conversation
2828
self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2)
2929
# prompt processing
3030
self.assertLess(summary["disc_onnx_ort_run2_abs"], 3e-2)
31 -        # multi-turn conversation
31 +        # token generation
3232
self.assertLess(summary["disc_onnx_ort_run3_abs"], 3e-2)
3333
self.assertIsInstance(data, dict)
3434
onnx_filename = data["onnx_filename"]

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_validate_tiny_llms_bfloat16(self):
4141
@requires_transformers("4.53")
4242
@requires_torch("2.7.99")
4343
@requires_experimental()
44 -    @hide_stdout()
44 +    # @hide_stdout()
4545
def test_validate_microsoft_phi4_reasoning(self):
4646
# python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning
4747
# --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def test_m_validate_model_vit_model(self):
229229
self.assertIsInstance(summary, dict)
230230
self.assertIsInstance(data, dict)
231231
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-3)
232 -        self.assertLess(summary["disc_onnx_ort_run22_abs"], 1e-3)
232 +        self.assertLess(summary["disc_onnx_ort_run2_abs"], 1e-3)
233233
self.assertEqual("dict(pixel_values:A1s2x3x30x30)", summary["run_feeds_inputs"])
234234
self.assertEqual("dict(pixel_values:A1s3x3x31x31)", summary["run_feeds_inputs2"])
235235
self.assertEqual("#1[A1s2x2]", summary["run_output_inputs"])

onnx_diagnostic/tasks/text_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def get_inputs(
281281
)["inputs"]
282282
# Token generation (decode) testing
283283
# NOTE: We have to export model in decode mode to preserve the cache
284 -    res["token_generation"] = get_inputs(
284 +    res["inputs3"] = get_inputs(
285285
model=model,
286286
config=config,
287287
dummy_max_token_id=dummy_max_token_id,

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1950,7 +1950,6 @@ def patched_sdpa_attention_forward(
19501950
torch._check(
19511951
attention_mask.shape[3] == key.shape[2] if attention_mask is not None else True
19521952
)
1953-
19541953
attn_output = torch.nn.functional.scaled_dot_product_attention(
19551954
query,
19561955
key,

onnx_diagnostic/torch_models/validate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -683,8 +683,8 @@ def validate_model(
683683
data,
684684
summary,
685685
k,
686 -            f"run2{k[6:]}",
687 -            f"run_expected2{k[6:]}",
686 +            f"run{k[6:]}",
687 +            f"run_expected{k[6:]}",
688688
verbose,
689689
1,
690690
0,
@@ -1431,7 +1431,7 @@ def _mk(key, flavour=flavour):
14311431

14321432
keys = [("inputs", "run_expected", "")]
14331433
if second_input_keys:
1434 -        keys.extend([(k, f"run_expected2{k[6:]}", f"2{k[6:]}") for k in second_input_keys])
1434 +        keys.extend([(k, f"run_expected{k[6:]}", f"{k[6:]}") for k in second_input_keys])
14351435
for k_input, k_expected, suffix in keys:
14361436
# make_feeds
14371437
assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"

0 commit comments

Comments (0)