Commit 92a750c

update feature extraction
1 parent 3cbc1e6 commit 92a750c

File tree

CHANGELOGS.rst
_unittests/ut_tasks/try_tasks.py
onnx_diagnostic/tasks/feature_extraction.py
onnx_diagnostic/tasks/text2text_generation.py
onnx_diagnostic/torch_models/validate.py

5 files changed: +140 −10 lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ Change Logs
 0.7.5
 +++++
 
+* :pr:`185`: remove the use of _seen_tokens in DynamicCache (removed in transformers>4.53),
+  updates dummy inputs for feature-extraction
 * :pr:`184`: implements side-by-side
 
 0.7.4
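
For context: in transformers>4.53 the DynamicCache class no longer carries the private _seen_tokens counter, so code that read it has to go through the public accessor instead. A minimal sketch of the replacement, assuming a recent transformers:

```python
# Minimal sketch, assuming transformers>4.53: the private _seen_tokens
# attribute is gone from DynamicCache; get_seq_length() is the public way
# to read how many tokens the cache has seen.
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# old: cache._seen_tokens  -> AttributeError on recent versions
print(cache.get_seq_length())  # 0 for an empty cache
```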

_unittests/ut_tasks/try_tasks.py

Lines changed: 47 additions & 0 deletions
@@ -1,6 +1,8 @@
 import unittest
+import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
 from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
 
 
@@ -378,6 +380,51 @@ def test_feature_extraction(self):
         model = BartModel.from_pretrained("facebook/bart-base")
         text = "Replace me by any text you'd like."
         encoded_input = tokenizer(text, return_tensors="pt")
+        sequence_length, sequence_length2 = 30, 4
+        sequence_length = 3
+        batch_size, encoder_attention_heads, encoder_ffn_dim = 1, 12, 64
+        batch_size, decoder_attention_heads, decoder_ffn_dim = 1, 12, 64
+        num_hidden_layers = 6
+        encoded_input["past_key_values"] = make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
         print()
         print("-- inputs", string_type(encoded_input, with_shape=True, with_min_max=True))
         output = model(**encoded_input)
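
Stripped to its skeleton, the block above builds one (key, value) pair per layer, each tensor shaped (batch, heads, past_length, head_dim), wraps each list in a DynamicCache, and zips the two caches into one encoder/decoder cache. A reduced sketch of the same pattern (2 layers instead of 6):

```python
# Reduced sketch of the pattern used in test_feature_extraction:
# per-layer (key, value) tensors shaped (batch, heads, past_len, head_dim).
import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
)

def layer_pairs(past_len: int, n_layers: int = 2):
    # batch=1, heads=12, head_dim=64, matching the test above
    return [
        (torch.randn(1, 12, past_len, 64), torch.randn(1, 12, past_len, 64))
        for _ in range(n_layers)
    ]

cache = make_encoder_decoder_cache(
    make_dynamic_cache(layer_pairs(3)),  # first sub-cache, sequence_length=3
    make_dynamic_cache(layer_pairs(4)),  # second sub-cache, sequence_length2=4
)
```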

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 86 additions & 5 deletions
@@ -1,17 +1,15 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 
 __TASK__ = "feature-extraction"
 
 
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
-    kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers, 2),
-        num_attention_heads=min(config.num_attention_heads, 4),
-    )
+    check_hasattr(config, "num_hidden_layers")
+    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, 2))
     update_config(config, kwargs)
     return kwargs
 
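With this change, reducing a feature-extraction model only shrinks the layer count; attention-head counts are left alone, which keeps them consistent with the cache tensors built further down. A usage sketch, assuming a BART-like checkpoint:

```python
# Hedged usage sketch: for facebook/bart-base (6 layers), the reduced
# config keeps everything except num_hidden_layers, capped at 2.
from transformers import AutoConfig
from onnx_diagnostic.tasks.feature_extraction import reduce_model_config

config = AutoConfig.from_pretrained("facebook/bart-base")
print(reduce_model_config(config))  # expected: {'num_hidden_layers': 2}
```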
@@ -22,6 +20,12 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
+    sequence_length2: int = 3,
+    decoder_attention_heads: Optional[int] = None,
+    encoder_attention_heads: Optional[int] = None,
+    encoder_ffn_dim: Optional[int] = None,
+    decoder_ffn_dim: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
     add_second_input: int = 1,
     **kwargs,  # unused
 ):
@@ -50,6 +54,66 @@ def get_inputs(
         ),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
+    if (
+        encoder_attention_heads
+        and decoder_attention_heads
+        and encoder_ffn_dim
+        and decoder_ffn_dim
+        and num_hidden_layers
+    ):
+        inputs["past_key_values"] = make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+        cache_length = "cache_length_key"
+        cache_length2 = "cache_length_val"
+        shapes["past_key_values"] = [
+            [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+            [
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+            ],
+        ]
+
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
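
A note on the nesting just added: the outer list appears to follow the EncoderDecoderCache layout, one [keys, values] pair of per-layer dict lists for each sub-cache, with the first sub-cache's length named "cache_length_key" and the second's "cache_length_val", matching the names used in text2text_generation.py. A small orientation sketch, under the assumption that the cached tensors are laid out as built above:

```python
# Orientation sketch (assumption: cached tensors are laid out
# (batch, heads, cache_length, head_dim), as built in get_inputs above):
# axes 0 and 2 are the only dynamic ones per layer.
import torch

key = torch.randn(1, 12, 30, 64)            # batch, heads, cache_length, head_dim
dynamic_axes = {0: "batch", 2: "cache_length_key"}
assert max(dynamic_axes) < key.ndim         # the named axes exist on the tensor
```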
@@ -61,6 +125,12 @@ def get_inputs(
         batch_size=batch_size + 1,
         sequence_length=sequence_length + add_second_input,
         dummy_max_token_id=dummy_max_token_id,
+        sequence_length2=sequence_length2,
+        decoder_attention_heads=decoder_attention_heads,
+        encoder_attention_heads=encoder_attention_heads,
+        encoder_ffn_dim=encoder_ffn_dim,
+        decoder_ffn_dim=decoder_ffn_dim,
+        num_hidden_layers=num_hidden_layers,
         add_second_input=0,
         **kwargs,
     )["inputs"]
@@ -80,4 +150,15 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         sequence_length=30,
         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
+    for att in [
+        "decoder_attention_heads",
+        "encoder_attention_heads",
+        "encoder_ffn_dim",
+        "decoder_ffn_dim",
+        "num_hidden_layers",
+    ]:
+        if hasattr(config, att):
+            kwargs[att] = getattr(config, att)
+    kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
+    print(kwargs)
     return kwargs, get_inputs
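
Two things are worth noting here. First, the *_ffn_dim values end up on the last cache axis, which in the tensors above plays the role of the per-head dimension; that would explain why both are pinned to 64 (bart-base: hidden size 768 across 12 heads) rather than using the config's much larger feed-forward sizes. Second, a hedged usage sketch, assuming a BART-like config:

```python
# Hedged usage sketch: random_input_kwargs copies the cache-related
# attributes off a BART-like config; the returned callable is get_inputs.
from transformers import AutoConfig
from onnx_diagnostic.tasks.feature_extraction import random_input_kwargs

config = AutoConfig.from_pretrained("facebook/bart-base")
kwargs, fct = random_input_kwargs(config)
# kwargs holds sequence_length=30, dummy_max_token_id, and the copied
# attention-head/layer attributes, with both ffn dims forced to 64.
print(kwargs)
```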

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 2 additions & 2 deletions
@@ -69,8 +69,8 @@ def get_inputs(
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
-    cache_length2 = "cache_length_val"  # torch.export.Dim("cache_length2", min=1, max=4096)
+    cache_length = "cache_length_key"
+    cache_length2 = "cache_length_val"
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},

onnx_diagnostic/torch_models/validate.py

Lines changed: 3 additions & 3 deletions
@@ -1090,7 +1090,7 @@ def validate_onnx_model(
     """
     import onnxruntime
 
-    def _mk(key):
+    def _mk(key, flavour=flavour):
         return f"{key}_{flavour}" if flavour else key
 
     summary: Dict[str, Any] = {}
@@ -1145,7 +1145,7 @@ def _mk(key):
         )
         sess = _quiet_or_not_quiet(
             quiet,
-            _mk("onnx_ort_create"),
+            _mk("create_onnx_ort"),
             summary,
             data,
             (lambda source=source, providers=providers: cls_runtime(source, providers)),
@@ -1180,7 +1180,7 @@ def _mk(key):
 
         got = _quiet_or_not_quiet(
             quiet,
-            _mk(f"time_onnx_ort_run{suffix}"),
+            _mk(f"run_onnx_ort{suffix}"),
             summary,
             data,
             (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
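
The change to _mk is the classic default-argument trick: a plain closure reads flavour from the enclosing scope at call time, while the default argument freezes its value when the function is defined, which is presumably the intent here (the same pattern already appears in the source=source and sess=sess lambdas). A standalone illustration:

```python
# Standalone illustration of why _mk gained flavour=flavour: default
# arguments capture the value at definition time, closures do not.
flavour = "fp16"

def late(key):
    return f"{key}_{flavour}" if flavour else key

def frozen(key, flavour=flavour):
    return f"{key}_{flavour}" if flavour else key

flavour = "fp32"
print(late("run"))    # run_fp32 -- the closure sees the rebound value
print(frozen("run"))  # run_fp16 -- the default kept the original value
```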
