Skip to content

Commit 64999f9

Browse files
authored
Add cached configurations for unit tests (#27)
* add cached configuration * ci * fix dict * fix onfigs
1 parent 2914c64 commit 64999f9

File tree

8 files changed

+3375
-126
lines changed

8 files changed

+3375
-126
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,19 @@ jobs:
7575
run: |
7676
export PYTHONPATH=.
7777
python _unittests/ut_torch_models/test_tiny_llms_onnx.py
78-
continue-on-error: true
78+
continue-on-error: true # connectivity issues
7979

8080
- name: tiny-llm example
8181
run: |
8282
export PYTHONPATH=.
8383
python _doc/examples/plot_export_tiny_llm.py
84-
continue-on-error: true
84+
continue-on-error: true # connectivity issues
8585

8686
- name: tiny-llm bypass
8787
run: |
8888
export PYTHONPATH=.
8989
python _doc/examples/plot_export_tiny_llm_patched.py
90+
continue-on-error: true # connectivity issues
9091

9192
- name: run tests
9293
run: |

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
data["n_weights"],
4545
)
4646

47-
print(f"model {size / 2**10:1.3f} Kb with {n_weights} parameters.")
47+
print(f"model {size / 2**20:1.3f} Mb with {n_weights // 1000} mille parameters.")
4848
# %%
4949
# The original model has 2.7 billion parameters. It was divided by more than 10.
5050
# Let's see the configuration.
@@ -156,4 +156,4 @@
156156
# It looks good.
157157

158158
# %%
159-
doc.plot_legend("untrained smaller\nmicrosoft/phi-2", "torch.onnx.export", "green")
159+
doc.plot_legend("untrained smaller\nmicrosoft/phi-2", "torch.onnx.export", "orange")

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,14 @@ def test_model_testings_and_architectures(self):
127127
task = task_from_id(mid)
128128
self.assertNotEmpty(task)
129129

130+
def test__ccached_config_64(self):
    """Checks that a cached configuration can be built without network access."""
    # The cached-configuration registry only collects functions whose name
    # starts with "_ccached_", so the factory is imported under that prefix.
    # NOTE(review): confirm the exact name against hub_data_cached_configs.
    from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
        _ccached_hf_internal_testing_tiny_random_beitforimageclassification,
    )

    conf = _ccached_hf_internal_testing_tiny_random_beitforimageclassification()
    self.assertEqual(conf.auxiliary_channels, 256)
137+
130138

131139
if __name__ == "__main__":
132140
unittest.main(verbosity=2)

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,58 @@
11
import functools
2-
from typing import List, Optional, Union
2+
from typing import Callable, Dict, List, Optional, Union
33
import transformers
44
from huggingface_hub import HfApi, model_info
5+
from . import hub_data_cached_configs
56
from .hub_data import __date__, __data_tasks__, load_architecture_task
67

78

8-
def get_pretrained_config(model_id: str, trust_remote_code: bool = True) -> str:
9-
"""Returns the config for a model_id."""
9+
@functools.cache
def _retrieve_cached_configurations() -> (
    Dict[str, Callable[[], transformers.PretrainedConfig]]
):
    """
    Collects the cached configuration factories defined in
    :mod:`hub_data_cached_configs`.

    Every module attribute whose name starts with ``_ccached_`` is registered
    under its docstring, which is the key later looked up by
    :func:`get_cached_configuration`. The result is computed once and memoized.

    :return: dictionary ``{docstring: factory}`` where calling ``factory()``
        builds the configuration
    """
    return {
        fct.__doc__: fct
        for attr_name, fct in vars(hub_data_cached_configs).items()
        # The docstring is the lookup key: an entry which is not callable or
        # has no docstring cannot be retrieved, so it is skipped.
        if attr_name.startswith("_ccached_") and callable(fct) and fct.__doc__
    }
17+
18+
19+
def get_cached_configuration(name: str) -> Optional[transformers.PretrainedConfig]:
    """
    Returns a cached configuration to avoid too many accesses to the internet.
    It returns None if the configuration is not cached.
    The list of cached models follows.

    .. runpython::

        import pprint
        from onnx_diagnostic.torch_models.hghub.hub_api import _retrieve_cached_configurations

        configs = _retrieve_cached_configurations()
        pprint.pprint(sorted(configs))

    :param name: key of the cached configuration, :func:`get_pretrained_config`
        passes the model id
    :return: a new configuration built by the cached factory,
        or None if *name* is not cached
    """
    cached = _retrieve_cached_configurations()
    assert cached, "no cached configuration, which is weird"
    # The registry stores factories, not instances: the factory is called
    # on every request.
    factory = cached.get(name)
    return None if factory is None else factory()
