
Commit 330b46c

improve task text2text
1 parent c3da823 commit 330b46c

File tree

CHANGELOGS.rst
_unittests/ut_tasks/test_tasks.py
onnx_diagnostic/tasks/text2text_generation.py
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
onnx_diagnostic/torch_models/test_helper.py

5 files changed: +97 -15 lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.5.0
 +++++
 
+* :pr:`101`: first draft to rewrite loops
 * :pr:`100`: implements a context to automatically rewrite methods or function with control flows
 * :pr:`96`: implements ``is_stealing``, ``steal_append`` to complement ``steal_forward``
 * :pr:`95`: fix Scan implementation for ``OnnxruntimeEvaluator``

_unittests/ut_tasks/test_tasks.py

Lines changed: 15 additions & 1 deletion
@@ -123,7 +123,7 @@ def test_fill_mask(self):
         )
 
     @hide_stdout()
-    def test_feature_extraction(self):
+    def test_feature_extraction_bart_base(self):
         mid = "facebook/bart-base"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "feature-extraction")
@@ -136,6 +136,20 @@ def test_feature_extraction(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
 
+    @hide_stdout()
+    def test_feature_extraction_tiny_bart(self):
+        mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "feature-extraction")
+        self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 43 additions & 13 deletions
@@ -21,9 +21,11 @@ def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
     dummy_max_token_id: int,
-    num_key_value_heads: int,
+    num_key_value_heads_encoder: int,
+    num_key_value_heads_decoder: int,
     num_hidden_layers: int,
-    head_dim: int,
+    head_dim_encoder: int,
+    head_dim_decoder: int,
     encoder_dim: int,
     batch_size: int = 2,
     sequence_length: int = 30,
@@ -36,7 +38,10 @@ def get_inputs(
 
     :param model: model to get the missing information
     :param config: configuration used to generate the model
-    :param head_dim: last dimension of the cache
+    :param head_dim_encoder: last dimension of the cache for the encoder
+    :param head_dim_decoder: last dimension of the cache for the decoder
+    :param num_key_value_heads_encoder: number of heads for the encoder
+    :param num_key_value_heads_decoder: number of heads for the decoder
     :param dummy_max_token_id: dummy max token id
     :param batch_size: batch size
     :param encoder_dim: last dimension of encoder_last_hidden_state
@@ -83,6 +88,7 @@ def get_inputs(
         # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
         # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
     }
+
     inputs = dict(
         input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
             torch.int64
@@ -99,10 +105,16 @@ def get_inputs(
         [
             (
                 torch.randn(
-                    batch_size, num_key_value_heads, sequence_length, head_dim
+                    batch_size,
+                    num_key_value_heads_encoder,
+                    sequence_length,
+                    head_dim_encoder,
                 ),
                 torch.randn(
-                    batch_size, num_key_value_heads, sequence_length, head_dim
+                    batch_size,
+                    num_key_value_heads_encoder,
+                    sequence_length,
+                    head_dim_encoder,
                 ),
             )
             for i in range(num_hidden_layers)
@@ -112,10 +124,16 @@ def get_inputs(
         [
             (
                 torch.randn(
-                    batch_size, num_key_value_heads, sequence_length2, head_dim
+                    batch_size,
+                    num_key_value_heads_decoder,
+                    sequence_length2,
+                    head_dim_decoder,
                 ),
                 torch.randn(
-                    batch_size, num_key_value_heads, sequence_length2, head_dim
+                    batch_size,
+                    num_key_value_heads_decoder,
+                    sequence_length2,
+                    head_dim_decoder,
                 ),
             )
             for i in range(num_hidden_layers)
@@ -132,9 +150,11 @@ def get_inputs(
         model=model,
         config=config,
         dummy_max_token_id=dummy_max_token_id,
-        num_key_value_heads=num_key_value_heads,
+        num_key_value_heads_encoder=num_key_value_heads_encoder,
+        num_key_value_heads_decoder=num_key_value_heads_decoder,
         num_hidden_layers=num_hidden_layers,
-        head_dim=head_dim,
+        head_dim_encoder=head_dim_encoder,
+        head_dim_decoder=head_dim_decoder,
         encoder_dim=encoder_dim,
         batch_size=batch_size + 1,
         sequence_length=sequence_length + 1,
@@ -173,20 +193,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         batch_size=2,
         sequence_length=30,
         sequence_length2=3,
-        head_dim=16 if config is None else (config.d_kv if hasattr(config, "d_kv") else 1),
+        head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
+        head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
         num_hidden_layers=(
             8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
         ),
-        num_key_value_heads=(
+        num_key_value_heads_encoder=(
+            16
+            if config is None
+            else _pick(
+                config,
+                "encoder_attention_heads",
+                "num_key_value_heads",
+                "num_heads",
+            )
+        ),
+        num_key_value_heads_decoder=(
             16
             if config is None
             else _pick(
                 config,
+                "decoder_attention_heads",
                 "num_key_value_heads",
                 "num_heads",
-                (sum, "encoder_attention_heads", "decoder_attention_heads"),
-                # exceptions=exceptions,
             )
         ),
         encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
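The renamed arguments are resolved from the configuration through the module's _pick helper. A minimal, non-authoritative sketch of that lookup (assuming _pick returns the first attribute actually present on the config; pick_first below is a hypothetical stand-in, not the library code) shows how the new encoder/decoder defaults would land on the tiny PLBart configuration cached further down in this commit:

# Hedged sketch, not the library's _pick implementation: return the value of the
# first attribute the config defines, mirroring the assumed lookup order above.
def pick_first(config, *names):
    for name in names:
        if hasattr(config, name):
            return getattr(config, name)
    raise AttributeError(f"none of {names} found on {type(config).__name__}")

# For the cached tiny PLBart config (encoder/decoder_attention_heads=4,
# encoder/decoder_ffn_dim=4, no d_kv attribute), the new defaults would resolve to:
#   num_key_value_heads_encoder -> encoder_attention_heads = 4
#   num_key_value_heads_decoder -> decoder_attention_heads = 4
#   head_dim_encoder            -> encoder_ffn_dim = 4  (d_kv absent)
#   head_dim_decoder            -> decoder_ffn_dim = 4  (d_kv absent)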

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 37 additions & 0 deletions
@@ -3850,3 +3850,40 @@ def _ccached_hustvl_yolos_tiny():
             "use_mid_position_embeddings": false,
         }
     )
+
+
+def _ccached_facebook_bart_base():
+    "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
+    return transformers.BartConfig(
+        **{
+            "activation_dropout": 0.0,
+            "activation_function": "gelu",
+            "architectures": ["PLBartForConditionalGeneration"],
+            "attention_dropout": 0.1,
+            "bos_token_id": 0,
+            "classifier_dropout": 0.0,
+            "d_model": 16,
+            "decoder_attention_heads": 4,
+            "decoder_ffn_dim": 4,
+            "decoder_layerdrop": 0.0,
+            "decoder_layers": 2,
+            "dropout": 0.1,
+            "encoder_attention_heads": 4,
+            "encoder_ffn_dim": 4,
+            "encoder_layerdrop": 0.0,
+            "encoder_layers": 2,
+            "eos_token_id": 2,
+            "forced_eos_token_id": 2,
+            "init_std": 0.02,
+            "is_encoder_decoder": true,
+            "max_position_embeddings": 100,
+            "model_type": "plbart",
+            "num_hidden_layers": 2,
+            "pad_token_id": 1,
+            "scale_embedding": true,
+            "torch_dtype": "float32",
+            "transformers_version": "4.52.0.dev0",
+            "use_cache": true,
+            "vocab_size": 50005,
+        }
+    )
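For context, a minimal usage sketch mirroring the new unit test (the import path is an assumption based on the package layout and is not shown in this diff): with this cached configuration in place, the library can build the untrained tiny PLBart model and its dummy inputs without touching the Hugging Face Hub.

# Minimal sketch following the new test; the import path is assumed.
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

mid = "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
model, inputs = data["model"], data["inputs"]
model(**inputs)  # forward pass on randomly initialized weights with dummy inputs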

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 1 addition & 1 deletion
@@ -374,7 +374,7 @@ def validate_model(
 
     for k in ["task", "size", "n_weights"]:
         summary[f"model_{k.replace('_','')}"] = data[k]
-    summary["model_inputs_opionts"] = str(input_options or "")
+    summary["model_inputs_options"] = str(input_options or "")
     summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
     summary["model_shapes"] = string_type(data["dynamic_shapes"])
     summary["model_class"] = data["model"].__class__.__name__

0 commit comments
