|
38 | 38 | patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99") |
39 | 39 |
|
40 | 40 |
|
41 | | -def _get_is_initialized(self): |
42 | | - return self.keys is not None |
43 | | - |
44 | | - |
45 | | -def _set_is_initialized(self, value): |
46 | | - assert (value and self.keys is not None) or (not value and self.keys is None), ( |
47 | | - f"The patch does not set is_initialized but checks the it is consistent with " |
48 | | - f"``self.keys is not None``, value={value}, " |
49 | | - f"self.keys is not None={self.keys is not None}" |
50 | | - ) |
51 | | - |
52 | | - |
53 | | -def apply_patch_for_is_initialized(): |
54 | | - """ |
55 | | - Fixes export issues introduced by PR `40791 <https://github.com/huggingface/transformers/pull/40791>`_. |
56 | | - The attribute is_initialized does not seem to be captured by :func:`torch.export.export`. |
57 | | - """ |
58 | | - if patch_is_initialized: |
59 | | - transformers.cache_utils.CacheLayerMixin.is_initialized = property( |
60 | | - _get_is_initialized, _set_is_initialized |
61 | | - ) |
62 | | - |
63 | | - |
64 | | -def disable_patch_for_is_initialized(): |
65 | | - """Disables the patch applied by function :func:`apply_patch_for_is_initialized`.""" |
66 | | - if patch_is_initialized: |
67 | | - delattr(transformers.cache_utils.CacheLayerMixin, "is_initialized") |
68 | | - |
69 | | - |
70 | 41 | if patch_masking_utils: |
71 | 42 | # Introduced in 4.52 |
72 | 43 | from transformers.masking_utils import ( |
@@ -245,6 +216,8 @@ def lazy_initialization(self, key_states: torch.Tensor): |
245 | 216 | new_shape[-2] = 0 |
246 | 217 | self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device) |
247 | 218 | self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device) |
| 219 | + if patch_is_initialized: |
| 220 | + self.is_initialized = True |
248 | 221 |
|
249 | 222 |
|
250 | 223 | def _patch_make_causal_mask( |
|
0 commit comments