36+
37+
38+
def get_pretrained_config(
    model_id: str, trust_remote_code: bool = True, use_cached: bool = True
) -> transformers.PretrainedConfig:
    """
    Returns the config for a model_id.

    :param model_id: model id
    :param trust_remote_code: trust_remote_code,
        see :meth:`transformers.AutoConfig.from_pretrained`
    :param use_cached: if True, the cached configuration returned by
        :func:`get_cached_configuration` is used when available to avoid
        accessing the network, the cached list is mostly for unit tests
    :return: the configuration, cached or loaded with
        :meth:`transformers.AutoConfig.from_pretrained`
    """
    if use_cached:
        conf = get_cached_configuration(model_id)
        if conf is not None:
            return conf
    # No cached version (or cache disabled): fall back to transformers.
    return transformers.AutoConfig.from_pretrained(
        model_id, trust_remote_code=trust_remote_code
    )

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 120 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,127 +1,129 @@
11
import io
22
import functools
3+
import textwrap
34
from typing import Dict, List
45

56
__date__ = "2025-03-26"
67

7-
__data_arch__ = """
8-
architecture,task
9-
ASTModel,feature-extraction
10-
AlbertModel,feature-extraction
11-
BeitForImageClassification,image-classification
12-
BigBirdModel,feature-extraction
13-
BlenderbotModel,feature-extraction
14-
BloomModel,feature-extraction
15-
CLIPModel,zero-shot-image-classification
16-
CLIPVisionModel,feature-extraction
17-
CamembertModel,feature-extraction
18-
CodeGenModel,feature-extraction
19-
ConvBertModel,feature-extraction
20-
ConvNextForImageClassification,image-classification
21-
ConvNextV2Model,image-feature-extraction
22-
CvtModel,feature-extraction
23-
DPTModel,image-feature-extraction
24-
Data2VecAudioModel,feature-extraction
25-
Data2VecTextModel,feature-extraction
26-
Data2VecVisionModel,image-feature-extraction
27-
DebertaModel,feature-extraction
28-
DebertaV2Model,feature-extraction
29-
DecisionTransformerModel,reinforcement-learning
30-
DeiTModel,image-feature-extraction
31-
DetrModel,image-feature-extraction
32-
Dinov2Model,image-feature-extraction
33-
DistilBertModel,feature-extraction
34-
DonutSwinModel,feature-extraction
35-
ElectraModel,feature-extraction
36-
EsmModel,feature-extraction
37-
GLPNModel,image-feature-extraction
38-
GPTBigCodeModel,feature-extraction
39-
GPTJModel,feature-extraction
40-
GPTNeoModel,feature-extraction
41-
GPTNeoXForCausalLM,text-generation
42-
GemmaForCausalLM,text-generation
43-
GraniteForCausalLM,text-generation
44-
GroupViTModel,feature-extraction
45-
HieraForImageClassification,image-classification
46-
HubertModel,feature-extraction
47-
IBertModel,feature-extraction
48-
ImageGPTModel,image-feature-extraction
49-
LayoutLMModel,feature-extraction
50-
LayoutLMv3Model,feature-extraction
51-
LevitModel,image-feature-extraction
52-
LiltModel,feature-extraction
53-
LlamaForCausalLM,text-generation
54-
LongT5Model,feature-extraction
55-
LongformerModel,feature-extraction
56-
MCTCTModel,feature-extraction
57-
MPNetModel,feature-extraction
58-
MT5Model,feature-extraction
59-
MarianMTModel,text2text-generation
60-
MarkupLMModel,feature-extraction
61-
MaskFormerForInstanceSegmentation,image-segmentation
62-
MegatronBertModel,feature-extraction
63-
MgpstrForSceneTextRecognition,feature-extraction
64-
MistralForCausalLM,text-generation
65-
MobileBertModel,feature-extraction
66-
MobileNetV1Model,image-feature-extraction
67-
MobileNetV2Model,image-feature-extraction
68-
MobileViTForImageClassification,image-classification
69-
ModernBertForMaskedLM,fill-mask
70-
MoonshineForConditionalGeneration,automatic-speech-recognition
71-
MptForCausalLM,text-generation
72-
MusicgenForConditionalGeneration,text-to-audio
73-
NystromformerModel,feature-extraction
74-
OPTModel,feature-extraction
75-
Olmo2ForCausalLM,text-generation
76-
OlmoForCausalLM,text-generation
77-
OwlViTModel,feature-extraction
78-
Owlv2Model,feature-extraction
79-
PatchTSMixerForPrediction,no-pipeline-tag
80-
PatchTSTForPrediction,no-pipeline-tag
81-
PegasusModel,feature-extraction
82-
Phi3ForCausalLM,text-generation
83-
PhiForCausalLM,text-generation
84-
Pix2StructForConditionalGeneration,image-to-text
85-
PoolFormerModel,image-feature-extraction
86-
PvtForImageClassification,image-classification
87-
Qwen2ForCausalLM,text-generation
88-
RTDetrForObjectDetection,object-detection
89-
RegNetModel,image-feature-extraction
90-
RemBertModel,feature-extraction
91-
ResNetForImageClassification,image-classification
92-
RoFormerModel,feature-extraction
93-
RobertaModel,feature-extraction
94-
RtDetrV2ForObjectDetection,object-detection
95-
SEWDModel,feature-extraction
96-
SEWModel,feature-extraction
97-
SamModel,mask-generation
98-
SegformerModel,image-feature-extraction
99-
SiglipModel,zero-shot-image-classification
100-
SiglipVisionModel,image-feature-extraction
101-
Speech2TextModel,feature-extraction
102-
SpeechT5ForTextToSpeech,text-to-audio
103-
SplinterModel,feature-extraction
104-
SqueezeBertModel,feature-extraction
105-
Swin2SRModel,image-feature-extraction
106-
SwinModel,image-feature-extraction
107-
Swinv2Model,image-feature-extraction
108-
T5ForConditionalGeneration,text2text-generation
109-
TableTransformerModel,image-feature-extraction
110-
UniSpeechForSequenceClassification,audio-classification
111-
ViTForImageClassification,image-classification
112-
ViTMAEModel,image-feature-extraction
113-
ViTMSNForImageClassification,image-classification
114-
VisionEncoderDecoderModel,document-question-answering
115-
VitPoseForPoseEstimation,keypoint-detection
116-
VitsModel,text-to-audio
117-
Wav2Vec2ConformerForCTC,automatic-speech-recognition
118-
Wav2Vec2Model,feature-extraction
119-
WhisperForConditionalGeneration,no-pipeline-tag
120-
XLMModel,feature-extraction
121-
XLMRobertaForCausalLM,text-generation
122-
YolosForObjectDetection,object-detection
123-
YolosModel,image-feature-extraction
124-
"""
8+
__data_arch__ = textwrap.dedent(
9+
"""
10+
architecture,task
11+
ASTModel,feature-extraction
12+
AlbertModel,feature-extraction
13+
BeitForImageClassification,image-classification
14+
BigBirdModel,feature-extraction
15+
BlenderbotModel,feature-extraction
16+
BloomModel,feature-extraction
17+
CLIPModel,zero-shot-image-classification
18+
CLIPVisionModel,feature-extraction
19+
CamembertModel,feature-extraction
20+
CodeGenModel,feature-extraction
21+
ConvBertModel,feature-extraction
22+
ConvNextForImageClassification,image-classification
23+
ConvNextV2Model,image-feature-extraction
24+
CvtModel,feature-extraction
25+
DPTModel,image-feature-extraction
26+
Data2VecAudioModel,feature-extraction
27+
Data2VecTextModel,feature-extraction
28+
Data2VecVisionModel,image-feature-extraction
29+
DebertaModel,feature-extraction
30+
DebertaV2Model,feature-extraction
31+
DecisionTransformerModel,reinforcement-learning
32+
DeiTModel,image-feature-extraction
33+
DetrModel,image-feature-extraction
34+
Dinov2Model,image-feature-extraction
35+
DistilBertModel,feature-extraction
36+
DonutSwinModel,feature-extraction
37+
ElectraModel,feature-extraction
38+
EsmModel,feature-extraction
39+
GLPNModel,image-feature-extraction
40+
GPTBigCodeModel,feature-extraction
41+
GPTJModel,feature-extraction
42+
GPTNeoModel,feature-extraction
43+
GPTNeoXForCausalLM,text-generation
44+
GemmaForCausalLM,text-generation
45+
GraniteForCausalLM,text-generation
46+
GroupViTModel,feature-extraction
47+
HieraForImageClassification,image-classification
48+
HubertModel,feature-extraction
49+
IBertModel,feature-extraction
50+
ImageGPTModel,image-feature-extraction
51+
LayoutLMModel,feature-extraction
52+
LayoutLMv3Model,feature-extraction
53+
LevitModel,image-feature-extraction
54+
LiltModel,feature-extraction
55+
LlamaForCausalLM,text-generation
56+
LongT5Model,feature-extraction
57+
LongformerModel,feature-extraction
58+
MCTCTModel,feature-extraction
59+
MPNetModel,feature-extraction
60+
MT5Model,feature-extraction
61+
MarianMTModel,text2text-generation
62+
MarkupLMModel,feature-extraction
63+
MaskFormerForInstanceSegmentation,image-segmentation
64+
MegatronBertModel,feature-extraction
65+
MgpstrForSceneTextRecognition,feature-extraction
66+
MistralForCausalLM,text-generation
67+
MobileBertModel,feature-extraction
68+
MobileNetV1Model,image-feature-extraction
69+
MobileNetV2Model,image-feature-extraction
70+
MobileViTForImageClassification,image-classification
71+
ModernBertForMaskedLM,fill-mask
72+
MoonshineForConditionalGeneration,automatic-speech-recognition
73+
MptForCausalLM,text-generation
74+
MusicgenForConditionalGeneration,text-to-audio
75+
NystromformerModel,feature-extraction
76+
OPTModel,feature-extraction
77+
Olmo2ForCausalLM,text-generation
78+
OlmoForCausalLM,text-generation
79+
OwlViTModel,feature-extraction
80+
Owlv2Model,feature-extraction
81+
PatchTSMixerForPrediction,no-pipeline-tag
82+
PatchTSTForPrediction,no-pipeline-tag
83+
PegasusModel,feature-extraction
84+
Phi3ForCausalLM,text-generation
85+
PhiForCausalLM,text-generation
86+
Pix2StructForConditionalGeneration,image-to-text
87+
PoolFormerModel,image-feature-extraction
88+
PvtForImageClassification,image-classification
89+
Qwen2ForCausalLM,text-generation
90+
RTDetrForObjectDetection,object-detection
91+
RegNetModel,image-feature-extraction
92+
RemBertModel,feature-extraction
93+
ResNetForImageClassification,image-classification
94+
RoFormerModel,feature-extraction
95+
RobertaModel,feature-extraction
96+
RtDetrV2ForObjectDetection,object-detection
97+
SEWDModel,feature-extraction
98+
SEWModel,feature-extraction
99+
SamModel,mask-generation
100+
SegformerModel,image-feature-extraction
101+
SiglipModel,zero-shot-image-classification
102+
SiglipVisionModel,image-feature-extraction
103+
Speech2TextModel,feature-extraction
104+
SpeechT5ForTextToSpeech,text-to-audio
105+
SplinterModel,feature-extraction
106+
SqueezeBertModel,feature-extraction
107+
Swin2SRModel,image-feature-extraction
108+
SwinModel,image-feature-extraction
109+
Swinv2Model,image-feature-extraction
110+
T5ForConditionalGeneration,text2text-generation
111+
TableTransformerModel,image-feature-extraction
112+
UniSpeechForSequenceClassification,audio-classification
113+
ViTForImageClassification,image-classification
114+
ViTMAEModel,image-feature-extraction
115+
ViTMSNForImageClassification,image-classification
116+
VisionEncoderDecoderModel,document-question-answering
117+
VitPoseForPoseEstimation,keypoint-detection
118+
VitsModel,text-to-audio
119+
Wav2Vec2ConformerForCTC,automatic-speech-recognition
120+
Wav2Vec2Model,feature-extraction
121+
WhisperForConditionalGeneration,no-pipeline-tag
122+
XLMModel,feature-extraction
123+
XLMRobertaForCausalLM,text-generation
124+
YolosForObjectDetection,object-detection
125+
YolosModel,image-feature-extraction"""
126+
)
125127

126128
__data_tasks__ = [
127129
"automatic-speech-recognition",

0 commit comments

Comments
 (0)