1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.7.15
++++++

* :pr:`264`: allows validating a model with inputs defined from another task
* :pr:`261`: updates to support ``transformers>=5.0``

0.7.14
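The first changelog entry is the core of this PR: validation can now build inputs for a task other than the one inferred from the model id. A minimal sketch of such a call, based on the validate_model changes shown later in this diff; only the task and verbose arguments are taken from the diff, the argument order and all other defaults are assumptions:

    from onnx_diagnostic.torch_models.validate import validate_model

    # "arnir0/Tiny-LLM" is the model id used in the unit tests of this PR;
    # "text-generation" is only an example of a task forced by the caller.
    summary, data = validate_model(
        "arnir0/Tiny-LLM",
        task="text-generation",
        verbose=1,
    )
    print(summary)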
2 changes: 1 addition & 1 deletion _doc/conf.py
@@ -229,7 +229,7 @@ def linkcode_resolve(domain, info):
"Linux": "https://www.linux.org/",
"ml_dtypes": "https://github.com/jax-ml/ml_dtypes",
"ModelBuilder": "https://onnxruntime.ai/docs/genai/howto/build-model.html",
"monai": "https://monai.io/",
"monai": "https://github.com/Project-MONAI/MONAI",
"numpy": "https://numpy.org/",
"onnx": "https://onnx.ai/onnx/",
"onnx-ir": "https://github.com/onnx/ir-py",
2 changes: 1 addition & 1 deletion _unittests/ut_torch_models/test_validate_whole_models1.py
@@ -98,7 +98,7 @@ def test_f_validate_model_onnx_dynamo_ir(self):
@requires_torch("2.7")
@requires_onnxscript("0.7")
@hide_stdout()
@ignore_warnings(FutureWarning)
@ignore_warnings((FutureWarning, RuntimeWarning))
def test_g_validate_model_onnx_dynamo_os_ort(self):
mid = "arnir0/Tiny-LLM"
summary, data = validate_model(
3 changes: 2 additions & 1 deletion onnx_diagnostic/helpers/config_helper.py
@@ -95,7 +95,8 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
mod_name = cls.__module__
mod = importlib.import_module(mod_name)
source = inspect.getsource(mod)
reg = re.compile("config: ([A-Za-z0-9]+)")
# [^O] avoids capturing Optional[Something]
reg = re.compile("config: ([^O][A-Za-z0-9]+)")
fall = reg.findall(source)
if len(fall) == 0:
assert not exc, (
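To see why the [^O] guard matters, here is a small illustration (not part of the PR) of what the old and new patterns capture:

    import re

    old = re.compile("config: ([A-Za-z0-9]+)")
    new = re.compile("config: ([^O][A-Za-z0-9]+)")

    source = "config: LlamaConfig\nconfig: Optional[LlamaConfig]"
    print(old.findall(source))  # ['LlamaConfig', 'Optional'] -> the annotation leaks in
    print(new.findall(source))  # ['LlamaConfig']             -> Optional[...] is skipped
    # Note: a config class whose name itself starts with 'O' would also be skipped.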
6 changes: 6 additions & 0 deletions onnx_diagnostic/tasks/text_generation.py
@@ -19,6 +19,9 @@
def reduce_model_config(config: Any) -> Dict[str, Any]:
"""Reduces a model size."""
# FalconMambaConfig: use_mambapy
if hasattr(config, "text_config"):
# The model is probably a mixture of models of which only the text part is used.
config = config.text_config
check_hasattr(
config,
("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
@@ -308,6 +311,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

If the configuration is None, the function selects typical dimensions.
"""
if hasattr(config, "text_config"):
# The model is probably a mixture of models of which only the text part is used.
config = config.text_config
if config is not None:
check_hasattr(
config,
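The text_config fallback added in both functions can be illustrated with a stand-in configuration; the object below is hypothetical, not a real transformers config:

    from types import SimpleNamespace

    # Hypothetical multimodal configuration: the attributes the task helpers need
    # (hidden_size, num_attention_heads, ...) only exist on the nested text_config.
    config = SimpleNamespace(
        text_config=SimpleNamespace(hidden_size=64, num_attention_heads=4)
    )
    if hasattr(config, "text_config"):
        # same fallback as in reduce_model_config / random_input_kwargs above
        config = config.text_config
    print(config.hidden_size, config.num_attention_heads)  # 64 4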
104 changes: 77 additions & 27 deletions onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1019,6 +1019,26 @@ def patched__compute_dynamic_ntk_parameters(
return inv_freq, attention_factor


def _get_rope_init_fn(self, layer_type=None) -> Callable:
if hasattr(self, "rope_init_fn"):
# transformers<=5.0
rope_init_fn = (
patched__compute_dynamic_ntk_parameters
if self.rope_init_fn
is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
else self.rope_init_fn
)
return rope_init_fn

rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
rope_init_fn = self.compute_default_rope_parameters
if rope_type != "default":
rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[rope_type]
if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
return patched__compute_dynamic_ntk_parameters
return rope_init_fn


def patched_dynamic_rope_update(rope_forward):
"""manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``

@@ -1082,22 +1102,27 @@ def wrapper(self, x, position_ids):

"""

def longrope_frequency_update(self, position_ids, device):
def longrope_frequency_update(self, position_ids, device, layer_type=None):
# It is no use to patch the function after the model is created
# as rope_init_fn is an attribute set to one function when the model
# is created and when no patch is applied yet.
# So we select the patched version here.
rope_init_fn = (
patched__compute_dynamic_ntk_parameters
if self.rope_init_fn
is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
else self.rope_init_fn
)
rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
seq_len = torch.max(position_ids) + 1
if hasattr(self.config, "original_max_position_embeddings"):
original_max_position_embeddings = self.config.original_max_position_embeddings
else:
original_max_position_embeddings = self.config.max_position_embeddings

if layer_type is None:
# rope_type = self.rope_type
original_inv_freq = self.original_inv_freq
prefix = ""
else:
# rope_type = self.rope_type[layer_type]
original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
prefix = f"{layer_type}_"

# At export time, seq_len is unknown.
long_inv_freq, _ = rope_init_fn(
self.config, device, seq_len=original_max_position_embeddings + 1
@@ -1112,13 +1137,13 @@ def longrope_frequency_update(self, position_ids, device):
(lambda x, y: y.clone()),
[long_inv_freq, original_inv_freq],
)
self.inv_freq = inv_freq
setattr(self, f"{prefix}inv_freq", inv_freq)
# if seq_len > original_max_position_embeddings:
# self.inv_freq = self.long_inv_freq
# else:
# self.inv_freq = self.original_inv_freq

def dynamic_frequency_update(self, position_ids, device):
def dynamic_frequency_update(self, position_ids, device, layer_type=None):
# constructor:
# - self.max_seq_len_cached = config.max_position_embeddings
# - self.original_max_seq_len = config.max_position_embeddings
@@ -1128,12 +1153,7 @@ def dynamic_frequency_update(self, position_ids, device):
# as rope_init_fn is an attribute set to one function when the model
# is created and when no patch is applied yet.
# So we select the patched version here.
rope_init_fn = (
patched__compute_dynamic_ntk_parameters
if self.rope_init_fn
is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
else self.rope_init_fn
)
rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)

# This behaviour is difficult to translate.
# The sequence always grows.
@@ -1162,6 +1182,19 @@ def dynamic_frequency_update(self, position_ids, device):
self.config, device, seq_len=seq_len
)

if layer_type is None:
# rope_type = self.rope_type
# max_seq_len_cached = self.max_seq_len_cached
original_inv_freq = self.original_inv_freq
prefix = ""
else:
# rope_type = self.rope_type[layer_type]
# max_seq_len_cached = getattr(
# self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
# )
original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
prefix = f"{layer_type}_"

# Second test to translate.
# Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
# But in that case the following condition is a way to restore the original cache.
@@ -1183,15 +1216,26 @@ def dynamic_frequency_update(self, position_ids, device):
(lambda x, y: y.clone()),
[long_inv_freq, original_inv_freq],
)
self.inv_freq = inv_freq
setattr(self, f"{prefix}inv_freq", inv_freq)

@wraps(rope_forward)
def wrapper(self, x, position_ids):
def wrapper(self, x, position_ids, layer_type=None):
if layer_type is None:
if "dynamic" in self.rope_type:
dynamic_frequency_update(self, position_ids, device=x.device)
elif self.rope_type == "longrope":
longrope_frequency_update(self, position_ids, device=x.device)
return rope_forward(self, x, position_ids)

if "dynamic" in self.rope_type:
dynamic_frequency_update(self, position_ids, device=x.device)
dynamic_frequency_update(
self, position_ids, device=x.device, layer_type=layer_type
)
elif self.rope_type == "longrope":
longrope_frequency_update(self, position_ids, device=x.device)
return rope_forward(self, x, position_ids)
longrope_frequency_update(
self, position_ids, device=x.device, layer_type=layer_type
)
return rope_forward(self, x, position_ids, layer_type=layer_type)

return wrapper
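
Both longrope_frequency_update and dynamic_frequency_update rely on the same export-friendly pattern: both candidate frequency tensors are computed unconditionally and torch.cond only selects one of them, so the data-dependent choice stays in the exported graph. A standalone sketch of that pattern; the predicate and threshold below are assumptions, since the real condition is folded out of this hunk:

    import torch

    def pick_inv_freq(seq_len, threshold, long_inv_freq, original_inv_freq):
        # Both tensors already exist; torch.cond only selects one of them,
        # which keeps the branch expressible for torch.export.
        return torch.cond(
            seq_len > threshold,
            lambda x, y: x.clone(),
            lambda x, y: y.clone(),
            [long_inv_freq, original_inv_freq],
        )

    long_f, orig_f = torch.full((4,), 0.5), torch.ones(4)
    print(pick_inv_freq(torch.tensor(4096), torch.tensor(2048), long_f, orig_f))  # long_f
    print(pick_inv_freq(torch.tensor(16), torch.tensor(2048), long_f, orig_f))    # orig_f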

@@ -1287,12 +1331,18 @@ class common_RotaryEmbedding(torch.nn.Module):
# @torch.no_grad()
# PATCHED: the decorator
@patched_dynamic_rope_update
def forward(self, x, position_ids):
def forward(self, x, position_ids, layer_type=None):
if layer_type is not None:
# transformers>=5.0
inv_freq = getattr(self, f"{layer_type}_inv_freq")
attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
else:
# transformers<5.0
inv_freq = self.inv_freq
attention_scaling = self.attention_scaling

inv_freq_expanded = (
self.inv_freq[None, :, None]
.float()
.expand(position_ids.shape[0], -1, 1)
.to(x.device)
inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
)
position_ids_expanded = position_ids[:, None, :].float()

@@ -1304,8 +1354,8 @@ def forward(self, x, position_ids):
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
cos = emb.cos() * attention_scaling
sin = emb.sin() * attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

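The whole patch above hinges on one naming convention: with transformers>=5.0 a rotary-embedding module keeps one set of buffers per layer type under a layer-type prefix, while older versions keep a single unprefixed set. A small stand-in showing the getattr pattern the patched code uses; this is not a real transformers module and the layer-type name is only an example:

    import torch
    from types import SimpleNamespace

    # Stand-in module with the legacy unprefixed buffer (transformers<5.0 layout)
    # and one prefixed buffer per layer type (transformers>=5.0 layout).
    rope = SimpleNamespace(
        inv_freq=torch.ones(4),
        sliding_attention_inv_freq=torch.full((4,), 2.0),
    )

    def get_inv_freq(module, layer_type=None):
        # mirrors common_RotaryEmbedding.forward / longrope_frequency_update above
        prefix = "" if layer_type is None else f"{layer_type}_"
        return getattr(module, f"{prefix}inv_freq")

    print(get_inv_freq(rope))                       # tensor([1., 1., 1., 1.])
    print(get_inv_freq(rope, "sliding_attention"))  # tensor([2., 2., 2., 2.])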
10 changes: 7 additions & 3 deletions onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -95,6 +95,8 @@ def get_untrained_model_with_inputs(
print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
print("-- configuration:", pprint.pformat(data['configuration']))
"""
if task == "":
task = None
assert not use_preinstalled or not use_only_preinstalled, (
f"model_id={model_id!r}, preinstalled model is only available "
f"if use_only_preinstalled is False."
@@ -120,14 +122,16 @@
**(model_kwargs or {}),
)

model, task, mkwargs, diff_config = None, None, {}, None
model, task_, mkwargs, diff_config = None, None, {}, None
if use_pretrained and same_as_pretrained:
if model_id in HANDLED_MODELS:
model, task, config = load_specific_model(model_id, verbose=verbose)
model, task_, config = load_specific_model(model_id, verbose=verbose)

if task is None:
task = task_
if model is None:
arch = architecture_from_config(config)
if arch is None:
if task is None and arch is None:
task = task_from_id(model_id, subfolder=subfolder)
assert task is not None or arch is not None, (
f"Unable to determine the architecture for model {model_id!r}, "
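With the change above, a task given by the caller wins over the task attached to a pre-loaded model, and an empty string behaves like None. A hedged usage sketch; the keys read at the end are the ones shown in the function's docstring, the argument order is an assumption and other arguments keep their defaults:

    from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

    # "text-generation" is only an example of a task forced by the caller.
    data = get_untrained_model_with_inputs(
        "arnir0/Tiny-LLM",
        task="text-generation",
        verbose=1,
    )
    print(sorted(data))            # available keys
    print(data["dynamic_shapes"])  # documented in the docstring shown above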
30 changes: 22 additions & 8 deletions onnx_diagnostic/torch_models/validate.py
@@ -117,11 +117,21 @@ def _make_folder_name(
drop_inputs: Optional[List[str]] = None,
same_as_pretrained: bool = False,
use_pretrained: bool = False,
task: Optional[str] = None,
) -> str:
"Creates a filename unique based on the given options."
els = [model_id.replace("/", "_")]
if subfolder:
els.append(subfolder.replace("/", "_"))
if task:
els.append(task)
if drop_inputs:
ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
els.append(f"I-{ii.upper()}")
if use_pretrained:
els.append("TRAINED")
elif same_as_pretrained:
els.append("SAMESIZE")
if exporter:
els.append(exporter)
if optimization:
@@ -142,14 +152,7 @@
els.append(sdev)
if opset is not None:
els.append(f"op{opset}")
if drop_inputs:
ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
els.append(f"I-{ii.upper()}")
if use_pretrained:
els.append("TRAINED")
elif same_as_pretrained:
els.append("SAMESIZE")
return "-".join(els)
return "/".join([e for e in els if e])


def version_summary() -> Dict[str, Union[int, float, str]]:
@@ -476,6 +479,7 @@ def validate_model(
drop_inputs=drop_inputs,
use_pretrained=use_pretrained,
same_as_pretrained=same_as_pretrained,
task=task,
)
dump_folder = os.path.join(dump_folder, folder_name)
if not os.path.exists(dump_folder):
@@ -490,6 +494,8 @@
print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
else:
print(f"[validate_model] validate model id {model_id!r}")
if task:
print(f"[validate_model] with task {task!r}")
print(f"[validate_model] patch={patch!r}")
if model_options:
print(f"[validate_model] model_options={model_options!r}")
@@ -765,6 +771,10 @@ def validate_model(
ep = data["exported_program"]
if verbose:
print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
assert isinstance(
folder_name, str
), f"folder_name={folder_name!r} should be a string"
folder_name = folder_name.replace("/", "-")
with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
f.write(str(ep))
torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))
@@ -773,6 +783,10 @@
if verbose:
print("[validate_model] done (dump ep)")
if "onnx_program" in data:
assert isinstance(
folder_name, str
), f"folder_name={folder_name!r} should be a string"
folder_name = folder_name.replace("/", "-")
epo = data["onnx_program"]
if verbose:
print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
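The new _make_folder_name turns every non-empty option into one path segment and joins them with "/", so a dump folder becomes a small directory tree; when a flat file name is needed for the dumped .ep/.pt2/.onnx files, the code falls back to folder_name.replace("/", "-"). A sketch with illustrative segments (the values below are made up):

    # Illustrative segments only: model id, task, exporter, device, opset.
    els = ["arnir0_Tiny-LLM", "text-generation", "custom-exporter", "cpu", "op18", ""]
    folder_name = "/".join(e for e in els if e)   # empty segments are dropped
    print(folder_name)
    # arnir0_Tiny-LLM/text-generation/custom-exporter/cpu/op18
    print(folder_name.replace("/", "-"))          # flat name used for dumped files
    # arnir0_Tiny-LLM-text-generation-custom-exporter-cpu-op18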