Changes from all commits (29 commits)
65f1ca0  draft improve llm random inputs (titaiwangms, Sep 19, 2025)
94a9b10  Merge branch 'main' into titaiwang/fix_modelbuilder_discrepancy (titaiwangms, Sep 19, 2025)
5c55755  resolve conflicts (titaiwangms, Sep 22, 2025)
aa8b0f8  revert unintentional changes (titaiwangms, Sep 22, 2025)
cd1a19f  add comments (titaiwangms, Sep 22, 2025)
f413ea7  draft-patched_sdpa (titaiwangms, Sep 23, 2025)
8bd2fa1  set is_causal (titaiwangms, Sep 23, 2025)
d65493b  support prompt processing and token generation (titaiwangms, Sep 23, 2025)
df3ad9b  fix mypy (titaiwangms, Sep 23, 2025)
d817f19  fix draft (titaiwangms, Sep 24, 2025)
f15360e  fix static cache (titaiwangms, Sep 24, 2025)
6fea147  fix torch.export 0/1 specializing (titaiwangms, Sep 25, 2025)
9568e18  add a test (titaiwangms, Sep 26, 2025)
393d391  fix CIs - 4.48.3 (titaiwangms, Sep 26, 2025)
d527851  fail fast (titaiwangms, Sep 26, 2025)
31dfd97  disable ort tests (titaiwangms, Sep 26, 2025)
77939dd  fix dynamic shape (titaiwangms, Sep 26, 2025)
5749221  modelbuilder test is duplicated (titaiwangms, Sep 27, 2025)
21355b5  broken api from tr main (titaiwangms, Sep 27, 2025)
358159a  Merge branch 'main' into titaiwang/fix_modelbuilder_discrepancy (titaiwangms, Sep 29, 2025)
dc11cfa  fix a test (titaiwangms, Sep 29, 2025)
3dd887a  fix patch (titaiwangms, Sep 30, 2025)
fd455e1  use multi-turn batch=1 to export (titaiwangms, Oct 6, 2025)
1f4ca3a  disable modeling_utils rewrite (titaiwangms, Oct 6, 2025)
dc02405  bring back inputs2 (titaiwangms, Oct 6, 2025)
28cd455  fix CI (titaiwangms, Oct 6, 2025)
2badb72  enable sdpa rewritten patch (titaiwangms, Oct 6, 2025)
3430eb5  only examine attention_mask shape when it's available (titaiwangms, Oct 6, 2025)
ddbbdb3  fix summary naming (titaiwangms, Oct 7, 2025)
5 changes: 4 additions & 1 deletion _doc/examples/plot_export_tiny_phi2.py
@@ -88,7 +88,10 @@
 # Shapes may not match on the second call with the modified inputs.


-with torch_export_patches(patch_transformers=True):
+with (
+    torch_export_patches(patch_transformers=True),
+    torch.fx.experimental._config.patch(backed_size_oblivious=True),
+):

     # Two unnecessary steps but useful in case of an error
     # We check the cache is registered.
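
Note: the two-context-manager pattern above recurs throughout this PR. A minimal standalone sketch of it, assuming model, inputs, and dynamic_shapes were prepared beforehand (e.g. by get_untrained_model_with_inputs), not the example's full code:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

# model, inputs, dynamic_shapes: assumed prepared earlier.
# backed_size_oblivious keeps torch.export from specializing dynamic
# dimensions whose sample value happens to be 0 or 1 (see commit 6fea147).
with (
    torch_export_patches(patch_transformers=True),
    torch.fx.experimental._config.patch(backed_size_oblivious=True),
):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes, strict=False
    )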
18 changes: 0 additions & 18 deletions _unittests/ut_export/test_dynamic_shapes.py
@@ -6,7 +6,6 @@
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
 from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes
 from onnx_diagnostic.torch_export_patches import torch_export_patches
-from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs


 class TestDynamicShapes(ExtTestCase):
@@ -848,23 +847,6 @@ def test_dynamic_cache_replace_by_string(self):
             as_string,
         )

-    @requires_transformers("4.51")
-    def test_unbatch_inputs(self):
-        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
-        cpl = CoupleInputsDynamicShapes(
-            None, data["inputs"], dynamic_shapes=data["dynamic_shapes"]
-        )
-        new_dims = cpl.change_dynamic_dimensions(
-            desired_values=dict(batch=1), only_desired=True
-        )
-        s = self.string_type(new_dims, with_shape=True)
-        self.assertEqual(
-            "dict(input_ids:T7s1x3,attention_mask:T7s1x33,position_ids:T7s1x3,"
-            "past_key_values:DynamicCache("
-            "key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))",
-            s,
-        )
-

 if __name__ == "__main__":
     unittest.main(verbosity=2)
12 changes: 7 additions & 5 deletions _unittests/ut_export/test_shape_helper.py
@@ -168,17 +168,19 @@ def test_guess_dynamic_shapes_from_inputs(self):
         guessed = guess_dynamic_shapes_from_inputs(
             [data["inputs"], data["inputs2"]], auto="dd"
         )
+        # TODO(xadupre): guess_dynamic_shapes_from_inputs does not support well when
+        # there are dim==1
         self.assertEqual(
             (
                 (),
                 {
-                    "attention_mask": {0: "dd_0I0", 1: "dd_0I1"},
-                    "input_ids": {0: "dd_1I0", 1: "dd_1I1"},
+                    "attention_mask": {1: "dd_0I1"},
+                    "input_ids": {1: "dd_1I1"},
                     "past_key_values": [
-                        [{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}],
-                        [{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}],
+                        [{2: "dd_2I_0o_0l2"}],
+                        [{2: "dd_2I_1o_0l2"}],
                     ],
-                    "position_ids": {0: "dd_3I0", 1: "dd_3I1"},
+                    "position_ids": {1: "dd_3I1"},
                 },
             ),
             guessed,
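
Note: a toy illustration of why the batch axes vanish from the guessed shapes (a hypothetical stand-in, not guess_dynamic_shapes_from_inputs itself): with batch=1 in both recorded input sets (commit fd455e1), the batch axis never varies, so it cannot be told apart from a static axis and drops out of the guess.

def guess_dynamic_dims(shapes):
    # mark axis i dynamic only when it varies across the recorded inputs
    return {
        i: f"dim_{i}"
        for i in range(len(shapes[0]))
        if len({s[i] for s in shapes}) > 1
    }

# input_ids seen as 1x3 then 1x33: the batch axis (always 1) looks static
print(guess_dynamic_dims([(1, 3), (1, 33)]))  # {1: 'dim_1'}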
70 changes: 0 additions & 70 deletions _unittests/ut_helpers/test_model_builder_helper.py
This file was deleted.
11 changes: 7 additions & 4 deletions _unittests/ut_tasks/test_tasks.py
@@ -35,6 +35,7 @@ def test_text2text_generation(self):
         )

     @hide_stdout()
+    @requires_transformers("4.55.4")  # modeling_units
     def test_text_generation(self):
         mid = "arnir0/Tiny-LLM"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
@@ -43,17 +44,19 @@
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
         model(**data["inputs2"])
-        with torch_export_patches(patch_transformers=True, verbose=10):
+        with (
+            torch_export_patches(patch_transformers=True, verbose=10),
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+        ):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )

-    def test_text_generation_empty_cache(self):
+    def test_text_generation_prompt_processing(self):
         mid = "arnir0/Tiny-LLM"
         data = get_untrained_model_with_inputs(mid, add_second_input=True)
         model, inputs = data["model"], data["inputs"]
-        self.assertIn("inputs_empty_cache", data)
-        empty_inputs = torch_deepcopy(data["inputs_empty_cache"])
+        empty_inputs = torch_deepcopy(data["inputs2"])
         model(**torch_deepcopy(empty_inputs))
         expected = model(**torch_deepcopy(inputs))
         self.assertEqual(
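
Note: a hedged illustration of the two LLM input regimes the renamed test distinguishes (shapes and names are assumptions for illustration, not the library API): prompt processing feeds the whole prompt with an empty cache, while token generation feeds one new token plus the cache built so far.

import torch

batch, seq, past = 1, 7, 7
prompt_processing = dict(
    input_ids=torch.randint(0, 100, (batch, seq)),
    attention_mask=torch.ones(batch, seq, dtype=torch.int64),
    position_ids=torch.arange(seq).unsqueeze(0),
    # past_key_values is empty: the model sees the prompt for the first time
)
token_generation = dict(
    input_ids=torch.randint(0, 100, (batch, 1)),
    attention_mask=torch.ones(batch, past + 1, dtype=torch.int64),
    position_ids=torch.tensor([[past]]),
    # past_key_values would hold `past` cached key/value states per layer
)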
39 changes: 39 additions & 0 deletions _unittests/ut_tasks/test_tasks_text_generation.py
@@ -0,0 +1,39 @@
+import unittest
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_transformers,
+    requires_torch,
+)
+from onnx_diagnostic.torch_models.validate import validate_model
+
+
+class TestTasksMaskGeneration(ExtTestCase):
+    @hide_stdout()
+    @requires_transformers("4.53")
+    @requires_torch("2.7.99")
+    def test_text_generation(self):
+        mid = "microsoft/phi-2"
+        summary, data = validate_model(
+            mid,
+            do_run=True,
+            verbose=10,
+            exporter="onnx-dynamo",
+            dump_folder="dump_test/microsoft_phi-2",
+            inputs2=True,
+            patch=True,
+        )
+        self.assertIsInstance(summary, dict)
+        # multi-turn conversation
+        self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2)
+        # prompt processing
+        self.assertLess(summary["disc_onnx_ort_run2_abs"], 3e-2)
+        # token generation
+        self.assertLess(summary["disc_onnx_ort_run3_abs"], 3e-2)
+        self.assertIsInstance(data, dict)
+        onnx_filename = data["onnx_filename"]
+        self.assertExists(onnx_filename)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
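
Note: the run, run2, and run3 summary keys compare ONNX Runtime against PyTorch on the three input sets the comments name: multi-turn conversation, prompt processing, and token generation. A sketch of the same check outside the test harness (argument values copied from the test; the 3e-2 tolerance is the test's own):

from onnx_diagnostic.torch_models.validate import validate_model

summary, data = validate_model(
    "microsoft/phi-2", do_run=True, exporter="onnx-dynamo",
    inputs2=True, patch=True,
)
for key, label in [
    ("disc_onnx_ort_run_abs", "multi-turn conversation"),
    ("disc_onnx_ort_run2_abs", "prompt processing"),
    ("disc_onnx_ort_run3_abs", "token generation"),
]:
    assert summary[key] < 3e-2, f"{label}: discrepancy {summary[key]} too large"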
10 changes: 8 additions & 2 deletions _unittests/ut_torch_export_patches/test_dynamic_class.py
@@ -307,7 +307,10 @@ def test_phi2_export_module(self):
             str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
         )

-        with torch_export_patches(patch_transformers=True):
+        with (
+            torch_export_patches(patch_transformers=True),
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+        ):
             ep = torch.export.export(
                 model,
                 (),
@@ -346,7 +349,10 @@ def test_phi2_export_interpreter(self):
             str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
         )

-        with torch_export_patches(patch_transformers=True, verbose=1):
+        with (
+            torch_export_patches(patch_transformers=True, verbose=1),
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+        ):
             if masking_utils is not None:
                 self.assertEqual(
                     masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
6 changes: 3 additions & 3 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
@@ -414,10 +414,10 @@ def _batch1(t):
         if got is not None:
            self.assertEqualArrayAny(expected, got)

-        if "inputs_empty_cache" not in data:
+        # inputs2 is prompt_processing (no cache)
+        if "inputs2" not in data:
             return

-        export_inputs = data["inputs_empty_cache"]
+        export_inputs = data["inputs2"]

         # with self.subTest(input="cache0", backed_size_oblivious=False):
         #     with torch_export_patches(patch_transformers=True):
2 changes: 1 addition & 1 deletion _unittests/ut_torch_models/test_validate_models.py
@@ -41,7 +41,7 @@ def test_validate_tiny_llms_bfloat16(self):
     @requires_transformers("4.53")
     @requires_torch("2.7.99")
     @requires_experimental()
-    @hide_stdout()
+    # @hide_stdout()
     def test_validate_microsoft_phi4_reasoning(self):
         # python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning
         #     --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch
4 changes: 3 additions & 1 deletion _unittests/ut_torch_models/test_validate_whole_models.py
@@ -118,6 +118,7 @@ def test_g_validate_model_onnx_dynamo_os_ort(self):
         self.assertExists(onnx_filename)

     @requires_torch("2.7")
+    @requires_transformers("4.55.4")  # modeling_units
     @hide_stdout()
     @ignore_warnings(FutureWarning)
     @requires_experimental()
@@ -147,6 +148,7 @@ def test_i_validate_model_custom(self):
         )

     @requires_torch("2.7")
+    @requires_transformers("4.55.4")  # modeling_units
     @hide_stdout()
     @ignore_warnings(FutureWarning)
     @requires_experimental()
@@ -227,7 +229,7 @@ def test_m_validate_model_vit_model(self):
         self.assertIsInstance(summary, dict)
         self.assertIsInstance(data, dict)
         self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-3)
-        self.assertLess(summary["disc_onnx_ort_run22_abs"], 1e-3)
+        self.assertLess(summary["disc_onnx_ort_run2_abs"], 1e-3)
         self.assertEqual("dict(pixel_values:A1s2x3x30x30)", summary["run_feeds_inputs"])
         self.assertEqual("dict(pixel_values:A1s3x3x31x31)", summary["run_feeds_inputs2"])
         self.assertEqual("#1[A1s2x2]", summary["run_output_inputs"])
30 changes: 0 additions & 30 deletions onnx_diagnostic/helpers/helper.py
@@ -1063,36 +1063,6 @@ def max_diff(
         print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
         return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)

-    if isinstance(got, (list, tuple)):
-        if len(got) != 1:
-            if verbose >= 6:
-                print(
-                    f"[max_diff] list,tuple,2: {string_type(expected)} "
-                    f"? {string_type(got)}"
-                )
-            if verbose > 2:
-                import torch
-
-                print(
-                    f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, "
-                    f"len(got)={len(got)}, level={level}, _index={_index}"
-                )
-                for i, (a, b) in enumerate(zip(expected, got)):
-                    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
-                        print(
-                            f"    i={i} expected {a.dtype}:{a.shape}, "
-                            f"has {b.dtype}:{b.shape}, _index={_index}"
-                        )
-                    else:
-                        print(
-                            f"    i={i} a is {type(a)}, "
-                            f"b is {type(b)}, _index={_index}"
-                        )
-            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
-        if verbose >= 6:
-            print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}")
-        return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws)
-
     if isinstance(expected, (tuple, list)):
         if verbose >= 6:
             print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")
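
Note: removing this branch changes max_diff's behavior: a got value wrapped in a one-element list or tuple is no longer silently unwrapped and compared against a bare expected tensor; the two structures must now line up. A hypothetical mini-version of the distinction (not the library code):

import torch

def aligned_max_abs(expected, got):
    # both sides must share a structure; no implicit [x] -> x unwrapping
    if isinstance(expected, (list, tuple)) and isinstance(got, (list, tuple)):
        return max(aligned_max_abs(e, g) for e, g in zip(expected, got))
    return float((expected - got).abs().max())

t = torch.ones(2, 3)
print(aligned_max_abs([t], [t + 1e-4]))  # ok: structures match
# aligned_max_abs(t, [t]) would now raise instead of comparing t to t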