Skip to content

Commit 1a95582

Browse files
committed
last issue
1 parent 11512cb commit 1a95582

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

onnx_diagnostic/helpers/torch_test_helper.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
import numpy as np
55
import torch
66
from .helper import string_type
7-
from .cache_helper import make_dynamic_cache, make_encoder_decoder_cache
7+
from .cache_helper import (
8+
make_dynamic_cache,
9+
make_encoder_decoder_cache,
10+
make_sliding_window_cache,
11+
)
812

913

1014
def _forward_(*args, _f=None, _context=None, **kwargs):
@@ -363,6 +367,10 @@ def torch_deepcopy(value: Any) -> Any:
363367
return make_dynamic_cache(
364368
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
365369
)
370+
if value.__class__.__name__ == "SlidingWindowCache":
371+
return make_sliding_window_cache(
372+
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
373+
)
366374
if value.__class__.__name__ == "EncoderDecoderCache":
367375
return make_encoder_decoder_cache(
368376
torch_deepcopy(value.self_attention_cache),

0 commit comments

Comments
 (0)