Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions _doc/cmds/validate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,19 @@ Let's export with ONNX this time and checks for discrepancies.

main("validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir".split())

Run onnxruntime fusions
+++++++++++++++++++++++

This option runs `transformers optimizations <https://onnxruntime.ai/docs/performance/transformers-optimization.html>`_
implemented in :epkg:`onnxruntime`. The list of supported ``model_type`` can be found in the documentation
of function :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion`.

.. code-block::

python -m onnx_diagnostic validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir --ortfusiontype ALL

.. runpython::

from onnx_diagnostic._command_lines_parser import main

main("validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir --ortfusiontype ALL".split())
70 changes: 69 additions & 1 deletion _unittests/ut_export/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes
Expand Down Expand Up @@ -679,6 +679,42 @@ def test_couple_input_ds_replace_string(self):
).replace_string_by(value="DYN"),
)

def test_couple_input_ds_replace_by_string(self):
T3x1 = torch.rand((3, 1))
T3x4 = torch.rand((3, 4))
T5x1 = torch.rand((5, 1))
args = (T5x1,)
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
ds_batch = {0: "batch"}
ds_batch_seq = {0: "batch", 1: "seq"}
ds = {"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)}
Cls = CoupleInputsDynamicShapes
res = Cls(
args,
kwargs,
ds,
args_names=["X"],
).replace_by_string()
self.assertEqual(ds, res)

ds_batch = {0: torch.export.Dim("batch")}
ds_batch_seq = {0: torch.export.Dim("batch"), 1: torch.export.Dim.DYNAMIC}
ds = {"X": ds_batch, "A": ds_batch_seq, "B": (ds_batch_seq, ds_batch_seq)}
res = Cls(
args,
kwargs,
ds,
args_names=["X"],
).replace_by_string()
self.assertEqual(
{
"X": {0: "batch"},
"A": {0: "batch", 1: "Dim1"},
"B": ({0: "batch", 1: "Dim2"}, {0: "batch", 1: "Dim3"}),
},
res,
)

def test_couple_input_ds_change_dynamic_dimensions(self):
T257 = torch.arange(2 * 5 * 7).reshape((2, 5, 7))
T29 = torch.arange(2 * 9).reshape((2, 9))
Expand All @@ -703,6 +739,38 @@ def test_couple_input_ds_change_dynamic_dimensions_fixed(self):
self.assertEqual((1, 5, 8), new_input["A"].shape)
self.assertEqual((1, 50), new_input["B"].shape)

@requires_transformers("4.51")
def test_dynamic_cache_replace_by_string(self):
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
[
(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
for i in range(n_layers)
]
)

DYN = torch.export.Dim.DYNAMIC
ds = {
"cache": [
[{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
[{0: DYN, 1: DYN}, {0: DYN, 1: DYN}],
]
}
inst = CoupleInputsDynamicShapes((), dict(cache=cache), ds)
as_string = inst.replace_by_string()
self.assertEqual(
{
"cache": [
{0: "Dim0", 1: "Dim1"},
{0: "Dim2", 1: "Dim3"},
{0: "Dim4", 1: "Dim5"},
{0: "Dim6", 1: "Dim7"},
]
},
as_string,
)


if __name__ == "__main__":
unittest.main(verbosity=2)
3 changes: 2 additions & 1 deletion onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ def get_parser_validate() -> ArgumentParser:
"--ortfusiontype",
required=False,
help="applies onnxruntime fusion, this parameter should contain the "
"model type or multiple values separated by |",
"model type or multiple values separated by `|`. `ALL` can be used "
"to run them all",
)
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
parser.add_argument("--dtype", help="changes dtype if necessary")
Expand Down
81 changes: 79 additions & 2 deletions onnx_diagnostic/export/dynamic_shapes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
from ..helpers import string_type
Expand All @@ -8,6 +8,30 @@
DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]


def flatten_dynamic_shapes(ds: Any) -> Any:
"""Flattens the dynamic shapes."""
if isinstance(ds, list):
return _flat_list([flatten_dynamic_shapes(t) for t in ds])
if isinstance(ds, tuple):
return tuple(_flat_list([flatten_dynamic_shapes(t) for t in ds]))
if isinstance(ds, dict):
if all(isinstance(i, int) for i in ds):
# That's a dynamic shape
return ds
return _flat_list([flatten_dynamic_shapes(t) for t in ds.values()])
raise AssertionError(f"Not implemented for {type(ds)}: {ds}")


def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
res = []
for t in li:
if isinstance(t, dict):
res.append(t)
else:
res.extend(t)
return res


class CoupleInputsDynamicShapes:
"""
Pair inputs / dynamic shapes.
Expand Down Expand Up @@ -76,7 +100,7 @@ def _replace_string_dim_tensor(cls, inputs, ds, value=None):
assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
f"a dictionary is expected to specify a dimension dimension"
f"a dictionary is expected to specify a dimension"
)
if value is None:
value = torch.export.Dim.DYNAMIC
Expand All @@ -86,6 +110,56 @@ def _replace_string_dim_tensor(cls, inputs, ds, value=None):
new_ds[i] = value
return new_ds

def replace_by_string(self):
"""
Replaces dimensions by strings.

Example:

.. runpython::
:showcode:

import torch
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

Dim = torch.export.Dim
T3x1 = torch.rand((3, 1))
T3x4 = torch.rand((3, 4))
ds_batch = {0: Dim("batch")}
ds_batch_seq = {0: Dim("batch"), 1: Dim("seq")}
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string())
"""
unique = set()
return self._generic_walker(
lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
inputs, ds, unique=unique
)
)

@classmethod
def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]):
assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
f"a dictionary is expected to specify a dimension"
)
new_ds = ds.copy()
for i, v in ds.items():
if isinstance(v, str):
unique.add(v)
new_ds[i] = v
elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO):
name = f"Dim{len(unique)}"
new_ds[i] = name
unique.add(name)
else:
name = v.__name__
unique.add(name)
new_ds[i] = name
return new_ds

def invalid_dimensions_for_export(self):
"""
Tells if the inputs are valid based on the dynamic shapes definition.
Expand Down Expand Up @@ -252,6 +326,9 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds):
f"map this class with the given dynamic shapes."
)
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
if all(isinstance(t, torch.Tensor) for t in flat):
# We need to flatten dynamic shapes as well
ds = flatten_dynamic_shapes(ds)
return cls._generic_walker_step(processor, flat, ds)

class ChangeDimensionProcessor:
Expand Down
64 changes: 47 additions & 17 deletions onnx_diagnostic/torch_models/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import onnx
import torch
from ..export import CoupleInputsDynamicShapes
from ..helpers import max_diff, string_type, string_diff
from ..helpers.helper import flatten_object
from ..helpers.rt_helper import make_feeds
Expand Down Expand Up @@ -506,7 +507,12 @@ def validate_model(
), f"Missing attribute num_attention_heads in configuration {config}"
num_attention_heads = config.num_attention_heads

model_types = ortfusiontype.split("|")
if ortfusiontype == "ALL":
from onnxruntime.transformers.optimizer import MODEL_TYPES

model_types = sorted(MODEL_TYPES)
else:
model_types = ortfusiontype.split("|")
for model_type in model_types:
flavour = f"ort{model_type}"
summary[f"version_{flavour}_hidden_size"] = hidden_size
Expand All @@ -517,13 +523,15 @@ def validate_model(
print(f"[validate_model] run onnxruntime fusion for {model_type!r}")
input_filename = data["onnx_filename"]
output_path = f"{os.path.splitext(input_filename)[0]}.ort.{model_type}.onnx"
run_ort_fusion(
ort_sum, ort_data = run_ort_fusion(
input_filename,
output_path,
model_type=model_type,
num_attention_heads=num_attention_heads,
hidden_size=hidden_size,
)
summary.update(ort_sum)
data.update(ort_data)
data[f"onnx_filename_{flavour}"] = output_path
duration = time.perf_counter() - begin
summary[f"time_ortfusion_{flavour}"] = duration
Expand Down Expand Up @@ -590,7 +598,7 @@ def call_exporter(
optimization=optimization,
)
return summary, data
if exporter.startswith("custom-"):
if exporter == "custom" or exporter.startswith("custom"):
# torch export
summary, data = call_torch_export_custom(
exporter=exporter,
Expand Down Expand Up @@ -746,7 +754,7 @@ def validate_onnx_model(
def _mk(key):
return f"{key}_{flavour}" if flavour else key

summary = {}
summary: Dict[str, Any] = {}
flat_inputs = flatten_object(data["inputs"], drop_keys=True)
d = flat_inputs[0].get_device()
providers = (
Expand All @@ -758,6 +766,9 @@ def _mk(key):

if input_data_key in data:
source = data[input_data_key]
if not os.path.exists(source):
summary[_mk("ERR_onnx_missing")] = f"FileNotFoundError({source!r})"
return summary, data
summary[input_data_key] = source
summary[_mk("onnx_size")] = os.stat(source).st_size
else:
Expand Down Expand Up @@ -866,7 +877,7 @@ def call_torch_export_onnx(
assert "model" in data, f"model is missing from data: {sorted(data)}"
assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
summary: Dict[str, Union[str, int, float]] = {}
dynamo = "nostrict" not in exporter
dynamo = "dynamo" in exporter
args, kwargs = split_args_kwargs(data["inputs_export"])
ds = data.get("dynamic_shapes", None)
if verbose:
Expand All @@ -884,15 +895,23 @@ def call_torch_export_onnx(
summary["export_args"] = string_type(args, with_shape=True)
summary["export_kwargs"] = string_type(kwargs, with_shape=True)

export_export_kwargs = (
dict(dynamo=True, dynamic_shapes=ds)
if dynamo
else dict(
dynamo=False,
dynamic_axes=CoupleInputsDynamicShapes(args, kwargs, ds).replace_by_string(),
)
)

begin = time.perf_counter()
if quiet:
try:
epo = torch.onnx.export(
data["model"],
args,
kwargs=kwargs,
dynamic_shapes=ds,
dynamo=dynamo,
**export_export_kwargs,
)
except Exception as e:
summary["ERR_export_export"] = str(e)
Expand All @@ -904,8 +923,7 @@ def call_torch_export_onnx(
data["model"],
args,
kwargs=kwargs,
dynamic_shapes=ds,
dynamo=dynamo,
**export_export_kwargs,
)

summary["time_export_export"] = time.perf_counter() - begin
Expand Down Expand Up @@ -966,6 +984,7 @@ def call_torch_export_custom(
None,
}, f"unexpected value for optimization={optimization}"
assert exporter in {
"custom",
"custom-strict",
"custom-strict-dec",
"custom-strict-all",
Expand Down Expand Up @@ -1155,14 +1174,24 @@ def run_ort_fusion(
f"[run_ort_fusion] starts optimization for "
f"model_type={model_type!r} with {n_nodes} nodes"
)
new_onx = optimize_by_fusion(
onx,
model_type=model_type,
num_heads=num_attention_heads,
hidden_size=hidden_size,
optimization_options=opts,
)
duration = {time.perf_counter() - begin}
try:
new_onx = optimize_by_fusion(
onx,
model_type=model_type,
num_heads=num_attention_heads,
hidden_size=hidden_size,
optimization_options=opts,
)
except Exception as e:
duration = time.perf_counter() - begin
if verbose:
print(f"[run_ort_fusion] failed in {duration} for model_type={model_type!r}")
return {
f"ERR_opt_ort_{model_type}": str(e),
f"opt_ort_{model_type}_duration": duration,
}, {}

duration = time.perf_counter() - begin
delta = len(new_onx.model.graph.node)
if verbose:
print(f"[run_ort_fusion] done in {duration} with {delta} nodes")
Expand All @@ -1175,6 +1204,7 @@ def run_ort_fusion(
return {
f"opt_ort_{model_type}_n_nodes1": n_nodes,
f"opt_ort_{model_type}_n_nodes2": delta,
f"opt_ort_{model_type}_delta_node": delta - n_nodes,
f"opt_ort_{model_type}_duration": duration,
f"opt_ort_{model_type}_duration_save": d,
}, {f"opt_ort_{model_type}": output_path}
Loading