Commit 035ccf8

add support for text-classification (#55)

* add support for text-classification
* fix zero shot
* fix examples

1 parent 89a50b0 commit 035ccf8

12 files changed (+166, -8 lines)
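
For orientation, a minimal sketch of what this commit enables, mirroring the new unit test below (assuming `get_untrained_model_with_inputs` is importable from `onnx_diagnostic.torch_models.hghub`):

    from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

    # build a reduced, untrained copy of the model plus matching dummy inputs
    data = get_untrained_model_with_inputs("Intel/bert-base-uncased-mrpc", verbose=1)
    model, inputs = data["model"], data["inputs"]
    model(**inputs)  # forward pass with the generated text-classification inputs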

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ Change Logs
 0.4.0
 +++++
 
+* :pr:`55`: add support for text-classification
+* :pr:`54`: add support for fill-mask, refactoring
 * :pr:`52`: add support for zero-shot-image-classification
 * :pr:`50`: add support for onnxruntime fusion
 * :pr:`48`: add support for EncoderDecoderCache, test with openai/whisper-tiny

_doc/api/tasks/index.rst

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@ onnx_diagnostic.tasks
     automatic_speech_recognition
     fill_mask
     image_classification
-    image_text_to_text
+    image_text_to_text
+    text_classification
     text_generation
     text2text_generation
     zero_shot_image_classification
_doc/api/tasks/text_classification.rst

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+
+onnx_diagnostic.tasks.text_classification
+==========================================
+
+.. automodule:: onnx_diagnostic.tasks.text_classification
+    :members:
+    :no-undoc-members:

_unittests/ut_tasks/test_tasks.py

Lines changed: 9 additions & 0 deletions
@@ -100,6 +100,15 @@ def test_fill_mask(self):
         model, inputs = data["model"], data["inputs"]
         model(**inputs)
 
+    @hide_stdout()
+    def test_text_classification(self):
+        mid = "Intel/bert-base-uncased-mrpc"
+        # mid = "Salesforce/codet5-small"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        self.assertIn((data["size"], data["n_weights"]), [(154420232, 38605058)])
+        model, inputs = data["model"], data["inputs"]
+        model(**inputs)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_tasks/try_tasks.py

Lines changed: 25 additions & 0 deletions
@@ -211,6 +211,31 @@ def test_fill_mask(self):
         output = model(**encoded_input)
         print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
 
+    @never_test()
+    def test_text_classification(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k text_cl
+        # https://huggingface.co/Intel/bert-base-uncased-mrpc
+
+        from transformers import BertTokenizer, BertModel
+
+        tokenizer = BertTokenizer.from_pretrained("Intel/bert-base-uncased-mrpc")
+        model = BertModel.from_pretrained("Intel/bert-base-uncased-mrpc")
+        text = "The inspector analyzed the soundness in the building."
+        encoded_input = tokenizer(text, return_tensors="pt")
+        print()
+        print("-- inputs", string_type(encoded_input, with_shape=True, with_min_max=True))
+        output = model(**encoded_input)
+        print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
+        # the output is a BaseModelOutputWithPoolingAndCrossAttentions with a pooler_output
+
+        # print the tokens and token ids of the input string
+        print("Tokenized Text: ", tokenizer.tokenize(text), "\n")
+        print("Token IDs: ", tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
+
+        # map the encoded ids back to tokens
+        print(tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0]))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_documentation_recipes.py

Lines changed: 4 additions & 1 deletion
@@ -53,7 +53,10 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
             # dot not installed, this part
             # is tested in onnx framework
             raise unittest.SkipTest(f"failed: {name!r} due to missing dot.")
-        if "We couldn't connect to 'https://huggingface.co'" in st:
+        if (
+            "We couldn't connect to 'https://huggingface.co'" in st
+            or "Cannot access content at: https://huggingface.co/" in st
+        ):
             raise unittest.SkipTest(f"Connectivity issues due to\n{err}")
         raise AssertionError(  # noqa: B904
             "Example '{}' (cmd: {} - exec_prefix='{}') "

onnx_diagnostic/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
     fill_mask,
     image_classification,
     image_text_to_text,
+    text_classification,
     text_generation,
     text2text_generation,
     zero_shot_image_classification,
@@ -14,6 +15,7 @@
     fill_mask,
     image_classification,
     image_text_to_text,
+    text_classification,
     text_generation,
     text2text_generation,
     zero_shot_image_classification,
onnx_diagnostic/tasks/text_classification.py

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr
+
+__TASK__ = "text-classification"
+
+
+def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
+    kwargs = dict(
+        num_hidden_layers=min(config.num_hidden_layers, 2),
+        num_attention_heads=min(config.num_attention_heads, 4),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    dummy_max_token_id: int,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``text-classification``.
+    Example:
+
+    ::
+
+        input_ids:T7s1x13[101,72654:A16789.23076923077],
+        token_type_ids:T7s1x13[0,0:A0.0],
+        attention_mask:T7s1x13[1,1:A1.0])
+    """
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    seq_length = torch.export.Dim("sequence_length", min=1, max=1024)
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: seq_length},
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+            torch.int64
+        ),
+        token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
+def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "vocab_size")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=30,
+        dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+    )
+    return kwargs, get_inputs
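
A minimal usage sketch (not part of the commit) of how the two helpers above compose; passing config=None selects the typical dimensions:

    from onnx_diagnostic.tasks import text_classification

    kwargs, fct = text_classification.random_input_kwargs(None, "text-classification")
    data = fct(model=None, config=None, **kwargs)
    print(data["inputs"]["input_ids"].shape)  # torch.Size([2, 30])
    print(sorted(data["dynamic_shapes"]))     # ['attention_mask', 'input_ids', 'token_type_ids']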

onnx_diagnostic/tasks/zero_shot_image_classification.py

Lines changed: 2 additions & 2 deletions
@@ -62,9 +62,9 @@ def get_inputs(
     ), f"Unexpected type for input_height {type(input_height)}{config}"
 
     batch = torch.export.Dim("batch", min=1, max=1024)
-    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
+    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     shapes = {
-        "inputs_ids": {0: batch, 1: seq_length},
+        "input_ids": {0: batch, 1: seq_length},
         "attention_mask": {0: batch, 1: seq_length},
         "pixel_values": {
             0: torch.export.Dim("batch_img", min=1, max=1024),
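
This hunk fixes two things: the key typo inputs_ids -> input_ids, and seq_length is now a plain string, which recent torch.export releases accept as a named dynamic dimension in dynamic_shapes. A toy sketch (not from the repo) of the string form:

    import torch

    class Toy(torch.nn.Module):
        def forward(self, input_ids, attention_mask):
            return input_ids * attention_mask

    batch = torch.export.Dim("batch", min=1, max=1024)
    ep = torch.export.export(
        Toy(),
        (torch.ones(2, 5, dtype=torch.int64), torch.ones(2, 5, dtype=torch.int64)),
        dynamic_shapes={
            # a string names the dimension without explicit min/max bounds
            "input_ids": {0: batch, 1: "seq_length"},
            "attention_mask": {0: batch, 1: "seq_length"},
        },
    )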

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 AlbertModel,feature-extraction
 BeitForImageClassification,image-classification
 BertForMaskedLM,fill-mask
+BertForSequenceClassification,text-classification
 BigBirdModel,feature-extraction
 BlenderbotModel,feature-extraction
 BloomModel,feature-extraction
@@ -145,6 +146,7 @@
     "no-pipeline-tag",
     "object-detection",
     "reinforcement-learning",
+    "text-classification",
     "text-generation",
     "text-to-audio",
     "text2text-generation",
