|
1 | 1 | import inspect |
2 | | -import sys |
3 | 2 | from dataclasses import dataclass |
4 | 3 | from typing import Any, Dict, List, Optional, Tuple |
5 | 4 | import torch |
@@ -44,56 +43,32 @@ def _patch_make_causal_mask( |
44 | 43 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) |
45 | 44 |
|
46 | 45 |
|
47 | | -if sys.version_info[:2] <= (3, 11): |
48 | | - |
49 | | - @dataclass |
50 | | - class patched_AttentionMaskConverter: |
51 | | - """ |
52 | | - Patches |
53 | | - ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. |
54 | | - """ |
55 | | - |
56 | | - _PATCHES_ = ["_make_causal_mask"] |
57 | | - _PATCHED_CLASS_ = AttentionMaskConverter |
58 | | - |
59 | | - @staticmethod |
60 | | - def _make_causal_mask( |
61 | | - input_ids_shape: torch.Size, |
62 | | - dtype: torch.dtype, |
63 | | - device: torch.device, |
64 | | - past_key_values_length: int = 0, |
65 | | - sliding_window: Optional[int] = None, |
66 | | - ): |
67 | | - """Patched method.""" |
68 | | - return _patch_make_causal_mask( |
69 | | - input_ids_shape, dtype, device, past_key_values_length, sliding_window |
70 | | - ) |
71 | | - |
72 | | -else: |
| 46 | +@dataclass |
| 47 | +class patched_AttentionMaskConverter: |
| 48 | + """ |
| 49 | + Patches |
| 50 | + ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. |
| 51 | + """ |
73 | 52 |
|
74 | | - @dataclass |
75 | | - class patched_AttentionMaskConverter: |
76 | | - """ |
77 | | - Patches |
78 | | - ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. |
79 | | - """ |
| 53 | + _PATCHES_ = ["_make_causal_mask"] |
| 54 | + _PATCHED_CLASS_ = AttentionMaskConverter |
80 | 55 |
|
81 | | - _PATCHES_ = ["_make_causal_mask"] |
82 | | - _PATCHED_CLASS_ = AttentionMaskConverter |
83 | | - |
84 | | - @staticmethod |
85 | | - def _make_causal_mask( |
86 | | - self, |
87 | | - input_ids_shape: torch.Size, |
88 | | - dtype: torch.dtype, |
89 | | - device: torch.device, |
90 | | - past_key_values_length: int = 0, |
91 | | - sliding_window: Optional[int] = None, |
92 | | - ): |
93 | | - """Patched method.""" |
94 | | - return _patch_make_causal_mask( |
95 | | - input_ids_shape, dtype, device, past_key_values_length, sliding_window |
96 | | - ) |
| 56 | + @staticmethod |
| 57 | + def _make_causal_mask( |
| 58 | + input_ids_shape: torch.Size, |
| 59 | + dtype: torch.dtype, |
| 60 | + device: torch.device, |
| 61 | + past_key_values_length: int = 0, |
| 62 | + sliding_window: Optional[int] = None, |
| 63 | + ): |
| 64 | + """Patched method.""" |
| 65 | + return _patch_make_causal_mask( |
| 66 | + input_ids_shape=input_ids_shape, |
| 67 | + dtype=dtype, |
| 68 | + device=device, |
| 69 | + past_key_values_length=past_key_values_length, |
| 70 | + sliding_window=sliding_window, |
| 71 | + ) |
97 | 72 |
|
98 | 73 |
|
99 | 74 | class patched_DynamicCache: |
|
0 commit comments