Skip to content

Commit 89c5d96

Browse files
committed
fix make_static_cache
1 parent 07a0b2a commit 89c5d96

File tree

4 files changed

+32
-5
lines changed

4 files changed

+32
-5
lines changed

.github/workflows/check-release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
matrix:
1717
os: [ubuntu-latest, macOS-latest, windows-latest]
1818
python: ['3.11', '3.12']
19-
transformers: ['4.48.3', '4.52.4', 'main']
19+
transformers: ['4.48.3', '4.52.4', '4.55.2', 'main']
2020
torch: ['2.7', '2.8', 'main']
2121

2222
steps:

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
matrix:
1717
os: [ubuntu-latest]
1818
python: ['3.10', '3.11', '3.12', '3.13']
19-
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.55.0', 'main']
19+
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.53.3', '4.55.2', 'main']
2020
torch: ['2.8', 'main']
2121
exclude:
2222
- python: '3.10'
@@ -28,15 +28,15 @@ jobs:
2828
- python: '3.10'
2929
transformers: '4.53.3'
3030
- python: '3.10'
31-
transformers: '4.55.0'
31+
transformers: '4.55.2'
3232
- python: '3.11'
3333
torch: 'main'
3434
- python: '3.11'
3535
transformers: '4.53.3'
3636
- python: '3.11'
3737
transformers: 'main'
3838
- python: '3.11'
39-
transformers: '4.55.0'
39+
transformers: '4.55.2'
4040
- python: '3.13'
4141
torch: '2.8'
4242
- python: '3.13'

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,33 @@ def __init__(self):
280280
max_cache_len=max_cache_len,
281281
)
282282
ca = CacheKeyValue(cache)
283+
if hasattr(cache, "layers") and len(ca.key_cache) == 0:
284+
# transformers>= 4.55.2, layers are empty
285+
for i, (key, value) in enumerate(key_value_pairs):
286+
cache.update(key, value, i)
287+
return cache
288+
289+
torch._check(
290+
len(key_value_pairs) == len(cache.layers),
291+
lambda: (
292+
f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
293+
f"len(cache.layers)={len(cache.layers)}"
294+
),
295+
)
296+
torch._check(
297+
len(key_value_pairs) == len(ca.key_cache),
298+
lambda: (
299+
f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
300+
f"len(ca.key_cache)={len(ca.key_cache)}"
301+
),
302+
)
303+
torch._check(
304+
len(key_value_pairs) == len(ca.value_cache),
305+
lambda: (
306+
f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
307+
f"len(ca.value_cache)={len(ca.value_cache)}"
308+
),
309+
)
283310
for i in range(len(key_value_pairs)):
284311
assert (
285312
key_value_pairs[i][0].shape == key_value_pairs[i][1].shape

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ numpy
22
onnx>=1.16.0
33
onnxruntime>=1.21
44
optree
5-
torch>=2.7
5+
torch>=2.8
66
torch_geometric

0 commit comments

Comments
 (0)