
Commit 1ed180a

simplifies patch
1 parent b70f957 commit 1ed180a

File tree

1 file changed: +24 additions, -49 deletions

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 24 additions & 49 deletions
@@ -1,5 +1,4 @@
 import inspect
-import sys
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple
 import torch
@@ -44,56 +43,32 @@ def _patch_make_causal_mask(
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
-if sys.version_info[:2] <= (3, 11):
-
-    @dataclass
-    class patched_AttentionMaskConverter:
-        """
-        Patches
-        ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
-        """
-
-        _PATCHES_ = ["_make_causal_mask"]
-        _PATCHED_CLASS_ = AttentionMaskConverter
-
-        @staticmethod
-        def _make_causal_mask(
-            input_ids_shape: torch.Size,
-            dtype: torch.dtype,
-            device: torch.device,
-            past_key_values_length: int = 0,
-            sliding_window: Optional[int] = None,
-        ):
-            """Patched method."""
-            return _patch_make_causal_mask(
-                input_ids_shape, dtype, device, past_key_values_length, sliding_window
-            )
-
-else:
+@dataclass
+class patched_AttentionMaskConverter:
+    """
+    Patches
+    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+    """
 
-    @dataclass
-    class patched_AttentionMaskConverter:
-        """
-        Patches
-        ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
-        """
+    _PATCHES_ = ["_make_causal_mask"]
+    _PATCHED_CLASS_ = AttentionMaskConverter
 
-        _PATCHES_ = ["_make_causal_mask"]
-        _PATCHED_CLASS_ = AttentionMaskConverter
-
-        @staticmethod
-        def _make_causal_mask(
-            self,
-            input_ids_shape: torch.Size,
-            dtype: torch.dtype,
-            device: torch.device,
-            past_key_values_length: int = 0,
-            sliding_window: Optional[int] = None,
-        ):
-            """Patched method."""
-            return _patch_make_causal_mask(
-                input_ids_shape, dtype, device, past_key_values_length, sliding_window
-            )
+    @staticmethod
+    def _make_causal_mask(
+        input_ids_shape: torch.Size,
+        dtype: torch.dtype,
+        device: torch.device,
+        past_key_values_length: int = 0,
+        sliding_window: Optional[int] = None,
+    ):
+        """Patched method."""
+        return _patch_make_causal_mask(
+            input_ids_shape=input_ids_shape,
+            dtype=dtype,
+            device=device,
+            past_key_values_length=past_key_values_length,
+            sliding_window=sliding_window,
+        )
 
 
 class patched_DynamicCache:
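
The diff shows the convention this module relies on: a patch class lists the method names it overrides in `_PATCHES_` and the class it targets in `_PATCHED_CLASS_`. Below is a minimal sketch of how such a patch class could be applied and reverted; it is not the library's actual helper, and the names `apply_patch` and its context-manager shape are assumptions made for illustration only.

import contextlib
import inspect


@contextlib.contextmanager
def apply_patch(patch_cls):
    """Hypothetical helper: temporarily installs the methods listed in
    ``patch_cls._PATCHES_`` onto ``patch_cls._PATCHED_CLASS_`` and restores
    the originals on exit."""
    target = patch_cls._PATCHED_CLASS_
    # getattr_static retrieves the raw class attributes (keeping the
    # @staticmethod descriptor intact) instead of bound/unwrapped functions.
    originals = {
        name: inspect.getattr_static(target, name) for name in patch_cls._PATCHES_
    }
    try:
        for name in patch_cls._PATCHES_:
            setattr(target, name, inspect.getattr_static(patch_cls, name))
        yield target
    finally:
        # Restore the original implementations even if the body raised.
        for name, original in originals.items():
            setattr(target, name, original)


# Hypothetical usage with the class from this commit:
#
# with apply_patch(patched_AttentionMaskConverter):
#     exported = torch.export.export(model, args)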
