Skip to content

Commit ce20a57

Browse files
committed
fix dummy inputs
1 parent bef2af2 commit ce20a57

File tree

5 files changed: +48 additions, −11 deletions

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_task_from_tags(self):
9898
]
9999
for tags, etask in _tags:
100100
with self.subTest(tags=tags, task=etask):
101-
task = task_from_tags(tags, True)
101+
task = task_from_tags(tags)
102102
self.assertEqual(etask, task)
103103

104104

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,27 @@ def test_config_class_from_architecture(self):
2020
self.assertEqual(config, transformers.LlamaConfig)
2121

2222
@hide_stdout()
23-
def test_get_untrained_model_with_inputs(self):
23+
def test_get_untrained_model_with_inputs_tiny_llm(self):
2424
mid = "arnir0/Tiny-LLM"
2525
data = get_untrained_model_with_inputs(mid, verbose=1)
2626
model, inputs = data["model"], data["inputs"]
2727
model(**inputs)
28-
self.assertEqual(data["size"], 1858125824)
29-
self.assertEqual(data["n_weights"], 464531456)
28+
self.assertEqual((data["size"], data["n_weights"]), (1858125824, 464531456))
29+
30+
@hide_stdout()
31+
def test_get_untrained_model_with_inputs_tiny_xlm_roberta(self):
32+
mid = "hf-internal-testing/tiny-xlm-roberta" # XLMRobertaConfig
33+
data = get_untrained_model_with_inputs(mid, verbose=1)
34+
model, inputs = data["model"], data["inputs"]
35+
model(**inputs)
36+
self.assertEqual((data["size"], data["n_weights"]), (126190824, 31547706))
37+
38+
def test_get_untrained_model_with_inputs_tiny_gpt_neo(self):
39+
mid = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
40+
data = get_untrained_model_with_inputs(mid, verbose=1)
41+
model, inputs = data["model"], data["inputs"]
42+
model(**inputs)
43+
self.assertEqual((data["size"], data["n_weights"]), (4291141632, 1072785408))
3044

3145

3246
if __name__ == "__main__":

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .hub_data import __date__, __data_tasks__, load_architecture_task
66

77

8-
def get_pretrained_config(model_id) -> str:
8+
def get_pretrained_config(model_id: str) -> str:
99
"""Returns the config for a model_id."""
1010
return transformers.AutoConfig.from_pretrained(model_id)
1111

@@ -61,7 +61,7 @@ def task_from_id(model_id: str, pretrained: bool = False) -> str:
6161
def task_from_tags(tags: Union[str, List[str]]) -> str:
6262
"""
6363
Guesses the task from the list of tags.
64-
If given by a string, ``|`` should be the separater.
64+
If given by a string, ``|`` should be the separator.
6565
"""
6666
if isinstance(tags, str):
6767
tags = tags.split("|")

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
__data_tasks__ = [
77
"automatic-speech-recognition",
8+
"image-text-to-text",
89
"image-to-text",
910
"text-generation",
1011
"object-detection",

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,31 @@
1010

1111

1212
@functools.cache
13-
def config_class_from_architecture(arch: str) -> type:
13+
def config_class_from_architecture(arch: str, exc: bool = False) -> type:
1414
"""
1515
Retrieves the configuration class for a given architecture.
16+
17+
:param arch: architecture (class name)
18+
:param exc: raise an exception if not found
19+
:return: type
1620
"""
1721
cls = getattr(transformers, arch)
1822
mod_name = cls.__module__
1923
mod = importlib.import_module(mod_name)
2024
source = inspect.getsource(mod)
2125
reg = re.compile("config: ([A-Za-z0-9]+)")
2226
fall = reg.findall(source)
27+
if len(fall) == 0:
28+
assert not exc, (
29+
f"Unable to guess Configuration class name for arch={arch!r}, "
30+
f"module={mod_name!r}, no candidate, source is\n{source}"
31+
)
32+
return None
2333
unique = set(fall)
2434
assert len(unique) == 1, (
2535
f"Unable to guess Configuration class name for arch={arch!r}, "
26-
f"module={mod_name!r}, source is\n{source}"
36+
f"module={mod_name!r}, found={unique} (#{len(unique)}), "
37+
f"source is\n{source}"
2738
)
2839
cls_name = unique.pop()
2940
return getattr(transformers, cls_name)
@@ -81,7 +92,14 @@ def get_untrained_model_with_inputs(
8192
arch = archs[0]
8293
if verbose:
8394
print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
84-
cls = config_class_from_architecture(arch)
95+
cls = config_class_from_architecture(arch, exc=False)
96+
if cls is None:
97+
if verbose:
98+
print(
99+
"[get_untrained_model_with_inputs] no found config name in the code, loads it"
100+
)
101+
config = get_pretrained_config(model_id)
102+
cls = config.__class__
85103
if verbose:
86104
print(f"[get_untrained_model_with_inputs] cls={cls.__name__!r}")
87105

@@ -107,12 +125,16 @@ def get_untrained_model_with_inputs(
107125
batch_size=2,
108126
sequence_length=30,
109127
sequence_length2=3,
110-
num_hidden_layers=config.num_hidden_layers,
111-
num_key_value_heads=config.num_key_value_heads,
112128
head_dim=getattr(
113129
config, "head_dim", config.hidden_size // config.num_attention_heads
114130
),
115131
max_token_id=config.vocab_size - 1,
132+
num_hidden_layers=config.num_hidden_layers,
133+
num_key_value_heads=(
134+
config.num_key_value_heads
135+
if hasattr(config, "num_key_value_heads")
136+
else config.num_attention_heads
137+
),
116138
)
117139
if inputs_kwargs:
118140
kwargs.update(inputs_kwargs)

0 commit comments

Comments (0)