Skip to content

Commit 3cbc1e6

Browse files
committed
remove _seen_tokens from the patched code
1 parent ebecb67 commit 3cbc1e6

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def update(
255255
"""
256256
# Update the number of seen tokens
257257
if layer_idx == 0:
258-
self._seen_tokens += key_states.shape[-2]
258+
if hasattr(self, "_seen_tokens"):
259+
self._seen_tokens += key_states.shape[-2]
259260

260261
# Update the cache
261262
if key_states is not None:
@@ -294,7 +295,8 @@ def crop(self, max_length: int):
294295
if self.get_seq_length() <= max_length:
295296
return
296297

297-
self._seen_tokens = max_length
298+
if hasattr(self, "_seen_tokens"):
299+
self._seen_tokens = max_length
298300
for idx in range(len(self.key_cache)):
299301
if self.key_cache[idx].numel():
300302
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]

0 commit comments

Comments
 (0)