Skip to content

Commit 094f23d

Browse files
authored
Patch for _compute_dynamic_ntk_parameters (#145)
* patch for _compute_dynamic_ntk_parameters * change * fix * custom patch * patch * doc
1 parent 137b16c commit 094f23d

File tree

8 files changed

+389
-58
lines changed

8 files changed

+389
-58
lines changed

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def linkcode_resolve(domain, info):
138138
("py:class", "transformers.cache_utils.SlidingWindowCache"),
139139
("py:class", "transformers.configuration_utils.PretrainedConfig"),
140140
("py:class", "transformers.modeling_outputs.BaseModelOutput"),
141+
("py:class", "transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding"),
141142
("py:func", "torch.export._draft_export.draft_export"),
142143
("py:func", "torch._export.tools.report_exportability"),
143144
("py:func", "torch.utils._pytree.register_pytree_node"),

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_validate_microsoft_phi4_reasoning(self):
2828
patch=True,
2929
rewrite=True,
3030
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
31-
dump_folder="dump_test_validate_model_custom",
31+
dump_folder="dump_test/validate_microsoft_phi4_reasoning",
3232
)
3333
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-5)
3434
self.assertIn("onnx_filename", data)

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import unittest
33
import packaging.version as pv
4+
import onnx
45
import torch
56
from onnx_diagnostic.ext_test_case import (
67
ExtTestCase,
@@ -63,7 +64,7 @@ def test_validate_model_export(self):
6364
do_run=True,
6465
verbose=10,
6566
exporter="export-nostrict",
66-
dump_folder="dump_test_validate_model_export",
67+
dump_folder="dump_test/validate_model_export",
6768
patch=True,
6869
)
6970
self.assertIsInstance(summary, dict)
@@ -79,7 +80,7 @@ def test_validate_model_onnx_dynamo_ir(self):
7980
do_run=True,
8081
verbose=10,
8182
exporter="onnx-dynamo",
82-
dump_folder="dump_test_validate_model_onnx_dynamo",
83+
dump_folder="dump_test/validate_model_onnx_dynamo_ir",
8384
patch=True,
8485
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
8586
optimization="ir",
@@ -104,7 +105,7 @@ def test_validate_model_onnx_dynamo_os_ort(self):
104105
do_run=True,
105106
verbose=10,
106107
exporter="onnx-dynamo",
107-
dump_folder="dump_test_validate_model_onnx_dynamo",
108+
dump_folder="dump_test/validate_model_onnx_dynamo_os_ort",
108109
patch=True,
109110
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
110111
optimization="os_ort",
@@ -126,7 +127,7 @@ def test_validate_model_custom_os_ort(self):
126127
do_run=True,
127128
verbose=10,
128129
exporter="custom",
129-
dump_folder="dump_validate_model_custom_os_ort",
130+
dump_folder="dump_test/validate_model_custom_os_ort",
130131
patch=True,
131132
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
132133
optimization="default+os_ort",
@@ -148,7 +149,7 @@ def test_validate_model_custom(self):
148149
do_run=True,
149150
verbose=10,
150151
exporter="custom",
151-
dump_folder="dump_test_validate_model_custom",
152+
dump_folder="dump_test/validate_model_custom_tiny_llm",
152153
patch=True,
153154
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
154155
optimization="default",
@@ -177,7 +178,7 @@ def test_validate_model_custom_torch(self):
177178
do_run=True,
178179
verbose=10,
179180
exporter="custom-noinline",
180-
dump_folder="dump_test_validate_model_custom_torch",
181+
dump_folder="dump_test/validate_model_custom_torch",
181182
patch=True,
182183
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
183184
optimization="default",
@@ -221,7 +222,7 @@ def test_validate_model_modelbuilder(self):
221222
do_run=True,
222223
verbose=10,
223224
exporter="modelbuilder",
224-
dump_folder="dump_test_validate_model_modelbuilder",
225+
dump_folder="dump_test/validate_model_modelbuilder",
225226
)
226227
self.assertIsInstance(summary, dict)
227228
self.assertIsInstance(data, dict)
@@ -240,7 +241,7 @@ def test_validate_model_vit_model(self):
240241
do_run=True,
241242
verbose=10,
242243
exporter="onnx-dynamo",
243-
dump_folder="dump_test_validate_model_onnx_dynamo",
244+
dump_folder="dump_test/validate_model_vit_model",
244245
inputs2=True,
245246
)
246247
self.assertIsInstance(summary, dict)
@@ -254,6 +255,30 @@ def test_validate_model_vit_model(self):
254255
onnx_filename = data["onnx_filename"]
255256
self.assertExists(onnx_filename)
256257

