diff --git a/_unittests/ut_export/test_simple_cases.py b/_unittests/ut_export/test_simple_cases.py
new file mode 100644
index 00000000..22edb588
--- /dev/null
+++ b/_unittests/ut_export/test_simple_cases.py
@@ -0,0 +1,33 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.reference import ExtendedReferenceEvaluator
+
+
+class TestDynamicShapes(ExtTestCase):
+    def test_getitem_index_put1(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, value):
+                x = x.clone()
+                x[:, :, :, : value.shape[-1]] = value
+                return x
+
+        inputs = (torch.randn(2, 2, 3, 4), torch.randn(2, 2, 3, 3))
+        model = Model()
+        expected = model(*inputs)
+
+        onx = self.to_onnx(model, inputs, dynamic_shapes=({3: "M"}, {3: "N"}))
+        self.dump_onnx("test_getitem_index_put1.onnx", onx)
+        feeds = dict(zip(["x", "value"], [x.detach().cpu().numpy() for x in inputs]))
+        ref = ExtendedReferenceEvaluator(onx, verbose=0)
+        got = ref.run(None, feeds)[0]
+        self.assertEqualArray(expected, got, atol=1e-5)
+        sess = self.ort().InferenceSession(
+            onx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        got = sess.run(None, feeds)[0]
+        self.assertEqualArray(expected, got, atol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py
index 2ac4dfb3..cadbdb5b 100644
--- a/_unittests/ut_torch_models/test_tiny_llms.py
+++ b/_unittests/ut_torch_models/test_tiny_llms.py
@@ -56,7 +56,7 @@ def test_tiny_llm_export_static(self):
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs)
         )
-        with torch_export_patches(patch_transformers=True, stop_if_static=1):
+        with torch_export_patches(patch_transformers=True, stop_if_static=0):
             ep = torch.export.export(
                 model,
                 (),
diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py
index 3179969a..4cc3c4d3 100644
--- a/_unittests/ut_xrun_doc/test_documentation_examples.py
+++ b/_unittests/ut_xrun_doc/test_documentation_examples.py
@@ -109,6 +109,13 @@ def add_test_methods(cls):
             ):
                 reason = "torch<2.8"
 
+            if (
+                not reason
+                and name in {"plot_dump_intermediate_results.py"}
+                and not has_torch("2.9.1")
+            ):
+                reason = "unstable, let's wait for the next version"
+
             if reason:
 
                 @unittest.skip(reason)
diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py
index 242b0ed4..b9129559 100644
--- a/onnx_diagnostic/ext_test_case.py
+++ b/onnx_diagnostic/ext_test_case.py
@@ -756,6 +756,18 @@ def todo(cls, f: Callable, msg: str):
         "Adds a todo printed when all test are run."
         cls._todos.append((f, msg))
 
+    @classmethod
+    def ort(cls):
+        import onnxruntime
+
+        return onnxruntime
+
+    @classmethod
+    def to_onnx(self, *args, **kwargs):
+        from experimental_experiment.torch_interpreter import to_onnx
+
+        return to_onnx(*args, **kwargs)
+
     def print_model(self, model: "ModelProto"):  # noqa: F821
         "Prints a ModelProto"
         from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index 759fce17..99413691 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -181,7 +181,8 @@ def make_static_cache(
                 torch.randn(bsize, nheads, slen, dim),
             )
             for i in range(n_layers)
-        ]
+        ],
+        max_cache_len=10,
     )
     print(string_type(past_key_values, with_shape=True))
     """
diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py
index a960505b..873fa4fc 100644
--- a/onnx_diagnostic/tasks/text_generation.py
+++ b/onnx_diagnostic/tasks/text_generation.py
@@ -176,8 +176,10 @@ def get_inputs(
         "attention_mask": {0: batch, 2: "seq"},
         "cache_position": {0: "seq"},
         "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
         ],
     }
    inputs = dict(
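
For context, a minimal standalone sketch of the export-and-verify round trip the new test exercises, outside the ExtTestCase harness. It assumes experimental_experiment and onnxruntime are installed; to_onnx is the same helper the new ExtTestCase.to_onnx classmethod forwards to, and the input names "x", "value" match the feeds used in the test.

# Standalone sketch mirroring test_getitem_index_put1 (assumes
# experimental_experiment and onnxruntime are installed).
import numpy as np
import torch
from experimental_experiment.torch_interpreter import to_onnx
import onnxruntime


class Model(torch.nn.Module):
    def forward(self, x, value):
        x = x.clone()
        # Slicing assignment exercises the getitem/index_put export path.
        x[:, :, :, : value.shape[-1]] = value
        return x


inputs = (torch.randn(2, 2, 3, 4), torch.randn(2, 2, 3, 3))
expected = Model()(*inputs)

# Declare the last dimension of each input dynamic ("M" and "N").
onx = to_onnx(Model(), inputs, dynamic_shapes=({3: "M"}, {3: "N"}))

# Validate the exported model against eager PyTorch with onnxruntime.
sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
feeds = dict(zip(["x", "value"], [t.detach().cpu().numpy() for t in inputs]))
got = sess.run(None, feeds)[0]
np.testing.assert_allclose(expected.detach().cpu().numpy(), got, atol=1e-5)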