Skip to content

Commit 0cf9729

Browse files
committed
Dynamic gen: Identify specific stop condition when ending generation
1 parent 3ffcc74 commit 0cf9729

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

exllamav2/generator/dynamic.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import itertools
2222
from dataclasses import dataclass
2323
# import xxhash
24-
# from line_profiler import profile
24+
from line_profiler import profile
2525

2626
# TODO:
2727
# - ExLlamaV2StreamingGenerator wrapper
@@ -893,6 +893,11 @@ def iterate(self) -> list[dict]:
893893
"stop_string"
894894
"max_new_tokens"
895895
"end_filter"
896+
optional, if "eos_reason" == "stop_token":
897+
"eos_triggering_token_id": int
898+
"eos_triggering_token_str": str
899+
optional, if "eos_reason" == "stop_string":
900+
"eos_triggering_string": str
896901
"full_completion": str - full text completion
897902
"new_tokens": int - number of tokens generated
898903
"time_enqueued": float - time from job was enqueued until it started, in seconds
@@ -1849,7 +1854,9 @@ def emit(
18491854
eos_reason: str = None,
18501855
emit_held = False,
18511856
suppressed_text = None,
1852-
suppressed_tokens = None
1857+
suppressed_tokens = None,
1858+
stop_token: int = None,
1859+
stop_string: str = None
18531860
):
18541861
r = {
18551862
"job": self,
@@ -1860,6 +1867,15 @@ def emit(
18601867

18611868
if eos_reason is not None:
18621869
r.update({ "eos_reason": eos_reason })
1870+
if eos_reason == "stop_token":
1871+
id_to_piece = self.generator.tokenizer.get_id_to_piece_list(True)
1872+
r.update({
1873+
"eos_triggering_token_id": stop_token,
1874+
"eos_triggering_token_str": id_to_piece[stop_token]
1875+
})
1876+
pass
1877+
if eos_reason == "stop_string":
1878+
r.update({ "eos_triggering_string": stop_string })
18631879

18641880
if emit_held:
18651881
if self.held_text != "":
@@ -1913,7 +1929,7 @@ def emit(
19131929
# End on stop tokens
19141930

19151931
if next_token.item() in self.stop_tokens:
1916-
return emit(results, emit_eos = True, eos_reason = "stop_token")
1932+
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
19171933

19181934
# Decode and buffer output
19191935

@@ -2032,8 +2048,12 @@ def rewind_checkpoint():
20322048
self.stop_strings_utf32_buffer
20332049
)
20342050
if match >= 0:
2051+
held = self.held_text[match:]
20352052
self.held_text = self.held_text[:match]
2036-
return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_string")
2053+
for s in self.stop_strings:
2054+
if held.startswith(s):
2055+
return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_string", stop_string = s)
2056+
assert False, "Detected stop string but couldn't identify it (logic error)"
20372057
if match == -2:
20382058
return emit(results)
20392059

0 commit comments

Comments
 (0)