Skip to content

Commit a90e771

Browse files
committed
refactor
1 parent f1c6529 commit a90e771

File tree

4 files changed

+148
-67
lines changed

4 files changed

+148
-67
lines changed

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,22 @@ class TestHuggingFaceHubApi(ExtTestCase):
2626

2727
@requires_transformers("4.50") # we limit to some versions of the CI
2828
@requires_torch("2.7")
29+
@hide_stdout()
2930
def test_enumerate_model_list(self):
3031
models = list(
3132
enumerate_model_list(
3233
2,
3334
verbose=1,
3435
dump="test_enumerate_model_list.csv",
35-
filter="text-generation",
36+
filter="image-classification",
3637
library="transformers",
3738
)
3839
)
3940
self.assertEqual(len(models), 2)
4041
df = pandas.read_csv("test_enumerate_model_list.csv")
4142
self.assertEqual(df.shape, (2, 12))
42-
tasks = [task_from_id(c) for c in df.id]
43-
self.assertEqual(["text-generation", "text-generation"], tasks)
43+
tasks = [task_from_id(c, "missing") for c in df.id]
44+
self.assertEqual(len(tasks), 2)
4445

4546
@requires_transformers("4.50")
4647
@requires_torch("2.7")

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,36 @@ def test_get_untrained_model_with_inputs_tiny_llm(self):
2929
data = get_untrained_model_with_inputs(mid, verbose=1)
3030
self.assertEqual(
3131
set(data),
32-
{"model", "inputs", "dynamic_shapes", "configuration", "size", "n_weights"},
32+
{
33+
"model",
34+
"inputs",
35+
"dynamic_shapes",
36+
"configuration",
37+
"size",
38+
"n_weights",
39+
"input_kwargs",
40+
"model_kwargs",
41+
},
3342
)
3443
model, inputs = data["model"], data["inputs"]
3544
model(**inputs)
36-
self.assertEqual((1858125824, 464531456), (data["size"], data["n_weights"]))
45+
self.assertEqual((51955968, 12988992), (data["size"], data["n_weights"]))
3746

3847
@hide_stdout()
3948
def test_get_untrained_model_with_inputs_tiny_xlm_roberta(self):
4049
mid = "hf-internal-testing/tiny-xlm-roberta" # XLMRobertaConfig
4150
data = get_untrained_model_with_inputs(mid, verbose=1)
4251
model, inputs = data["model"], data["inputs"]
4352
model(**inputs)
44-
self.assertEqual((126190824, 31547706), (data["size"], data["n_weights"]))
53+
self.assertEqual((8642088, 2160522), (data["size"], data["n_weights"]))
4554

4655
@hide_stdout()
4756
def test_get_untrained_model_with_inputs_tiny_gpt_neo(self):
4857
mid = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
4958
data = get_untrained_model_with_inputs(mid, verbose=1)
5059
model, inputs = data["model"], data["inputs"]
5160
model(**inputs)
52-
self.assertEqual((4291141632, 1072785408), (data["size"], data["n_weights"]))
61+
self.assertEqual((316712, 79178), (data["size"], data["n_weights"]))
5362

5463
@hide_stdout()
5564
def test_get_untrained_model_with_inputs_phi_2(self):
@@ -60,7 +69,7 @@ def test_get_untrained_model_with_inputs_phi_2(self):
6069
# different expected value for different version of transformers
6170
self.assertIn(
6271
(data["size"], data["n_weights"]),
63-
[(1040293888, 260073472), (1040498688, 260124672)],
72+
[(453330944, 113332736)],
6473
)
6574

6675
@hide_stdout()
@@ -70,7 +79,7 @@ def test_get_untrained_model_with_inputs_beit(self):
7079
model, inputs = data["model"], data["inputs"]
7180
model(**inputs)
7281
# different expected value for different version of transformers
73-
self.assertIn((data["size"], data["n_weights"]), [(30732296, 7683074)])
82+
self.assertIn((data["size"], data["n_weights"]), [(111448, 27862)])
7483

7584
@hide_stdout()
7685
@long_test()
@@ -81,7 +90,7 @@ def _diff(c1, c2):
8190
if isinstance(v, (str, dict, list, tuple, int, float)) and v != getattr(
8291
c2, k, None
8392
):
84-
rows.append(f"{k} :: -- {v} ++ {getattr(c2, k, "MISS")}")
93+
rows.append(f"{k} :: -- {v} ++ {getattr(c2, k, 'MISS')}")
8594
return "\n".join(rows)
8695