258+
@requires_torch("2.7")
259+
@hide_stdout()
260+
@ignore_warnings(FutureWarning)
261+
@requires_transformers("4.51")
262+
def test_validate_phi35_mini_instruct(self):
263+
mid = "microsoft/Phi-3.5-mini-instruct"
264+
summary, data = validate_model(
265+
mid,
266+
do_run=True,
267+
verbose=10,
268+
exporter="custom",
269+
dump_folder="dump_test/validate_phi35_mini_instruct",
270+
inputs2=True,
271+
patch=True,
272+
rewrite=True,
273+
# model_options={"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}},
274+
)
275+
self.assertIsInstance(summary, dict)
276+
self.assertIsInstance(data, dict)
277+
onnx_filename = data["onnx_filename"]
278+
onx = onnx.load(onnx_filename)
279+
op_types = set(n.op_type for n in onx.graph.node)
280+
self.assertIn("If", op_types)
281+
257282

258283
if __name__ == "__main__":
259284
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
import textwrap
77
import onnx
8-
from typing import Any, List, Optional
8+
from typing import Any, Dict, List, Optional, Union
99
from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
1010
from textwrap import dedent
1111

@@ -291,6 +291,14 @@ def _cmd_config(argv: List[Any]):
291291
print(f"task: {task_from_id(args.mid)}")
292292

293293

294+
def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
295+
assert isinstance(value, str), f"value should be string but value={value!r}"
296+
if value and value[0] == "{" and value[-1] == "}":
297+
# a dictionary
298+
return json.loads(value.replace("'", '"'))
299+
return value
300+
301+
294302
class _ParseDict(argparse.Action):
295303
def __call__(self, parser, namespace, values, option_string=None):
296304
d = getattr(namespace, self.dest) or {}
@@ -314,7 +322,7 @@ def __call__(self, parser, namespace, values, option_string=None):
314322
continue
315323
except (TypeError, ValueError):
316324
pass
317-
d[key] = value
325+
d[key] = _parse_json(value)
318326

319327
setattr(namespace, self.dest, d)
320328

