
Commit b4360cb

add patches for Funnel (#367)
* add patches for Funnel
* a few fixes
* first step for t5gemma
* fixes
* fix one file
1 parent b6ac5b3 commit b4360cb

9 files changed: +270 −80 lines changed

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 40 additions & 1 deletion
@@ -18,7 +18,10 @@
 from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
-from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5
+from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+    patch_qwen2_5,
+    patch_funnel,
+)
 from onnx_diagnostic.export.api import to_onnx


@@ -787,6 +790,42 @@ def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
         self.assertLess(results.diffs[0]["abs"], 1e-5)

+    @unittest.skipIf(not patch_funnel, "Funnel not part of this transformers")
+    def test_model_funnel(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+            patched_FunnelAttentionStructure,
+            patched_FunnelRelMultiheadAttention,
+        )
+
+        pos = torch.tensor([0, 4, 5, 8], dtype=torch.long)
+        stride = 2
+        config = transformers.models.funnel.modeling_funnel.FunnelConfig()
+        original = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure(config)
+        patched = patched_FunnelAttentionStructure()
+        self.assertEqualArray(
+            original.relative_pos(pos, stride=stride), patched.relative_pos(pos, stride=stride)
+        )
+
+        rmha = transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention(
+            config, 2
+        )
+        patched = patched_FunnelRelMultiheadAttention()
+        patched.config = config
+        for att in ["block_index", "r_r_bias", "scale", "r_kernel"]:
+            setattr(patched, att, getattr(rmha, att))
+        inputs = dict(
+            position_embeds=[
+                [torch.rand((24, 768)), None],
+                [torch.rand((12, 768)), torch.rand((24, 768))],
+                [torch.rand((6, 768)), torch.rand((12, 768))],
+            ],
+            q_head=torch.rand((2, 12, 12, 64)),
+            context_len=12,
+        )
+        expected = rmha.relative_positional_attention(**inputs)
+        got = patched.relative_positional_attention(**inputs)
+        self.assertEqualArray(expected, got)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
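The new test checks the patched Funnel helpers against the upstream ones value for value. Outside the unit test, the patches are meant to be active while exporting; below is a minimal sketch, assuming the package's usual torch_export_patches(patch_transformers=True) entry point applies the _PATCHES_/_PATCHED_CLASS_ rewrites. The tiny FunnelConfig values are illustrative only, and the export call may still need dynamic-shape arguments or further patches depending on the transformers version.

import torch
import transformers
from onnx_diagnostic.torch_export_patches import torch_export_patches

# Small, illustrative configuration so the sketch stays cheap to run.
config = transformers.FunnelConfig(
    vocab_size=1024, block_sizes=[1, 1], d_model=32, n_head=2, d_head=16, d_inner=64
)
model = transformers.FunnelModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (2, 16), dtype=torch.long)

# The context manager is expected to swap in the patched Funnel methods
# (relative_pos, relative_positional_attention) for the duration of the export.
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(model, (input_ids,))
print(type(ep))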

onnx_diagnostic/ci_models/export_qwen25_vl.py

Lines changed: 6 additions & 1 deletion
@@ -59,6 +59,7 @@
 import os
 import sys
 import time
+import warnings
 from typing import Any, Dict, List, Tuple
 from .ci_helpers import (
     check_for_discrepancies_and_log_everything_into_a_json_file,
@@ -301,7 +302,11 @@ def main(
     print(f"-- config._attn_implementation={model.config._attn_implementation}")
     print(f"-- model.dtype={model.dtype}")
     print(f"-- model.device={model.device}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+    try:
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+    except OSError as e:
+        warnings.warn(f"Unable to access internet due to {e!r}", ResourceWarning, stacklevel=0)
+        return
     print(f"-- processor={type(processor)}")

     export_inputs, other_inputs = None, None
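The try/except keeps this CI script from failing outright when the hub is unreachable: the download error becomes a warning and the run exits early. The sketch below generalizes the pattern; the helper name and model id are illustrative, not part of the repository.

import warnings
from transformers import AutoProcessor

def load_processor_or_none(model_id: str):
    """Return the processor, or None when the hub cannot be reached (offline CI)."""
    try:
        return AutoProcessor.from_pretrained(model_id, use_fast=True)
    except OSError as e:
        warnings.warn(f"Unable to access internet due to {e!r}", ResourceWarning, stacklevel=0)
        return None

processor = load_processor_or_none("Qwen/Qwen2.5-VL-7B-Instruct")
if processor is None:
    print("skipping: no network access")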

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 15 additions & 5 deletions
@@ -13,6 +13,10 @@
 __TASK__ = "image-text-to-text"


+def should_have_vision_config(config):
+    return config.architectures != ["FuyuForCausalLM"]
+
+
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
@@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "hidden_size",
             "pad_token_id",
         )
-        check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
+        if should_have_vision_config(config):
+            check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(
@@ -491,7 +496,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "vision_config",
         )
         text_config = False
-    check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
+    if should_have_vision_config(config):
+        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
         head_dim=(
             16
@@ -552,17 +558,21 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         ),
         width=(
             224
-            if config is None or not hasattr(config.vision_config, "image_size")
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
        height=(
             224
-            if config is None or not hasattr(config.vision_config, "image_size")
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         num_channels=(
             3
-            if config is None
+            if config is None or not should_have_vision_config(config)
             else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
         ),
         pad_token_id=(
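should_have_vision_config exempts FuyuForCausalLM, whose configuration carries no nested vision_config; when the helper returns False, the input generator skips the vision_config checks and falls back to the 224x224, 3-channel defaults. A small illustration with stand-in objects (plain namespaces, not real transformers configs):

from types import SimpleNamespace

def should_have_vision_config(config):
    # copy of the helper added above, for illustration
    return config.architectures != ["FuyuForCausalLM"]

fuyu_like = SimpleNamespace(architectures=["FuyuForCausalLM"])
llava_like = SimpleNamespace(architectures=["LlavaForConditionalGeneration"])

assert not should_have_vision_config(fuyu_like)   # skip vision_config checks, use 224/224/3 defaults
assert should_have_vision_config(llava_like)      # vision_config is still required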

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 84 additions & 48 deletions
@@ -18,6 +18,22 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         config.num_decoder_layers = min(config.num_decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    if hasattr(config, "encoder") and hasattr(config.encoder, "layer_types"):
+        default_layer_types = [
+            "sliding_attention",
+            "full_attention",
+            "sliding_attention",
+            "full_attention",
+        ]
+        config.encoder.num_hidden_layers = 4
+        config.encoder.layer_types = (
+            default_layer_types if config is None else config.encoder.layer_types[:4]
+        )
+        config.decoder.num_hidden_layers = 4
+        config.decoder.layer_types = (
+            default_layer_types if config is None else config.decoder.layer_types[:4]
+        )
+
     update_config(config, kwargs)
     return kwargs

@@ -177,55 +193,75 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
 
     If the configuration is None, the function selects typical dimensions.
     """
+    path = 1
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_hidden_layers", "num_layers"),
-            ("n_positions", "d_model"),
-            (
-                "num_key_value_heads",
-                "num_heads",
-                ("decoder_attention_heads", "encoder_attention_heads"),
-            ),
-        )
-        # exceptions = {
-        #     "PLBartForConditionalGeneration": (
-        #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
-        #     )
-        # }
-    kwargs = dict(
-        batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
-        head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
-        head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=(
-            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
-        ),
-        num_key_value_heads_encoder=(
-            16
-            if config is None
-            else _pick(
+        if hasattr(config, "num_attention_heads"):
+            check_hasattr(
                 config,
-                "encoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_hidden_layers", "num_layers"),
+                ("n_positions", "d_model"),
+                (
+                    "num_key_value_heads",
+                    "num_heads",
+                    ("decoder_attention_heads", "encoder_attention_heads"),
+                ),
             )
-        ),
-        num_key_value_heads_decoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "decoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
-    )
+        else:
+            check_hasattr(config, "encoder", "decoder")
+            path = 2
+
+    if path == 1:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            head_dim_encoder=(
+                16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim")
+            ),
+            head_dim_decoder=(
+                16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim")
+            ),
+            dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+            num_hidden_layers=(
+                8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
+            ),
+            num_key_value_heads_encoder=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "encoder_attention_heads",
+                    "num_key_value_heads",
+                    "num_heads",
+                )
+            ),
+            num_key_value_heads_decoder=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "decoder_attention_heads",
+                    "num_key_value_heads",
+                    "num_heads",
+                )
+            ),
+            encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
+        )
+    else:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            dummy_max_token_id=config.encoder.vocab_size - 1,
+            num_key_value_heads_encoder=config.encoder.num_key_value_heads,
+            num_key_value_heads_decoder=config.decoder.num_key_value_heads,
+            num_hidden_layers=len(config.encoder.layer_types),
+            head_dim_encoder=config.encoder.head_dim,
+            head_dim_decoder=config.decoder.head_dim,
+            encoder_dim=256,
+        )
+
     return kwargs, get_inputs
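The new path == 2 branch covers configurations that expose no top-level num_attention_heads and instead nest everything under config.encoder / config.decoder (the "first step for t5gemma" in the commit message). A sketch of the dispatch with stand-in namespaces (hypothetical values, not real transformers configs):

from types import SimpleNamespace

sub = dict(
    vocab_size=32000,
    num_key_value_heads=4,
    head_dim=64,
    layer_types=["sliding_attention", "full_attention", "sliding_attention", "full_attention"],
)
nested = SimpleNamespace(encoder=SimpleNamespace(**sub), decoder=SimpleNamespace(**sub))

# same dispatch as in random_input_kwargs
path = 1 if hasattr(nested, "num_attention_heads") else 2
assert path == 2

# path 2 reads dimensions straight from the nested configs
assert len(nested.encoder.layer_types) == 4
assert nested.decoder.head_dim == 64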

onnx_diagnostic/tasks/text_generation.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             state_size=8 if config is None else getattr(config, "state_size", None),
             conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
         )
+    elif config.__class__.__name__ == "FunnelConfig":
+        # does not support num_hidden_layers
+        kwargs = dict()
     else:
         kwargs = dict(
             head_dim=getattr(
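The empty kwargs for FunnelConfig avoid the generic num_hidden_layers reduction: in recent transformers releases FunnelConfig derives that value from block_sizes and rejects assignment, which is the assumption behind the short sketch below.

from transformers import FunnelConfig

cfg = FunnelConfig(block_sizes=[2, 2, 2])
print(cfg.num_hidden_layers)  # derived from block_sizes, not stored directly
try:
    cfg.num_hidden_layers = 2  # what the generic reduction path would attempt
except NotImplementedError as exc:
    print(f"cannot set num_hidden_layers: {exc}")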
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+import torch
+
+try:
+    import transformers.models.funnel.modeling_funnel
+
+    patch_funnel = True
+except ImportError:
+    patch_funnel = False
+
+if patch_funnel:
+    from transformers.models.funnel.modeling_funnel import _relative_shift_gather
+
+    class patched_FunnelAttentionStructure(torch.nn.Module):
+        _PATCHES_ = ["relative_pos"]
+        _PATCHED_CLASS_ = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure
+
+        def relative_pos(
+            self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1
+        ) -> torch.Tensor:
+            if pooled_pos is None:
+                pooled_pos = pos
+            ref_point = pooled_pos[0] - pos[0]
+            # PATCHED
+            num_remove = shift * pooled_pos.shape[0]
+            max_dist = ref_point + num_remove * stride
+            min_dist = pooled_pos[0] - pos[-1]
+            return torch.arange(
+                max_dist.to(torch.long),
+                (min_dist - 1).to(torch.long),
+                torch.tensor(-stride, dtype=torch.long),
+                dtype=torch.long,
+                device=pos.device,
+            )
+
+    class patched_FunnelRelMultiheadAttention(torch.nn.Module):
+        _PATCHES_ = ["relative_positional_attention"]
+        _PATCHED_CLASS_ = (
+            transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention
+        )
+
+        def relative_positional_attention(
+            self, position_embeds, q_head, context_len, cls_mask=None
+        ):
+            """Relative attention score for the positional encodings"""
+            # q_head has shape batch_size x sea_len x n_head x d_head
+            if self.config.attention_type == "factorized":
+                phi, pi, psi, omega = position_embeds
+                # Shape n_head x d_head
+                u = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape batch_size x sea_len x n_head x d_model
+                q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
+                q_r_attention_1 = q_r_attention * phi[:, None]
+                q_r_attention_2 = q_r_attention * pi[:, None]
+
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = torch.einsum(
+                    "bind,jd->bnij", q_r_attention_1, psi
+                ) + torch.einsum("bind,jd->bnij", q_r_attention_2, omega)
+            else:
+                shift = 2 if q_head.shape[1] != context_len else 1
+                r = position_embeds[self.block_index][shift - 1]
+                # Shape n_head x d_head
+                v = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape max_rel_len x n_head x d_model
+                r_head = torch.einsum("td,dnh->tnh", r, w_r)
+                # Shape batch_size x n_head x seq_len x max_rel_len
+                positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
+
+            if cls_mask is not None:
+                # PATCHED
+                positional_attn = positional_attn * cls_mask
+            return positional_attn
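The # PATCHED comments mark the spots rewritten to stay traceable under torch.export: relative_pos builds its torch.arange from tensor-valued bounds and a tensor step instead of Python scalars, and the cls_mask scaling avoids an in-place update. The _PATCHES_ / _PATCHED_CLASS_ attributes tell the patching machinery which methods to transplant onto which transformers class. The sketch below shows what that convention roughly amounts to, assuming a plain setattr swap is equivalent; it is an illustration, not the library's own implementation.

# Manual sketch of the _PATCHES_/_PATCHED_CLASS_ convention.
import transformers.models.funnel.modeling_funnel  # noqa: F401
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_FunnelAttentionStructure,
)

target = patched_FunnelAttentionStructure._PATCHED_CLASS_  # FunnelAttentionStructure
saved = {name: getattr(target, name) for name in patched_FunnelAttentionStructure._PATCHES_}
try:
    # transplant the patched methods onto the transformers class
    for name in patched_FunnelAttentionStructure._PATCHES_:
        setattr(target, name, getattr(patched_FunnelAttentionStructure, name))
    # ... export a Funnel model here while the patched methods are in place ...
finally:
    # restore the original methods
    for name, method in saved.items():
        setattr(target, name, method)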
