diff --git a/_unittests/ut_torch_models/test_validate_whole_models.py b/_unittests/ut_torch_models/test_validate_whole_models.py index 4fd63f02..9f53ef73 100644 --- a/_unittests/ut_torch_models/test_validate_whole_models.py +++ b/_unittests/ut_torch_models/test_validate_whole_models.py @@ -195,7 +195,7 @@ def test_k_filter_inputs(self): @ignore_warnings(FutureWarning) @requires_transformers("4.51") def test_l_validate_model_modelbuilder(self): - mid = "arnir0/Tiny-LLM" + mid = "meta-llama/Llama-2-7b-hf" summary, data = validate_model( mid, do_run=True, @@ -205,7 +205,7 @@ def test_l_validate_model_modelbuilder(self): ) self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) - self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4) + self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-2) onnx_filename = data["onnx_filename"] self.assertExists(onnx_filename) diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index ebd6e157..a3b06898 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -3,7 +3,6 @@ import onnx import torch from .helper import string_type, flatten_object -from .onnx_helper import dtype_to_tensor_dtype from .cache_helper import is_cache_dynamic_registered @@ -23,6 +22,7 @@ def make_feeds( use_numpy: bool = False, copy: bool = False, check_flatten: bool = True, + is_modelbuilder: bool = False, ) -> Dict[str, Union[torch.Tensor, np.ndarray]]: """ Serializes the inputs to produce feeds expected @@ -35,10 +35,15 @@ def make_feeds( by ``OrtValue`` :param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten`` returns the same number of outputs + :param is_modelbuilder: if True, the exporter is ModelBuilder, and we need to reorder + the past_key_values inputs to match the expected order, and get rid of position_ids. :return: feeds dictionary """ - # position_ids is a special case because ModelBuilder does not usually use it. - # We use types to detect the best inputs. + # NOTE: position_ids is a special case because ModelBuilder does not usually use it, + # because it's fued into rotary embedding in GQA. + if isinstance(inputs, dict): + inputs.pop("position_ids", None) # Ensure 'position_ids' absent before removing. + flat = flatten_object(inputs, drop_keys=True) assert ( not check_flatten @@ -76,39 +81,6 @@ def make_feeds( f"\n-- inputs={string_type(inputs, with_shape=True)}" f"\n-- names={names}" ) - if len(names) < len(flat) and ( - isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs") - ): - - typed_names = ( - [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input] - if isinstance(proto, onnx.ModelProto) - else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()] - ) - - new_flat = [] - pos = 0 - for _name, dtype in typed_names: - assert isinstance( - dtype, int - ), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}" - itype = dtype_to_tensor_dtype(flat[pos].dtype) - while dtype != itype: - pos += 1 - if pos >= len(flat): - break - itype = dtype_to_tensor_dtype(flat[pos].dtype) - if pos >= len(flat): - break - new_flat.append(flat[pos]) - pos += 1 - assert len(new_flat) == len(names), ( - f"Unable to align expected input {names} with the given input, " - f"type(proto)={type(proto)}" - f"\n-- inputs: {string_type(inputs, with_shape=True)}" - f"\n-- typed_names: {typed_names}" - ) - flat = new_flat if copy: flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat] @@ -122,4 +94,45 @@ def make_feeds( elif isinstance(i, float): i = np.array(i, dtype=np.float32) new_flat.append(i) + + # NOTE: model builder has a different order for past_key_values + # we need to reorder them to match the expected order + if is_modelbuilder: + # We assume that if "past_key_values" is in the names when it's + # modelbuilder + non_past_kv_input_names = [n for n in names if "past_key_values" not in n] + past_kv_names = [n for n in names if "past_key_values" in n] + reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names) + names = non_past_kv_input_names + reorder_past_kv_names return dict(zip(names, new_flat)) + + +def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]: + """ + Reorders the past_kvs for ModelBuilder to match the expected order + by PyTorch exported models. + + NOTE: This function can take either the names or the actual tensors + as long as they are in a list. + + Conceptually, + + From: + [past_key_values.0.key, past_key_values.0.value, + past_key_values.1.key, past_key_values.1.value, ...] + To: + [past_key_values.0.key, past_key_values.1.key, + ..., past_key_values.0.value, past_key_values.1.value, ...] + + :param flat: list of flattened inputs + :return: reordered list of flattened inputs + """ + total_len = len(past_kv) + if total_len % 2 != 0: + raise ValueError("The length of past_key_values should be even.") + keys = [] + values = [] + for i in range(0, total_len, 2): + keys.append(past_kv[i]) + values.append(past_kv[i + 1]) + return keys + values diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 084194fe..0e2ff083 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -11,7 +11,7 @@ from ..export import CoupleInputsDynamicShapes from ..helpers import max_diff, string_type, string_diff from ..helpers.helper import flatten_object -from ..helpers.rt_helper import make_feeds +from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch from ..helpers.torch_helper import to_any, torch_deepcopy from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes from ..tasks import random_input_kwargs @@ -536,6 +536,11 @@ def validate_model( if verbose: print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}") + # modelbuilder needs different treatments sometimes, so + # we mark it for later usage. + # for example, it has different past_kv ordering than + # flattened CacheObject + data["exporter"] = exporter data["input_options"] = iop data["model_options"] = mop data["model_dump_folder"] = dump_folder @@ -1322,7 +1327,13 @@ def _mk(key, flavour=flavour): print( f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}" ) - feeds = make_feeds(sess, data[k_input], use_numpy=True, check_flatten=False) + feeds = make_feeds( + sess, + data[k_input], + use_numpy=True, + check_flatten=False, + is_modelbuilder=data["exporter"] == "modelbuilder", + ) if verbose: print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}") summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True) @@ -1342,6 +1353,13 @@ def _mk(key, flavour=flavour): repeat=repeat, warmup=warmup, ) + # NOTE: modelbuilder has different order on past_kv outputs + if data["exporter"] == "modelbuilder": + logits = got[:1] + past_key_values = got[1:] + reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values) + got = logits + reorder_past_key_values + if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary: return summary, data