Skip to content

Commit 809b0da

Browse files
committed
Added max_total_tokens attribute to class Generator; fixed type annotations
1 parent ae04741 commit 809b0da

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

exllamav3/generator/generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
80 80

81 81
# Paging
82 82
self.pagetable = PageTable(self, cache)
83 +
self.max_total_tokens = PAGE_SIZE * self.pagetable.max_pages
83 84

84 85
# Draft model
85 86
self.draft_model = draft_model
@@ -563,7 +564,7 @@ def generate(
563 564
prompt: list[tuple] | list[str] | tuple | str,
564 565
max_new_tokens: int | None = None,
565 566
min_new_tokens: int = 0,
566 -
seed: int or None = None,
567 +
seed: int | None = None,
567 568
sampler: Sampler | list[Sampler] | None = None,
568 569
token_healing: bool = False,
569 570
encode_special_tokens: bool = False,

exllamav3/generator/job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
16 16
from ..util.tensor import SeqTensor
17 17

18 18
# Convert list of strings to UTF32 format to pass by reference to partial matching function
19 -
def _strings_to_utf32(strings: list[str]) -> (np.array, list[int]):
19 +
def _strings_to_utf32(strings: list[str]) -> tuple[np.ndarray, np.ndarray] | None:
20 20

21 21
if not strings: return bytearray(), None
22 22

0 commit comments

Comments
 (0)