Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions _unittests/ut_torch_models/test_validate_whole_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def test_k_filter_inputs(self):
@ignore_warnings(FutureWarning)
@requires_transformers("4.51")
def test_l_validate_model_modelbuilder(self):
mid = "arnir0/Tiny-LLM"
mid = "meta-llama/Llama-2-7b-hf"
summary, data = validate_model(
mid,
do_run=True,
Expand All @@ -205,7 +205,7 @@ def test_l_validate_model_modelbuilder(self):
)
self.assertIsInstance(summary, dict)
self.assertIsInstance(data, dict)
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-2)
onnx_filename = data["onnx_filename"]
self.assertExists(onnx_filename)

Expand Down
85 changes: 49 additions & 36 deletions onnx_diagnostic/helpers/rt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import onnx
import torch
from .helper import string_type, flatten_object
from .onnx_helper import dtype_to_tensor_dtype
from .cache_helper import is_cache_dynamic_registered


Expand All @@ -23,6 +22,7 @@ def make_feeds(
use_numpy: bool = False,
copy: bool = False,
check_flatten: bool = True,
is_modelbuilder: bool = False,
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
"""
Serializes the inputs to produce feeds expected
Expand All @@ -35,10 +35,15 @@ def make_feeds(
by ``OrtValue``
:param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten``
returns the same number of outputs
:param is_modelbuilder: if True, the exporter is ModelBuilder, and we need to reorder
the past_key_values inputs to match the expected order, and get rid of position_ids.
:return: feeds dictionary
"""
# position_ids is a special case because ModelBuilder does not usually use it.
# We use types to detect the best inputs.
# NOTE: position_ids is a special case because ModelBuilder does not usually use it,
# because it's fued into rotary embedding in GQA.
if isinstance(inputs, dict):
inputs.pop("position_ids", None) # Ensure 'position_ids' absent before removing.

flat = flatten_object(inputs, drop_keys=True)
assert (
not check_flatten
Expand Down Expand Up @@ -76,39 +81,6 @@ def make_feeds(
f"\n-- inputs={string_type(inputs, with_shape=True)}"
f"\n-- names={names}"
)
if len(names) < len(flat) and (
isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
):

typed_names = (
[(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
if isinstance(proto, onnx.ModelProto)
else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
)

new_flat = []
pos = 0
for _name, dtype in typed_names:
assert isinstance(
dtype, int
), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
itype = dtype_to_tensor_dtype(flat[pos].dtype)
while dtype != itype:
pos += 1
if pos >= len(flat):
break
itype = dtype_to_tensor_dtype(flat[pos].dtype)
if pos >= len(flat):
break
new_flat.append(flat[pos])
pos += 1
assert len(new_flat) == len(names), (
f"Unable to align expected input {names} with the given input, "
f"type(proto)={type(proto)}"
f"\n-- inputs: {string_type(inputs, with_shape=True)}"
f"\n-- typed_names: {typed_names}"
)
flat = new_flat

if copy:
flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
Expand All @@ -122,4 +94,45 @@ def make_feeds(
elif isinstance(i, float):
i = np.array(i, dtype=np.float32)
new_flat.append(i)

# NOTE: model builder has a different order for past_key_values
# we need to reorder them to match the expected order
if is_modelbuilder:
# We assume that if "past_key_values" is in the names when it's
# modelbuilder
non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
past_kv_names = [n for n in names if "past_key_values" in n]
reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
names = non_past_kv_input_names + reorder_past_kv_names
return dict(zip(names, new_flat))


def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
"""
Reorders the past_kvs for ModelBuilder to match the expected order
by PyTorch exported models.

NOTE: This function can take either the names or the actual tensors
as long as they are in a list.

Conceptually,

From:
[past_key_values.0.key, past_key_values.0.value,
past_key_values.1.key, past_key_values.1.value, ...]
To:
[past_key_values.0.key, past_key_values.1.key,
..., past_key_values.0.value, past_key_values.1.value, ...]

:param flat: list of flattened inputs
:return: reordered list of flattened inputs
"""
total_len = len(past_kv)
if total_len % 2 != 0:
raise ValueError("The length of past_key_values should be even.")
keys = []
values = []
for i in range(0, total_len, 2):
keys.append(past_kv[i])
values.append(past_kv[i + 1])
return keys + values
22 changes: 20 additions & 2 deletions onnx_diagnostic/torch_models/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..export import CoupleInputsDynamicShapes
from ..helpers import max_diff, string_type, string_diff
from ..helpers.helper import flatten_object
from ..helpers.rt_helper import make_feeds
from ..helpers.rt_helper import make_feeds, reorder_modelbuilder_cache_to_torch
from ..helpers.torch_helper import to_any, torch_deepcopy
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from ..tasks import random_input_kwargs
Expand Down Expand Up @@ -536,6 +536,11 @@ def validate_model(
if verbose:
print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")

# modelbuilder needs different treatments sometimes, so
# we mark it for later usage.
# for example, it has different past_kv ordering than
# flattened CacheObject
data["exporter"] = exporter
data["input_options"] = iop
data["model_options"] = mop
data["model_dump_folder"] = dump_folder
Expand Down Expand Up @@ -1322,7 +1327,13 @@ def _mk(key, flavour=flavour):
print(
f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}"
)
feeds = make_feeds(sess, data[k_input], use_numpy=True, check_flatten=False)
feeds = make_feeds(
sess,
data[k_input],
use_numpy=True,
check_flatten=False,
is_modelbuilder=data["exporter"] == "modelbuilder",
)
if verbose:
print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True)
Expand All @@ -1342,6 +1353,13 @@ def _mk(key, flavour=flavour):
repeat=repeat,
warmup=warmup,
)
# NOTE: modelbuilder has different order on past_kv outputs
if data["exporter"] == "modelbuilder":
logits = got[:1]
past_key_values = got[1:]
reorder_past_key_values = reorder_modelbuilder_cache_to_torch(past_key_values)
got = logits + reorder_past_key_values

if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
return summary, data

Expand Down
Loading