
Commit 63abbbb

add summarization

1 parent b065f65

9 files changed: +328 -1 lines changed

_doc/api/tasks/index.rst

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ Or:
     mixture_of_expert
     object_detection
     sentence_similarity
+    summarization
     text_classification
     text_generation
     text2text_generation

_doc/api/tasks/summarization.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.tasks.summarization
+===================================
+
+.. automodule:: onnx_diagnostic.tasks.summarization
+    :members:
+    :no-undoc-members:

_unittests/ut_tasks/test_tasks.py

Lines changed: 14 additions & 0 deletions
@@ -150,6 +150,20 @@ def test_feature_extraction_tiny_bart(self):
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
         )

+    @hide_stdout()
+    def test_summarization(self):
+        mid = "facebook/bart-large-cnn"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "summarization")
+        self.assertIn((data["size"], data["n_weights"]), [(1625161728, 406290432)])
+        model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        # with torch_export_patches(patch_transformers=True, verbose=10):
+        #     torch.export.export(
+        #         model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+        #     )
+
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"

onnx_diagnostic/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
     mixture_of_expert,
     object_detection,
     sentence_similarity,
+    summarization,
     text_classification,
     text_generation,
     text2text_generation,
@@ -23,6 +24,7 @@
     mixture_of_expert,
     object_detection,
     sentence_similarity,
+    summarization,
     text_classification,
     text_generation,
     text2text_generation,
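
Once registered here, the new module is importable from the package. A quick sanity check built only on what this commit adds (the __TASK__ constant is defined in the new file below):

    from onnx_diagnostic.tasks import summarization

    print(summarization.__TASK__)  # "summarization"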
onnx_diagnostic/tasks/summarization.py

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
+from ..helpers.config_helper import update_config, check_hasattr, _pick
+
+__TASK__ = "summarization"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces the model size."""
+    kwargs: Dict[str, Any] = {}
+    if hasattr(config, "num_decoder_layers"):
+        config.num_decoder_layers = min(config.num_decoder_layers, 2)
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads_encoder: int,
+    num_key_value_heads_decoder: int,
+    num_hidden_layers: int,
+    head_dim_encoder: int,
+    head_dim_decoder: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``summarization``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim_encoder: last dimension of the cache for the encoder
+    :param head_dim_decoder: last dimension of the cache for the decoder
+    :param num_key_value_heads_encoder: number of heads for the encoder
+    :param num_key_value_heads_decoder: number of heads for the decoder
+    :param dummy_max_token_id: dummy max token id
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
+    :return: dictionary with the inputs and their dynamic shapes
+
+    Inputs captured from one model:
+
+    ::
+
+        cache_position:T7s1
+        past_key_values:EncoderDecoderCache(
+            self_attention_cache=DynamicCache(
+                key_cache=#6[T1s1x8x1x64,...],
+                value_cache=#6[T1s1x8x1x64,...]),
+            cross_attention_cache=DynamicCache(
+                key_cache=#6[T1s1x8x16x64,...],
+                value_cache=#6[T1s1x8x16x64,...])),
+        decoder_input_ids:T7s1x1,
+        encoder_outputs:dict(last_hidden_state:T1s1x16x512)
+    """
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
+    cache_length2 = "cache_length_val"  # torch.export.Dim("cache_length2", min=1, max=4096)
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "decoder_input_ids": {0: batch, 1: "seq_ids"},
+        "attention_mask": {0: batch, 1: "seq_mask"},
+        # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "past_key_values": [
+            [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+            [
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+            ],
+        ],
+        # one of these is selected based on the forward method signature
+        # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
+    }
+
+    inputs = dict(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+            torch.int64
+        ),
+        decoder_input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length2)
+        ).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+        # cache_position=torch.arange(sequence_length, sequence_length + sequence_length2)
+        # .to(torch.int64)
+        # .expand((batch_size, -1)),
+        past_key_values=make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            num_key_value_heads_encoder,
+                            sequence_length,
+                            head_dim_encoder,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            num_key_value_heads_encoder,
+                            sequence_length,
+                            head_dim_encoder,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            num_key_value_heads_decoder,
+                            sequence_length2,
+                            head_dim_decoder,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            num_key_value_heads_decoder,
+                            sequence_length2,
+                            head_dim_decoder,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        ),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads_encoder=num_key_value_heads_encoder,
+            num_key_value_heads_decoder=num_key_value_heads_decoder,
+            num_hidden_layers=num_hidden_layers,
+            head_dim_encoder=head_dim_encoder,
+            head_dim_decoder=head_dim_decoder,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + 1,
+            sequence_length2=sequence_length2 + 1,
+            **kwargs,
+        )["inputs"]
+    return res


+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(
+            config,
+            "vocab_size",
+            "hidden_size",
+            "num_attention_heads",
+            ("num_hidden_layers", "num_layers"),
+            ("n_positions", "d_model"),
+            (
+                "num_key_value_heads",
+                "num_heads",
+                ("decoder_attention_heads", "encoder_attention_heads"),
+            ),
+        )
+    # exceptions = {
+    #     "PLBartForConditionalGeneration": (
+    #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
+    #     )
+    # }
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=30,
+        sequence_length2=3,
+        head_dim_encoder=(
+            16 if config is None else int(_pick(config, "encoder_ffn_dim") ** 0.5)
+        ),
+        head_dim_decoder=(
+            16 if config is None else int(_pick(config, "decoder_ffn_dim") ** 0.5)
+        ),
+        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+        num_hidden_layers=(
+            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
+        ),
+        num_key_value_heads_encoder=(
+            16
+            if config is None
+            else _pick(
+                config,
+                "encoder_attention_heads",
+                "num_key_value_heads",
+                "num_heads",
+            )
+        ),
+        num_key_value_heads_decoder=(
+            16
+            if config is None
+            else _pick(
+                config,
+                "decoder_attention_heads",
+                "num_key_value_heads",
+                "num_heads",
+            )
+        ),
+    )
+    return kwargs, get_inputs
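
Together, the two functions give a self-contained way to build dummy summarization inputs. A minimal sketch using only the defaults selected when config is None; passing model=None works here because get_inputs only forwards the argument:

    from onnx_diagnostic.tasks.summarization import random_input_kwargs

    kwargs, fct = random_input_kwargs(None)  # fct is get_inputs
    data = fct(model=None, config=None, **kwargs)
    inputs, shapes = data["inputs"], data["dynamic_shapes"]
    print(inputs["input_ids"].shape)          # torch.Size([2, 30])
    print(inputs["decoder_input_ids"].shape)  # torch.Size([2, 3])
    print(list(shapes))  # dynamic-shape spec keyed like the inputs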

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 ASTModel,feature-extraction
 AlbertModel,feature-extraction
 BeitForImageClassification,image-classification
+BartForConditionalGeneration,summarization
 BartModel,feature-extraction
 BertForMaskedLM,fill-mask
 BertForSequenceClassification,text-classification
@@ -163,6 +164,7 @@
     "object-detection",
     "reinforcement-learning",
     "sentence-similarity",
+    "summarization",
     "text-classification",
     "text-generation",
     "text-to-image",

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 65 additions & 1 deletion
@@ -3852,7 +3852,7 @@ def _ccached_hustvl_yolos_tiny():
     )


-def _ccached_facebook_bart_base():
+def _ccached_tiny_random_plbart_for_conditional_generation():
     "hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration"
     return transformers.BartConfig(
         **{
@@ -3887,3 +3887,67 @@ def _ccached_facebook_bart_base():
             "vocab_size": 50005,
         }
     )
+
+
+def _ccached_facebook_bart_large_cnn():
+    "facebook/bart-large-cnn"
+    return transformers.BartConfig(
+        **{
+            "_num_labels": 3,
+            "activation_dropout": 0.0,
+            "activation_function": "gelu",
+            "add_final_layer_norm": False,
+            "architectures": ["BartForConditionalGeneration"],
+            "attention_dropout": 0.0,
+            "bos_token_id": 0,
+            "classif_dropout": 0.0,
+            "classifier_dropout": 0.0,
+            "d_model": 1024,
+            "decoder_attention_heads": 16,
+            "decoder_ffn_dim": 4096,
+            "decoder_layerdrop": 0.0,
+            "decoder_layers": 12,
+            "decoder_start_token_id": 2,
+            "dropout": 0.1,
+            "early_stopping": True,
+            "encoder_attention_heads": 16,
+            "encoder_ffn_dim": 4096,
+            "encoder_layerdrop": 0.0,
+            "encoder_layers": 12,
+            "eos_token_id": 2,
+            "force_bos_token_to_be_generated": True,
+            "forced_bos_token_id": 0,
+            "forced_eos_token_id": 2,
+            "gradient_checkpointing": False,
+            "id2label": {"0": "LABEL_0", "1": "LABEL_1", "2": "LABEL_2"},
+            "init_std": 0.02,
+            "is_encoder_decoder": True,
+            "label2id": {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2},
+            "length_penalty": 2.0,
+            "max_length": 142,
+            "max_position_embeddings": 1024,
+            "min_length": 56,
+            "model_type": "bart",
+            "no_repeat_ngram_size": 3,
+            "normalize_before": False,
+            "num_beams": 4,
+            "num_hidden_layers": 12,
+            "output_past": True,
+            "pad_token_id": 1,
+            "prefix": " ",
+            "scale_embedding": False,
+            "task_specific_params": {
+                "summarization": {
+                    "early_stopping": True,
+                    "length_penalty": 2.0,
+                    "max_length": 142,
+                    "min_length": 56,
+                    "no_repeat_ngram_size": 3,
+                    "num_beams": 4,
+                }
+            },
+            "transformers_version": "4.7.0.dev0",
+            "use_cache": True,
+            "vocab_size": 50264,
+        }
+    )
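
With the configuration cached, the model architecture can be instantiated without downloading anything. A short sketch built on the private helper added above:

    from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
        _ccached_facebook_bart_large_cnn,
    )

    config = _ccached_facebook_bart_large_cnn()
    assert config.model_type == "bart"
    assert (config.encoder_layers, config.decoder_layers) == (12, 12)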

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 6 additions & 0 deletions
@@ -1,4 +1,5 @@
 import inspect
+import os
 from typing import Any, Dict, Optional, Tuple
 import torch
 import transformers
@@ -132,6 +133,11 @@ def get_untrained_model_with_inputs(
     kwargs, fct = random_input_kwargs(config, task)
     if verbose:
         print(f"[get_untrained_model_with_inputs] use fct={fct}")
+    if os.environ.get("PRINT_CONFIG") == "1":
+        import pprint
+
+        print(f"-- input kwargs for task {task!r}")
+        pprint.pprint(kwargs)
     if inputs_kwargs:
         kwargs.update(inputs_kwargs)
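
The PRINT_CONFIG environment variable adds an opt-in dump of the generated input kwargs. A minimal usage sketch (environment variable values are strings, hence the comparison against "1"):

    import os

    os.environ["PRINT_CONFIG"] = "1"  # must be set before the call below

    from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

    # Prints "-- input kwargs for task 'summarization'" followed by the kwargs dict.
    data = get_untrained_model_with_inputs("facebook/bart-large-cnn", verbose=1)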
