38 changes: 38 additions & 0 deletions _unittests/ut_torch_models/test_validate_models.py
@@ -0,0 +1,38 @@
import unittest
import packaging.version as pv
import torch
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
hide_stdout,
requires_torch,
requires_experimental,
requires_transformers,
)
from onnx_diagnostic.torch_models.test_helper import validate_model


class TestValidateModel(ExtTestCase):
@requires_transformers("4.52")
@requires_torch("2.7.99")
@requires_experimental()
@hide_stdout()
def test_validate_microsoft_phi4_reasoning(self):
# python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning
# --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch
summary, data = validate_model(
"microsoft/Phi-4-mini-reasoning",
do_run=True,
verbose=2,
exporter="custom",
do_same=True,
patch=True,
rewrite=True,
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
dump_folder="dump_test_validate_model_custom",
)
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-5)
self.assertIn("onnx_filename", data)


if __name__ == "__main__":
unittest.main(verbosity=2)
2 changes: 1 addition & 1 deletion onnx_diagnostic/_command_lines_parser.py
@@ -356,7 +356,7 @@ def get_parser_validate() -> ArgumentParser:
"--runtime",
choices=["onnxruntime", "torch", "ref"],
default="onnxruntime",
help="onnx runtime to use, ",
help="onnx runtime to use, onnxruntime by default",
)
parser.add_argument(
"-o",
88 changes: 88 additions & 0 deletions onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,5 +1,6 @@
import inspect
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import transformers
@@ -531,3 +532,90 @@ def prepare_inputs_for_generation(
# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
model_inputs.pop("labels", None)
return model_inputs


def patched_dynamic_rope_update(rope_forward):
"""
patch:transformers.modeling_rope_utils.dynamic_rope_update
"""

def longrope_frequency_update(self, position_ids, device):
seq_len = torch.max(position_ids) + 1
if hasattr(self.config, "original_max_position_embeddings"):
original_max_position_embeddings = self.config.original_max_position_embeddings
else:
original_max_position_embeddings = self.config.max_position_embeddings
# At export time, seq_len is unknown.
long_inv_freq, _ = self.rope_init_fn(
self.config, device, seq_len=original_max_position_embeddings + 1
)
original_inv_freq = self.original_inv_freq.to(device)

cond = (seq_len > original_max_position_embeddings).item()
inv_freq = torch.cond(
cond,
(lambda x, y: x.clone()),
(lambda x, y: y.clone()),
[long_inv_freq, original_inv_freq],
)
self.inv_freq = inv_freq
Collaborator comment: We don't need to register the buffer?
# if seq_len > original_max_position_embeddings:
# self.inv_freq = self.long_inv_freq
# else:
# self.inv_freq = self.original_inv_freq
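        # Hedged note on the review question above (not part of the original
        # patch): registering the result as a non-persistent buffer would
        # mirror what dynamic_frequency_update does below, e.g.
        #   self.register_buffer("inv_freq", inv_freq, persistent=False)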

def dynamic_frequency_update(self, position_ids, device):
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len_cached = seq_len

if (
seq_len < self.original_max_seq_len
and self.max_seq_len_cached > self.original_max_seq_len
):
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@wraps(rope_forward)
def wrapper(self, x, position_ids):
if "dynamic" in self.rope_type:
dynamic_frequency_update(self, position_ids, device=x.device)
elif self.rope_type == "longrope":
longrope_frequency_update(self, position_ids, device=x.device)
return rope_forward(self, x, position_ids)

return wrapper


class patched_Phi3RotaryEmbedding(torch.nn.Module):
_PATCHES_ = ["forward"]
_PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding

@torch.no_grad()
@patched_dynamic_rope_update
def forward(self, x, position_ids):
inv_freq_expanded = (
self.inv_freq[None, :, None]
.float()
.expand(position_ids.shape[0], -1, 1)
.to(x.device)
)
position_ids_expanded = position_ids[:, None, :].float()

device_type = (
x.device.type
if isinstance(x.device.type, str) and x.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
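For context, a hedged sketch (not part of this diff) of how the cos/sin pair returned here is typically consumed, following the standard rotate_half convention used by transformers' rotary helpers; the function names below are illustrative:

import torch

def rotate_half(x):
    # Rotate the last dimension: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Broadcast cos/sin over the head axis, then mix with the rotated halves.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed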
142 changes: 142 additions & 0 deletions onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -3951,3 +3951,145 @@ def _ccached_facebook_bart_large_cnn():
"vocab_size": 50264,
}
)


def _ccached_microsoft_phi4_reasoning():
"microsoft/Phi-4-mini-reasoning"
return transformers.Phi3Config(
**{
"architectures": ["Phi3ForCausalLM"],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 199999,
"embd_pdrop": 0.0,
"eos_token_id": 199999,
"full_attn_mod": 1,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"interpolate_factor": 1,
"lm_head_bias": false,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "phi3",
"num_attention_heads": 24,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"original_max_position_embeddings": 4096,
"pad_token_id": 199999,
"partial_rotary_factor": 0.75,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"long_factor": [
1,
1.118320672,
1.250641126,
1.398617824,
1.564103225,
1.74916897,
1.956131817,
2.187582649,
2.446418898,
2.735880826,
3.059592084,
3.421605075,
3.826451687,
4.279200023,
4.785517845,
5.351743533,
5.984965424,
6.693110555,
7.485043894,
8.370679318,
9.36110372,
10.4687158,
11.70738129,
13.09260651,
14.64173252,
16.37415215,
18.31155283,
20.47818807,
22.90118105,
25.61086418,
28.64115884,
32.03,
32.1,
32.13,
32.23,
32.6,
32.61,
32.64,
32.66,
32.7,
32.71,
32.93,
32.97,
33.28,
33.49,
33.5,
44.16,
47.77,
],
"short_factor": [
1,
1.118320672,
1.250641126,
1.398617824,
1.564103225,
1.74916897,
1.956131817,
2.187582649,
2.446418898,
2.735880826,
3.059592084,
3.421605075,
3.826451687,
4.279200023,
4.785517845,
5.351743533,
5.984965424,
6.693110555,
7.485043894,
8.370679318,
9.36110372,
10.4687158,
11.70738129,
13.09260651,
14.64173252,
16.37415215,
18.31155283,
20.47818807,
22.90118105,
25.61086418,
28.64115884,
32.03,
32.1,
32.13,
32.23,
32.6,
32.61,
32.64,
32.66,
32.7,
32.71,
32.93,
32.97,
33.28,
33.49,
33.5,
44.16,
47.77,
],
"type": "longrope",
},
"rope_theta": 10000.0,
"sliding_window": 262144,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0",
"use_cache": true,
"vocab_size": 200064,
}
)
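Usage sketch (an assumption, not part of the diff): a cached config like this one lets callers build the model architecture offline; roughly:

import transformers

# Instantiate an untrained Phi-3 model from the cached configuration;
# weights are randomly initialized, nothing is fetched from the hub.
config = _ccached_microsoft_phi4_reasoning()
model = transformers.Phi3ForCausalLM(config)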