Skip to content

Commit 7d8c261

Browse files
committed
keep head_dim
1 parent c4228f6 commit 7d8c261

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

onnx_diagnostic/tasks/text_generation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
3434
)
3535
else:
3636
kwargs = dict(
37+
head_dim=getattr(
38+
config, "head_dim", config.hidden_size // config.num_attention_heads
39+
),
3740
num_hidden_layers=min(config.num_hidden_layers, 2),
3841
num_key_value_heads=(
3942
config.num_key_value_heads

0 commit comments

Comments
 (0)