Skip to content

Commit 6bca0b4

Browse files
committed
patch for _compute_dynamic_ntk_parameters
1 parent 137b16c commit 6bca0b4

File tree

6 files changed

+354
-61
lines changed

6 files changed

+354
-61
lines changed

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_validate_microsoft_phi4_reasoning(self):
2828
patch=True,
2929
rewrite=True,
3030
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
31-
dump_folder="dump_test_validate_model_custom",
31+
dump_folder="dump_test/validate_microsoft_phi4_reasoning",
3232
)
3333
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-5)
3434
self.assertIn("onnx_filename", data)

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import unittest
33
import packaging.version as pv
4+
import onnx
45
import torch
56
from onnx_diagnostic.ext_test_case import (
67
ExtTestCase,
@@ -63,7 +64,7 @@ def test_validate_model_export(self):
6364
do_run=True,
6465
verbose=10,
6566
exporter="export-nostrict",
66-
dump_folder="dump_test_validate_model_export",
67+
dump_folder="dump_test/validate_model_export",
6768
patch=True,
6869
)
6970
self.assertIsInstance(summary, dict)
@@ -79,7 +80,7 @@ def test_validate_model_onnx_dynamo_ir(self):
7980
do_run=True,
8081
verbose=10,
8182
exporter="onnx-dynamo",
82-
dump_folder="dump_test_validate_model_onnx_dynamo",
83+
dump_folder="dump_test/validate_model_onnx_dynamo_ir",
8384
patch=True,
8485
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
8586
optimization="ir",
@@ -104,7 +105,7 @@ def test_validate_model_onnx_dynamo_os_ort(self):
104105
do_run=True,
105106
verbose=10,
106107
exporter="onnx-dynamo",
107-
dump_folder="dump_test_validate_model_onnx_dynamo",
108+
dump_folder="dump_test/validate_model_onnx_dynamo_os_ort",
108109
patch=True,
109110
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
110111
optimization="os_ort",
@@ -126,7 +127,7 @@ def test_validate_model_custom_os_ort(self):
126127
do_run=True,
127128
verbose=10,
128129
exporter="custom",
129-
dump_folder="dump_validate_model_custom_os_ort",
130+
dump_folder="dump_test/validate_model_custom_os_ort",
130131
patch=True,
131132
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
132133
optimization="default+os_ort",
@@ -148,7 +149,7 @@ def test_validate_model_custom(self):
148149
do_run=True,
149150
verbose=10,
150151
exporter="custom",
151-
dump_folder="dump_test_validate_model_custom",
152+
dump_folder="dump_test/validate_model_custom_tiny_llm",
152153
patch=True,
153154
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
154155
optimization="default",
@@ -177,7 +178,7 @@ def test_validate_model_custom_torch(self):
177178
do_run=True,
178179
verbose=10,
179180
exporter="custom-noinline",
180-
dump_folder="dump_test_validate_model_custom_torch",
181+
dump_folder="dump_test/validate_model_custom_torch",
181182
patch=True,
182183
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
183184
optimization="default",
@@ -221,7 +222,7 @@ def test_validate_model_modelbuilder(self):
221222
do_run=True,
222223
verbose=10,
223224
exporter="modelbuilder",
224-
dump_folder="dump_test_validate_model_modelbuilder",
225+
dump_folder="dump_test/validate_model_modelbuilder",
225226
)
226227
self.assertIsInstance(summary, dict)
227228
self.assertIsInstance(data, dict)
@@ -240,7 +241,7 @@ def test_validate_model_vit_model(self):
240241
do_run=True,
241242
verbose=10,
242243
exporter="onnx-dynamo",
243-
dump_folder="dump_test_validate_model_onnx_dynamo",
244+
dump_folder="dump_test/validate_model_vit_model",
244245
inputs2=True,
245246
)
246247
self.assertIsInstance(summary, dict)
@@ -254,6 +255,30 @@ def test_validate_model_vit_model(self):
254255
onnx_filename = data["onnx_filename"]
255256
self.assertExists(onnx_filename)
256257

258+
@requires_torch("2.7")
259+
@hide_stdout()
260+
@ignore_warnings(FutureWarning)
261+
@requires_transformers("4.51")
262+
def test_validate_phi3_mini_4k_instruct(self):
263+
mid = "microsoft/Phi-3-mini-4k-instruct"
264+
summary, data = validate_model(
265+
mid,
266+
do_run=True,
267+
verbose=10,
268+
exporter="custom",
269+
dump_folder="dump_test/validate_phi3_mini_4k_instruct",
270+
inputs2=True,
271+
patch=True,
272+
rewrite=True,
273+
model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}},
274+
)
275+
self.assertIsInstance(summary, dict)
276+
self.assertIsInstance(data, dict)
277+
onnx_filename = data["onnx_filename"]
278+
onx = onnx.load(onnx_filename)
279+
op_types = set(n.op_type for n in onx.graph.node)
280+
self.assertIn("If", op_types)
281+
257282

258283
if __name__ == "__main__":
259284
unittest.main(verbosity=2)

