Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def linkcode_resolve(domain, info):
("py:class", "transformers.cache_utils.SlidingWindowCache"),
("py:class", "transformers.configuration_utils.PretrainedConfig"),
("py:class", "transformers.modeling_outputs.BaseModelOutput"),
("py:class", "transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding"),
("py:func", "torch.export._draft_export.draft_export"),
("py:func", "torch._export.tools.report_exportability"),
("py:func", "torch.utils._pytree.register_pytree_node"),
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_torch_models/test_validate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_validate_microsoft_phi4_reasoning(self):
patch=True,
rewrite=True,
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
dump_folder="dump_test_validate_model_custom",
dump_folder="dump_test/validate_microsoft_phi4_reasoning",
)
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-5)
self.assertIn("onnx_filename", data)
Expand Down
41 changes: 33 additions & 8 deletions _unittests/ut_torch_models/test_validate_whole_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import unittest
import packaging.version as pv
import onnx
import torch
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
Expand Down Expand Up @@ -63,7 +64,7 @@ def test_validate_model_export(self):
do_run=True,
verbose=10,
exporter="export-nostrict",
dump_folder="dump_test_validate_model_export",
dump_folder="dump_test/validate_model_export",
patch=True,
)
self.assertIsInstance(summary, dict)
Expand All @@ -79,7 +80,7 @@ def test_validate_model_onnx_dynamo_ir(self):
do_run=True,
verbose=10,
exporter="onnx-dynamo",
dump_folder="dump_test_validate_model_onnx_dynamo",
dump_folder="dump_test/validate_model_onnx_dynamo_ir",
patch=True,
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
optimization="ir",
Expand All @@ -104,7 +105,7 @@ def test_validate_model_onnx_dynamo_os_ort(self):
do_run=True,
verbose=10,
exporter="onnx-dynamo",
dump_folder="dump_test_validate_model_onnx_dynamo",
dump_folder="dump_test/validate_model_onnx_dynamo_os_ort",
patch=True,
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
optimization="os_ort",
Expand All @@ -126,7 +127,7 @@ def test_validate_model_custom_os_ort(self):
do_run=True,
verbose=10,
exporter="custom",
dump_folder="dump_validate_model_custom_os_ort",
dump_folder="dump_test/validate_model_custom_os_ort",
patch=True,
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
optimization="default+os_ort",
Expand All @@ -148,7 +149,7 @@ def test_validate_model_custom(self):
do_run=True,
verbose=10,
exporter="custom",
dump_folder="dump_test_validate_model_custom",
dump_folder="dump_test/validate_model_custom_tiny_llm",
patch=True,
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
optimization="default",
Expand Down Expand Up @@ -177,7 +178,7 @@ def test_validate_model_custom_torch(self):
do_run=True,
verbose=10,
exporter="custom-noinline",
dump_folder="dump_test_validate_model_custom_torch",
dump_folder="dump_test/validate_model_custom_torch",
patch=True,
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
optimization="default",
Expand Down Expand Up @@ -221,7 +222,7 @@ def test_validate_model_modelbuilder(self):
do_run=True,
verbose=10,
exporter="modelbuilder",
dump_folder="dump_test_validate_model_modelbuilder",
dump_folder="dump_test/validate_model_modelbuilder",
)
self.assertIsInstance(summary, dict)
self.assertIsInstance(data, dict)
Expand All @@ -240,7 +241,7 @@ def test_validate_model_vit_model(self):
do_run=True,
verbose=10,
exporter="onnx-dynamo",
dump_folder="dump_test_validate_model_onnx_dynamo",
dump_folder="dump_test/validate_model_vit_model",
inputs2=True,
)
self.assertIsInstance(summary, dict)
Expand All @@ -254,6 +255,30 @@ def test_validate_model_vit_model(self):
onnx_filename = data["onnx_filename"]
self.assertExists(onnx_filename)

@requires_torch("2.7")
@hide_stdout()
@ignore_warnings(FutureWarning)
@requires_transformers("4.51")
def test_validate_phi35_mini_instruct(self):
mid = "microsoft/Phi-3.5-mini-instruct"
summary, data = validate_model(
mid,
do_run=True,
verbose=10,
exporter="custom",
dump_folder="dump_test/validate_phi35_mini_instruct",
inputs2=True,
patch=True,
rewrite=True,
# model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}},
)
self.assertIsInstance(summary, dict)
self.assertIsInstance(data, dict)
onnx_filename = data["onnx_filename"]
onx = onnx.load(onnx_filename)
op_types = set(n.op_type for n in onx.graph.node)
self.assertIn("If", op_types)


if __name__ == "__main__":
unittest.main(verbosity=2)
15 changes: 12 additions & 3 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import textwrap
import onnx
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
from textwrap import dedent

