Skip to content

Commit 2e92eda

Browse files
committed
add feature extraction
1 parent f38681a commit 2e92eda

File tree

8 files changed

+106
-0
lines changed

8 files changed

+106
-0
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ jobs:
2727
transformers: '4.51.3'
2828
- python: '3.11'
2929
torch: '2.7'
30+
- python: '3.12'
31+
torch: '2.6'
3032
steps:
3133
- uses: actions/checkout@v3
3234

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.tasks.feature_extraction
3+
========================================
4+
5+
.. automodule:: onnx_diagnostic.tasks.feature_extraction
6+
:members:
7+
:no-undoc-members:

_doc/api/tasks/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Or:
3434

3535
automatic_speech_recognition
3636
fill_mask
37+
feature_extraction
3738
image_classification
3839
image_text_to_text
3940
mixture_of_expert

_unittests/ut_tasks/test_tasks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,18 @@ def test_fill_mask(self):
116116
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
117117
)
118118

119+
@hide_stdout()
def test_feature_extraction(self):
    # Builds the model and its inputs for facebook/bart-base, checks the
    # reported size, runs one forward pass, then exports with dynamic shapes.
    model_id = "facebook/bart-base"
    data = get_untrained_model_with_inputs(model_id, verbose=1)
    expected_sizes = [(557681664, 139420416)]
    self.assertIn((data["size"], data["n_weights"]), expected_sizes)
    model = data["model"]
    inputs = data["inputs"]
    ds = data["dynamic_shapes"]
    # The forward pass must succeed before attempting the export.
    model(**inputs)
    with bypass_export_some_errors(patch_transformers=True, verbose=10):
        torch.export.export(
            model,
            (),
            kwargs=inputs,
            dynamic_shapes=use_dyn_not_str(ds),
            strict=False,
        )
130+
119131
@hide_stdout()
120132
def test_text_classification(self):
121133
mid = "Intel/bert-base-uncased-mrpc"

_unittests/ut_tasks/try_tasks.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,22 @@ def test_fill_mask(self):
338338
output = model(**encoded_input)
339339
print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
340340

341+
@never_test()
def test_feature_extraction(self):
    # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k feature_ex
    # https://huggingface.co/facebook/bart-base

    from transformers import BartTokenizer, BartModel

    # Single source of truth for the checkpoint used by tokenizer and model.
    model_id = "facebook/bart-base"
    tokenizer = BartTokenizer.from_pretrained(model_id)
    model = BartModel.from_pretrained(model_id)
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text, return_tensors="pt")
    print()
    print("-- inputs", string_type(encoded_input, with_shape=True, with_min_max=True))
    output = model(**encoded_input)
    print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
356+
341357
@never_test()
342358
def test_text_classification(self):
343359
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k text_cl

onnx_diagnostic/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Callable, Dict, List, Tuple
22
from . import (
33
automatic_speech_recognition,
4+
feature_extraction,
45
fill_mask,
56
image_classification,
67
image_text_to_text,
@@ -14,6 +15,7 @@
1415

1516
__TASKS__ = [
1617
automatic_speech_recognition,
18+
feature_extraction,
1719
fill_mask,
1820
image_classification,
1921
image_text_to_text,
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from typing import Any, Callable, Dict, Optional, Tuple
2+
import torch
3+
from ..helpers.config_helper import update_config, check_hasattr
4+
5+
__TASK__ = "feature-extraction"
6+
7+
8+
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """
    Reduces a model size.

    Caps ``num_hidden_layers`` at 2 and ``num_attention_heads`` at 4,
    hands the capped values to ``update_config`` together with *config*
    and returns them.

    :param config: configuration exposing ``num_attention_heads``
        and ``num_hidden_layers``
    :return: the capped values, keyed by attribute name
    """
    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
    n_layers = min(config.num_hidden_layers, 2)
    n_heads = min(config.num_attention_heads, 4)
    kwargs = {"num_hidden_layers": n_layers, "num_attention_heads": n_heads}
    update_config(config, kwargs)
    return kwargs
17+
18+
19+
def get_inputs(
20+
model: torch.nn.Module,
21+
config: Optional[Any],
22+
batch_size: int,
23+
sequence_length: int,
24+
dummy_max_token_id: int,
25+
**kwargs, # unused
26+
):
27+
"""
28+
Generates inputs for task ``feature-extraction``.
29+
Example:
30+
31+
::
32+
33+
input_ids:T7s1x13[101,72654:A16789.23076923077],
34+
token_type_ids:T7s1x13[0,0:A0.0],
35+
attention_mask:T7s1x13[1,1:A1.0])
36+
"""
37+
batch = torch.export.Dim("batch", min=1, max=1024)
38+
seq_length = "sequence_length"
39+
shapes = {
40+
"input_ids": {0: batch, 1: seq_length},
41+
"attention_mask": {0: batch, 1: seq_length},
42+
}
43+
inputs = dict(
44+
input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
45+
torch.int64
46+
),
47+
attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
48+
)
49+
return dict(inputs=inputs, dynamic_shapes=shapes)
50+
51+
52+
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    Otherwise ``dummy_max_token_id`` is derived from ``config.vocab_size``.

    :param config: configuration exposing ``vocab_size``, or None
    :return: the keyword arguments and the function ``get_inputs``
    """
    if config is None:
        max_token_id = 31999
    else:
        check_hasattr(config, "vocab_size")
        max_token_id = config.vocab_size - 1
    kwargs = {
        "batch_size": 2,
        "sequence_length": 30,
        "dummy_max_token_id": max_token_id,
    }
    return kwargs, get_inputs

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ASTModel,feature-extraction
1414
AlbertModel,feature-extraction
1515
BeitForImageClassification,image-classification
16+
BartModel,feature-extraction
1617
BertForMaskedLM,fill-mask
1718
BertForSequenceClassification,text-classification
1819
BertModel,sentence-similarity

0 commit comments

Comments
 (0)