
Commit 05a2e93

Allows to convert a model with other task inputs (#264)
* allow to convert a model with other task inputs
* doc
* pym
* fix
* none
* fix rope
* fix optional
* fix rotary patch
1 parent 2b5cb7c commit 05a2e93

8 files changed: +117 additions, -41 deletions

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.15
 ++++++
 
+* :pr:`264`: allows to validate a model with inputs defined from another task
 * :pr:`261`: updates to support ``transformers>=5.0``
 
 0.7.14

_doc/conf.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def linkcode_resolve(domain, info):
     "Linux": "https://www.linux.org/",
     "ml_dtypes": "https://github.com/jax-ml/ml_dtypes",
     "ModelBuilder": "https://onnxruntime.ai/docs/genai/howto/build-model.html",
-    "monai": "https://monai.io/",
+    "monai": "https://github.com/Project-MONAI/MONAI",
     "numpy": "https://numpy.org/",
     "onnx": "https://onnx.ai/onnx/",
     "onnx-ir": "https://github.com/onnx/ir-py",

_unittests/ut_torch_models/test_validate_whole_models1.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def test_f_validate_model_onnx_dynamo_ir(self):
     @requires_torch("2.7")
     @requires_onnxscript("0.7")
     @hide_stdout()
-    @ignore_warnings(FutureWarning)
+    @ignore_warnings((FutureWarning, RuntimeWarning))
     def test_g_validate_model_onnx_dynamo_os_ort(self):
         mid = "arnir0/Tiny-LLM"
         summary, data = validate_model(
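The decorator now silences a tuple of warning categories rather than a single one. `ignore_warnings` comes from the project's test helpers; the sketch below only illustrates the assumed behaviour of such a decorator, not the actual implementation:

import warnings
from functools import wraps

def ignore_warnings(categories):
    # Accepts a single warning category or a tuple of them, mirroring the
    # usage @ignore_warnings((FutureWarning, RuntimeWarning)).
    if not isinstance(categories, tuple):
        categories = (categories,)

    def decorator(fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            with warnings.catch_warnings():
                for cat in categories:
                    warnings.simplefilter("ignore", cat)
                return fn(*args, **kwargs)
        return wrapped
    return decorator

@ignore_warnings((FutureWarning, RuntimeWarning))
def noisy():
    warnings.warn("deprecated", FutureWarning)
    warnings.warn("overflow", RuntimeWarning)
    return "ok"

assert noisy() == "ok"  # both warning categories are silenced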

onnx_diagnostic/helpers/config_helper.py

Lines changed: 2 additions & 1 deletion
@@ -95,7 +95,8 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
     mod_name = cls.__module__
     mod = importlib.import_module(mod_name)
     source = inspect.getsource(mod)
-    reg = re.compile("config: ([A-Za-z0-9]+)")
+    # [^O] avoids capturing Optional[Something]
+    reg = re.compile("config: ([^O][A-Za-z0-9]+)")
     fall = reg.findall(source)
     if len(fall) == 0:
         assert not exc, (
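The effect of the `[^O]` guard can be checked in isolation. A minimal sketch with hypothetical source strings (not taken from transformers):

import re

# Patched pattern: the first captured character must not be "O", so
# annotations such as "config: Optional[LlamaConfig]" are skipped.
reg = re.compile("config: ([^O][A-Za-z0-9]+)")

source = """
class LlamaModel:
    config: LlamaConfig
class Wrapper:
    config: Optional[LlamaConfig] = None
"""

print(reg.findall(source))  # ['LlamaConfig'] -- the Optional[...] annotation is ignored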

onnx_diagnostic/tasks/text_generation.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,9 @@
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     # FalconMambaConfig: use_mambapy
+    if hasattr(config, "text_config"):
+        # The model is probably a mixture of models used only for text.
+        config = config.text_config
     check_hasattr(
         config,
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),

@@ -308,6 +311,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
 
     If the configuration is None, the function selects typical dimensions.
     """
+    if hasattr(config, "text_config"):
+        # The model is probably a mixture of models used only for text.
+        config = config.text_config
     if config is not None:
         check_hasattr(
             config,
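Both functions now unwrap multimodal configurations the same way. A minimal sketch of that guard with hypothetical mock configs (SimpleNamespace stands in for real transformers config classes):

from types import SimpleNamespace

# Hypothetical configs: a plain text model and a vision-language wrapper
# whose text part lives under `text_config`.
text_cfg = SimpleNamespace(hidden_size=64, num_attention_heads=4, num_hidden_layers=2)
multimodal_cfg = SimpleNamespace(text_config=text_cfg, vision_config=SimpleNamespace())

def pick_text_config(config):
    # Same guard as in reduce_model_config / random_input_kwargs:
    # multimodal configs are handled through their text sub-configuration.
    if hasattr(config, "text_config"):
        config = config.text_config
    return config

assert pick_text_config(multimodal_cfg) is text_cfg
assert pick_text_config(text_cfg) is text_cfg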

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 77 additions & 27 deletions
@@ -1019,6 +1019,26 @@ def patched__compute_dynamic_ntk_parameters(
     return inv_freq, attention_factor
 
 
+def _get_rope_init_fn(self, layer_type=None) -> Callable:
+    if hasattr(self, "rope_init_fn"):
+        # transformers<=5.0
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+        return rope_init_fn
+
+    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
+    rope_init_fn = self.compute_default_rope_parameters
+    if rope_type != "default":
+        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
+    if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
+        return patched__compute_dynamic_ntk_parameters
+    return rope_init_fn
+
+
 def patched_dynamic_rope_update(rope_forward):
     """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 
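The helper factors out a single decision: only the dynamic NTK initializer is swapped for its patched counterpart, every other rope initializer is returned unchanged. A minimal sketch of that dispatch with hypothetical stand-in functions:

def _original_ntk(config, device, seq_len=None):
    ...  # stands in for transformers.modeling_rope_utils._compute_dynamic_ntk_parameters

def _patched_ntk(config, device, seq_len=None):
    ...  # stands in for patched__compute_dynamic_ntk_parameters

def pick_init_fn(current_fn):
    # Identity check, as in the patch: only the dynamic NTK initializer
    # is redirected to the patched version.
    return _patched_ntk if current_fn is _original_ntk else current_fn

assert pick_init_fn(_original_ntk) is _patched_ntk
assert pick_init_fn(len) is len  # any other initializer passes through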
@@ -1082,22 +1102,27 @@ def wrapper(self, x, position_ids):
 
 
     """
-    def longrope_frequency_update(self, position_ids, device):
+    def longrope_frequency_update(self, position_ids, device, layer_type=None):
         # It is no use to patch the function after the model is created
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
             original_max_position_embeddings = self.config.original_max_position_embeddings
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # At export time, seq_len is unknown.
         long_inv_freq, _ = rope_init_fn(
             self.config, device, seq_len=original_max_position_embeddings + 1

@@ -1112,13 +1137,13 @@ def longrope_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
         # if seq_len > original_max_position_embeddings:
         #     self.inv_freq = self.long_inv_freq
         # else:
         #     self.inv_freq = self.original_inv_freq
 
-    def dynamic_frequency_update(self, position_ids, device):
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
         # constructor:
         # - self.max_seq_len_cached = config.max_position_embeddings
         # - self.original_max_seq_len = config.max_position_embeddings

@@ -1128,12 +1153,7 @@ def dynamic_frequency_update(self, position_ids, device):
         # as rope_init_fn is an attribute set to one function when the model
         # is created and when no patch is applied yet.
         # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
 
         # This behaviour is difficult to translate.
         # The sequence always grows.

@@ -1162,6 +1182,19 @@ def dynamic_frequency_update(self, position_ids, device):
             self.config, device, seq_len=seq_len
         )
 
+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
         # Second test to translate.
         # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
         # But in that case the following condition is a way to restore the original cache.

@@ -1183,15 +1216,26 @@ def dynamic_frequency_update(self, position_ids, device):
             (lambda x, y: y.clone()),
             [long_inv_freq, original_inv_freq],
         )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
 
     @wraps(rope_forward)
-    def wrapper(self, x, position_ids):
+    def wrapper(self, x, position_ids, layer_type=None):
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
         if "dynamic" in self.rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device)
+            dynamic_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
         elif self.rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device)
-        return rope_forward(self, x, position_ids)
+            longrope_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
+        return rope_forward(self, x, position_ids, layer_type=layer_type)
 
     return wrapper
 

@@ -1287,12 +1331,18 @@ class common_RotaryEmbedding(torch.nn.Module):
     # @torch.no_grad()
     # PATCHED: the decorator
     @patched_dynamic_rope_update
-    def forward(self, x, position_ids):
+    def forward(self, x, position_ids, layer_type=None):
+        if layer_type is not None:
+            # transformers>=5.0
+            inv_freq = getattr(self, f"{layer_type}_inv_freq")
+            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+        else:
+            # transformers<5.0
+            inv_freq = self.inv_freq
+            attention_scaling = self.attention_scaling
+
         inv_freq_expanded = (
-            self.inv_freq[None, :, None]
-            .float()
-            .expand(position_ids.shape[0], -1, 1)
-            .to(x.device)
+            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         )
         position_ids_expanded = position_ids[:, None, :].float()
 

@@ -1304,8 +1354,8 @@ def forward(self, x, position_ids):
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
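Throughout the patch, the `layer_type` argument selects between unprefixed buffers (transformers<5.0) and `<layer_type>_`-prefixed buffers (transformers>=5.0). A minimal sketch of that naming convention with a hypothetical module, unrelated to the real rotary embedding classes:

import torch

class TinyRope(torch.nn.Module):
    # Hypothetical module mimicking the attribute layout the patch relies on:
    # an unprefixed buffer plus one buffer per layer type.
    def __init__(self):
        super().__init__()
        self.inv_freq = torch.ones(4)
        self.full_attention_inv_freq = torch.ones(4) * 2

    def update(self, new_inv_freq, layer_type=None):
        # Same prefix trick as setattr(self, f"{prefix}inv_freq", inv_freq).
        prefix = "" if layer_type is None else f"{layer_type}_"
        setattr(self, f"{prefix}inv_freq", new_inv_freq)

m = TinyRope()
m.update(torch.zeros(4))                                # updates m.inv_freq
m.update(torch.zeros(4), layer_type="full_attention")   # updates m.full_attention_inv_freq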

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 7 additions & 3 deletions
@@ -95,6 +95,8 @@ def get_untrained_model_with_inputs(
         print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
         print("-- configuration:", pprint.pformat(data['configuration']))
     """
+    if task == "":
+        task = None
     assert not use_preinstalled or not use_only_preinstalled, (
         f"model_id={model_id!r}, preinstalled model is only available "
         f"if use_only_preinstalled is False."

@@ -120,14 +122,16 @@ def get_untrained_model_with_inputs(
         **(model_kwargs or {}),
     )
 
-    model, task, mkwargs, diff_config = None, None, {}, None
+    model, task_, mkwargs, diff_config = None, None, {}, None
     if use_pretrained and same_as_pretrained:
         if model_id in HANDLED_MODELS:
-            model, task, config = load_specific_model(model_id, verbose=verbose)
+            model, task_, config = load_specific_model(model_id, verbose=verbose)
 
+    if task is None:
+        task = task_
     if model is None:
         arch = architecture_from_config(config)
-        if arch is None:
+        if task is None and arch is None:
             task = task_from_id(model_id, subfolder=subfolder)
         assert task is not None or arch is not None, (
             f"Unable to determine the architecture for model {model_id!r}, "

onnx_diagnostic/torch_models/validate.py

Lines changed: 22 additions & 8 deletions
@@ -117,11 +117,21 @@ def _make_folder_name(
     drop_inputs: Optional[List[str]] = None,
     same_as_pretrained: bool = False,
     use_pretrained: bool = False,
+    task: Optional[str] = None,
 ) -> str:
     "Creates a filename unique based on the given options."
     els = [model_id.replace("/", "_")]
     if subfolder:
         els.append(subfolder.replace("/", "_"))
+    if not task:
+        els.append(task)  # type: ignore[arg-type]
+    if drop_inputs:
+        ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
+        els.append(f"I-{ii.upper()}")
+    if use_pretrained:
+        els.append("TRAINED")
+    elif same_as_pretrained:
+        els.append("SAMESIZE")
     if exporter:
         els.append(exporter)
     if optimization:

@@ -142,14 +152,7 @@ def _make_folder_name(
         els.append(sdev)
     if opset is not None:
         els.append(f"op{opset}")
-    if drop_inputs:
-        ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
-        els.append(f"I-{ii.upper()}")
-    if use_pretrained:
-        els.append("TRAINED")
-    elif same_as_pretrained:
-        els.append("SAMESIZE")
-    return "-".join(els)
+    return "/".join([e for e in els if e])
 
 
 def version_summary() -> Dict[str, Union[int, float, str]]:

@@ -476,6 +479,7 @@ def validate_model(
         drop_inputs=drop_inputs,
         use_pretrained=use_pretrained,
         same_as_pretrained=same_as_pretrained,
+        task=task,
     )
     dump_folder = os.path.join(dump_folder, folder_name)
     if not os.path.exists(dump_folder):

@@ -490,6 +494,8 @@ def validate_model(
            print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
        else:
            print(f"[validate_model] validate model id {model_id!r}")
+       if task:
+           print(f"[validate_model] with task {task!r}")
        print(f"[validate_model] patch={patch!r}")
        if model_options:
            print(f"[validate_model] model_options={model_options!r}")

@@ -765,6 +771,10 @@ def validate_model(
        ep = data["exported_program"]
        if verbose:
            print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
+       assert isinstance(
+           folder_name, str
+       ), f"folder_name={folder_name!r} should be a string"
+       folder_name = folder_name.replace("/", "-")
        with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
            f.write(str(ep))
        torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))

@@ -773,6 +783,10 @@ def validate_model(
        if verbose:
            print("[validate_model] done (dump ep)")
    if "onnx_program" in data:
+       assert isinstance(
+           folder_name, str
+       ), f"folder_name={folder_name!r} should be a string"
+       folder_name = folder_name.replace("/", "-")
        epo = data["onnx_program"]
        if verbose:
            print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