8796
# UNHIDE=1 LONGTEST=1 python _unittests/ut_torch_models/test_hghub_model.py -k L -f

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from .hub_data import __date__, __data_tasks__, load_architecture_task
66

77

8-
def get_pretrained_config(model_id: str) -> str:
8+
def get_pretrained_config(model_id: str, trust_remote_code: bool = True) -> str:
99
"""Returns the config for a model_id."""
10-
return transformers.AutoConfig.from_pretrained(model_id)
10+
return transformers.AutoConfig.from_pretrained(
11+
model_id, trust_remote_code=trust_remote_code
12+
)
1113

1214

1315
def get_model_info(model_id) -> str:
@@ -16,11 +18,12 @@ def get_model_info(model_id) -> str:
1618

1719

1820
@functools.cache
19-
def task_from_arch(arch: str) -> str:
21+
def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
2022
"""
2123
This function relies on stored information. That information needs to be refreshed.
2224
2325
:param arch: architecture name
26+
:param default_value: default value in case the task cannot be determined
2427
:return: task
2528
2629
.. runpython::
@@ -33,17 +36,24 @@ def task_from_arch(arch: str) -> str:
3336
<onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
3437
"""
3538
data = load_architecture_task()
39+
if default_value is not None:
40+
return data.get(arch, default_value)
3641
assert arch in data, f"Architecture {arch!r} is unknown, last refresh in {__date__}"
3742
return data[arch]
3843

3944

