diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py index 57c83b55..a24dc500 100644 --- a/_unittests/ut_tasks/test_tasks.py +++ b/_unittests/ut_tasks/test_tasks.py @@ -43,7 +43,9 @@ def test_text_generation(self): model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] model(**inputs) model(**data["inputs2"]) - with torch_export_patches(patch_transformers=True, verbose=10): + with torch_export_patches( + patch_transformers=True, verbose=10 + ), torch.fx.experimental._config.patch(backed_size_oblivious=True): torch.export.export( model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) diff --git a/_unittests/ut_tasks/test_tasks_text_generation.py b/_unittests/ut_tasks/test_tasks_text_generation.py new file mode 100644 index 00000000..986211d8 --- /dev/null +++ b/_unittests/ut_tasks/test_tasks_text_generation.py @@ -0,0 +1,39 @@ +import unittest +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + hide_stdout, + requires_transformers, + requires_torch, +) +from onnx_diagnostic.torch_models.validate import validate_model + + +class TestTasksMaskGeneration(ExtTestCase): + @hide_stdout() + @requires_transformers("4.53") + @requires_torch("2.7.99") + def test_text_generation(self): + mid = "microsoft/phi-2" + summary, data = validate_model( + mid, + do_run=True, + verbose=10, + exporter="onnx-dynamo", + dump_folder="dump_test/microsoft_phi-2", + inputs2=True, + patch=True, + ) + self.assertIsInstance(summary, dict) + # token generation + self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2) + # prompt processing + self.assertLess(summary["disc_onnx_ort_run2_abs"], 3e-2) + # multi-turn conversation + self.assertLess(summary["disc_onnx_ort_run3_abs"], 3e-2) + self.assertIsInstance(data, dict) + onnx_filename = data["onnx_filename"] + self.assertExists(onnx_filename) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index da4cbd91..97afde0f 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -307,7 +307,9 @@ def test_phi2_export_module(self): str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) - with torch_export_patches(patch_transformers=True): + with torch_export_patches( + patch_transformers=True + ), torch.fx.experimental._config.patch(backed_size_oblivious=True): ep = torch.export.export( model, (), @@ -346,7 +348,9 @@ def test_phi2_export_interpreter(self): str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) - with torch_export_patches(patch_transformers=True, verbose=1): + with torch_export_patches( + patch_transformers=True, verbose=1 + ), torch.fx.experimental._config.patch(backed_size_oblivious=True): if masking_utils is not None: self.assertEqual( masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"], diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 7879a1fd..0c3fa94b 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1061,36 +1061,6 @@ def max_diff( print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}") return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws) - if isinstance(got, (list, tuple)): - if len(got) != 1: - if verbose >= 6: - print( - f"[max_diff] list,tuple,2: {string_type(expected)} " - f"? 
{string_type(got)}" - ) - if verbose > 2: - import torch - - print( - f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, " - f"len(got)={len(got)}, level={level}, _index={_index}" - ) - for i, (a, b) in enumerate(zip(expected, got)): - if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): - print( - f" i={i} expected {a.dtype}:{a.shape}, " - f"has {b.dtype}:{b.shape}, _index={_index}" - ) - else: - print( - f" i={i} a is {type(a)}, " - f"b is {type(b)}, _index={_index}" - ) - return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) - if verbose >= 6: - print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}") - return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws) - if isinstance(expected, (tuple, list)): if verbose >= 6: print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}") diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 6e6e29ba..c7249828 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -59,8 +59,8 @@ def get_inputs( dummy_max_token_id: int, num_hidden_layers: int, batch_size: int = 2, - sequence_length: int = 30, - sequence_length2: int = 3, + past_sequence_length: int = 30, + sequence_length: int = 3, dynamic_rope: bool = False, num_key_value_heads: Optional[int] = None, head_dim: Optional[int] = None, @@ -76,17 +76,18 @@ def get_inputs( :param head_dim: last dimension of the cache :param dummy_max_token_id: dummy max token id :param batch_size: batch size - :param sequence_length: sequence length - :param sequence_length2: new sequence length + :param past_sequence_length: past sequence length + :param sequence_length: new sequence length :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) :param cls_cache: cache class, by default it is :class:`transformers.cache_utils.DynamicCache` :return: dictionary """ batch = "batch" - seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096) - cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) + seq_length = "seq_length" + past_seq_length = "past_seq_length" + # TODO(team): Is this code block still necessary? 
if config is not None and config.__class__.__name__ == "FalconMambaConfig": try: from transformers.models.mamba.modeling_mamba import MambaCache @@ -98,23 +99,23 @@ def get_inputs( MambaCache, ), f"Unexpected value for cls_cache={cls_cache} and config={config}" seq_length_multiple = 8 - sequence_length = ( - (sequence_length + seq_length_multiple) + past_sequence_length = ( + (past_sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple ) # sequence_inc = seq_length_multiple - sequence_length2 = seq_length_multiple + sequence_length = seq_length_multiple shapes = { "input_ids": {0: batch, 1: "sequence_length"}, "attention_mask": { 0: batch, - 1: "cache+seq", # cache_length + seq_length + 1: "cache+seq", # past_seq_length + seq_length }, "cache_position": { 0: batch, - 1: "cache+seq", # cache_length + seq_length + 1: "cache+seq", # past_seq_length + seq_length }, "cache_params": [ [{0: batch} for _ in range(num_hidden_layers)], @@ -123,9 +124,9 @@ def get_inputs( } inputs = dict( input_ids=torch.randint( - 0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2) + 0, dummy_max_token_id, (batch_size, past_sequence_length + sequence_length) ).to(torch.int64), - attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( + attention_mask=torch.ones((batch_size, past_sequence_length + sequence_length)).to( torch.int64 ), cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64), @@ -167,46 +168,51 @@ def get_inputs( make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name] is_static = cache_name == "StaticCache" + # TODO(team): Is this code block still necessary? if is_static: # static shapes = { "input_ids": {0: batch, 1: seq_length}, - "attention_mask": {0: batch, 2: "seq"}, - "cache_position": {0: "seq"}, + "attention_mask": {0: batch, 2: "past_sequence_length"}, + "cache_position": {0: "past_sequence_length"}, "past_key_values": [ - # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + # past_sequence_length is now static [{0: batch} for _ in range(num_hidden_layers)], [{0: batch} for _ in range(num_hidden_layers)], ], } inputs = dict( input_ids=torch.randint( - 0, dummy_max_token_id, (batch_size, sequence_length2) + 0, dummy_max_token_id, (batch_size, past_sequence_length) ).to(torch.int64), attention_mask=torch.ones( - (batch_size, num_key_value_heads, sequence_length2, head_dim) + ( + batch_size, + num_key_value_heads, + past_sequence_length, + head_dim, + ) ).to(torch.bool), - cache_position=torch.arange(sequence_length2).to(torch.int64), + cache_position=torch.arange(past_sequence_length).to(torch.int64), past_key_values=make_static_cache( [ ( torch.randn( batch_size, num_key_value_heads, - sequence_length + sequence_length2, + sequence_length + past_sequence_length, head_dim, ), torch.randn( batch_size, num_key_value_heads, - sequence_length + sequence_length2, + sequence_length + past_sequence_length, head_dim, ), ) for i in range(num_hidden_layers) ], - max_cache_len=max(sequence_length + sequence_length2, head_dim), + max_cache_len=max(past_sequence_length, head_dim), ), ) else: @@ -215,53 +221,74 @@ def get_inputs( "input_ids": {0: batch, 1: seq_length}, "attention_mask": { 0: batch, - 1: "cache+seq", # cache_length + seq_length + 1: "past_seq_length+seq_length", # past_seq_length + seq_length }, "position_ids": { 0: batch, - 1: "cache+seq", # cache_length + seq_length + 1: seq_length, }, 
"past_key_values": [ - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], - [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], + [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)], + [{0: batch, 2: past_seq_length} for _ in range(num_hidden_layers)], ], } - inputs = dict( input_ids=torch.randint( - 0, dummy_max_token_id, (batch_size, sequence_length2) + 0, dummy_max_token_id, (batch_size, sequence_length) ).to(torch.int64), - attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( - torch.int64 - ), - position_ids=torch.arange(sequence_length, sequence_length + sequence_length2) + attention_mask=torch.ones( + (batch_size, sequence_length + past_sequence_length) + ).to(torch.int64), + position_ids=torch.arange( + past_sequence_length, sequence_length + past_sequence_length + ) .to(torch.int64) .expand((batch_size, -1)), past_key_values=make_cache( # type: ignore[operator] [ ( torch.randn( - batch_size, num_key_value_heads, sequence_length, head_dim + batch_size, num_key_value_heads, past_sequence_length, head_dim ), torch.randn( - batch_size, num_key_value_heads, sequence_length, head_dim + batch_size, num_key_value_heads, past_sequence_length, head_dim ), ) for i in range(num_hidden_layers) ] ), ) + # NOTE: past_sequence_length can be 0 when testing prompt processing, + # which it becomes an empty tensor res = dict(inputs=inputs, dynamic_shapes=shapes) if add_second_input: + # prompt processing (prefill) testing res["inputs2"] = get_inputs( model=model, config=config, dummy_max_token_id=dummy_max_token_id, num_hidden_layers=num_hidden_layers, - batch_size=(batch_size + 1) if add_second_input > 0 else 1, - sequence_length=sequence_length + 1, - sequence_length2=sequence_length2 - + (add_second_input if add_second_input > 0 else -add_second_input), + batch_size=batch_size, + past_sequence_length=0, + sequence_length=32, + dynamic_rope=dynamic_rope, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + cls_cache=cls_cache, + add_second_input=0, + **kwargs, + )["inputs"] + # multi-turn conversation + # prompt-processing -> token-generation(loop output) -> + # prompt-processing from the loop output + res["inputs3"] = get_inputs( + model=model, + config=config, + dummy_max_token_id=dummy_max_token_id, + num_hidden_layers=num_hidden_layers, + batch_size=1, + past_sequence_length=32, + sequence_length=8, dynamic_rope=dynamic_rope, num_key_value_heads=num_key_value_heads, head_dim=head_dim, @@ -276,6 +303,23 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: """ Inputs kwargs. + NOTE: We test two scenarios: + 1. prompt processing (aka prefill): + input_ids=(batch_size, prompt_length) + attn_mask=(batch_size, 0+prompt_length) = (batch_size, prompt_length) + pos_ids=(batch_size, prompt_length) + past_key_values=(batch_size, num_key_value_heads, 0, head_dim) + present_key_values=(batch_size, num_key_value_heads, 0+prompt_length, head_dim) + 2. token generation (aka decode). + input_ids=(batch_size, 1) + attn_mask=(batch_size, past_sequence_length+1) + pos_ids=(batch_size, 1) + past_key_values=(batch_size, num_key_value_heads, past_sequence_length, + head_dim) + present_key_values=(batch_size, num_key_value_heads, + past_sequence_length+1, head_dim) + + If the configuration is None, the function selects typical dimensions. 
""" if config is not None: @@ -290,8 +334,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8 kwargs = dict( batch_size=2, - sequence_length=30, - sequence_length2=3, + past_sequence_length=30, + sequence_length=3, dummy_max_token_id=31999 if config is None else (config.vocab_size - 1), num_hidden_layers=4 if config is None else config.num_hidden_layers, intermediate_size=256 if config is None else config.intermediate_size, @@ -300,10 +344,12 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]: conv_kernel=8 if config is None else getattr(config, "conv_kernel", None), ) else: + # Token generation (decode) testing + # NOTE: We have to export model in decode mode to preserve the cache kwargs = dict( batch_size=2, - sequence_length=30, - sequence_length2=3, + past_sequence_length=32, + sequence_length=1, head_dim=( 16 if config is None diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index f115718d..aee5149b 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -415,6 +415,11 @@ def torch_export_patches( except ImportError: masking_utils = None + try: + import transformers.modeling_utils as modeling_utils + except ImportError: + modeling_utils = None + if verbose: import transformers @@ -509,6 +514,23 @@ def torch_export_patches( patch_transformers_list.patched_sdpa_mask_recent_torch ) + if ( + modeling_utils + and patch_transformers_list.patch_modeling_utils + and "sdpa" in modeling_utils.ALL_ATTENTION_FUNCTIONS + ): + if verbose: + print( + "[torch_export_patches] patches " + "transformers.modeling_utils.sdpa_attention_forward" + ) + f_transformers_sdpa_attention_forward = modeling_utils.ALL_ATTENTION_FUNCTIONS[ + "sdpa" + ] + modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] = ( + patch_transformers_list.patched_sdpa_attention_forward + ) + if custom_patches: if verbose: print("[torch_export_patches] applies custom patches") @@ -688,6 +710,19 @@ def torch_export_patches( "transformers.masking_utils.sdpa_mask " "in ALL_MASK_ATTENTION_FUNCTIONS" ) + if ( + modeling_utils + and patch_transformers_list.patch_modeling_utils + and "sdpa" in modeling_utils.ALL_ATTENTION_FUNCTIONS + ): + modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] = ( + f_transformers_sdpa_attention_forward + ) + if verbose: + print( + "[torch_export_patches] restored " + "transformers.modeling_utils.sdpa_attention_forward" + ) ######## # caches diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index e95a0a47..4a485bce 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -986,7 +986,7 @@ def wrapper(self, x, position_ids): return wrapper -def common_eager_attention_forward( +def _common_eager_attention_forward( module: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -1033,7 +1033,7 @@ def patched_model_bart_eager_attention_forward( **kwargs, ): """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]""" - return common_eager_attention_forward( + return _common_eager_attention_forward( module, query, key, @@ -1058,7 +1058,7 @@ def patched_modeling_marian_eager_attention_forward( **kwargs, ): 
"""[patch:transformers.models.marian.modeling_marian.eager_attention_forward]""" - return common_eager_attention_forward( + return _common_eager_attention_forward( module, query, key, @@ -1629,3 +1629,71 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim ) return final_hidden_states, router_logits + + +##### Attention ##### + +try: + import transformers.modeling_utils + + patch_modeling_utils = True + + from transformers.integrations.sdpa_attention import use_gqa_in_sdpa, repeat_kv + +except ImportError: + patch_modeling_utils = False + +if patch_modeling_utils: + + def patched_sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, + ) -> tuple[torch.Tensor, None]: + """manual patch for function ```transformers.integrations.sdpa_attention.sdpa_attention_forward```.""" # noqa: E501 + sdpa_kwargs = {} + if hasattr(module, "num_key_value_groups"): + if not use_gqa_in_sdpa(attention_mask, key): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + else: + sdpa_kwargs = {"enable_gqa": True} + + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # noqa: E501 + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # noqa: E501 + # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool` # noqa: E501 + if is_causal is None: + # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag # noqa: E501 + # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns # noqa: E501 + # is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True) # noqa: E501 + # NOTE: query.shape[2] == 1 or > 1 should have the same output for causal attention + # so we simplify the condition to: + is_causal = attention_mask is None and getattr(module, "is_causal", True) + + # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # noqa: E501 + # We convert it to a bool for the SDPA kernel that only accepts bools. 
+ if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + **sdpa_kwargs, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None diff --git a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py index f8b7fe63..cc7a3390 100644 --- a/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +++ b/onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py @@ -5,7 +5,7 @@ def get_tiny_llm( batch_size: int = 2, sequence_length: int = 30, - sequence_length2: int = 3, + past_sequence_length: int = 3, dynamic_rope: bool = False, use_static_cache: bool = False, **kwargs, @@ -15,7 +15,7 @@ def get_tiny_llm( :param batch_size: batch size :param sequence_length: sequence length - :param sequence_length2: new sequence length + :param past_sequence_length: past sequence length :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) :param use_static_cache: use StaticCache instead of DynamicCache :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1`` @@ -62,7 +62,7 @@ def get_tiny_llm( num_hidden_layers=config["num_hidden_layers"], # type: ignore[arg-type] batch_size=batch_size, sequence_length=sequence_length, - sequence_length2=sequence_length2, + past_sequence_length=past_sequence_length, dynamic_rope=dynamic_rope, num_key_value_heads=config["num_key_value_heads"], # type: ignore[arg-type] cls_cache="StaticCache" if use_static_cache else "DynamicCache", diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index e2c5b9ec..3863899b 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -521,30 +521,6 @@ def validate_model( f.write(f"model_id: {model_id}\n------\n") f.write(pprint.pformat(dump_info)) - if exporter == "modelbuilder": - # Models used with ModelBuilder do not like batch size > 1. - # Let's change that. - for k in ["inputs", "inputs2"]: - if k not in data: - continue - if verbose: - print(f"[validate_model] set batch=1 for data[{k!r}]") - print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}") - cpl = CoupleInputsDynamicShapes( - tuple(), data[k], dynamic_shapes=data["dynamic_shapes"] - ) - if patch_kwargs.get("patch", False): - with torch_export_patches(**patch_kwargs): # type: ignore[arg-type] - data[k] = cpl.change_dynamic_dimensions( - desired_values=dict(batch=1), only_desired=True - ) - else: - data[k] = cpl.change_dynamic_dimensions( - desired_values=dict(batch=1), only_desired=True - ) - if verbose: - print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}") - # modelbuilder needs different treatments sometimes, so # we mark it for later usage. 
# for example, it has different past_kv ordering than @@ -597,6 +573,7 @@ def validate_model( if verbose: print(f"[validate_model] new inputs: {string_type(data['inputs'])}") print(f"[validate_model] new dynamic_hapes: {string_type(data['dynamic_shapes'])}") + # NOTE: The dynamic_shapes is always the same across inputs sets if inputs2: assert ( "inputs2" in data @@ -607,6 +584,14 @@ def validate_model( model=data["model"], dynamic_shapes=data["dynamic_shapes"], ) + # NOTE: text-generation tests 3rd inputs for multi-turn conversation + if "inputs3" in data: + data["inputs3"], _ = filter_inputs( + data["inputs3"], + drop_names=drop_inputs, + model=data["model"], + dynamic_shapes=data["dynamic_shapes"], + ) if not empty(dtype): if isinstance(dtype, str): @@ -618,6 +603,8 @@ def validate_model( summary["model_dtype"] = str(dtype) if "inputs2" in data: data["inputs2"] = to_any(data["inputs2"], dtype) # type: ignore + if "inputs3" in data: + data["inputs3"] = to_any(data["inputs3"], dtype) # type: ignore if not empty(device): if verbose: @@ -627,6 +614,8 @@ def validate_model( summary["model_device"] = str(device) if "inputs2" in data: data["inputs2"] = to_any(data["inputs2"], device) # type: ignore + if "inputs3" in data: + data["inputs3"] = to_any(data["inputs3"], device) # type: ignore for k in ["task", "size", "n_weights"]: summary[f"model_{k.replace('_','')}"] = data[k] @@ -662,10 +651,12 @@ def validate_model( _validate_do_run_model( data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet ) - if inputs2: - _validate_do_run_model( - data, summary, "inputs2", "run2", "run_expected2", verbose, 1, 0, quiet - ) + _validate_do_run_model( + data, summary, "inputs2", "run2", "run_expected2", verbose, 1, 0, quiet + ) + _validate_do_run_model( + data, summary, "inputs3", "run3", "run_expected3", verbose, 1, 0, quiet + ) if exporter: print( @@ -923,6 +914,10 @@ def _validate_do_run_model( if verbose: print(f"[validate_model] -- run the model inputs={key!r}...") print(f"[validate_model] {key}={string_type(data[key], with_shape=True)}") + if key not in data: + if verbose: + print(f"[validate_model] input; {key!r} not defined, skip.") + return # We make a copy of the input just in case the model modifies them inplace hash_inputs = string_type(data[key], with_shape=True) inputs = torch_deepcopy(data[key]) @@ -1129,16 +1124,24 @@ def call_torch_export_export( print("[call_torch_export_export] export...") model = data["model"] + + def _run_torch_export(): + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + ep = torch.export.export( + model, + args, + kwargs=kwargs, + dynamic_shapes=dse, + strict=strict, + ) + return ep + ep = _quiet_or_not_quiet( quiet, "export_export", summary, data, - ( - lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: ( - torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s) - ) - ), + _run_torch_export, ) if "ERR_export_export" in summary: return summary, data @@ -1345,6 +1348,9 @@ def _mk(key, flavour=flavour): keys = [("inputs", "run_expected", "")] if inputs2: keys.append(("inputs2", "run_expected2", "2")) + # text-generation tests multi-turn conversation as 3rd inputs + if "inputs3" in data: + keys.append(("inputs3", "run_expected3", "3")) for k_input, k_expected, suffix in keys: # make_feeds if verbose: @@ -1739,23 +1745,24 @@ def call_torch_export_custom( kws["target_opset"] = opset if output_names: kws["output_names"] = output_names - - epo, opt_stats = _quiet_or_not_quiet( - quiet, - "export_export_onnx_c", 
- summary,
- data,
- (
- lambda m=model, args=args, kwargs=kwargs, kws=kws: (
- to_onnx(
- model,
- args,
- kwargs=kwargs,
- **kws,
+ # avoid specializing dynamic dimensions to 0/1 during torch.export.export
+ with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+ epo, opt_stats = _quiet_or_not_quiet(
+ quiet,
+ "export_export_onnx_c",
+ summary,
+ data,
+ (
+ lambda m=model, args=args, kwargs=kwargs, kws=kws: (
+ to_onnx(
+ model,
+ args,
+ kwargs=kwargs,
+ **kws,
+ )
- )
- ),
- )
+ ),
+ )
if "ERR_export_onnx_c" in summary: return summary, data
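
The updated tests wrap torch.export.export in both torch_export_patches(patch_transformers=True) and torch.fx.experimental._config.patch(backed_size_oblivious=True). A minimal sketch of that pattern, assuming get_tiny_llm (modified above) returns its usual dictionary with "model", "inputs" and "dynamic_shapes"; the model size is a placeholder, not part of the patch:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.untrained.llm_tiny_llm import get_tiny_llm

# assumption: get_tiny_llm returns dict(model=..., inputs=..., dynamic_shapes=...)
data = get_tiny_llm(num_hidden_layers=2)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

# backed_size_oblivious=True avoids specializing symbolic dimensions on 0/1,
# which keeps the past sequence length dynamic even when the sample cache is
# small or empty (see the prefill inputs with past_sequence_length=0 above).
with torch_export_patches(patch_transformers=True), torch.fx.experimental._config.patch(
    backed_size_oblivious=True
):
    # depending on the torch version, string dimension names may need to be
    # converted first (the tests use use_dyn_not_str for that step)
    ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
print(ep.graph)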
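
The renamed arguments in onnx_diagnostic/tasks/text_generation.py separate past_sequence_length (cache length) from sequence_length (new tokens), and the two scenarios documented in random_input_kwargs differ only in those two numbers. A sketch of both input sets, assuming make_dynamic_cache from onnx_diagnostic.helpers.cache_helper is the helper get_inputs relies on; make_inputs and the sizes below are purely illustrative:

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

batch, heads, head_dim, n_layers, vocab = 2, 4, 16, 2, 1024

def make_inputs(past_sequence_length, sequence_length):
    # input_ids carries only the new tokens, the mask covers past + new tokens
    return dict(
        input_ids=torch.randint(0, vocab, (batch, sequence_length), dtype=torch.int64),
        attention_mask=torch.ones(
            (batch, past_sequence_length + sequence_length), dtype=torch.int64
        ),
        position_ids=torch.arange(
            past_sequence_length, past_sequence_length + sequence_length
        ).expand((batch, -1)),
        # with past_sequence_length=0 the cache tensors have an empty time axis,
        # which is the prefill case mentioned in the NOTE above
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.randn(batch, heads, past_sequence_length, head_dim),
                    torch.randn(batch, heads, past_sequence_length, head_dim),
                )
                for _ in range(n_layers)
            ]
        ),
    )

decode = make_inputs(past_sequence_length=32, sequence_length=1)   # token generation
prefill = make_inputs(past_sequence_length=0, sequence_length=32)  # prompt processing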
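
The change in onnx_export_errors.py swaps transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] for the new patched_sdpa_attention_forward while the context manager is active and restores the original on exit. A short sketch to observe that behaviour, assuming an installed transformers version that exposes ALL_ATTENTION_FUNCTIONS (otherwise the patch is skipped and both prints report False/True accordingly):

import transformers.modeling_utils as modeling_utils
from onnx_diagnostic.torch_export_patches import torch_export_patches

original = modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"]
with torch_export_patches(patch_transformers=True, verbose=1):
    inside = modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"]
    print("patched:", inside is not original)
# on exit the original transformers implementation is put back
print("restored:", modeling_utils.ALL_ATTENTION_FUNCTIONS["sdpa"] is original)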
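
patched_sdpa_attention_forward only simplifies the is_causal decision (it no longer depends on query.shape[2]); the attention computation itself is unchanged. A quick numerical check, under the assumptions that the function was defined (patch_modeling_utils is True for the installed transformers) and that a bare torch.nn.Module without num_key_value_groups is enough to skip the GQA branch:

import torch
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_sdpa_attention_forward,
)

# query/key/value in (batch, num_heads, seq, head_dim) layout
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# bare module: no num_key_value_groups attribute, so repeat_kv is not applied
out, _ = patched_sdpa_attention_forward(
    torch.nn.Module(), q, k, v, attention_mask=None, is_causal=False
)
expected = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, is_causal=False
).transpose(1, 2)
torch.testing.assert_close(out, expected)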