
Commit d8eee07

add missing patched method

1 parent eabc4ef commit d8eee07

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 13 additions & 0 deletions
@@ -105,6 +105,19 @@ class patched_DynamicCache:
     _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits"]
     _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
 
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states.
+        A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
+        is_empty_layer = (
+            len(self.key_cache) == 0  # no cache in any layer
+            or len(self.key_cache)
+            <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+            or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
+        )
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        return layer_seq_length
+
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
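
For context, a minimal sketch of what the added get_seq_length computes: the sequence dimension (shape[-2]) of the cached key tensor for the requested layer, or 0 when that layer holds no cache. The TinyCache class below is a hypothetical stand-in that only reproduces the same key_cache layout and logic; it is not part of this commit, of onnx_diagnostic, or of transformers.

import torch

class TinyCache:
    """Hypothetical stand-in with the key_cache layout the patched method assumes."""
    def __init__(self):
        # one tensor per layer, shaped (batch, heads, seq_len, head_dim)
        self.key_cache = []

    def get_seq_length(self, layer_idx=0):
        # same logic as the patched method above
        is_empty_layer = (
            len(self.key_cache) == 0                     # no cache in any layer
            or len(self.key_cache) <= layer_idx          # layer_idx was never cached
            or self.key_cache[layer_idx].numel() == 0    # the layer has an empty cache
        )
        return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0

cache = TinyCache()
print(cache.get_seq_length())               # 0  -> nothing cached yet
cache.key_cache.append(torch.zeros(1, 8, 17, 64))
print(cache.get_seq_length(0))              # 17 -> sequence length of layer 0's cached keys
print(cache.get_seq_length(1))              # 0  -> layer 1 has no cache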
