
Commit ddb8c90

fix config
1 parent 7d65218 commit ddb8c90

2 files changed: +5 -2 lines

onnx_diagnostic/export/validate.py

Lines changed: 2 additions & 2 deletions
@@ -75,8 +75,8 @@ def _get(a):
         begin = time.perf_counter()
         print(
             f"[compare_modules] check ep with "
-            f"args={string_type(args, with_shape=True)}, "
-            f"kwargs={string_type(kwargs, with_shape=True)}..."
+            f"args={string_type(args, with_shape=True, with_device=True)}, "
+            f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}..."
         )
     got = modep(*_get(args), **_get(kwargs))
     if verbose:
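For context, this change only affects the verbose trace: string_type renders a compact description of a tensor, and with_device=True also reports which device each tensor lives on. A minimal sketch of the difference, assuming string_type is importable from onnx_diagnostic.helpers (the import path and the rendered output are assumptions, not taken from this commit):

import torch
from onnx_diagnostic.helpers import string_type  # import path is an assumption

t = torch.zeros((2, 3), dtype=torch.float32)
# Without the flag, only dtype and shape are rendered.
print(string_type(t, with_shape=True))
# With with_device=True the tensor's device is included as well,
# which is what the updated [compare_modules] message now prints
# for both args and kwargs.
print(string_type(t, with_shape=True, with_device=True))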

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 3 additions & 0 deletions
@@ -477,6 +477,9 @@ def get_inputs_for_text_generation(
     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
     :return: dictionary
     """
+    if head_dim is None:
+        assert config, "head_dim is None, the value cannot be set without a configuration"
+        head_dim = config.hidden_size // config.num_attention_heads
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = torch.export.Dim("cache_length", min=1, max=4096)
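The new fallback applies the usual transformer relation head_dim = hidden_size // num_attention_heads when the caller does not pass head_dim explicitly. A minimal sketch of that computation, using a SimpleNamespace as a stand-in for a Hugging Face configuration with Llama-7B-like values (both the stand-in and the values are illustrative assumptions):

from types import SimpleNamespace

# Stand-in for a transformers PretrainedConfig; the values mirror a
# Llama-7B-style model and are illustrative only.
config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

head_dim = None  # the caller did not provide head_dim
if head_dim is None:
    assert config, "head_dim is None, the value cannot be set without a configuration"
    head_dim = config.hidden_size // config.num_attention_heads

print(head_dim)  # 4096 // 32 = 128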
