Commit 41544ae

Fix @lru_cache memory leaks
1 parent 19c7010 commit 41544ae

5 files changed: +84 -53 lines

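Background on the pattern applied throughout this commit: `functools.lru_cache` on an instance method keeps its cache on the function object in the class dict and stores `self` as part of every cache key, so each instance that ever calls the method is pinned by the cache and never garbage collected. `functools.cached_property` stores the computed value in the instance's own `__dict__` instead, so it is released together with the object. A minimal sketch of the two patterns (class names here are illustrative, not from the repository):

```python
from functools import lru_cache, cached_property

class Leaky:
    # The lru_cache lives on the function object in the class dict and
    # keeps a strong reference to `self` inside each cache key, so
    # instances stay in memory for the lifetime of the class.
    @lru_cache
    def expensive(self):
        return sum(range(1_000_000))

class Fixed:
    # cached_property writes the result into self.__dict__ on first
    # access; the cached value is freed when the instance is.
    @cached_property
    def _expensive(self):
        return sum(range(1_000_000))

    # Thin wrapper preserves the original call-style API, obj.expensive(),
    # which is the shape this commit uses for its public methods.
    def expensive(self):
        return self._expensive
```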

exllamav3/loader/safetensors.py

Lines changed: 20 additions & 13 deletions

```diff
@@ -9,7 +9,7 @@
 import mmap
 from ..util import Timer, cuda_sync_active
 from ..ext import exllamav3_ext as ext
-from functools import lru_cache
+from functools import cached_property
 from fnmatch import fnmatch
 import time

@@ -192,11 +192,13 @@ def get_tensor_size(
         return bytesize


-    @lru_cache
-    def get_tensor_file_map_trie(self):
+    @cached_property
+    def _get_tensor_file_map_trie(self):
         import marisa_trie
         trie = marisa_trie.Trie(self.tensor_file_map.keys())
         return trie
+    def get_tensor_file_map_trie(self):
+        return self._get_tensor_file_map_trie


     def list_tensors(
@@ -400,12 +402,16 @@ def close(self):
             self.handles[filename] = None


-    @lru_cache
-    def max_key_len(self):
+    @cached_property
+    def _max_key_len(self):
         l = max(len(k) for k in self.tensor_file_map.keys())
         return l


+    def max_key_len(self):
+        return self._max_key_len
+
+
     def set_new_tensors(self, new_tensors):
         self.new_tensors = new_tensors

@@ -513,7 +519,7 @@ def __init__(
     ):
         self.main = main
         self.stcs = []
-
+        self._get_tensor_sizes_cache = {}

     def compile_star_globs(self, patterns, *, flags = 0):
         # Turn list of filter globs into single, compiled regex
@@ -572,17 +578,18 @@ def has_tensor_group(
             return True


-    @lru_cache
     def get_tensor_sizes(
         self,
         prefix: str,
     ):
-        keys = [self.main.tensor_file_map.get(prefix)]
-        if keys[0] is None:
-            keys = []
-        keys += self.main.get_tensor_file_map_trie().keys(prefix + ".")
-        sizes = [self.get_tensor_size(key) for key in keys]
-        return sizes
+        if prefix not in self._get_tensor_sizes_cache:
+            keys = [self.main.tensor_file_map.get(prefix)]
+            if keys[0] is None:
+                keys = []
+            keys += self.main.get_tensor_file_map_trie().keys(prefix + ".")
+            sizes = [self.get_tensor_size(key) for key in keys]
+            self._get_tensor_sizes_cache[prefix] = sizes
+        return self._get_tensor_sizes_cache[prefix]


     def get_tensor_size(
```
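Of the converted call sites, `get_tensor_sizes` is the one that takes an argument (`prefix`), so `cached_property` alone does not fit; the commit replaces its `lru_cache` with a plain dict owned by the instance and keyed on the argument. A stripped-down sketch of that idea, with a placeholder class and lookup that are not from the repository:

```python
class PrefixSizeCache:
    # Hypothetical stand-in class; only the caching structure mirrors
    # the change above (a per-instance dict instead of lru_cache).
    def __init__(self, tensor_sizes: dict[str, int]):
        self.tensor_sizes = tensor_sizes
        self._sizes_cache = {}   # owned by the instance, freed with it

    def get_tensor_sizes(self, prefix: str) -> list[int]:
        if prefix not in self._sizes_cache:
            # Placeholder lookup: gather sizes for the exact key and for
            # everything nested under "<prefix>.".
            self._sizes_cache[prefix] = [
                size for key, size in self.tensor_sizes.items()
                if key == prefix or key.startswith(prefix + ".")
            ]
        return self._sizes_cache[prefix]
```

Unlike a function-level `lru_cache`, this cache never outlives the object that owns it; it has no eviction, which is typically acceptable when the set of possible prefixes is small and fixed.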

exllamav3/model/model.py

Lines changed: 9 additions & 5 deletions

```diff
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from functools import lru_cache
+from functools import cached_property
 from typing import Callable
 import torch
 from .config import Config
@@ -50,14 +50,18 @@ def find_module(self, key: str):
         return self.modules_dict[key]


-    @lru_cache
-    def get_cache_layers(self):
+    @cached_property
+    def _get_cache_layers(self):
         return [m for m in self if m.caps.get("kv_cache")]
+    def get_cache_layers(self):
+        return self._get_cache_layers


-    @lru_cache
-    def get_recurrent_layers(self):
+    @cached_property
+    def _get_recurrent_layers(self):
         return [m for m in self if m.caps.get("recurrent_cache")]
+    def get_recurrent_layers(self):
+        return self._get_recurrent_layers


     @staticmethod
```

exllamav3/modules/linear.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from functools import lru_cache
+from functools import cached_property
 from typing_extensions import override
 import torch
 import torch.nn.functional as F
@@ -347,13 +347,15 @@ def quant_format_id(self):
         return None


-    @lru_cache
-    def storage_size(self):
+    @cached_property
+    def _storage_size(self):
         # alt_key is only used when loading unquantized model
         if self.is_exl3_storage(self.key):
             return sum(self.config.stc.get_tensor_sizes(prefix = self.key))
         else:
             return 2 * self.in_features * self.out_features
+    def storage_size(self):
+        return self._storage_size


     def recons_size(self):
```

exllamav3/modules/module.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -8,7 +8,7 @@
 if TYPE_CHECKING:
     from ..model.config import Config
     from ..model.model_tp_alloc import TPAllocation
-from functools import lru_cache
+from functools import cached_property

 # Use host bounce when moving state from device to device in layer split
 no_p2p_copy = os.environ.get('EXLLAMA_NO_P2P_COPY', None)
@@ -133,9 +133,11 @@ def tp_import(local_context, plan, loaded):
         """
         raise NotImplementedError()

-    @lru_cache
-    def all_cache_modules(self) -> list[Module]:
+    @cached_property
+    def _all_cache_modules(self) -> list[Module]:
         return [m for m in self if m.caps.get("kv_cache")]
+    def all_cache_modules(self):
+        return self._all_cache_modules

     @abstractmethod
     def optimizer_targets(self):
```

exllamav3/tokenizer/tokenizer.py

Lines changed: 45 additions & 29 deletions

```diff
@@ -5,7 +5,7 @@
 from ..util import synchronized
 from ..util.file import maybe_read_json
 from ..model.config import Config
-from functools import lru_cache
+from functools import lru_cache, cached_property
 from typing import TYPE_CHECKING
 from ..util import profile_opt
 if TYPE_CHECKING:
@@ -162,7 +162,7 @@ def get_default_token_id(config_key: str, current: int | None, default: int | No
         self.actual_vocab_size = 1 + max(
             list(self.extended_id_to_piece.keys()) + \
             list(self.unspecial_id_to_piece.keys()) + \
-            [self.raw_vocab_size() - 1]
+            [self.raw_vocab_size - 1]
         )

         # Useful token IDs
@@ -185,7 +185,7 @@ def get_default_token_id(config_key: str, current: int | None, default: int | No
         self.get_id_to_piece_list(False)
         self.get_piece_to_id_dict()

-    @lru_cache
+    @cached_property
     def raw_vocab_size(self):
         """
         Cache this function because it's suspiciously slow in HF Tokenizers
@@ -383,14 +383,14 @@ def decode_(self, seq, decode_special_tokens):

         if not decode_special_tokens:

-            max_token = self.raw_vocab_size()
+            max_token = self.raw_vocab_size
             seq = [t for t in seq if (t != self.pad_token_id and t < max_token and t != self.eos_token_id)]
             if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
             return self.decode_unspecial(seq)

         else:

-            max_token = self.raw_vocab_size()
+            max_token = self.raw_vocab_size
             seq = [t for t in seq if (t != self.pad_token_id and t < max_token)]
             text = ""
             start = 0
@@ -486,11 +486,10 @@ def num_tokens(self, text):

     # Get ordinals of single-byte tokens

-    @synchronized
-    @lru_cache
-    def get_id_to_ord_list(self):
+    @cached_property
+    def _get_id_to_ord_list(self):

-        self.id_to_ord = list(range(self.raw_vocab_size()))
+        self.id_to_ord = list(range(self.raw_vocab_size))

         def clean_special_chars(p):
             p = p.replace(self.space_char_, " ")
@@ -508,7 +507,7 @@ def piece_to_ord(p):
             if o <= 255: return o
             return -1

-        i = self.raw_vocab_size()
+        i = self.raw_vocab_size
         while True:
             if i in self.extended_id_to_piece:
                 self.id_to_ord.append(piece_to_ord(self.extended_id_to_piece[i]))
@@ -520,10 +519,15 @@ def piece_to_ord(p):
                 break
             i += 1

+        return self.id_to_ord
+    @synchronized
+    def get_id_to_ord_list(self):
+        return self._get_id_to_ord_list
+
     # Copy vocabulary from model

-    @lru_cache
-    def get_fixed_vocab(self):
+    @cached_property
+    def _get_fixed_vocab(self):
         test_enc = self.tokenizer.encode(" t", add_special_tokens = False)
         test_count = len(test_enc.ids)
         assert test_count > 0, "Tokenizer error, test string encodes to zero tokens"
@@ -532,35 +536,37 @@ def get_fixed_vocab(self):

         if test_count == 1 and len(test_piece) == len(" t"):
             vocab = self.tokenizer.decode_batch(
-                [[i] for i in range(self.raw_vocab_size())],
+                [[i] for i in range(self.raw_vocab_size)],
                 skip_special_tokens = False
             )
         else:
             prefix_id = self.tokenizer.encode(" ", add_special_tokens = False).ids[0]
             prefix_piece = self.tokenizer.decode([prefix_id])
             prefix_len = len(prefix_piece)
             vocab = self.tokenizer.decode_batch(
-                [[prefix_id, i] for i in range(self.raw_vocab_size())]
+                [[prefix_id, i] for i in range(self.raw_vocab_size)]
             )
             vocab = [v[prefix_len:] for v in vocab]

         return vocab
+    def get_fixed_vocab(self):
+        return self._get_fixed_vocab

-    @synchronized
-    @lru_cache
-    def get_id_to_piece_list(self, include_special_tokens = False):
+    @cached_property
+    def _get_id_to_piece_list_spc(self, include_special_tokens = False):

-        if include_special_tokens:
-            id_to_piece_extended = self.get_id_to_piece_list().copy()
-            for k, v in self.extended_id_to_piece.items():
-                id_to_piece_extended[k] = v
+        id_to_piece_extended = self.get_id_to_piece_list().copy()
+        for k, v in self.extended_id_to_piece.items():
+            id_to_piece_extended[k] = v

-            self.id_to_piece_with_special = id_to_piece_extended
-            return self.id_to_piece_with_special
+        self.id_to_piece_with_special = id_to_piece_extended
+        return self.id_to_piece_with_special
+    @cached_property
+    def _get_id_to_piece_list_nonspc(self, include_special_tokens = False):

         self.id_to_piece = self.get_fixed_vocab()

-        i = self.raw_vocab_size()
+        i = self.raw_vocab_size
         while True:
             if i in self.extended_id_to_piece:
                 self.id_to_piece.append(self.extended_id_to_piece[i])
@@ -573,13 +579,21 @@ def get_id_to_piece_list(self, include_special_tokens = False):
             i += 1

         return self.id_to_piece
-
     @synchronized
-    @lru_cache
-    def get_piece_to_id_dict(self):
+    def get_id_to_piece_list(self, include_special_tokens = False):
+        if include_special_tokens:
+            return self._get_id_to_piece_list_spc
+        else:
+            return self._get_id_to_piece_list_nonspc
+
+    @cached_property
+    def _get_piece_to_id_dict(self):
         all_pieces = self.get_id_to_piece_list()
         self.piece_to_id = {piece: idx for idx, piece in enumerate(all_pieces)}
         return self.piece_to_id
+    @synchronized
+    def get_piece_to_id_dict(self):
+        return self._get_piece_to_id_dict

     @staticmethod
     def from_config(config: Config):
@@ -603,12 +617,14 @@ def get_tokens_with_prefix_id(self, prefix_id: int):
         prefix = id_to_piece[prefix_id]
         return self.get_tokens_with_prefix_string(prefix)

-    @lru_cache
-    def get_vocab_dict(self):
+    @cached_property
+    def _get_vocab_dict(self):
         """
         Return tokenizer (dictionary for Formatron)
         """
         return {
             self.tokenizer.id_to_token(i): i
             for i in range(self.tokenizer.get_vocab_size())
         }
+    def get_vocab_dict(self):
+        return self._get_vocab_dict
```
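A quick way to observe the difference this commit targets (an illustrative check, not part of the change): cache a result through each decorator, drop the last strong reference to the object, and use `weakref` to see whether it was collected.

```python
import gc
import weakref
from functools import lru_cache, cached_property

class WithLru:
    @lru_cache
    def compute(self):
        return 42

class WithCachedProperty:
    @cached_property
    def compute(self):
        return 42

def survives(factory):
    # Build an object via `factory`, then check whether it is still alive
    # after the last strong reference to it is gone.
    ref = weakref.ref(factory())
    gc.collect()
    return ref() is not None

def use_lru():
    obj = WithLru()
    obj.compute()   # populates the class-level lru_cache; `obj` is part of the key
    return obj

def use_cached_property():
    obj = WithCachedProperty()
    obj.compute     # stores the value in obj.__dict__
    return obj

print(survives(use_lru))              # True: the lru_cache still references the instance
print(survives(use_cached_property))  # False: the instance was garbage collected
```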
