@@ -30,7 +30,7 @@
     ankiconnect, debug_chain, model_name_matcher,
     average_word_length, wpm, get_splitter,
     check_docs_tkn_length, get_tkn_length,
-    extra_args_keys, disable_internet, loaders_temp_dir_file)
+    extra_args_keys, disable_internet)
 from .utils.prompts import PR_CONDENSE_QUESTION, PR_EVALUATE_DOC, PR_ANSWER_ONE_DOC, PR_COMBINE_INTERMEDIATE_ANSWERS
 from .utils.tasks.query import format_chat_history, refilter_docs, check_intermediate_answer, parse_eval_output, query_eval_cache
 
@@ -75,7 +75,7 @@
 class DocToolsLLM_class:
     "This docstring is dynamically replaced by the content of DocToolsLLM/docs/USAGE.md"
 
-    VERSION: str = "0.45"
+    VERSION: str = "0.49"
 
     #@optional_typecheck
     @typechecked
@@ -154,9 +154,6 @@ def p(message: str) -> None:
         red(pyfiglet.figlet_format("DocToolsLLM"))
         log.info("Starting DocToolsLLM")
 
-        # erases content that links to the loaders temporary files at startup
-        loaders_temp_dir_file.write_text("")
-
         # make sure the extra args are valid
         for k in cli_kwargs:
             if k not in self.allowed_extra_keys:
@@ -857,6 +854,13 @@ def prepare_query_task(self) -> None:
         # parse filters as callable for faiss filtering
         if "filter_metadata" in self.cli_kwargs or "filter_content" in self.cli_kwargs:
             if "filter_metadata" in self.cli_kwargs:
+                # get the list of all metadata to see if a filter was not misspelled
+                all_metadata_keys = set()
+                for doc in tqdm(self.loaded_embeddings.docstore._dict.values(), desc="gathering metadata keys", unit="doc"):
+                    for k in doc.metadata.keys():
+                        all_metadata_keys.add(k)
+                assert all_metadata_keys, "No metadata keys found in any metadata, something went wrong!"
+
                 if isinstance(self.cli_kwargs["filter_metadata"], str):
                     filter_metadata = self.cli_kwargs["filter_metadata"].split(",")
                 else:
@@ -921,6 +925,10 @@ def prepare_query_task(self) -> None:
                 filters_b_minus_keys = tuple(filters_b_minus_keys)
                 filters_b_minus_values = tuple(filters_b_minus_values)
 
+                # check that all key filters indeed match metadata keys
+                for k in filters_k_plus + filters_k_minus + filters_b_plus_keys + filters_b_minus_keys:
+                    assert any(k.match(key) for key in all_metadata_keys), f"Key {k} didn't match any key in the metadata"
+
                 def filter_meta(meta: dict) -> bool:
                     # match keys
                     for inc in filters_k_plus:
@@ -1025,7 +1033,7 @@ def filter_cont(cont: str) -> bool:
             self.unfiltered_docstore = self.loaded_embeddings.serialize_to_bytes()
             status = self.loaded_embeddings.delete(ids_to_del)
 
-            # checking deletiong want well
+            # checking deletions went well
             if status is False:
                 raise Exception("Vectorstore filtering failed")
             elif status is None:
@@ -1132,10 +1140,22 @@ def query_task(self, query: Optional[str]) -> Optional[str]:
 
         # answer 0 or 1 if the document is related
         if not hasattr(self, "eval_llm"):
-            self.eval_llm_params = litellm.get_supported_openai_params(
-                model=self.query_eval_modelname,
-                custom_llm_provider=self.query_eval_modelbackend,
-            )
+            failed = False
+            if self.query_eval_modelbackend == "openrouter":
+                try:
+                    self.eval_llm_params = litellm.get_supported_openai_params(
+                        model_name_matcher(
+                            self.query_eval_modelname.split("/", 1)[1]
+                        )
+                    )
+                except Exception as err:
+                    failed = True
+                    red(f"Failed to get query_eval_model parameters information bypassing openrouter: '{err}'")
+            if self.query_eval_modelbackend != "openrouter" or failed:
+                self.eval_llm_params = litellm.get_supported_openai_params(
+                    model=self.query_eval_modelname,
+                    custom_llm_provider=self.query_eval_modelbackend,
+                )
             eval_args = {}
             if "n" in self.eval_llm_params:
                 eval_args["n"] = self.query_eval_check_number
@@ -1186,10 +1206,10 @@ def evaluate_doc_chain(
         if "n" in self.eval_llm_params or self.query_eval_check_number == 1:
             out = self.eval_llm._generate_with_cache(PR_EVALUATE_DOC.format_messages(**inputs))
             reasons = [gen.generation_info["finish_reason"] for gen in out.generations]
-            # don't crash if finish_reason is not stop, because it can sometimes still be parsed.
-            if not all(r == "stop" for r in reasons):
-                red(f"Unexpected generation finish_reason: '{reasons}'")
             outputs = [gen.text for gen in out.generations]
+            # don't crash if finish_reason is not stop, because it can sometimes still be parsed.
+            if not all(r in ["stop", "length"] for r in reasons):
+                red(f"Unexpected generation finish_reason: '{reasons}' for generations: '{outputs}'")
             assert outputs, "No generations found by query eval llm"
             outputs = [parse_eval_output(o) for o in outputs]
             if out.llm_output:
@@ -1216,17 +1236,17 @@ async def do_eval(inputs):
             outs = loop.run_until_complete(asyncio.gather(*outs))
             for out in outs:
                 assert len(out.generations) == 1, f"Query eval llm produced more than 1 evaluations: '{out.generations}'"
-                finish_reason = out.generations[0].generation_info["finish_reason"]
-                if not finish_reason == "stop":
-                    red(f"Unexpected finish_reason: '{finish_reason}'")
                 outputs.append(out.generations[0].text)
+                finish_reason = out.generations[0].generation_info["finish_reason"]
+                if finish_reason not in ["stop", "length"]:
+                    red(f"Unexpected finish_reason: '{finish_reason}' for generation '{outputs[-1]}'")
                 if out.llm_output:
                     new_p += out.llm_output["token_usage"]["prompt_tokens"]
                     new_c += out.llm_output["token_usage"]["completion_tokens"]
             assert outputs, "No generations found by query eval llm"
             outputs = [parse_eval_output(o) for o in outputs]
 
-        assert len(outputs) == self.query_eval_check_number, f"query eval model failed to produce {self.query_eval_check_number} outputs"
+        assert len(outputs) == self.query_eval_check_number, f"query eval model failed to produce {self.query_eval_check_number} outputs: '{outputs}'"
 
         self.eval_llm.callbacks[0].prompt_tokens += new_p
         self.eval_llm.callbacks[0].completion_tokens += new_c