 class DocToolsLLM_class:
     "This docstring is dynamically replaced by the content of DocToolsLLM/docs/USAGE.md"

-    VERSION: str = "0.37"
+    VERSION: str = "0.38"

     #@optional_typecheck
     @typechecked
@@ -1064,6 +1064,10 @@ def evaluate_doc_chain(
         ) -> List[str]:
             if "n" in self.eval_llm_params or self.query_eval_check_number == 1:
                 out = self.eval_llm._generate_with_cache(PR_EVALUATE_DOC.format_messages(**inputs))
+                reasons = [gen.generation_info["finish_reason"] for gen in out.generations]
+                # don't crash if finish_reason is not "stop", because the output can sometimes still be parsed
+                if not all(r == "stop" for r in reasons):
+                    red(f"Unexpected generation finish_reason: '{reasons}'")
                 outputs = [gen.text for gen in out.generations]
                 assert outputs, "No generations found by query eval llm"
                 outputs = [parse_eval_output(o) for o in outputs]
@@ -1085,12 +1089,15 @@ async def eval(inputs):
                 ]
                 try:
                     loop = asyncio.get_event_loop()
-                except:
+                except RuntimeError:
                     loop = asyncio.new_event_loop()
                     asyncio.set_event_loop(loop)
                 outs = loop.run_until_complete(asyncio.gather(*outs))
                 for out in outs:
                     assert len(out.generations) == 1, f"Query eval llm produced more than 1 evaluations: '{out.generations}'"
+                    finish_reason = out.generations[0].generation_info["finish_reason"]
+                    if not finish_reason == "stop":
+                        red(f"Unexpected finish_reason: '{finish_reason}'")
                     outputs.append(out.generations[0].text)
                     if out.llm_output:
                         new_p += out.llm_output["token_usage"]["prompt_tokens"]
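For readers less familiar with the pattern in this hunk: the async branch reuses the current event loop if one exists, otherwise creates and registers a fresh one, then gathers all evaluation coroutines and only warns (via `red`) instead of asserting when a generation's `finish_reason` is not `stop`. Below is a minimal, self-contained sketch of that flow; `fake_eval` and `run_evals` are hypothetical stand-ins, not part of the codebase, and `fake_eval` merely mimics what `self.eval_llm._agenerate_with_cache(...)` would return.

```python
import asyncio

async def fake_eval(i: int) -> dict:
    # hypothetical stand-in for self.eval_llm._agenerate_with_cache(...)
    return {"text": str(i % 2), "finish_reason": "stop"}

def run_evals(n: int) -> list:
    tasks = [fake_eval(i) for i in range(n)]
    try:
        # reuse the current event loop if there is one
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # otherwise create one and make it the default for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    results = loop.run_until_complete(asyncio.gather(*tasks))
    answers = []
    for r in results:
        if r["finish_reason"] != "stop":
            print(f"Unexpected finish_reason: '{r['finish_reason']}'")  # warn, don't crash
        answers.append(r["text"])
    return answers

print(run_evals(3))  # ['0', '1', '0']
```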
@@ -1219,67 +1226,6 @@ def retrieve_documents(inputs):
                 | StrOutputParser()
             }

-        # the eval doc chain needs its own caching
-        if self.no_llm_cache:
-            def eval_cache_wrapper(func): return func
-        else:
-            eval_cache_wrapper = doc_eval_cache.cache
-
-        @chain
-        @optional_typecheck
-        @eval_cache_wrapper
-        def evaluate_doc_chain(
-            inputs: dict,
-            query_nb: int = self.query_eval_check_number,
-            eval_model_name: str = self.query_eval_modelname,
-        ) -> List[str]:
-            if "n" in self.eval_llm_params or self.query_eval_check_number == 1:
-                out = self.eval_llm._generate_with_cache(PR_EVALUATE_DOC.format_messages(**inputs))
-                reasons = [gen.generation_info["finish_reason"] for gen in out.generations]
-                assert all(r == "stop" for r in reasons), f"Unexpected generation finish_reason: '{reasons}'"
-                outputs = [gen.text for gen in out.generations]
-                assert outputs, "No generations found by query eval llm"
-                outputs = [parse_eval_output(o) for o in outputs]
-                if out.llm_output:
-                    new_p = out.llm_output["token_usage"]["prompt_tokens"]
-                    new_c = out.llm_output["token_usage"]["completion_tokens"]
-                else:
-                    new_p = 0
-                    new_c = 0
-            else:
-                outputs = []
-                new_p = 0
-                new_c = 0
-                async def eval(inputs):
-                    return await self.eval_llm._agenerate_with_cache(PR_EVALUATE_DOC.format_messages(**inputs))
-                outs = [
-                    eval(inputs)
-                    for i in range(self.query_eval_check_number)
-                ]
-                try:
-                    loop = asyncio.get_event_loop()
-                except RuntimeError:
-                    loop = asyncio.new_event_loop()
-                    asyncio.set_event_loop(loop)
-                outs = loop.run_until_complete(asyncio.gather(*outs))
-                for out in outs:
-                    assert len(out.generations) == 1, f"Query eval llm produced more than 1 evaluations: '{out.generations}'"
-                    outputs.append(out.generations[0].text)
-                    finish_reason = out.generations[0].generation_info["finish_reason"]
-                    assert finish_reason == "stop", f"unexpected finish_reason: '{finish_reason}'"
-                    if out.llm_output:
-                        new_p += out.llm_output["token_usage"]["prompt_tokens"]
-                        new_c += out.llm_output["token_usage"]["completion_tokens"]
-            assert outputs, "No generations found by query eval llm"
-            outputs = [parse_eval_output(o) for o in outputs]
-
-            assert len(outputs) == self.query_eval_check_number, f"query eval model failed to produce {self.query_eval_check_number} outputs"
-
-            self.eval_llm.callbacks[0].prompt_tokens += new_p
-            self.eval_llm.callbacks[0].completion_tokens += new_c
-            self.eval_llm.callbacks[0].total_tokens += new_p + new_c
-            return outputs
-
         # for some reason I needed to have at least one chain object otherwise rag_chain is a dict
         @chain
         def retrieve_documents(inputs):
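The 61 deleted lines in this hunk drop what appears to be a duplicated `evaluate_doc_chain` definition; the copy kept earlier in the file is the one patched by the first two hunks. One detail worth spelling out from that block is the conditional caching decorator: when `no_llm_cache` is set, the wrapper is a plain identity function, so the evaluation function is defined the same way whether caching is on or off. Below is a rough sketch of that pattern only, using `functools.lru_cache` as an illustrative stand-in for `doc_eval_cache.cache`; `no_llm_cache` and `evaluate` are hypothetical names, not the project's API.

```python
from functools import lru_cache

no_llm_cache = False  # illustrative flag, mirroring self.no_llm_cache

if no_llm_cache:
    def eval_cache_wrapper(func):
        return func  # identity decorator: no caching
else:
    eval_cache_wrapper = lru_cache(maxsize=None)  # stand-in for doc_eval_cache.cache

@eval_cache_wrapper
def evaluate(doc: str) -> str:
    # stand-in for the real evaluation body
    return doc.upper()

print(evaluate("some document"))
print(evaluate("some document"))  # second call can be served from the cache
```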