4045
def task_from_id(
41-
model_id: str, pretrained: bool = False, fall_back_to_pretrained: bool = True
46+
model_id: str,
47+
default_value: Optional[str] = None,
48+
pretrained: bool = False,
49+
fall_back_to_pretrained: bool = True,
4250
) -> str:
4351
"""
4452
Returns the task attached to a model id.
4553
4654
:param model_id: model id
55+
:param default_value: if specified, the function returns this value
56+
if the task cannot be determined
4757
:param pretrained: uses the config
4858
:param fall_back_to_pretrained: falls back to the pretrained config
4959
:return: task
@@ -62,7 +72,7 @@ def task_from_id(
6272
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
6373
f"architectures={config.architectures} in config={config}"
6474
)
65-
return task_from_arch(config.architectures[0])
75+
return task_from_arch(config.architectures[0], default_value=default_value)
6676

6777

6878
def task_from_tags(tags: Union[str, List[str]]) -> str:

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 113 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import importlib
33
import inspect
44
import re
5-
from typing import Any, Dict, Optional, Tuple
5+
from typing import Any, Callable, Dict, Optional, Tuple
66
import torch
77
import transformers
88
from ...cache_helpers import make_dynamic_cache
@@ -46,6 +46,104 @@ def _update_config(config: Any, kwargs: Dict[str, Any]):
4646
setattr(config, k, v)
4747

4848

49+
def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
50+
"""Reduces a model size."""
51+
if task == "text-generation":
52+
kwargs = dict(
53+
head_dim=getattr(
54+
config, "head_dim", config.hidden_size // config.num_attention_heads
55+
),
56+
num_hidden_layers=min(config.num_hidden_layers, 2),
57+
num_key_value_heads=(
58+
config.num_key_value_heads
59+
if hasattr(config, "num_key_value_heads")
60+
else config.num_attention_heads
61+
),
62+
intermediate_size=(
63+
min(config.intermediate_size, 24576 // 4)
64+
if config.intermediate_size % 4 == 0
65+
else config.intermediate_size
66+
),
67+
hidden_size=(
68+
min(config.hidden_size, 3072 // 4)
69+
if config.hidden_size % 4 == 0
70+
else config.hidden_size
71+
),
72+
)
73+
elif task == "image-classification":
74+
if isinstance(config.image_size, int):
75+
kwargs = dict(
76+
batch_size=2,
77+
input_width=config.image_size,
78+
input_height=config.image_size,
79+
input_channels=config.num_channels,
80+
)
81+
else:
82+
kwargs = dict(
83+
batch_size=2,
84+
input_width=config.image_size[0],
85+
input_height=config.image_size[1],
86+
input_channels=config.num_channels,
87+
)
88+
else:
89+
raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")
90+
91+
for k, v in kwargs.items():
92+
setattr(config, k, v)
93+
return kwargs
94+
95+
96+
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
97+
"""Inputs kwargs"""
98+
if task == "text-generation":
99+
kwargs = dict(
100+
batch_size=2,
101+
sequence_length=30,
102+
sequence_length2=3,
103+
head_dim=getattr(
104+
config, "head_dim", config.hidden_size // config.num_attention_heads
105+
),
106+
dummy_max_token_id=config.vocab_size - 1,
107+
num_hidden_layers=min(config.num_hidden_layers, 2),
108+
num_key_value_heads=(
109+
config.num_key_value_heads
110+
if hasattr(config, "num_key_value_heads")
111+
else config.num_attention_heads
112+
),
113+
intermediate_size=(
114+
min(config.intermediate_size, 24576 // 4)
115+
if config.intermediate_size % 4 == 0
116+
else config.intermediate_size
117+
),
118+
hidden_size=(
119+
min(config.hidden_size, 3072 // 4)
120+
if config.hidden_size % 4 == 0
121+
else config.hidden_size
122+
),
123+
)
124+
fct = get_inputs_for_text_generation
125+
elif task == "image-classification":
126+
if isinstance(config.image_size, int):
127+
kwargs = dict(
128+
batch_size=2,
129+
input_width=config.image_size,
130+
input_height=config.image_size,
131+
input_channels=config.num_channels,
132+
)
133+
else:
134+
kwargs = dict(
135+
batch_size=2,
136+
input_width=config.image_size[0],
137+
input_height=config.image_size[1],
138+
input_channels=config.num_channels,
139+
)
140+
fct = get_inputs_for_image_classification # type: ignore
141+
else:
142+
raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")
143+
144+
return kwargs, fct
145+
146+
49147
def get_untrained_model_with_inputs(
50148
model_id: str,
51149
config: Optional[Any] = None,
@@ -114,63 +212,26 @@ def get_untrained_model_with_inputs(
114212
config.rope_scaling = (
115213
{"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None
116214
)
215+
216+
# updating the configuration
217+
if not same_as_pretrained:
218+
mkwargs = reduce_model_config(config, task)
219+
else:
220+
mkwargs = {}
117221
if model_kwargs:
118222
for k, v in model_kwargs.items():
119223
setattr(config, k, v)
120-
121-
if task == "text-generation":
122-
kwargs = dict(
123-
batch_size=2,
124-
sequence_length=30,
125-
sequence_length2=3,
126-
head_dim=getattr(
127-
config, "head_dim", config.hidden_size // config.num_attention_heads
128-
),
129-
dummy_max_token_id=config.vocab_size - 1,
130-
num_hidden_layers=min(config.num_hidden_layers, 2),
131-
num_key_value_heads=(
132-
config.num_key_value_heads
133-
if hasattr(config, "num_key_value_heads")
134-
else config.num_attention_heads
135-
),
136-
intermediate_size=(
137-
min(config.intermediate_size, 24576 // 4)
138-
if config.intermediate_size % 4 == 0
139-
else config.intermediate_size
140-
),
141-
hidden_size=(
142-
min(config.hidden_size, 3072 // 4)
143-
if config.hidden_size % 4 == 0
144-
else config.hidden_size
145-
),
146-
)
147-
148-
fct = get_inputs_for_text_generation
149-
elif task == "image-classification":
150-
if isinstance(config.image_size, int):
151-
kwargs = dict(
152-
batch_size=2,
153-
input_width=config.image_size,
154-
input_height=config.image_size,
155-
input_channels=config.num_channels,
156-
)
157-
else:
158-
kwargs = dict(
159-
batch_size=2,
160-
input_width=config.image_size[0],
161-
input_height=config.image_size[1],
162-
input_channels=config.num_channels,
163-
)
164-
fct = get_inputs_for_image_classification
165-
else:
166-
raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")
167-
224+
mkwargs[k] = v
225+
# input kwargs
226+
kwargs, fct = random_input_kwargs(config, task)
168227
if inputs_kwargs:
169228
kwargs.update(inputs_kwargs)
170-
true_kwargs = (inputs_kwargs or {}) if same_as_pretrained else kwargs
171-
_update_config(config, true_kwargs)
229+
172230
model = getattr(transformers, arch)(config)
173-
return fct(model, config, **true_kwargs)
231+
res = fct(model, config, **kwargs)
232+
res["input_kwargs"] = kwargs
233+
res["model_kwargs"] = mkwargs
234+
return res
174235

175236

176237
def compute_model_size(model: torch.nn.Module) -> Tuple[int, int]:

0 commit comments

Comments
 (0)