1 | 1 | import inspect |
2 | | -import sys |
3 | 2 | from dataclasses import dataclass |
4 | 3 | from typing import Any, Dict, List, Optional, Tuple |
5 | 4 | import torch |
@@ -44,56 +43,47 @@ def _patch_make_causal_mask( |
44 | 43 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) |
45 | 44 | |
46 | 45 | |
47 | | -if sys.version_info[:2] <= (3, 11): |
48 | | - |
49 | | - @dataclass |
50 | | - class patched_AttentionMaskConverter: |
51 | | - """ |
52 | | - Patches |
53 | | - ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. |
54 | | - """ |
55 | | - |
56 | | - _PATCHES_ = ["_make_causal_mask"] |
57 | | - _PATCHED_CLASS_ = AttentionMaskConverter |
58 | | - |
59 | | - @staticmethod |
60 | | - def _make_causal_mask( |
61 | | - input_ids_shape: torch.Size, |
62 | | - dtype: torch.dtype, |
63 | | - device: torch.device, |
64 | | - past_key_values_length: int = 0, |
65 | | - sliding_window: Optional[int] = None, |
66 | | - ): |
67 | | - """Patched method.""" |
68 | | - return _patch_make_causal_mask( |
69 | | - input_ids_shape, dtype, device, past_key_values_length, sliding_window |
70 | | - ) |
| 46 | +@dataclass |
| 47 | +class patched_AttentionMaskConverter: |
| 48 | + """ |
| 49 | + Patches |
| 50 | + ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. |
| 51 | + """ |
71 | 52 | |
72 | | -else: |
| 53 | + _PATCHES_ = ["_make_causal_mask"] |
| 54 | + _PATCHED_CLASS_ = AttentionMaskConverter |
73 | 55 | |
74 | | - @dataclass |
75 | | - class patched_AttentionMaskConverter: |
76 | | - """ |
77 | | - Patches |
78 | | - ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. |
| 56 | + @staticmethod |
| 57 | + def _make_causal_mask( |
| 58 | + *args, |
| 59 | + **kwargs, |
| 60 | + # input_ids_shape: torch.Size, |
| 61 | + # dtype: torch.dtype, |
| 62 | + # device: torch.device, |
| 63 | + # past_key_values_length: int = 0, |
| 64 | + # sliding_window: Optional[int] = None, |
| 65 | + ): |
79 | 66 | """ |
| 67 | + Patched method. |
80 | 68 | |
81 | | - _PATCHES_ = ["_make_causal_mask"] |
82 | | - _PATCHED_CLASS_ = AttentionMaskConverter |
83 | | - |
84 | | - @staticmethod |
85 | | - def _make_causal_mask( |
86 | | - self, |
87 | | - input_ids_shape: torch.Size, |
88 | | - dtype: torch.dtype, |
89 | | - device: torch.device, |
90 | | - past_key_values_length: int = 0, |
91 | | - sliding_window: Optional[int] = None, |
92 | | - ): |
93 | | - """Patched method.""" |
94 | | - return _patch_make_causal_mask( |
95 | | - input_ids_shape, dtype, device, past_key_values_length, sliding_window |
96 | | - ) |
| 69 | + This static method may be called with ``AttentionMaskConverter._make_causal_mask`` |
| 70 | + or ``self._make_causal_mask``. That changes the arguments it receives. |
| 71 | + That should not matter for a staticmethod, but the arguments are normalized defensively. |
| 72 | + """ |
| 73 | + if args: |
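| | + # args[0] is the input shape only when the function is called directly; |
| | + # otherwise the first argument is the (unused) instance and is skipped. |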
| 74 | + index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1 |
| 75 | + names = [ |
| 76 | + "input_ids_shape", |
| 77 | + "dtype", |
| 78 | + "device", |
| 79 | + "past_key_values_length", |
| 80 | + "sliding_window", |
| 81 | + ] |
| 82 | + for i, a in enumerate(args): |
| 83 | + if i < index: |
| 84 | + continue |
| 85 | + kwargs[names[i - index]] = a |
| 86 | + return _patch_make_causal_mask(**kwargs) |
97 | 87 | |
98 | 88 | |
99 | 89 | class patched_DynamicCache: |
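A note on the new ``*args`` handling above: the ``_PATCHES_`` and ``_PATCHED_CLASS_`` attributes suggest the patch is installed by assigning the function onto ``AttentionMaskConverter``. A minimal sketch of why that makes the call style matter (``Target``, ``PatchSpec``, and ``apply_patches`` are hypothetical names, not the project's actual patcher):

    # Assignment-based patching: getattr on a @staticmethod attribute returns
    # the bare function, so the re-assigned attribute becomes a regular method.
    class Target:
        def method(self, shape):
            return ("original", shape)

    def patched_method(*args):
        return args

    class PatchSpec:
        _PATCHES_ = ["method"]
        _PATCHED_CLASS_ = Target
        method = staticmethod(patched_method)

    def apply_patches(spec):
        # Copy each listed attribute onto the target class.
        for name in spec._PATCHES_:
            setattr(spec._PATCHED_CLASS_, name, getattr(spec, name))

    apply_patches(PatchSpec)
    t = Target()
    print(Target.method((2, 3)))  # ((2, 3),)
    print(t.method((2, 3)))       # (<Target ...>, (2, 3)) -> extra leading argument

The ``isinstance(args[0], (tuple, torch.Size))`` check in the patched ``_make_causal_mask`` detects that extra leading argument and skips it before mapping the remaining positionals onto keyword names.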
|