21
21
import itertools
22
22
from dataclasses import dataclass
23
23
# import xxhash
24
- # from line_profiler import profile
24
+ from line_profiler import profile
25
25
26
26
# TODO:
27
27
# - ExLlamaV2StreamingGenerator wrapper
@@ -893,6 +893,11 @@ def iterate(self) -> list[dict]:
893
893
"stop_string"
894
894
"max_new_tokens"
895
895
"end_filter"
896
+ optional, if "eos_reason" == "stop_token":
897
+ "eos_triggering_token_id": int
898
+ "eos_triggering_token_str": str
899
+ optional, if "eos_reason" == "stop_string":
900
+ "eos_triggering_string": str
896
901
"full_completion": str - full text completion
897
902
"new_tokens": int - number of tokens generated
898
903
"time_enqueued": float - time from when the job was enqueued until it started, in seconds
@@ -1849,7 +1854,9 @@ def emit(
1849
1854
eos_reason : str = None ,
1850
1855
emit_held = False ,
1851
1856
suppressed_text = None ,
1852
- suppressed_tokens = None
1857
+ suppressed_tokens = None ,
1858
+ stop_token : int = None ,
1859
+ stop_string : str = None
1853
1860
):
1854
1861
r = {
1855
1862
"job" : self ,
@@ -1860,6 +1867,15 @@ def emit(
1860
1867
1861
1868
if eos_reason is not None :
1862
1869
r .update ({ "eos_reason" : eos_reason })
1870
+ if eos_reason == "stop_token" :
1871
+ id_to_piece = self .generator .tokenizer .get_id_to_piece_list (True )
1872
+ r .update ({
1873
+ "eos_triggering_token_id" : stop_token ,
1874
+ "eos_triggering_token_str" : id_to_piece [stop_token ]
1875
+ })
1876
+ pass
1877
+ if eos_reason == "stop_string" :
1878
+ r .update ({ "eos_triggering_string" : stop_string })
1863
1879
1864
1880
if emit_held :
1865
1881
if self .held_text != "" :
@@ -1913,7 +1929,7 @@ def emit(
1913
1929
# End on stop tokens
1914
1930
1915
1931
if next_token .item () in self .stop_tokens :
1916
- return emit (results , emit_eos = True , eos_reason = "stop_token" )
1932
+ return emit (results , emit_eos = True , eos_reason = "stop_token" , stop_token = next_token . item () )
1917
1933
1918
1934
# Decode and buffer output
1919
1935
@@ -2032,8 +2048,12 @@ def rewind_checkpoint():
2032
2048
self .stop_strings_utf32_buffer
2033
2049
)
2034
2050
if match >= 0 :
2051
+ held = self .held_text [match :]
2035
2052
self .held_text = self .held_text [:match ]
2036
- return emit (results , emit_eos = True , emit_held = True , eos_reason = "stop_string" )
2053
+ for s in self .stop_strings :
2054
+ if held .startswith (s ):
2055
+ return emit (results , emit_eos = True , emit_held = True , eos_reason = "stop_string" , stop_string = s )
2056
+ assert False , "Detected stop string but couldn't identify it (logic error)"
2037
2057
if match == - 2 :
2038
2058
return emit (results )
2039
2059
0 commit comments