onnx_diagnostic/helpers/config_helper.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
3434
config._attn_implementation_autoset = False
3535
continue
3636
if isinstance(v, dict):
37-
assert hasattr(
38-
config, k
39-
), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
40-
update_config(getattr(config, k), v)
37+
if not hasattr(config, k) or getattr(config, k) is None:
38+
setattr(config, k, v)
39+
continue
40+
existing = getattr(config, k)
41+
if type(existing) is dict:
42+
existing.update(v)
43+
else:
44+
update_config(getattr(config, k), v)
4145
continue
4246
setattr(config, k, v)
4347

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
1+
import functools
2+
import importlib
13
import contextlib
2-
from typing import Any, Callable, Dict, List, Optional
4+
import re
5+
from typing import Any, Callable, Dict, List, Optional, Tuple
36
from .onnx_export_serialization import (
47
register_cache_serialization,
58
unregister_cache_serialization,
69
)
710
from .patches import patch_transformers as patch_transformers_list
811

912

10-
def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
13+
def get_function(name: str) -> Tuple["module", "function"]: # noqa: F821
1114
"""
12-
Applies all patches defined in classes prefixed by ``patched_``
13-
``cls._PATCHED_CLASS_`` defines the class to patch,
14-
``cls._PATCHES_`` defines the method to patch.
15-
The returns information needs to be sent to :func:`unpatch_module_or_classes`
16-
to revert the changes.
17-
18-
:param mod: module of list of clsses to patch
19-
:param verbose: verbosity
20-
:return: patch info
15+
Returns the module and the function based on its name.
2116
"""
17+
spl = name.split(".")
18+
module_name = ".".join(spl[:-1])
19+
fname = spl[-1]
20+
mod = importlib.import_module(module_name)
21+
return mod, getattr(mod, fname)
22+
23+
24+
@functools.lru_cache
25+
def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
26+
"""Returns the list of patches to make for a specific module."""
2227
if isinstance(mod, list):
2328
to_patch = mod
2429
name = "list"
@@ -29,10 +34,50 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
2934
v = getattr(mod, k)
3035
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
3136
to_patch.append(v)
37+
else:
38+
# a function
39+
doc = v.__doc__
40+
if doc.startswith("manual patch"):
41+
continue
42+
reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
43+
fall = reg.findall(doc)
44+
assert (
45+
len(fall) == 1
46+
), f"Unable to find patching information for {v} in \n{doc}"
47+
fmod, f = get_function(fall[0])
48+
to_patch.append({"module": fmod, "function": f, "patch": v})
49+
3250
name = mod.__name__
51+
return name, to_patch
52+
53+
54+
def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
55+
"""
56+
Applies all patches defined in classes prefixed by ``patched_``
57+
``cls._PATCHED_CLASS_`` defines the class to patch,
58+
``cls._PATCHES_`` defines the method to patch.
59+
The returns information needs to be sent to :func:`unpatch_module_or_classes`
60+
to revert the changes.
61+
62+
:param mod: module of list of clsses to patch
63+
:param verbose: verbosity
64+
:return: patch info
65+
"""
66+
name, to_patch = get_patches(mod, verbose)
3367

3468
res = {}
3569
for cls in to_patch:
70+
if isinstance(cls, dict):
71+
# a function
72+
keep = {}
73+
original = cls["module"]
74+
f = cls["function"]
75+
res[f] = f
76+
if verbose:
77+
print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
78+
setattr(original, f.__name__, cls["patch"])
79+
continue
80+
3681
original = cls._PATCHED_CLASS_
3782
methods = cls._PATCHES_
3883
if verbose:
@@ -53,30 +98,35 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
5398
:param mod: module of list of clsses to patch
5499
:param verbose: verbosity
55100
"""
56-
if isinstance(mod, list):
57-
to_patch = mod
58-
name = "list"
59-
else:
60-
to_patch = []
61-
for k in dir(mod):
62-
if k.startswith("patched_"):
63-
v = getattr(mod, k)
64-
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
65-
to_patch.append(v)
66-
name = mod.__name__
67-
set_patch = set(to_patch)
101+
name, to_patch = get_patches(mod, verbose)
102+
set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
103+
dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}
68104

69105
for cls, methods in info.items():
70-
assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
106+
if cls in set_patch_cls:
107+
if verbose:
108+
print(
109+
f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
110+
)
111+
original = cls._PATCHED_CLASS_
112+
for n, v in methods.items():
113+
if v is None:
114+
# The method did not exist. We remove it.
115+
delattr(original, n)
116+
else:
117+
setattr(original, n, v)
118+
continue
119+
assert cls in dict_patch_fct, (
120+
f"No patch registered for {cls} in {mod} "
121+
f"(found {set_patch_cls} and {set(dict_patch_fct)})"
122+
)
123+
patch = dict_patch_fct[cls]
71124
if verbose:
72-
print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
73-
original = cls._PATCHED_CLASS_
74-
for n, v in methods.items():
75-
if v is None:
76-
# The method did not exist. We remove it.
77-
delattr(original, n)
78-
else:
79-
setattr(original, n, v)
125+
print(
126+
f"[unpatch_module_or_classes] function "
127+
f"{patch['module'].__name__}.{cls.__name__}"
128+
)
129+
setattr(patch["module"], cls.__name__, patch["function"])
80130

81131

82132
@contextlib.contextmanager

0 commit comments

Comments
 (0)