Skip to content

Commit 81cd6b7

Browse files
committed
Dynamic gen: Return held output with last results
1 parent 304e021 commit 81cd6b7

File tree

1 file changed

+31 -7 lines changed

1 file changed

+31 -7 lines changed

exllamav2/generator/dynamic.py

Lines changed: 31 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1856,7 +1856,8 @@ def emit(
18561856
suppressed_text = None,
18571857
suppressed_tokens = None,
18581858
stop_token: int = None,
1859-
stop_string: str = None
1859+
stop_string: str = None,
1860+
rem_held_text: str = None
18601861
):
18611862
r = {
18621863
"job": self,
@@ -1919,18 +1920,29 @@ def emit(
19191920
"accepted_draft_tokens": self.accepted_draft_tokens,
19201921
"rejected_draft_tokens": self.rejected_draft_tokens
19211922
})
1923+
if eos_reason == "stop_string":
1924+
self.held_text = rem_held_text
1925+
rh = {}
1926+
if self.held_text:
1927+
rh.update({ "text": self.held_text })
1928+
if self.held_tokens:
1929+
rh.update({ "token_ids": self.held_tokens.torch().clone() })
1930+
if self.held_probs:
1931+
rh.update({ "token_probs": self.held_probs.torch().clone() })
1932+
if self.held_k_tokens:
1933+
rh.update({ "top_k_tokens": self.held_k_tokens.torch().clone() })
1934+
rh.update({ "top_k_probs": self.held_k_probs.torch().clone() })
1935+
if self.held_logits:
1936+
rh.update({ "logits": self.held_logits.torch().clone() })
1937+
if rh:
1938+
r.update({ "held": rh })
19221939

19231940
if self.identifier is not None:
19241941
r.update({ "identifier": self.identifier })
19251942

19261943
results.append(r)
19271944
return emit_eos, next_token
19281945

1929-
# End on stop tokens
1930-
1931-
if next_token.item() in self.stop_tokens:
1932-
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
1933-
19341946
# Decode and buffer output
19351947

19361948
id_to_piece = self.generator.tokenizer.get_id_to_piece_list(self.decode_special_tokens)
@@ -1950,6 +1962,11 @@ def emit(
19501962
if self.return_logits:
19511963
self.held_logits.append(logits[:1, :, :])
19521964

1965+
# End on stop tokens
1966+
1967+
if next_token.item() in self.stop_tokens:
1968+
return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
1969+
19531970
# Stop if we reach max_new_tokens
19541971

19551972
if self.new_tokens >= self.max_new_tokens - self.generator.num_draft_tokens:
@@ -2052,7 +2069,14 @@ def rewind_checkpoint():
20522069
self.held_text = self.held_text[:match]
20532070
for s in self.stop_strings:
20542071
if held.startswith(s):
2055-
return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_string", stop_string = s)
2072+
return emit(
2073+
results,
2074+
emit_eos = True,
2075+
emit_held = True,
2076+
eos_reason = "stop_string",
2077+
stop_string = s,
2078+
rem_held_text = held
2079+
)
20562080
assert False, "Detected stop string but couldn't identify it (logic error)"
20572081
if match == -2:
20582082
return emit(results)

0 commit comments

Comments
 (0)