@@ -1856,7 +1856,8 @@ def emit(
1856
1856
suppressed_text = None ,
1857
1857
suppressed_tokens = None ,
1858
1858
stop_token : int = None ,
1859
- stop_string : str = None
1859
+ stop_string : str = None ,
1860
+ rem_held_text : str = None
1860
1861
):
1861
1862
r = {
1862
1863
"job" : self ,
@@ -1919,18 +1920,29 @@ def emit(
1919
1920
"accepted_draft_tokens" : self .accepted_draft_tokens ,
1920
1921
"rejected_draft_tokens" : self .rejected_draft_tokens
1921
1922
})
1923
+ if eos_reason == "stop_string" :
1924
+ self .held_text = rem_held_text
1925
+ rh = {}
1926
+ if self .held_text :
1927
+ rh .update ({ "text" : self .held_text })
1928
+ if self .held_tokens :
1929
+ rh .update ({ "token_ids" : self .held_tokens .torch ().clone () })
1930
+ if self .held_probs :
1931
+ rh .update ({ "token_probs" : self .held_probs .torch ().clone () })
1932
+ if self .held_k_tokens :
1933
+ rh .update ({ "top_k_tokens" : self .held_k_tokens .torch ().clone () })
1934
+ rh .update ({ "top_k_probs" : self .held_k_probs .torch ().clone () })
1935
+ if self .held_logits :
1936
+ rh .update ({ "logits" : self .held_logits .torch ().clone () })
1937
+ if rh :
1938
+ r .update ({ "held" : rh })
1922
1939
1923
1940
if self .identifier is not None :
1924
1941
r .update ({ "identifier" : self .identifier })
1925
1942
1926
1943
results .append (r )
1927
1944
return emit_eos , next_token
1928
1945
1929
- # End on stop tokens
1930
-
1931
- if next_token .item () in self .stop_tokens :
1932
- return emit (results , emit_eos = True , eos_reason = "stop_token" , stop_token = next_token .item ())
1933
-
1934
1946
# Decode and buffer output
1935
1947
1936
1948
id_to_piece = self .generator .tokenizer .get_id_to_piece_list (self .decode_special_tokens )
@@ -1950,6 +1962,11 @@ def emit(
1950
1962
if self .return_logits :
1951
1963
self .held_logits .append (logits [:1 , :, :])
1952
1964
1965
+ # End on stop tokens
1966
+
1967
+ if next_token .item () in self .stop_tokens :
1968
+ return emit (results , emit_eos = True , eos_reason = "stop_token" , stop_token = next_token .item ())
1969
+
1953
1970
# Stop if we reach max_new_tokens
1954
1971
1955
1972
if self .new_tokens >= self .max_new_tokens - self .generator .num_draft_tokens :
@@ -2052,7 +2069,14 @@ def rewind_checkpoint():
2052
2069
self .held_text = self .held_text [:match ]
2053
2070
for s in self .stop_strings :
2054
2071
if held .startswith (s ):
2055
- return emit (results , emit_eos = True , emit_held = True , eos_reason = "stop_string" , stop_string = s )
2072
+ return emit (
2073
+ results ,
2074
+ emit_eos = True ,
2075
+ emit_held = True ,
2076
+ eos_reason = "stop_string" ,
2077
+ stop_string = s ,
2078
+ rem_held_text = held
2079
+ )
2056
2080
assert False , "Detected stop string but couldn't identify it (logic error)"
2057
2081
if match == - 2 :
2058
2082
return emit (results )
0 commit comments