diff --git a/_doc/examples/plot_export_tiny_phi2.py b/_doc/examples/plot_export_tiny_phi2.py index b4979334..aa09d2c8 100644 --- a/_doc/examples/plot_export_tiny_phi2.py +++ b/_doc/examples/plot_export_tiny_phi2.py @@ -88,7 +88,10 @@ # Shapes may not match on the second call with the modified inputs. -with torch_export_patches(patch_transformers=True): +with ( + torch_export_patches(patch_transformers=True), + torch.fx.experimental._config.patch(backed_size_oblivious=True), +): # Two unnecessary steps but useful in case of an error # We check the cache is registered. diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 52b758f1..f7dfd65e 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -6,7 +6,6 @@ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes from onnx_diagnostic.torch_export_patches import torch_export_patches -from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs class TestDynamicShapes(ExtTestCase): @@ -848,23 +847,6 @@ def test_dynamic_cache_replace_by_string(self): as_string, ) - @requires_transformers("4.51") - def test_unbatch_inputs(self): - data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") - cpl = CoupleInputsDynamicShapes( - None, data["inputs"], dynamic_shapes=data["dynamic_shapes"] - ) - new_dims = cpl.change_dynamic_dimensions( - desired_values=dict(batch=1), only_desired=True - ) - s = self.string_type(new_dims, with_shape=True) - self.assertEqual( - "dict(input_ids:T7s1x3,attention_mask:T7s1x33,position_ids:T7s1x3," - "past_key_values:DynamicCache(" - "key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))", - s, - ) - if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py index 2917fa7b..ef18248f 100644 --- a/_unittests/ut_export/test_shape_helper.py +++ b/_unittests/ut_export/test_shape_helper.py @@ -168,17 +168,19 @@ def test_guess_dynamic_shapes_from_inputs(self): guessed = guess_dynamic_shapes_from_inputs( [data["inputs"], data["inputs2"]], auto="dd" ) + # TODO(xadupre): guess_dynamic_shapes_from_inputs does not support well when + # there are dim==1 self.assertEqual( ( (), { - "attention_mask": {0: "dd_0I0", 1: "dd_0I1"}, - "input_ids": {0: "dd_1I0", 1: "dd_1I1"}, + "attention_mask": {1: "dd_0I1"}, + "input_ids": {1: "dd_1I1"}, "past_key_values": [ - [{0: "dd_2I_0o_0l0", 2: "dd_2I_0o_0l2"}], - [{0: "dd_2I_1o_0l0", 2: "dd_2I_1o_0l2"}], + [{2: "dd_2I_0o_0l2"}], + [{2: "dd_2I_1o_0l2"}], ], - "position_ids": {0: "dd_3I0", 1: "dd_3I1"}, + "position_ids": {1: "dd_3I1"}, }, ), guessed, diff --git a/_unittests/ut_helpers/test_model_builder_helper.py b/_unittests/ut_helpers/test_model_builder_helper.py deleted file mode 100644 index 94fe28f2..00000000 --- a/_unittests/ut_helpers/test_model_builder_helper.py +++ /dev/null @@ -1,70 +0,0 @@ -import unittest -from onnx_diagnostic.ext_test_case import ( - ExtTestCase, - ignore_errors, - requires_torch, - requires_transformers, - hide_stdout, -) -from onnx_diagnostic.helpers.model_builder_helper import ( - download_model_builder_to_cache, - import_model_builder, - create_model_builder, - save_model_builder, -) -from onnx_diagnostic.torch_models.hghub import ( - get_untrained_model_with_inputs, -) -from onnx_diagnostic.helpers.rt_helper import make_feeds - - -class 
TestModelBuilderHelper(ExtTestCase): - # This is to limit impact on CI. - @requires_transformers("4.52") - @requires_torch("2.7.99") - @ignore_errors(OSError) # connectivity issues - def test_download_model_builder(self): - path = download_model_builder_to_cache() - self.assertExists(path) - builder = import_model_builder() - self.assertHasAttr(builder, "create_model") - - # This is to limit impact on CI. - @requires_transformers("4.52") - @requires_torch("2.7.99") - @hide_stdout() - @ignore_errors(OSError) # connectivity issues - def test_model_builder_id(self): - # clear&&python ~/.cache/onnx-diagnostic/builder.py - # --model arnir0/Tiny-LLM -p fp16 -c dump_cache -e cpu -o dump_model - folder = self.get_dump_folder("test_model_builder_id") - data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") - onnx_model = create_model_builder( - data["configuration"], - data["model"], - precision="fp32", - execution_provider="cpu", - cache_dir=folder, - verbose=1, - ) - self.assertGreater(onnx_model.model.graph.num_nodes(), 5) - model_name = save_model_builder(onnx_model, folder, verbose=1) - self.assertExists(model_name) - - import onnxruntime - - sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"]) - del data["inputs"]["position_ids"] - feeds = make_feeds([i.name for i in sess.get_inputs()], data["inputs"], use_numpy=True) - expected = data["model"](**data["inputs"]) - - try: - got = sess.run(None, feeds) - except onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument as e: - if "batch_size must be 1 when sequence_length > 1" in str(e): - raise unittest.SkipTest("batch_size must be 1 when sequence_length > 1") - self.assertEqualAny(expected, got) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py index 6815e8dc..5d5d2683 100644 --- a/_unittests/ut_tasks/test_tasks.py +++ b/_unittests/ut_tasks/test_tasks.py @@ -35,6 +35,7 @@ def test_text2text_generation(self): ) @hide_stdout() + @requires_transformers("4.55.4") # modeling_units def test_text_generation(self): mid = "arnir0/Tiny-LLM" data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True) @@ -43,17 +44,19 @@ def test_text_generation(self): model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] model(**inputs) model(**data["inputs2"]) - with torch_export_patches(patch_transformers=True, verbose=10): + with ( + torch_export_patches(patch_transformers=True, verbose=10), + torch.fx.experimental._config.patch(backed_size_oblivious=True), + ): torch.export.export( model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) - def test_text_generation_empty_cache(self): + def test_text_generation_prompt_processing(self): mid = "arnir0/Tiny-LLM" data = get_untrained_model_with_inputs(mid, add_second_input=True) model, inputs = data["model"], data["inputs"] - self.assertIn("inputs_empty_cache", data) - empty_inputs = torch_deepcopy(data["inputs_empty_cache"]) + empty_inputs = torch_deepcopy(data["inputs2"]) model(**torch_deepcopy(empty_inputs)) expected = model(**torch_deepcopy(inputs)) self.assertEqual( diff --git a/_unittests/ut_tasks/test_tasks_text_generation.py b/_unittests/ut_tasks/test_tasks_text_generation.py new file mode 100644 index 00000000..10379f37 --- /dev/null +++ b/_unittests/ut_tasks/test_tasks_text_generation.py @@ -0,0 +1,39 @@ +import unittest +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + hide_stdout, + requires_transformers, + 
requires_torch, +) +from onnx_diagnostic.torch_models.validate import validate_model + + +class TestTasksMaskGeneration(ExtTestCase): + @hide_stdout() + @requires_transformers("4.53") + @requires_torch("2.7.99") + def test_text_generation(self): + mid = "microsoft/phi-2" + summary, data = validate_model( + mid, + do_run=True, + verbose=10, + exporter="onnx-dynamo", + dump_folder="dump_test/microsoft_phi-2", + inputs2=True, + patch=True, + ) + self.assertIsInstance(summary, dict) + # multi-turn conversation + self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2) + # prompt processing + self.assertLess(summary["disc_onnx_ort_run2_abs"], 3e-2) + # token generation + self.assertLess(summary["disc_onnx_ort_run3_abs"], 3e-2) + self.assertIsInstance(data, dict) + onnx_filename = data["onnx_filename"] + self.assertExists(onnx_filename) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index da4cbd91..7dca25eb 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -307,7 +307,10 @@ def test_phi2_export_module(self): str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) - with torch_export_patches(patch_transformers=True): + with ( + torch_export_patches(patch_transformers=True), + torch.fx.experimental._config.patch(backed_size_oblivious=True), + ): ep = torch.export.export( model, (), @@ -346,7 +349,10 @@ def test_phi2_export_interpreter(self): str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) - with torch_export_patches(patch_transformers=True, verbose=1): + with ( + torch_export_patches(patch_transformers=True, verbose=1), + torch.fx.experimental._config.patch(backed_size_oblivious=True), + ): if masking_utils is not None: self.assertEqual( masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"], diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 98bb720a..58d833b0 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -414,10 +414,10 @@ def _batch1(t): if got is not None: self.assertEqualArrayAny(expected, got) - if "inputs_empty_cache" not in data: + # inputs2 is prompt_processing (no cache) + if "inputs2" not in data: return - - export_inputs = data["inputs_empty_cache"] + export_inputs = data["inputs2"] # with self.subTest(input="cache0", backed_size_oblivious=False): # with torch_export_patches(patch_transformers=True): diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index 6bbf60ee..741966b3 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -41,7 +41,7 @@ def test_validate_tiny_llms_bfloat16(self): @requires_transformers("4.53") @requires_torch("2.7.99") @requires_experimental() - @hide_stdout() + # @hide_stdout() def test_validate_microsoft_phi4_reasoning(self): # python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning # --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch diff --git a/_unittests/ut_torch_models/test_validate_whole_models.py b/_unittests/ut_torch_models/test_validate_whole_models.py index 830ee0b2..df17277f 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models.py +++ 
b/_unittests/ut_torch_models/test_validate_whole_models.py @@ -118,6 +118,7 @@ def test_g_validate_model_onnx_dynamo_os_ort(self): self.assertExists(onnx_filename) @requires_torch("2.7") + @requires_transformers("4.55.4") # modeling_units @hide_stdout() @ignore_warnings(FutureWarning) @requires_experimental() @@ -147,6 +148,7 @@ def test_i_validate_model_custom(self): ) @requires_torch("2.7") + @requires_transformers("4.55.4") # modeling_units @hide_stdout() @ignore_warnings(FutureWarning) @requires_experimental() @@ -227,7 +229,7 @@ def test_m_validate_model_vit_model(self): self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-3) - self.assertLess(summary["disc_onnx_ort_run22_abs"], 1e-3) + self.assertLess(summary["disc_onnx_ort_run2_abs"], 1e-3) self.assertEqual("dict(pixel_values:A1s2x3x30x30)", summary["run_feeds_inputs"]) self.assertEqual("dict(pixel_values:A1s3x3x31x31)", summary["run_feeds_inputs2"]) self.assertEqual("#1[A1s2x2]", summary["run_output_inputs"]) diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index e80567d7..a6751721 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1063,36 +1063,6 @@ def max_diff( print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}") return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws) - if isinstance(got, (list, tuple)): - if len(got) != 1: - if verbose >= 6: - print( - f"[max_diff] list,tuple,2: {string_type(expected)} " - f"? {string_type(got)}" - ) - if verbose > 2: - import torch - - print( - f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, " - f"len(got)={len(got)}, level={level}, _index={_index}" - ) - for i, (a, b) in enumerate(zip(expected, got)): - if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): - print( - f" i={i} expected {a.dtype}:{a.shape}, " - f"has {b.dtype}:{b.shape}, _index={_index}" - ) - else: - print( - f" i={i} a is {type(a)}, " - f"b is {type(b)}, _index={_index}" - ) - return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) - if verbose >= 6: - print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}") - return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws) - if isinstance(expected, (tuple, list)): if verbose >= 6: print(f"[max_diff] list,tuple,0: {string_type(expected)} ? 
{string_type(got)}") diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 964e0462..b8e07045 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -59,8 +59,8 @@ def get_inputs( dummy_max_token_id: int, num_hidden_layers: int, batch_size: int = 2, - sequence_length: int = 30, - sequence_length2: int = 3, + past_sequence_length: int = 30, + sequence_length: int = 3, dynamic_rope: bool = False, num_key_value_heads: Optional[int] = None, head_dim: Optional[int] = None, @@ -76,17 +76,18 @@ def get_inputs( :param head_dim: last dimension of the cache :param dummy_max_token_id: dummy max token id :param batch_size: batch size - :param sequence_length: sequence length - :param sequence_length2: new sequence length + :param past_sequence_length: past sequence length + :param sequence_length: new sequence length :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) :param cls_cache: cache class, by default it is :class:`transformers.cache_utils.DynamicCache` :return: dictionary """ batch = "batch" - seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) - cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) + seq_length = "seq_length" + past_seq_length = "past_seq_length" + # TODO(team): Is this code block still necessary? if config is not None and config.__class__.__name__ == "FalconMambaConfig": try: from transformers.models.mamba.modeling_mamba import MambaCache @@ -98,23 +99,23 @@ def get_inputs( MambaCache, ), f"Unexpected value for cls_cache={cls_cache} and config={config}" seq_length_multiple = 8 - sequence_length = ( - (sequence_length + seq_length_multiple) + past_sequence_length = ( + (past_sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple ) # sequence_inc = seq_length_multiple - sequence_length2 = seq_length_multiple + sequence_length = seq_length_multiple shapes = { "input_ids": {0: batch, 1: "sequence_length"}, "attention_mask": { 0: batch, - 1: "cache+seq", # cache_length + seq_length + 1: "cache+seq", # past_seq_length + seq_length }, "cache_position": { 0: batch, - 1: "cache+seq", # cache_length + seq_length + 1: "cache+seq", # past_seq_length + seq_length }, "cache_params": [ [{0: batch} for _ in range(num_hidden_layers)], @@ -123,9 +124,9 @@ def get_inputs( } inputs = dict( input_ids=torch.randint( - 0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2) + 0, dummy_max_token_id, (batch_size, past_sequence_length + sequence_length) ).to(torch.int64), - attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( + attention_mask=torch.ones((batch_size, past_sequence_length + sequence_length)).to( torch.int64 ), cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64), @@ -167,46 +168,51 @@ def get_inputs( make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name] is_static = cache_name == "StaticCache" + # TODO(team): Is this code block still necessary? 
if is_static: # static shapes = { "input_ids": {0: batch, 1: seq_length}, - "attention_mask": {0: batch, 2: "seq"}, - "cache_position": {0: "seq"}, + "attention_mask": {0: batch, 2: "past_sequence_length"}, + "cache_position": {0: "past_sequence_length"}, "past_key_values": [ - # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + # past_sequence_length is now static [{0: batch} for _ in range(num_hidden_layers)], [{0: batch} for _ in range(num_hidden_layers)], ], } inputs = dict( input_ids=torch.randint( - 0, dummy_max_token_id, (batch_size, sequence_length2) + 0, dummy_max_token_id, (batch_size, past_sequence_length) ).to(torch.int64), attention_mask=torch.ones( - (batch_size, num_key_value_heads, sequence_length2, head_dim) + ( + batch_size, + num_key_value_heads, + past_sequence_length, + head_dim, + ) ).to(torch.bool), - cache_position=torch.arange(sequence_length2).to(torch.int64), + cache_position=torch.arange(past_sequence_length).to(torch.int64), past_key_values=make_static_cache( [ ( torch.randn( batch_size, num_key_value_heads, - sequence_length + sequence_length2, + sequence_length + past_sequence_length, head_dim, ), torch.randn( batch_size, num_key_value_heads, - sequence_length + sequence_length2, + sequence_length + past_sequence_length, head_dim, ), ) for i in range(num_hidden_layers) ], - max_cache_len=max(sequence_length + sequence_length2, head_dim), + max_cache_len=max(past_sequence_length, head_dim), ), ) else: @@ -215,53 +221,57 @@ def get_inputs( "input_ids": {0: batch, 1: seq_length}, "attention_mask": { 0: batch, - 1: "cache+seq", # cache_length + seq_length + 1: "past_seq_length+seq_length", # past_seq_length + seq_length }, "position_ids": { 0: batch, - 1: "cache+seq", # cache_length + seq_length + 1: seq_length, }, "past_key_values": [ - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)], + [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)], ], } - inputs = dict( input_ids=torch.randint( - 0, dummy_max_token_id, (batch_size, sequence_length2) + 0, dummy_max_token_id, (batch_size, sequence_length) ).to(torch.int64), - attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( - torch.int64 - ), - position_ids=torch.arange(sequence_length, sequence_length + sequence_length2) + attention_mask=torch.ones( + (batch_size, sequence_length + past_sequence_length) + ).to(torch.int64), + position_ids=torch.arange( + past_sequence_length, sequence_length + past_sequence_length + ) .to(torch.int64) .expand((batch_size, -1)), past_key_values=make_cache( # type: ignore[operator] [ ( torch.randn( - batch_size, num_key_value_heads, sequence_length, head_dim + batch_size, num_key_value_heads, past_sequence_length, head_dim ), torch.randn( - batch_size, num_key_value_heads, sequence_length, head_dim + batch_size, num_key_value_heads, past_sequence_length, head_dim ), ) for i in range(num_hidden_layers) ] ), ) + # NOTE: past_sequence_length can be 0 when testing prompt processing, + # which it becomes an empty tensor res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + # TODO(titaiwang): Make input key more informative + # prompt processing (prefill) testing res["inputs2"] = get_inputs( model=model, config=config, dummy_max_token_id=dummy_max_token_id, num_hidden_layers=num_hidden_layers, - 
batch_size=(batch_size + 1) if add_second_input > 0 else 1, - sequence_length=sequence_length + 1, - sequence_length2=sequence_length2 - + (add_second_input if add_second_input > 0 else -add_second_input), + batch_size=batch_size, + past_sequence_length=0, + sequence_length=32, dynamic_rope=dynamic_rope, num_key_value_heads=num_key_value_heads, head_dim=head_dim, @@ -269,14 +279,16 @@ def get_inputs( add_second_input=0, **kwargs, )["inputs"] - res["inputs_empty_cache"] = get_inputs( + # Token generation (decode) testing + # NOTE: We have to export model in decode mode to preserve the cache + res["inputs3"] = get_inputs( model=model, config=config, dummy_max_token_id=dummy_max_token_id, num_hidden_layers=num_hidden_layers, - batch_size=batch_size, - sequence_length=0, - sequence_length2=sequence_length2, + batch_size=2, + past_sequence_length=32, + sequence_length=1, dynamic_rope=dynamic_rope, num_key_value_heads=num_key_value_heads, head_dim=head_dim, @@ -291,6 +303,23 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: """ Inputs kwargs. + NOTE: We test two scenarios: + 1. prompt processing (aka prefill): + input_ids=(batch_size, prompt_length) + attn_mask=(batch_size, 0+prompt_length) = (batch_size, prompt_length) + pos_ids=(batch_size, prompt_length) + past_key_values=(batch_size, num_key_value_heads, 0, head_dim) + present_key_values=(batch_size, num_key_value_heads, 0+prompt_length, head_dim) + 2. token generation (aka decode). + input_ids=(batch_size, 1) + attn_mask=(batch_size, past_sequence_length+1) + pos_ids=(batch_size, 1) + past_key_values=(batch_size, num_key_value_heads, past_sequence_length, + head_dim) + present_key_values=(batch_size, num_key_value_heads, + past_sequence_length+1, head_dim) + + If the configuration is None, the function selects typical dimensions. 
""" if config is not None: @@ -305,8 +334,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8 kwargs = dict( batch_size=2, - sequence_length=30, - sequence_length2=3, + past_sequence_length=30, + sequence_length=3, dummy_max_token_id=31999 if config is None else (config.vocab_size - 1), num_hidden_layers=4 if config is None else config.num_hidden_layers, intermediate_size=256 if config is None else config.intermediate_size, @@ -315,10 +344,16 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: conv_kernel=8 if config is None else getattr(config, "conv_kernel", None), ) else: + # multi-turn conversation + # prompt-processing -> token-generation(loop output) -> + # prompt-processing from the loop output + # Token generation (decode) testing + # NOTE: We have to export model in decode mode to preserve the cache + # NOTE: batch_size=1 for ORT GQA to run kwargs = dict( - batch_size=2, - sequence_length=30, - sequence_length2=3, + batch_size=1, + past_sequence_length=32, + sequence_length=16, head_dim=( 16 if config is None diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 972c728c..f177ce65 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -450,6 +450,11 @@ def torch_export_patches( except ImportError: masking_utils = None + try: + import transformers.modeling_utils as modeling_utils + except ImportError: + modeling_utils = None + if verbose: import transformers @@ -544,6 +549,23 @@ def torch_export_patches( patch_transformers_list.patched_sdpa_mask_recent_torch ) + if ( + modeling_utils + and patch_transformers_list.patch_modeling_utils + and "sdpa" in modeling_utils.ALL_ATTENTION_FUNCTIONS + ): + if verbose: + print( + "[torch_export_patches] patches " + "transformers.modeling_utils.sdpa_attention_forward" + ) + f_transformers_sdpa_attention_forward = modeling_utils.ALL_ATTENTION_FUNCTIONS[ + "sdpa" + ] + modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] = ( + patch_transformers_list.patched_sdpa_attention_forward + ) + if custom_patches: if verbose: print("[torch_export_patches] applies custom patches") @@ -729,6 +751,19 @@ def torch_export_patches( "transformers.masking_utils.sdpa_mask " "in ALL_MASK_ATTENTION_FUNCTIONS" ) + if ( + modeling_utils + and patch_transformers_list.patch_modeling_utils + and "sdpa" in modeling_utils.ALL_ATTENTION_FUNCTIONS + ): + modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] = ( + f_transformers_sdpa_attention_forward + ) + if verbose: + print( + "[torch_export_patches] restored " + "transformers.modeling_utils.sdpa_attention_forward" + ) ######## # caches diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 3fe0ba83..48b00fde 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -1196,7 +1196,7 @@ def wrapper(self, x, position_ids): return wrapper -def common_eager_attention_forward( +def _common_eager_attention_forward( module: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -1244,7 +1244,7 @@ def patched_model_bart_eager_attention_forward( **kwargs, ): """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]""" - return 
common_eager_attention_forward( + return _common_eager_attention_forward( module, query, key, @@ -1269,7 +1269,7 @@ def patched_modeling_marian_eager_attention_forward( **kwargs, ): """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]""" - return common_eager_attention_forward( + return _common_eager_attention_forward( module, query, key, @@ -1894,3 +1894,72 @@ def get_placeholder_mask( ), ) return special_image_mask + + +##### Attention ##### + +try: + import transformers.modeling_utils + + patch_modeling_utils = True + + from transformers.integrations.sdpa_attention import use_gqa_in_sdpa, repeat_kv + +except ImportError: + patch_modeling_utils = False + +if patch_modeling_utils: + + def patched_sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, + ) -> tuple[torch.Tensor, None]: + """manual patch for function ```transformers.integrations.sdpa_attention.sdpa_attention_forward```.""" # noqa: E501 + sdpa_kwargs = {} + if hasattr(module, "num_key_value_groups"): + if not use_gqa_in_sdpa(attention_mask, key): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + else: + sdpa_kwargs = {"enable_gqa": True} + + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # noqa: E501 + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # noqa: E501 + # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool` # noqa: E501 + if is_causal is None: + # NOTE: attention_mask should always be not None + # https://github.com/huggingface/transformers/blob/def4a37e19601b597f170e81684c8b0b5f84db39/src/transformers/masking_utils.py#L240-L243 + is_causal = False + # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # noqa: E501 + # We convert it to a bool for the SDPA kernel that only accepts bools. 
+ if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + + # From causal_mask generation, attention_mask is 4D, and the last dim + # should be the same as key's seq_len + torch._check( + attention_mask.shape[3] == key.shape[2] if attention_mask is not None else True + ) + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + **sdpa_kwargs, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None diff --git a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py index f8b7fe63..cc7a3390 100644 --- a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +++ b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py @@ -5,7 +5,7 @@ def get_tiny_llm( batch_size: int = 2, sequence_length: int = 30, - sequence_length2: int = 3, + past_sequence_length: int = 3, dynamic_rope: bool = False, use_static_cache: bool = False, **kwargs, @@ -15,7 +15,7 @@ def get_tiny_llm( :param batch_size: batch size :param sequence_length: sequence length - :param sequence_length2: new sequence length + :param past_sequence_length: past sequence length :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) :param use_static_cache: use StaticCache instead of DynamicCache :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1`` @@ -62,7 +62,7 @@ def get_tiny_llm( num_hidden_layers=config["num_hidden_layers"], # type: ignore[arg-type] batch_size=batch_size, sequence_length=sequence_length, - sequence_length2=sequence_length2, + past_sequence_length=past_sequence_length, dynamic_rope=dynamic_rope, num_key_value_heads=config["num_key_value_heads"], # type: ignore[arg-type] cls_cache="StaticCache" if use_static_cache else "DynamicCache", diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 9bccf9f1..8ff741d6 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -556,30 +556,6 @@ def validate_model( f.write(f"model_id: {model_id}\n------\n") f.write(pprint.pformat(dump_info)) - if exporter == "modelbuilder": - # Models used with ModelBuilder do not like batch size > 1. - # Let's change that. - for k in ["inputs", "inputs2"]: - if k not in data: - continue - if verbose: - print(f"[validate_model] set batch=1 for data[{k!r}]") - print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}") - cpl = CoupleInputsDynamicShapes( - tuple(), data[k], dynamic_shapes=data["dynamic_shapes"] - ) - if patch_kwargs.get("patch", False): - with torch_export_patches(**patch_kwargs): # type: ignore[arg-type] - data[k] = cpl.change_dynamic_dimensions( - desired_values=dict(batch=1), only_desired=True - ) - else: - data[k] = cpl.change_dynamic_dimensions( - desired_values=dict(batch=1), only_desired=True - ) - if verbose: - print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}") - # modelbuilder needs different treatments sometimes, so # we mark it for later usage. 
# for example, it has different past_kv ordering than @@ -707,8 +683,8 @@ def validate_model( data, summary, k, - f"run2{k[6:]}", - f"run_expected2{k[6:]}", + f"run{k[6:]}", + f"run_expected{k[6:]}", verbose, 1, 0, @@ -992,6 +968,10 @@ def node_iter(proto): def _validate_do_run_model( data, summary, key, tag, expected_tag, verbose, repeat, warmup, quiet ): + if key not in data: + if verbose: + print(f"[validate_model] input; {key!r} not defined, skip.") + return if verbose: print(f"[validate_model] -- run the model inputs={key!r}...") print(f"[validate_model] {key}={string_type(data[key], with_shape=True)}") @@ -1228,16 +1208,24 @@ def call_torch_export_export( print("[call_torch_export_export] export...") model = data["model"] + + def _run_torch_export(): + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + ep = torch.export.export( + model, + args, + kwargs=kwargs, + dynamic_shapes=dse, + strict=strict, + ) + return ep + ep = _quiet_or_not_quiet( quiet, "export_export", summary, data, - ( - lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: ( - torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s) - ) - ), + _run_torch_export, ) if "ERR_export_export" in summary: return summary, data @@ -1443,7 +1431,7 @@ def _mk(key, flavour=flavour): keys = [("inputs", "run_expected", "")] if second_input_keys: - keys.extend([(k, f"run_expected2{k[6:]}", f"2{k[6:]}") for k in second_input_keys]) + keys.extend([(k, f"run_expected{k[6:]}", f"{k[6:]}") for k in second_input_keys]) for k_input, k_expected, suffix in keys: # make_feeds assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}" @@ -1937,23 +1925,24 @@ def call_torch_export_custom( kws["target_opset"] = opset if output_names: kws["output_names"] = output_names - - epo, opt_stats = _quiet_or_not_quiet( - quiet, - "export_export_onnx_c", - summary, - data, - ( - lambda m=model, args=args, kwargs=kwargs, kws=kws: ( - to_onnx( - model, - args, - kwargs=kwargs, - **kws, + # anti-specializing 0/1 during torch.export.export + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + epo, opt_stats = _quiet_or_not_quiet( + quiet, + "export_export_onnx_c", + summary, + data, + ( + lambda m=model, args=args, kwargs=kwargs, kws=kws: ( + to_onnx( + model, + args, + kwargs=kwargs, + **kws, + ) ) - ) - ), - ) + ), + ) if "ERR_export_onnx_c" in summary: return summary, data
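
Usage sketch (illustrative only, not part of the patch): with these changes a single `validate_model` call exercises the three text-generation scenarios built by `onnx_diagnostic.tasks.text_generation.get_inputs` — the default multi-turn inputs, prompt processing (`inputs2`, empty cache), and token generation (`inputs3`, one new token against a filled cache) — and reports one discrepancy per scenario. The arguments mirror the new `test_text_generation` above, with the model id swapped for the lighter Tiny-LLM checkpoint used elsewhere in these tests and all other arguments left at their defaults; as in the new tests, a recent transformers (>= 4.53) and torch are assumed.

    from onnx_diagnostic.torch_models.validate import validate_model

    # Runs the model eagerly, exports it with the patches enabled, then compares
    # ONNX Runtime against eager execution for each input scenario.
    summary, data = validate_model(
        "arnir0/Tiny-LLM",   # untrained text-generation model used in the tests above
        do_run=True,
        exporter="onnx-dynamo",
        inputs2=True,        # also build inputs2 (prefill) and inputs3 (decode)
        patch=True,          # torch_export_patches + backed_size_oblivious, as wired above
        verbose=1,
    )
    print(summary["disc_onnx_ort_run_abs"])    # default (multi-turn) inputs
    print(summary["disc_onnx_ort_run2_abs"])   # prompt processing, past_sequence_length=0
    print(summary["disc_onnx_ort_run3_abs"])   # token generation, sequence_length=1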