Expand Down Expand Up @@ -291,6 +291,14 @@ def _cmd_config(argv: List[Any]):
print(f"task: {task_from_id(args.mid)}")


def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
assert isinstance(value, str), f"value should be string but value={value!r}"
if value and value[0] == "{" and value[-1] == "}":
# a dictionary
return json.loads(value.replace("'", '"'))
return value


class _ParseDict(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
d = getattr(namespace, self.dest) or {}
Expand All @@ -314,7 +322,7 @@ def __call__(self, parser, namespace, values, option_string=None):
continue
except (TypeError, ValueError):
pass
d[key] = value
d[key] = _parse_json(value)

setattr(namespace, self.dest, d)

Expand Down Expand Up @@ -430,7 +438,8 @@ def get_parser_validate() -> ArgumentParser:
metavar="KEY=VALUE",
nargs="*",
help="Additional model options, use to change some parameters of the model, "
"example: --mop attn_implementation=eager",
"example: ``--mop attn_implementation=eager`` or "
"``--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"``",
action=_ParseDict,
)
parser.add_argument(
Expand Down
12 changes: 8 additions & 4 deletions onnx_diagnostic/helpers/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
config._attn_implementation_autoset = False
continue
if isinstance(v, dict):
assert hasattr(
config, k
), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
update_config(getattr(config, k), v)
if not hasattr(config, k) or getattr(config, k) is None:
setattr(config, k, v)
continue
existing = getattr(config, k)
if type(existing) is dict:
existing.update(v)
else:
update_config(getattr(config, k), v)
continue
setattr(config, k, v)

Expand Down
103 changes: 78 additions & 25 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,50 @@
import functools
import importlib
import contextlib
from typing import Any, Callable, Dict, List, Optional
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
from .onnx_export_serialization import (
register_cache_serialization,
unregister_cache_serialization,
)
from .patches import patch_transformers as patch_transformers_list


def get_function(name: str) -> Tuple[type, Callable]:
"""Returns the module and the function based on its name."""
spl = name.split(".")
module_name = ".".join(spl[:-1])
fname = spl[-1]
mod = importlib.import_module(module_name)
return mod, getattr(mod, fname)


@functools.lru_cache
def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
"""Returns the list of patches to make for a specific module."""
to_patch = []
for k in dir(mod):
if k.startswith("patched_"):
v = getattr(mod, k)
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
to_patch.append(v)
else:
# a function
doc = v.__doc__.lstrip()
if doc.startswith("manual patch"):
continue
reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
fall = reg.findall(doc)
assert (
len(fall) == 1
), f"Unable to find patching information for {v} in \n{doc}"
fmod, f = get_function(fall[0])
to_patch.append({"module": fmod, "function": f, "patch": v})

name = mod.__name__
return name, to_patch


def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
"""
Applies all patches defined in classes prefixed by ``patched_``
Expand All @@ -23,16 +61,21 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
to_patch = mod
name = "list"
else:
to_patch = []
for k in dir(mod):
if k.startswith("patched_"):
v = getattr(mod, k)
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
to_patch.append(v)
name = mod.__name__
name, to_patch = get_patches(mod, verbose)

res = {}
for cls in to_patch:
if isinstance(cls, dict):
# a function
keep = {}
original = cls["module"]
f = cls["function"]
res[f] = f
if verbose:
print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
setattr(original, f.__name__, cls["patch"])
continue

original = cls._PATCHED_CLASS_
methods = cls._PATCHES_
if verbose:
Expand All @@ -57,26 +100,36 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
to_patch = mod
name = "list"
else:
to_patch = []
for k in dir(mod):
if k.startswith("patched_"):
v = getattr(mod, k)
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
to_patch.append(v)
name = mod.__name__
set_patch = set(to_patch)
name, to_patch = get_patches(mod, verbose)

set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}

for cls, methods in info.items():
assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
if cls in set_patch_cls:
if verbose:
print(
f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
)
original = cls._PATCHED_CLASS_
for n, v in methods.items():
if v is None:
# The method did not exist. We remove it.
delattr(original, n)
else:
setattr(original, n, v)
continue
assert cls in dict_patch_fct, (
f"No patch registered for {cls} in {mod} "
f"(found {set_patch_cls} and {set(dict_patch_fct)})"
)
patch = dict_patch_fct[cls]
if verbose:
print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
original = cls._PATCHED_CLASS_
for n, v in methods.items():
if v is None:
# The method did not exist. We remove it.
delattr(original, n)
else:
setattr(original, n, v)
print(
f"[unpatch_module_or_classes] function "
f"{patch['module'].__name__}.{cls.__name__}"
)
setattr(patch["module"], cls.__name__, patch["function"])


@contextlib.contextmanager
Expand Down
Loading
Loading