@@ -430,7 +438,8 @@ def get_parser_validate() -> ArgumentParser:
430438
metavar="KEY=VALUE",
431439
nargs="*",
432440
help="Additional model options, use to change some parameters of the model, "
433-
"example: --mop attn_implementation=eager",
441+
"example: ``--mop attn_implementation=eager`` or "
442+
"``--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"``",
434443
action=_ParseDict,
435444
)
436445
parser.add_argument(

onnx_diagnostic/helpers/config_helper.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
3434
config._attn_implementation_autoset = False
3535
continue
3636
if isinstance(v, dict):
37-
assert hasattr(
38-
config, k
39-
), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
40-
update_config(getattr(config, k), v)
37+
if not hasattr(config, k) or getattr(config, k) is None:
38+
setattr(config, k, v)
39+
continue
40+
existing = getattr(config, k)
41+
if type(existing) is dict:
42+
existing.update(v)
43+
else:
44+
update_config(getattr(config, k), v)
4145
continue
4246
setattr(config, k, v)
4347

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 78 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,50 @@
1+
import functools
2+
import importlib
13
import contextlib
2-
from typing import Any, Callable, Dict, List, Optional
4+
import re
5+
from typing import Any, Callable, Dict, List, Optional, Tuple
36
from .onnx_export_serialization import (
47
register_cache_serialization,
58
unregister_cache_serialization,
69
)
710
from .patches import patch_transformers as patch_transformers_list
811

912

13+
def get_function(name: str) -> Tuple[type, Callable]:
14+
"""Returns the module and the function based on its name."""
15+
spl = name.split(".")
16+
module_name = ".".join(spl[:-1])
17+
fname = spl[-1]
18+
mod = importlib.import_module(module_name)
19+
return mod, getattr(mod, fname)
20+
21+
22+
@functools.lru_cache
23+
def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
24+
"""Returns the list of patches to make for a specific module."""
25+
to_patch = []
26+
for k in dir(mod):
27+
if k.startswith("patched_"):
28+
v = getattr(mod, k)
29+
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
30+
to_patch.append(v)
31+
else:
32+
# a function
33+
doc = v.__doc__.lstrip()
34+
if doc.startswith("manual patch"):
35+
continue
36+
reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
37+
fall = reg.findall(doc)
38+
assert (
39+
len(fall) == 1
40+
), f"Unable to find patching information for {v} in \n{doc}"
41+
fmod, f = get_function(fall[0])
42+
to_patch.append({"module": fmod, "function": f, "patch": v})
43+
44+
name = mod.__name__
45+
return name, to_patch
46+
47+
1048
def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
1149
"""
1250
Applies all patches defined in classes prefixed by ``patched_``
@@ -23,16 +61,21 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
2361
to_patch = mod
2462
name = "list"
2563
else:
26-
to_patch = []
27-
for k in dir(mod):
28-
if k.startswith("patched_"):
29-
v = getattr(mod, k)
30-
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
31-
to_patch.append(v)
32-
name = mod.__name__
64+
name, to_patch = get_patches(mod, verbose)
3365

3466
res = {}
3567
for cls in to_patch:
68+
if isinstance(cls, dict):
69+
# a function
70+
keep = {}
71+
original = cls["module"]
72+
f = cls["function"]
73+
res[f] = f
74+
if verbose:
75+
print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
76+
setattr(original, f.__name__, cls["patch"])
77+
continue
78+
3679
original = cls._PATCHED_CLASS_
3780
methods = cls._PATCHES_
3881
if verbose:
@@ -57,26 +100,36 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
57100
to_patch = mod
58101
name = "list"
59102
else:
60-
to_patch = []
61-
for k in dir(mod):
62-
if k.startswith("patched_"):
63-
v = getattr(mod, k)
64-
if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
65-
to_patch.append(v)
66-
name = mod.__name__
67-
set_patch = set(to_patch)
103+
name, to_patch = get_patches(mod, verbose)
104+
105+
set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
106+
dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}
68107

69108
for cls, methods in info.items():
70-
assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
109+
if cls in set_patch_cls:
110+
if verbose:
111+
print(
112+
f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
113+
)
114+
original = cls._PATCHED_CLASS_
115+
for n, v in methods.items():
116+
if v is None:
117+
# The method did not exist. We remove it.
118+
delattr(original, n)
119+
else:
120+
setattr(original, n, v)
121+
continue
122+
assert cls in dict_patch_fct, (
123+
f"No patch registered for {cls} in {mod} "
124+
f"(found {set_patch_cls} and {set(dict_patch_fct)})"
125+
)
126+
patch = dict_patch_fct[cls]
71127
if verbose:
72-
print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
73-
original = cls._PATCHED_CLASS_
74-
for n, v in methods.items():
75-
if v is None:
76-
# The method did not exist. We remove it.
77-
delattr(original, n)
78-
else:
79-
setattr(original, n, v)
128+
print(
129+
f"[unpatch_module_or_classes] function "
130+
f"{patch['module'].__name__}.{cls.__name__}"
131+
)
132+
setattr(patch["module"], cls.__name__, patch["function"])
80133

81134

82135
@contextlib.contextmanager

0 commit comments

Comments
 (0)