
Commit d817f19

fix draft
1 parent df3ad9b commit d817f19

File tree

2 files changed: +3 -9 lines


onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 0 additions & 6 deletions
@@ -4,7 +4,6 @@
 from functools import wraps
 from typing import Callable, List, Optional, Tuple
 import packaging.version as pv
-from sklearn import logger
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -1658,11 +1657,6 @@ def patched_sdpa_attention_forward(
     **kwargs,
 ) -> tuple[torch.Tensor, None]:
     """manual patch for function ```transformers.integrations.sdpa_attention.sdpa_attention_forward```."""  # noqa: E501
-    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
-        logger.warning_once(
-            "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
-            " Please set your attention to `eager` if you want any of these features."
-        )
     sdpa_kwargs = {}
     if hasattr(module, "num_key_value_groups"):
         if not use_gqa_in_sdpa(attention_mask, key):
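The dropped `from sklearn import logger` line was an accidental import, so the `logger.warning_once` block that depended on it is removed as well. If that warning were still wanted, a minimal sketch, assuming the standard `transformers.utils.logging` helper (not part of this commit), could look like this:

# Hypothetical sketch, not part of this commit: the logger transformers
# itself provides would replace the accidental sklearn import removed above.
from transformers.utils import logging

logger = logging.get_logger(__name__)


def warn_if_sdpa_unsupported(kwargs: dict) -> None:
    """Warn once when options unsupported by SDPA attention are requested."""
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )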

onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 def get_tiny_llm(
     batch_size: int = 2,
     sequence_length: int = 30,
-    sequence_length2: int = 3,
+    past_sequence_length: int = 3,
     dynamic_rope: bool = False,
     use_static_cache: bool = False,
     **kwargs,
@@ -15,7 +15,7 @@ def get_tiny_llm(
 
     :param batch_size: batch size
     :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
+    :param past_sequence_length: past sequence length
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param use_static_cache: use StaticCache instead of DynamicCache
     :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
@@ -62,7 +62,7 @@ def get_tiny_llm(
         num_hidden_layers=config["num_hidden_layers"],  # type: ignore[arg-type]
         batch_size=batch_size,
         sequence_length=sequence_length,
-        sequence_length2=sequence_length2,
+        past_sequence_length=past_sequence_length,
         dynamic_rope=dynamic_rope,
         num_key_value_heads=config["num_key_value_heads"],  # type: ignore[arg-type]
         cls_cache="StaticCache" if use_static_cache else "DynamicCache",
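With the parameter renamed from `sequence_length2` to `past_sequence_length`, callers pass the cache length under its new name. A hypothetical usage sketch; the structure of the returned value is an assumption, not taken from this commit:

# Hypothetical usage sketch: only the keyword rename matters here; the
# keys of the returned dictionary are assumed for illustration.
from onnx_diagnostic.torch_models.untrained.llm_tiny_llm import get_tiny_llm

data = get_tiny_llm(
    batch_size=2,
    sequence_length=30,
    past_sequence_length=3,  # previously sequence_length2
    use_static_cache=False,
)
model, inputs = data["model"], data["inputs"]  # assumed keys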
