Skip to content

Commit c46fe1e

Browse files
committed
mypy
1 parent 2523b0d commit c46fe1e

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

onnx_diagnostic/export/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def to_onnx(
6565
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
6666
assert hasattr(mod, "config"), f"configuration is missing in model class {type(mod)}"
6767
assert not args, f"only kwargs can be defined with exporter={exporter!r}"
68-
assert list(kwargs) == ["input_ids", "attention_mask", "past_key_values"], (
68+
assert list(kwargs) == ["input_ids", "attention_mask", "past_key_values"], ( # type: ignore[arg-type]
6969
f"Only a specified set of inputs is supported for exporter={exporter!r}, "
70-
f"but it is {list(kwargs)}"
70+
f"but it is {list(kwargs)}" # type: ignore[arg-type]
7171
)
7272
flat_inputs = flatten_object(kwargs, drop_keys=True)
7373
first = flat_inputs[0]

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,19 @@ def _is_torchdynamo_exporting() -> bool:
7373
# Introduced in 4.52
7474
from transformers.masking_utils import (
7575
_ignore_causal_mask_sdpa,
76-
_ignore_bidirectional_mask_sdpa,
7776
and_masks,
7877
bidirectional_mask_function,
7978
causal_mask_function,
8079
padding_mask_function,
8180
prepare_padding_mask,
8281
)
8382

83+
try:
84+
# transformers>=5.0
85+
from transformers.masking_utils import _ignore_bidirectional_mask_sdpa
86+
except ImportError:
87+
_ignore_bidirectional_mask_sdpa = None
88+
8489
def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
8590
"""manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
8691
from ...helpers import string_type
@@ -187,7 +192,11 @@ def patched_sdpa_mask_recent_torch(
187192
padding_mask, q_length, kv_length, kv_offset, local_size
188193
):
189194
return None
190-
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
195+
if (
196+
allow_is_bidirectional_skip
197+
and _ignore_bidirectional_mask_sdpa
198+
and _ignore_bidirectional_mask_sdpa(padding_mask)
199+
):
191200
return None
192201

193202
if mask_function is bidirectional_mask_function:

0 commit comments

Comments
 (0)