
Commit 8d96eef

patch for Phi4 reasoning (#128)
1 parent 60d7f71 commit 8d96eef

4 files changed: +269 -1 lines changed
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+import unittest
+import packaging.version as pv
+import torch
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_torch,
+    requires_experimental,
+    requires_transformers,
+)
+from onnx_diagnostic.torch_models.test_helper import validate_model
+
+
+class TestValidateModel(ExtTestCase):
+    @requires_transformers("4.52")
+    @requires_torch("2.7.99")
+    @requires_experimental()
+    @hide_stdout()
+    def test_validate_microsoft_phi4_reasoning(self):
+        # python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning
+        # --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch
+        summary, data = validate_model(
+            "microsoft/Phi-4-mini-reasoning",
+            do_run=True,
+            verbose=2,
+            exporter="custom",
+            do_same=True,
+            patch=True,
+            rewrite=True,
+            stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
+            dump_folder="dump_test_validate_model_custom",
+        )
+        self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-5)
+        self.assertIn("onnx_filename", data)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion

@@ -356,7 +356,7 @@ def get_parser_validate() -> ArgumentParser:
         "--runtime",
         choices=["onnxruntime", "torch", "ref"],
         default="onnxruntime",
-        help="onnx runtime to use, ",
+        help="onnx runtime to use, onnxruntime by default",
     )
     parser.add_argument(
         "-o",
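For context, the --runtime option whose help string this hunk fixes accepts onnxruntime (the default), torch, or ref. A hypothetical invocation, reusing the command quoted in the test's comments above and spelling out the now-documented default:

    python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \
        --run -v 1 --export custom -o dump_test --no-quiet --device cuda --patch \
        --runtime onnxruntime
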
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 88 additions & 0 deletions

@@ -1,5 +1,6 @@
 import inspect
 from dataclasses import dataclass
+from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import transformers

@@ -531,3 +532,90 @@ def prepare_inputs_for_generation(
         # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
         model_inputs.pop("labels", None)
         return model_inputs
+
+
+def patched_dynamic_rope_update(rope_forward):
+    """
+    patch:transformers.modeling_rope_utils.dynamic_rope_update
+    """
+
+    def longrope_frequency_update(self, position_ids, device):
+        seq_len = torch.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
+        # At export time, seq_len is unknown.
+        long_inv_freq, _ = self.rope_init_fn(
+            self.config, device, seq_len=original_max_position_embeddings + 1
+        )
+        original_inv_freq = self.original_inv_freq.to(device)
+
+        cond = (seq_len > original_max_position_embeddings).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        self.inv_freq = inv_freq
+        # if seq_len > original_max_position_embeddings:
+        #     self.inv_freq = self.long_inv_freq
+        # else:
+        #     self.inv_freq = self.original_inv_freq
+
+    def dynamic_frequency_update(self, position_ids, device):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+            self.max_seq_len_cached = seq_len
+
+        if (
+            seq_len < self.original_max_seq_len
+            and self.max_seq_len_cached > self.original_max_seq_len
+        ):
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @wraps(rope_forward)
+    def wrapper(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device)
+        return rope_forward(self, x, position_ids)
+
+    return wrapper
+
+
+class patched_Phi3RotaryEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+    @torch.no_grad()
+    @patched_dynamic_rope_update
+    def forward(self, x, position_ids):
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None]
+            .float()
+            .expand(position_ids.shape[0], -1, 1)
+            .to(x.device)
+        )
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
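
The commented-out lines kept inside longrope_frequency_update show the data-dependent branch the patch replaces: under torch.export, seq_len depends on the input, so a plain Python if cannot be traced, whereas torch.cond records both branches and defers the choice to runtime. A minimal, self-contained sketch of the same pattern, not part of the commit (pick_inv_freq and limit are illustrative names; assumes a PyTorch recent enough to expose torch.cond, roughly 2.4+):

    import torch

    def pick_inv_freq(seq_len, long_inv_freq, original_inv_freq, limit):
        # Both branches are traced at export time; which one runs is decided
        # at runtime by the predicate. The clones avoid aliasing the inputs,
        # mirroring the patch above.
        return torch.cond(
            seq_len > limit,
            lambda a, b: a.clone(),  # long context: recomputed frequencies
            lambda a, b: b.clone(),  # short context: original buffer
            [long_inv_freq, original_inv_freq],
        )

    long_f, orig_f = torch.rand(48), torch.rand(48)
    print(pick_inv_freq(torch.tensor(5000), long_f, orig_f, 4096))  # picks long_f

The _PATCHES_ / _PATCHED_CLASS_ attributes on patched_Phi3RotaryEmbedding appear to be the repository's convention for declaring which methods replace which transformers class when patching is enabled.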

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 142 additions & 0 deletions

@@ -3951,3 +3951,145 @@ def _ccached_facebook_bart_large_cnn():
             "vocab_size": 50264,
         }
     )
+
+
+def _ccached_microsoft_phi4_reasoning():
+    "microsoft/Phi-4-mini-reasoning"
+    return transformers.Phi3Config(
+        **{
+            "architectures": ["Phi3ForCausalLM"],
+            "attention_bias": False,
+            "attention_dropout": 0.0,
+            "bos_token_id": 199999,
+            "embd_pdrop": 0.0,
+            "eos_token_id": 199999,
+            "full_attn_mod": 1,
+            "hidden_act": "silu",
+            "hidden_size": 3072,
+            "initializer_range": 0.02,
+            "intermediate_size": 8192,
+            "interpolate_factor": 1,
+            "lm_head_bias": False,
+            "max_position_embeddings": 131072,
+            "mlp_bias": False,
+            "model_type": "phi3",
+            "num_attention_heads": 24,
+            "num_hidden_layers": 32,
+            "num_key_value_heads": 8,
+            "original_max_position_embeddings": 4096,
+            "pad_token_id": 199999,
+            "partial_rotary_factor": 0.75,
+            "resid_pdrop": 0.0,
+            "rms_norm_eps": 1e-05,
+            "rope_scaling": {
+                "long_factor": [
+                    1,
+                    1.118320672,
+                    1.250641126,
+                    1.398617824,
+                    1.564103225,
+                    1.74916897,
+                    1.956131817,
+                    2.187582649,
+                    2.446418898,
+                    2.735880826,
+                    3.059592084,
+                    3.421605075,
+                    3.826451687,
+                    4.279200023,
+                    4.785517845,
+                    5.351743533,
+                    5.984965424,
+                    6.693110555,
+                    7.485043894,
+                    8.370679318,
+                    9.36110372,
+                    10.4687158,
+                    11.70738129,
+                    13.09260651,
+                    14.64173252,
+                    16.37415215,
+                    18.31155283,
+                    20.47818807,
+                    22.90118105,
+                    25.61086418,
+                    28.64115884,
+                    32.03,
+                    32.1,
+                    32.13,
+                    32.23,
+                    32.6,
+                    32.61,
+                    32.64,
+                    32.66,
+                    32.7,
+                    32.71,
+                    32.93,
+                    32.97,
+                    33.28,
+                    33.49,
+                    33.5,
+                    44.16,
+                    47.77,
+                ],
+                "short_factor": [
+                    1,
+                    1.118320672,
+                    1.250641126,
+                    1.398617824,
+                    1.564103225,
+                    1.74916897,
+                    1.956131817,
+                    2.187582649,
+                    2.446418898,
+                    2.735880826,
+                    3.059592084,
+                    3.421605075,
+                    3.826451687,
+                    4.279200023,
+                    4.785517845,
+                    5.351743533,
+                    5.984965424,
+                    6.693110555,
+                    7.485043894,
+                    8.370679318,
+                    9.36110372,
+                    10.4687158,
+                    11.70738129,
+                    13.09260651,
+                    14.64173252,
+                    16.37415215,
+                    18.31155283,
+                    20.47818807,
+                    22.90118105,
+                    25.61086418,
+                    28.64115884,
+                    32.03,
+                    32.1,
+                    32.13,
+                    32.23,
+                    32.6,
+                    32.61,
+                    32.64,
+                    32.66,
+                    32.7,
+                    32.71,
+                    32.93,
+                    32.97,
+                    33.28,
+                    33.49,
+                    33.5,
+                    44.16,
+                    47.77,
+                ],
+                "type": "longrope",
+            },
+            "rope_theta": 10000.0,
+            "sliding_window": 262144,
+            "tie_word_embeddings": True,
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.50.0",
+            "use_cache": True,
+            "vocab_size": 200064,
+        }
+    )
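
A quick sanity check this config supports: with longrope scaling, each entry of long_factor / short_factor rescales one inverse frequency, so each list should hold one value per rotary frequency, i.e. (hidden_size / num_attention_heads) * partial_rotary_factor / 2 = 128 * 0.75 / 2 = 48 entries, matching the 48 values above. An illustrative check, not part of the commit (the import path follows the file header above):

    from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
        _ccached_microsoft_phi4_reasoning,
    )

    config = _ccached_microsoft_phi4_reasoning()
    head_dim = config.hidden_size // config.num_attention_heads      # 3072 // 24 = 128
    rotary_half = int(head_dim * config.partial_rotary_factor) // 2  # 96 // 2 = 48
    assert len(config.rope_scaling["long_factor"]) == rotary_half
    # In this checkpoint the long and short factor lists are identical.
    assert config.rope_scaling["long_factor"] == config.rope_scaling["short_factor"]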
