Skip to content

Commit f1c6529

Browse files
committed
refactor
1 parent 677510a commit f1c6529

File tree

2 files changed

+65
-39
lines changed

2 files changed

+65
-39
lines changed

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pprint
12
import unittest
23
import transformers
34
from onnx_diagnostic.ext_test_case import (
@@ -11,6 +12,7 @@
1112
config_class_from_architecture,
1213
get_untrained_model_with_inputs,
1314
)
15+
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
1416
from onnx_diagnostic.torch_models.hghub.hub_data import load_models_testing
1517

1618

@@ -73,12 +75,30 @@ def test_get_untrained_model_with_inputs_beit(self):
7375
@hide_stdout()
7476
@long_test()
7577
def test_get_untrained_model_Ltesting_models(self):
78+
def _diff(c1, c2):
79+
rows = [f"types {c1.__class__.__name__} <> {c2.__class__.__name__}"]
80+
for k, v in c1.__dict__.items():
81+
if isinstance(v, (str, dict, list, tuple, int, float)) and v != getattr(
82+
c2, k, None
83+
):
84+
rows.append(f"{k} :: -- {v} ++ {getattr(c2, k, "MISS")}")
85+
return "\n".join(rows)
86+
7687
# UNHIDE=1 LONGTEST=1 python _unittests/ut_torch_models/test_hghub_model.py -k L -f
7788
for mid in load_models_testing():
7889
with self.subTest(mid=mid):
7990
data = get_untrained_model_with_inputs(mid, verbose=1)
8091
model, inputs = data["model"], data["inputs"]
81-
model(**inputs)
92+
try:
93+
model(**inputs)
94+
except Exception as e:
95+
diff = _diff(get_pretrained_config(mid), data["configuration"])
96+
raise AssertionError(
97+
f"Computation failed due to {e}.\n--- pretrained\n"
98+
f"{pprint.pformat(get_pretrained_config(mid))}\n"
99+
f"--- modified\n{data['configuration']}\n"
100+
f"--- diff\n{diff}"
101+
) from e
82102
# different expected value for different version of transformers
83103
if data["size"] > 2**30:
84104
raise AssertionError(

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -102,31 +102,21 @@ def get_untrained_model_with_inputs(
102102
arch = archs[0]
103103
if verbose:
104104
print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
105-
cls = config_class_from_architecture(arch, exc=False)
106-
if cls is None:
107-
if verbose:
108-
print(
109-
"[get_untrained_model_with_inputs] no found config name in the code, loads it"
110-
)
111-
config = get_pretrained_config(model_id)
112-
cls = config.__class__
105+
config = get_pretrained_config(model_id)
106+
if verbose:
107+
print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
108+
task = task_from_arch(arch)
113109
if verbose:
114-
print(f"[get_untrained_model_with_inputs] cls={cls.__name__!r}")
110+
print(f"[get_untrained_model_with_inputs] task={task!r}")
115111

116-
# model creation
117-
kwargs: Dict[str, Any] = dict(
118-
num_hidden_layers=1,
119-
)
112+
# model kwargs
120113
if dynamic_rope is not None:
121-
kwargs["rope_scaling"] = (
114+
config.rope_scaling = (
122115
{"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None
123116
)
124117
if model_kwargs:
125-
kwargs.update(model_kwargs)
126-
config = cls(**kwargs)
127-
task = task_from_arch(arch)
128-
if verbose:
129-
print(f"[get_untrained_model_with_inputs] task={task!r}")
118+
for k, v in model_kwargs.items():
119+
setattr(config, k, v)
130120

131121
if task == "text-generation":
132122
kwargs = dict(
@@ -136,7 +126,7 @@ def get_untrained_model_with_inputs(
136126
head_dim=getattr(
137127
config, "head_dim", config.hidden_size // config.num_attention_heads
138128
),
139-
max_token_id=config.vocab_size - 1,
129+
dummy_max_token_id=config.vocab_size - 1,
140130
num_hidden_layers=min(config.num_hidden_layers, 2),
141131
num_key_value_heads=(
142132
config.num_key_value_heads
@@ -154,25 +144,29 @@ def get_untrained_model_with_inputs(
154144
else config.hidden_size
155145
),
156146
)
157-
if inputs_kwargs:
158-
kwargs.update(inputs_kwargs)
159147

160-
_update_config(config, kwargs)
161-
model = getattr(transformers, arch)(config)
162148
fct = get_inputs_for_text_generation
163149
elif task == "image-classification":
164-
kwargs = dict(
165-
batch_size=2,
166-
width=config.image_size,
167-
height=config.image_size,
168-
channels=config.num_channels,
169-
)
170-
if inputs_kwargs:
171-
kwargs.update(inputs_kwargs)
150+
if isinstance(config.image_size, int):
151+
kwargs = dict(
152+
batch_size=2,
153+
input_width=config.image_size,
154+
input_height=config.image_size,
155+
input_channels=config.num_channels,
156+
)
157+
else:
158+
kwargs = dict(
159+
batch_size=2,
160+
input_width=config.image_size[0],
161+
input_height=config.image_size[1],
162+
input_channels=config.num_channels,
163+
)
172164
fct = get_inputs_for_image_classification
173165
else:
174166
raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")
175167

168+
if inputs_kwargs:
169+
kwargs.update(inputs_kwargs)
176170
true_kwargs = (inputs_kwargs or {}) if same_as_pretrained else kwargs
177171
_update_config(config, true_kwargs)
178172
model = getattr(transformers, arch)(config)
@@ -192,7 +186,7 @@ def compute_model_size(model: torch.nn.Module) -> Tuple[int, int]:
192186
def get_inputs_for_text_generation(
193187
model: torch.nn.Module,
194188
config: Optional[Any],
195-
max_token_id: int,
189+
dummy_max_token_id: int,
196190
num_key_value_heads: int,
197191
num_hidden_layers: int,
198192
head_dim: int,
@@ -208,6 +202,7 @@ def get_inputs_for_text_generation(
208202
:param model: model to get the missing information
209203
:param config: configuration used to generate the model
210204
:param head_dim: last dimension of the cache
205+
:param dummy_max_token_id: exclusive upper bound of the random token ids generated for ``input_ids``
211206
:param batch_size: batch size
212207
:param sequence_length: sequence length
213208
:param sequence_length2: new sequence length
@@ -235,7 +230,7 @@ def get_inputs_for_text_generation(
235230
],
236231
}
237232
inputs = dict(
238-
input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
233+
input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
239234
torch.int64
240235
),
241236
attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
@@ -268,9 +263,9 @@ def get_inputs_for_text_generation(
268263
def get_inputs_for_image_classification(
269264
model: torch.nn.Module,
270265
config: Optional[Any],
271-
width: int,
272-
height: int,
273-
channels: int,
266+
input_width: int,
267+
input_height: int,
268+
input_channels: int,
274269
batch_size: int = 2,
275270
dynamic_rope: bool = False,
276271
**kwargs,
@@ -281,9 +276,18 @@ def get_inputs_for_image_classification(
281276
:param model: model to get the missing information
282277
:param config: configuration used to generate the model
283278
:param batch_size: batch size
279+
:param input_channels: number of input channels
280+
:param input_width: input width
281+
:param input_height: input height
284282
:param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
285283
:return: dictionary
286284
"""
285+
assert isinstance(
286+
input_width, int
287+
), f"Unexpected type for input_width {type(input_width)}{config}"
288+
assert isinstance(
289+
input_width, int
290+
), f"Unexpected type for input_height {type(input_height)}{config}"
287291

288292
shapes = {
289293
"pixel_values": {
@@ -293,7 +297,9 @@ def get_inputs_for_image_classification(
293297
},
294298
}
295299
inputs = dict(
296-
pixel_values=torch.randn(batch_size, channels, width, height).clamp(-1, 1),
300+
pixel_values=torch.randn(batch_size, input_channels, input_width, input_height).clamp(
301+
-1, 1
302+
),
297303
)
298304
sizes = compute_model_size(model)
299305
return dict(

0 commit comments

Comments
 (0)