Commit d633adc

Small fixes for ModelBuilder

1 parent 77a6db1 commit d633adc

6 files changed: +45 -13 lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
@@ -1,6 +1,11 @@
 Change Logs
 ===========
 
+0.6.1
++++++
+
+* :pr:`112`: fixes a couple of issues with ModelBuilder
+
 0.6.0
 +++++
 

_doc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -206,6 +206,7 @@ The function replaces dynamic dimensions defined as strings by
 Older versions
 ++++++++++++++
 
+* `0.6.1 <../v0.6.1/index.html>`_
 * `0.6.0 <../v0.6.0/index.html>`_
 * `0.5.0 <../v0.5.0/index.html>`_
 * `0.4.4 <../v0.4.4/index.html>`_

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,5 +3,5 @@
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.6.0"
+__version__ = "0.6.1"
 __author__ = "Xavier Dupré"

onnx_diagnostic/helpers/model_builder_helper.py

Lines changed: 24 additions & 8 deletions
@@ -237,18 +237,24 @@ def create_model_builder(
         "OlmoForCausalLM": builder.OLMoModel,
         "PhiForCausalLM": builder.PhiModel,
         "Phi3ForCausalLM": (
-            lambda config, *_: (
-                builder.Phi3MiniModel
-                if config.max_position_embeddings == config.original_max_position_embeddings
-                else builder.Phi3MiniLongRoPEModel
+            lambda config, *args: (
+                (
+                    builder.Phi3MiniModel
+                    if config.max_position_embeddings
+                    == config.original_max_position_embeddings
+                    else builder.Phi3MiniLongRoPEModel
+                )(config, *args)
             )
         ),
         "PhiMoEForCausalLM": builder.Phi3MoELongRoPEModel,
         "Phi3SmallForCausalLM": (
-            lambda config, *_: (
-                builder.Phi3SmallModel
-                if config.max_position_embeddings == config.original_max_position_embeddings
-                else builder.Phi3SmallLongRoPEModel
+            lambda config, *args: (
+                (
+                    builder.Phi3SmallModel
+                    if config.max_position_embeddings
+                    == config.original_max_position_embeddings
+                    else builder.Phi3SmallLongRoPEModel
+                )(config, *args)
             )
         ),
         "Phi3VForCausalLM": builder.Phi3VModel,
@@ -317,7 +323,17 @@ def _post(onnx_model):
     )
 
     cls = arch_map[config.architectures[0]]
+
+    # ModelBuilder does not like None values for some parameters.
+    remove = set()
+    for c in ["head_dim"]:
+        if hasattr(config, c) and getattr(config, c) is None:
+            remove.add(c)
+    for c in remove:
+        delattr(config, c)
+
     onnx_model = cls(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+
     if post:
         post(onnx_model)
     _make_model(onnx_model, model, verbose=verbose)
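The second hunk strips attributes that are explicitly set to None before the config reaches the ModelBuilder classes: as the new comment says, ModelBuilder does not like None values for some parameters, whereas a missing attribute lets a consumer fall back to a computed default. A self-contained sketch of the same idea, with SimpleNamespace standing in for the real transformers configuration object:

from types import SimpleNamespace

config = SimpleNamespace(hidden_size=1024, num_attention_heads=16, head_dim=None)

# Collect the offending attributes first, then delete them.
remove = {c for c in ["head_dim"] if getattr(config, c, 0) is None}
for c in remove:
    delattr(config, c)

# A consumer can now compute a default instead of choking on an explicit None.
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
print(head_dim)  # 64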

onnx_diagnostic/tasks/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -41,7 +41,17 @@ def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
     """Reduces a model size."""
     tasks = {mod.__TASK__: mod.reduce_model_config for mod in __TASKS__}
     assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
-    return tasks[task](config)
+    res = tasks[task](config)
+    if "head_dim" in res:
+        head_size = (
+            config.head_dim
+            if hasattr(config, "head_dim") and config.head_dim
+            else config.hidden_size // config.num_attention_heads
+        )
+        assert (
+            head_size % 16 == 0
+        ), f"head_size should be a multiple of 16, res={res}, config=\n{config}"
+    return res
 
 
 def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
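The new guard validates the reduced configuration before returning it: whenever the task touched head_dim, the effective head size, either config.head_dim or the usual hidden_size // num_attention_heads fallback, must be a multiple of 16, presumably because downstream consumers such as ModelBuilder require it. The check in isolation, as a hypothetical standalone helper:

def check_head_size(hidden_size: int, num_attention_heads: int, head_dim=None) -> int:
    # Mirrors the assertion added above: explicit head_dim wins,
    # otherwise derive the head size from the hidden size.
    head_size = head_dim if head_dim else hidden_size // num_attention_heads
    assert head_size % 16 == 0, f"head_size={head_size} should be a multiple of 16"
    return head_size

print(check_head_size(1024, 16))  # 64: passes
print(check_head_size(768, 12))   # 64: passes
# check_head_size(768, 32) would raise: head_size=24 is not a multiple of 16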

onnx_diagnostic/tasks/text_generation.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     kwargs = dict(
         num_hidden_layers=min(config.num_hidden_layers, 2),
         intermediate_size=256 if config is None else min(512, config.intermediate_size),
-        hidden_size=256 if config is None else min(256, config.hidden_size),
+        hidden_size=512 if config is None else min(512, config.hidden_size),
         cls_cache="MambaCache",
         state_size=8 if config is None else getattr(config, "state_size", None),
         conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
@@ -44,8 +44,8 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             else config.num_attention_heads
         ),
         hidden_size=(
-            min(config.hidden_size, 3072 // 4)
-            if config.hidden_size % 4 == 0
+            min(config.hidden_size, 4096 // 4)
+            if config.hidden_size % 64 == 0
             else config.hidden_size
         ),
     )
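Both hunks relax the size reduction: the MambaCache branch doubles its hidden_size cap from 256 to 512, and the attention branch raises its cap from 3072 // 4 = 768 to 4096 // 4 = 1024 while tightening the divisibility test from % 4 to % 64, which appears intended to keep reduced head sizes compatible with the multiple-of-16 assertion added in tasks/__init__.py. A worked example with a hypothetical 32-head model:

hidden_size, num_attention_heads = 2048, 32

old = min(hidden_size, 3072 // 4) if hidden_size % 4 == 0 else hidden_size
new = min(hidden_size, 4096 // 4) if hidden_size % 64 == 0 else hidden_size

print(old, old // num_attention_heads)  # 768, head size 24: not a multiple of 16
print(new, new // num_attention_heads)  # 1024, head size 32: passes the new check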
