
Commit 528fa41
Commit message: change
1 parent: 6bca0b4

File tree: 3 files changed (+40, -9 lines)


_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ def test_validate_phi3_mini_4k_instruct(self):
             mid,
             do_run=True,
             verbose=10,
-            exporter="custom",
+            exporter="onnx-dynamo",
             dump_folder="dump_test/validate_phi3_mini_4k_instruct",
             inputs2=True,
             patch=True,

onnx_diagnostic/_command_lines_parser.py

Lines changed: 12 additions & 3 deletions
@@ -5,7 +5,7 @@
 import sys
 import textwrap
 import onnx
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union
 from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
 from textwrap import dedent

@@ -291,6 +291,14 @@ def _cmd_config(argv: List[Any]):
     print(f"task: {task_from_id(args.mid)}")


+def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
+    assert isinstance(value, str), f"value should be string but value={value!r}"
+    if value and value[0] == "{" and value[-1] == "}":
+        # a dictionary
+        return json.loads(value.replace("'", '"'))
+    return value
+
+
 class _ParseDict(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         d = getattr(namespace, self.dest) or {}
@@ -314,7 +322,7 @@ def __call__(self, parser, namespace, values, option_string=None):
                 continue
             except (TypeError, ValueError):
                 pass
-            d[key] = value
+            d[key] = _parse_json(value)

         setattr(namespace, self.dest, d)

@@ -430,7 +438,8 @@ def get_parser_validate() -> ArgumentParser:
         metavar="KEY=VALUE",
         nargs="*",
         help="Additional model options, use to change some parameters of the model, "
-        "example: --mop attn_implementation=eager",
+        "example: ``--mop attn_implementation=eager`` or "
+        "``--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"``",
         action=_ParseDict,
     )
     parser.add_argument(
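Taken together, these changes let a --mop value carry a nested dictionary instead of only flat strings. A minimal standalone sketch of the new behaviour; the helper body is copied from this commit, while the example values around it are made up for illustration and the real argparse wiring (_ParseDict, get_parser_validate) is not reproduced here:

import json
from typing import Any, Dict, Union


def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
    # Same logic as the helper added in this commit: a value that looks like a
    # dictionary is decoded as JSON (after turning single quotes into double
    # quotes), anything else is returned unchanged.
    assert isinstance(value, str), f"value should be string but value={value!r}"
    if value and value[0] == "{" and value[-1] == "}":
        return json.loads(value.replace("'", '"'))
    return value


# A plain value stays a string, a dict-like value becomes a dict:
print(_parse_json("eager"))
# eager
print(_parse_json("{'rope_type': 'dynamic', 'factor': 10.0}"))
# {'rope_type': 'dynamic', 'factor': 10.0}

Because _ParseDict now routes every VALUE through _parse_json, the same conversion happens when argparse collects KEY=VALUE pairs, which is what makes the new help-text example --mop "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}" usable from the command line.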

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 27 additions & 5 deletions
@@ -540,7 +540,7 @@ def patched__compute_dynamic_ntk_parameters(
     seq_len: Optional[int] = None,
     **rope_kwargs,
 ) -> Tuple["torch.Tensor", float]:
-    """
+    """manual patch:
     ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``

     Computes the inverse frequencies with NTK scaling.
@@ -594,8 +594,9 @@ def patched__compute_dynamic_ntk_parameters(
         seq_len = max_position_embeddings
     else:
         torch._check(isinstance(seq_len, torch.Tensor))
-        seq_len = torch.max(
-            seq_len, torch.Tensor(max_position_embeddings, dtype=seq_len.dtype)
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
         )

     # Compute the inverse frequencies
@@ -676,13 +677,23 @@ def wrapper(self, x, position_ids):
     """

     def longrope_frequency_update(self, position_ids, device):
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
         # At export time, seq_len is unknown.
-        long_inv_freq, _ = self.rope_init_fn(
+        long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1
         )
         original_inv_freq = self.original_inv_freq.to(device)
@@ -706,6 +717,17 @@ def dynamic_frequency_update(self, position_ids, device):
         # - self.original_max_seq_len = config.max_position_embeddings
         # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+
         # This behaviour is difficult to translate.
         # The sequence always grows.
         # The test should always True.
@@ -729,7 +751,7 @@ def dynamic_frequency_update(self, position_ids, device):
         # )

         seq_len = torch.max(position_ids) + 1
-        long_inv_freq, self.attention_scaling = self.rope_init_fn(
+        long_inv_freq, self.attention_scaling = rope_init_fn(
             self.config, device, seq_len=seq_len
         )

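Both longrope_frequency_update and dynamic_frequency_update now re-select the rope init function at call time: self.rope_init_fn is bound once when the model is created, before the patches are installed, so it may still point at the unpatched transformers function. A self-contained sketch of that selection pattern, with stand-in functions and a hypothetical minimal class replacing the real transformers objects:

# Stand-ins for transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
# and its patched counterpart; only the identity check matters for this sketch.
def _compute_dynamic_ntk_parameters(config=None, device=None, seq_len=None):
    return "unpatched", 1.0


def patched__compute_dynamic_ntk_parameters(config=None, device=None, seq_len=None):
    return "patched", 1.0


class RotaryEmbedding:
    # Hypothetical minimal class: the attribute is set once at construction,
    # which in the real model happens before any patch is applied.
    def __init__(self):
        self.rope_init_fn = _compute_dynamic_ntk_parameters

    def dynamic_frequency_update(self):
        # Swap in the patched function only when the stored attribute still
        # points at the unpatched original; otherwise keep whatever is there.
        rope_init_fn = (
            patched__compute_dynamic_ntk_parameters
            if self.rope_init_fn is _compute_dynamic_ntk_parameters
            else self.rope_init_fn
        )
        return rope_init_fn()


print(RotaryEmbedding().dynamic_frequency_update())  # ('patched', 1.0)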