Skip to content

Commit 1018374

Browse files
sdpythonxadupre
andauthored
switch to 4.57.1 in CI (#261)
* v57.1 * fix issues --------- Co-authored-by: xadupre <[email protected]>
1 parent fa91495 commit 1018374

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [ubuntu-latest]
1919
python: ['3.10', '3.11', '3.12', '3.13']
20-
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57', 'main']
20+
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57.1', 'main']
2121
torch: ['2.8', 'main']
2222
exclude:
2323
- python: '3.10'
@@ -31,7 +31,7 @@ jobs:
3131
- python: '3.10'
3232
transformers: '4.56.2'
3333
- python: '3.10'
34-
transformers: '4.57.0'
34+
transformers: '4.57.1'
3535
- python: '3.11'
3636
torch: 'main'
3737
- python: '3.11'
@@ -41,7 +41,7 @@ jobs:
4141
- python: '3.11'
4242
transformers: '4.56.2'
4343
- python: '3.11'
44-
transformers: '4.57.0'
44+
transformers: '4.57.1'
4545
- python: '3.13'
4646
torch: '2.8'
4747
- python: '3.13'

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,54 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
134134
return len(cache2.key_cache) == len(cache.value_cache)
135135

136136

137-
if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
137+
if (
138+
pv.Version(transformers.__version__) > pv.Version("4.99.99999")
139+
or transformers.__version__ == "4.57.0.dev0"
140+
):
141+
142+
def make_dynamic_cache(
143+
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
144+
) -> transformers.cache_utils.DynamicCache:
145+
"""
146+
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
147+
This version is valid for ``transformers >= 4.50``.
148+
149+
:param key_value_pairs: list of pairs of (key, values)
150+
:return: :class:`transformers.cache_utils.DynamicCache`
151+
152+
Example:
153+
154+
.. runpython::
155+
:showcode:
156+
157+
import torch
158+
from onnx_diagnostic.helpers import string_type
159+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
160+
161+
n_layers = 2
162+
bsize, nheads, slen, dim = 2, 4, 3, 7
163+
164+
past_key_values = make_dynamic_cache(
165+
[
166+
(
167+
torch.randn(bsize, nheads, slen, dim),
168+
torch.randn(bsize, nheads, slen, dim),
169+
)
170+
for i in range(n_layers)
171+
]
172+
)
173+
print(string_type(past_key_values, with_shape=True))
174+
"""
175+
cache = transformers.cache_utils.DynamicCache(
176+
[(None, k, v) for k, v in key_value_pairs]
177+
)
178+
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
179+
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
180+
f"{len(key_value_pairs)} expected."
181+
)
182+
return finalize_cache(cache)
183+
184+
elif pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
138185

139186
def make_dynamic_cache(
140187
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],

0 commit comments

Comments
 (0)