
Commit ea8b65b: Merge branch 'dev'
2 parents e38ef9e + 3705a6c

7 files changed: 343 additions, 70 deletions

DocToolsLLM/DocToolsLLM.py (37 additions, 17 deletions)
```diff
@@ -30,7 +30,7 @@
     ankiconnect, debug_chain, model_name_matcher,
     average_word_length, wpm, get_splitter,
     check_docs_tkn_length, get_tkn_length,
-    extra_args_keys, disable_internet, loaders_temp_dir_file)
+    extra_args_keys, disable_internet)
 from .utils.prompts import PR_CONDENSE_QUESTION, PR_EVALUATE_DOC, PR_ANSWER_ONE_DOC, PR_COMBINE_INTERMEDIATE_ANSWERS
 from .utils.tasks.query import format_chat_history, refilter_docs, check_intermediate_answer, parse_eval_output, query_eval_cache

@@ -75,7 +75,7 @@
 class DocToolsLLM_class:
     "This docstring is dynamically replaced by the content of DocToolsLLM/docs/USAGE.md"

-    VERSION: str = "0.45"
+    VERSION: str = "0.49"

     #@optional_typecheck
     @typechecked

@@ -154,9 +154,6 @@ def p(message: str) -> None:
         red(pyfiglet.figlet_format("DocToolsLLM"))
         log.info("Starting DocToolsLLM")

-        # erases content that links to the loaders temporary files at startup
-        loaders_temp_dir_file.write_text("")
-
         # make sure the extra args are valid
         for k in cli_kwargs:
             if k not in self.allowed_extra_keys:

@@ -857,6 +854,13 @@ def prepare_query_task(self) -> None:
         # parse filters as callable for faiss filtering
         if "filter_metadata" in self.cli_kwargs or "filter_content" in self.cli_kwargs:
             if "filter_metadata" in self.cli_kwargs:
+                # get the list of all metadata to see if a filter was not misspelled
+                all_metadata_keys = set()
+                for doc in tqdm(self.loaded_embeddings.docstore._dict.values(), desc="gathering metadata keys", unit="doc"):
+                    for k in doc.metadata.keys():
+                        all_metadata_keys.add(k)
+                assert all_metadata_keys, "No metadata keys found in any metadata, something went wrong!"
+
                 if isinstance(self.cli_kwargs["filter_metadata"], str):
                     filter_metadata = self.cli_kwargs["filter_metadata"].split(",")
                 else:

@@ -921,6 +925,10 @@ def prepare_query_task(self) -> None:
             filters_b_minus_keys = tuple(filters_b_minus_keys)
             filters_b_minus_values = tuple(filters_b_minus_values)

+            # check that all key filters indeed match metadata keys
+            for k in filters_k_plus + filters_k_minus + filters_b_plus_keys + filters_b_minus_keys:
+                assert any(k.match(key) for key in all_metadata_keys), f"Key {k} didn't match any key in the metadata"
+
             def filter_meta(meta: dict) -> bool:
                 # match keys
                 for inc in filters_k_plus:
```
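
The two hunks above add an early sanity check to `prepare_query_task`: every key filter passed via `--filter_metadata` must match at least one metadata key actually present in the loaded embeddings, so a misspelled filter fails fast instead of silently filtering everything out. A minimal standalone sketch of the idea, with the docstore layout and filter syntax simplified (these are assumptions, not the exact DocToolsLLM structures):

```python
import re

# stand-in for self.loaded_embeddings.docstore._dict.values()
docstore = [
    {"source": "a.pdf", "page": 1},
    {"source": "b.pdf", "anki_deck": "default"},
]

# gather every metadata key present in any document
all_metadata_keys = {k for meta in docstore for k in meta}
assert all_metadata_keys, "No metadata keys found in any metadata, something went wrong!"

# key filters are compiled regexes; a misspelled one matches nothing and fails fast
filters_k_plus = (re.compile("sourc.*"), re.compile("anki_.*"))
for f in filters_k_plus:
    assert any(f.match(key) for key in all_metadata_keys), (
        f"Key {f.pattern} didn't match any key in the metadata"
    )
```

Collecting the keys once up front costs a single pass over the docstore, but it turns a silently empty result set into an immediate, explicit assertion error.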
```diff
@@ -1025,7 +1033,7 @@ def filter_cont(cont: str) -> bool:
             self.unfiltered_docstore = self.loaded_embeddings.serialize_to_bytes()
             status = self.loaded_embeddings.delete(ids_to_del)

-            # checking deletiong want well
+            # checking deletions went well
             if status is False:
                 raise Exception("Vectorstore filtering failed")
             elif status is None:

@@ -1132,10 +1140,22 @@ def query_task(self, query: Optional[str]) -> Optional[str]:

         # answer 0 or 1 if the document is related
         if not hasattr(self, "eval_llm"):
-            self.eval_llm_params = litellm.get_supported_openai_params(
-                model=self.query_eval_modelname,
-                custom_llm_provider=self.query_eval_modelbackend,
-            )
+            failed = False
+            if self.query_eval_modelbackend == "openrouter":
+                try:
+                    self.eval_llm_params = litellm.get_supported_openai_params(
+                        model_name_matcher(
+                            self.query_eval_modelname.split("/", 1)[1]
+                        )
+                    )
+                except Exception as err:
+                    failed = True
+                    red(f"Failed to get query_eval_model parameters information bypassing openrouter: '{err}'")
+            if self.modelbackend != "openrouter" or failed:
+                self.eval_llm_params = litellm.get_supported_openai_params(
+                    model=self.query_eval_modelname,
+                    custom_llm_provider=self.query_eval_modelbackend,
+                )
             eval_args = {}
             if "n" in self.eval_llm_params:
                 eval_args["n"] = self.query_eval_check_number
```
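
This hunk makes the eval-model parameter lookup openrouter-aware: it first tries to resolve the underlying model directly, bypassing the `openrouter/` prefix via `model_name_matcher`, and falls back to the generic litellm lookup if that fails. A condensed sketch of the try-then-fallback pattern (model names are illustrative, and `model_name_matcher` is replaced by a plain string split, so this approximates rather than reproduces the committed logic):

```python
import litellm

def get_eval_params(model: str, backend: str):
    if backend == "openrouter":
        try:
            # bypass openrouter: query the underlying model directly,
            # e.g. "openrouter/openai/gpt-4o" -> "openai/gpt-4o"
            return litellm.get_supported_openai_params(model.split("/", 1)[1])
        except Exception as err:
            print(f"Direct lookup failed, falling back: '{err}'")
    # generic path: let litellm resolve the model through the provider
    return litellm.get_supported_openai_params(
        model=model,
        custom_llm_provider=backend,
    )

params = get_eval_params("openrouter/openai/gpt-4o", "openrouter")
print("n" in params)  # whether the eval model supports the `n` sampling argument
```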
```diff
@@ -1186,10 +1206,10 @@ def evaluate_doc_chain(
         if "n" in self.eval_llm_params or self.query_eval_check_number == 1:
             out = self.eval_llm._generate_with_cache(PR_EVALUATE_DOC.format_messages(**inputs))
             reasons = [gen.generation_info["finish_reason"] for gen in out.generations]
-            # don't crash if finish_reason is not stop, because it can sometimes still be parsed.
-            if not all(r == "stop" for r in reasons):
-                red(f"Unexpected generation finish_reason: '{reasons}'")
             outputs = [gen.text for gen in out.generations]
+            # don't crash if finish_reason is not stop, because it can sometimes still be parsed.
+            if not all(r in ["stop", "length"] for r in reasons):
+                red(f"Unexpected generation finish_reason: '{reasons}' for generations: '{outputs}'")
             assert outputs, "No generations found by query eval llm"
             outputs = [parse_eval_output(o) for o in outputs]
             if out.llm_output:

@@ -1216,17 +1236,17 @@ async def do_eval(inputs):
             outs = loop.run_until_complete(asyncio.gather(*outs))
             for out in outs:
                 assert len(out.generations) == 1, f"Query eval llm produced more than 1 evaluation: '{out.generations}'"
-                finish_reason = out.generations[0].generation_info["finish_reason"]
-                if not finish_reason == "stop":
-                    red(f"Unexpected finish_reason: '{finish_reason}'")
                 outputs.append(out.generations[0].text)
+                finish_reason = out.generations[0].generation_info["finish_reason"]
+                if not finish_reason in ["stop", "length"]:
+                    red(f"Unexpected finish_reason: '{finish_reason}' for generation '{outputs[-1]}'")
                 if out.llm_output:
                     new_p += out.llm_output["token_usage"]["prompt_tokens"]
                     new_c += out.llm_output["token_usage"]["completion_tokens"]
             assert outputs, "No generations found by query eval llm"
             outputs = [parse_eval_output(o) for o in outputs]

-            assert len(outputs) == self.query_eval_check_number, f"query eval model failed to produce {self.query_eval_check_number} outputs"
+            assert len(outputs) == self.query_eval_check_number, f"query eval model failed to produce {self.query_eval_check_number} outputs: '{outputs}'"

             self.eval_llm.callbacks[0].prompt_tokens += new_p
             self.eval_llm.callbacks[0].completion_tokens += new_c
```
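
Both evaluation paths now collect the generation text before inspecting `finish_reason`, tolerate `length` in addition to `stop` (a truncated 0/1 eval can often still be parsed), and include the offending text in the warning. A self-contained sketch of the async variant, with the LLM call stubbed out:

```python
import asyncio

async def eval_once(i: int) -> dict:
    await asyncio.sleep(0)  # stand-in for an async LLM eval call
    return {"text": str(i % 2), "finish_reason": "stop"}

async def run_evals(n: int) -> list:
    outs = await asyncio.gather(*(eval_once(i) for i in range(n)))
    outputs = []
    for out in outs:
        outputs.append(out["text"])  # keep the text first...
        # ...so a warning about an odd finish_reason can show it
        if out["finish_reason"] not in ("stop", "length"):
            print(f"Unexpected finish_reason for generation '{outputs[-1]}'")
    assert len(outputs) == n, f"expected {n} outputs: '{outputs}'"
    return outputs

print(asyncio.run(run_evals(3)))  # e.g. ['0', '1', '0']
```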

DocToolsLLM/docs/USAGE.md (4 additions, 4 deletions)
```diff
@@ -248,11 +248,11 @@
 * `--anki_fields`: List[str]
     * List of fields to keep
 * `--anki_mode`: str
-    * any of `window`, `concatenate`, `single_note`: (or _ separated
-      value like `concatenate_window`). By default `single_note`
+    * any of `window`, `concatenate`, `singlecard`: (or _ separated
+      value like `concatenate_window`). By default `singlecard`
       is used.
     * Modes:
-        * `single_note`: 1 document is 1 anki note.
+        * `singlecard`: 1 document is 1 anki card.
         * `window`: 1 document is 5 anki notes, overlapping (so
           10 anki notes will result in 5 documents)
         * `concatenate`: 1 document is all anki notes concatenated as a
```
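
Because the class docstring is generated from this USAGE.md, the CLI flags map onto keyword arguments. A hypothetical query over an anki collection using the renamed mode might look like this (the exact constructor signature is an assumption based on the flags documented above, not something shown in this commit):

```python
from DocToolsLLM.DocToolsLLM import DocToolsLLM_class

# hypothetical call: keyword names mirror the documented CLI flags
doctools = DocToolsLLM_class(
    task="query",
    filetype="anki",
    anki_mode="singlecard",  # was "single_note" before this commit
    anki_fields=["Front", "Back"],
)
```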
```diff
@@ -385,7 +385,7 @@
     * a string that will be added to the document metadata at the
       key `source_tag`. Useful when using filetype combination.

-* `--loading_failure`: str, default `crash`
+* `--loading_failure`: str, default `warn`
     * either `crash` or `warn`. Determines what to do with
       exceptions happening when loading a document. This can be set
       per document if a recursive_paths filetype is used.
```
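
And a hypothetical call relying on the new default: with `loading_failure` unset or set to `warn`, a batch containing one unreadable file keeps going instead of crashing (same caveat as above about the assumed constructor signature):

```python
from DocToolsLLM.DocToolsLLM import DocToolsLLM_class

# hypothetical: load a folder recursively, tolerating individual failures
doctools = DocToolsLLM_class(
    task="query",
    filetype="recursive_paths",
    path="~/Documents/notes",
    loading_failure="warn",  # the new default; use "crash" to fail fast
)
```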

DocToolsLLM/utils/batch_file_loader.py (38 additions, 12 deletions)
```diff
@@ -12,7 +12,6 @@
 from tqdm import tqdm
 from functools import cache as memoizer
 import time
-import os
 from typing import List, Tuple
 from functools import wraps
 import random

@@ -89,8 +88,8 @@ def batch_load_doc(
     if "path" in cli_kwargs and isinstance(cli_kwargs["path"], str):
         cli_kwargs["path"] = cli_kwargs["path"].strip()

-    load_failure = cli_kwargs["load_failure"] if "load_failure" in cli_kwargs else "crash"
-    assert load_failure in ["crash", "warn"], f"load_failure must be either crash or warn. Not {load_failure}"
+    loading_failure = cli_kwargs["loading_failure"] if "loading_failure" in cli_kwargs else "warn"
+    assert loading_failure in ["crash", "warn"], f"loading_failure must be either crash or warn. Not {loading_failure}"

     # expand the list of documents to load as long as there are recursive types
     to_load = [cli_kwargs.copy()]

@@ -144,7 +143,6 @@ def batch_load_doc(
         if new_doc_to_load:
             assert to_load[ild]["filetype"] in recursive_types
             to_load.remove(to_load[ild])
-            ild_done = None
             to_load.extend(new_doc_to_load)
             new_doc_to_load = []
             continue

@@ -176,7 +174,7 @@
                 del doc[k]
         # filter out the usual unexpected
         all_unexp_keys = [a for a in all_unexp_keys if a not in [
-            "out_file", "file_loader_n_jobs"
+            "out_file", "file_loader_n_jobs", "loading_failure",
         ]]
         if all_unexp_keys:
             red(f"Found unexpected keys in doc_kwargs: '{all_unexp_keys}'")

@@ -250,12 +248,12 @@ def load_one_doc_wrapped(**doc_kwargs):
         except Exception as err:
             filetype = doc_kwargs["filetype"]
             red(f"Error when loading doc with filetype {filetype}: '{err}'. Arguments: {doc_kwargs}")
-            if load_failure == "crash" or is_debug:
+            if loading_failure == "crash" or is_debug:
                 raise
-            elif load_failure == "warn":
+            elif loading_failure == "warn":
                 return None
             else:
-                raise ValueError(load_failure)
+                raise ValueError(loading_failure)

     if len(to_load) == 1 or is_debug:
         n_jobs = 1
```
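
The wrapper feeding the parallel loader now uses the renamed `loading_failure` setting: it logs the error, re-raises under `crash`, and returns None under `warn` so the rest of the batch continues. A standalone sketch of the pattern with the loader stubbed out (the real `load_one_doc_wrapped` receives the full `doc_kwargs`):

```python
def load_one_doc(path: str) -> list:
    # hypothetical loader stub that fails on one input
    if path.endswith(".bad"):
        raise ValueError("unparsable file")
    return [f"contents of {path}"]

def load_one_doc_wrapped(path: str, loading_failure: str = "warn"):
    try:
        return load_one_doc(path)
    except Exception as err:
        print(f"Error when loading {path}: '{err}'")
        if loading_failure == "crash":
            raise
        elif loading_failure == "warn":
            return None  # a None result marks this input as failed
        else:
            raise ValueError(loading_failure)

doc_lists = [load_one_doc_wrapped(p) for p in ["a.txt", "b.bad"]]
print(doc_lists)  # [['contents of a.txt'], None]
```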
```diff
@@ -293,12 +291,40 @@ def load_one_doc_wrapped(**doc_kwargs):
             colour="magenta",
         )
     )
+
+    # erases content that links to the loaders temporary files at startup
+    loaders_temp_dir_file.write_text("")
+
     red(f"Done loading all {len(to_load)} documents in {time.time()-t_load:.2f}s")
-    n_failed = len([d for d in doc_lists if d is None])
-    if n_failed:
-        red(f"Number of failed documents: {n_failed}")
-    [docs.extend(d) for d in doc_lists if d is not None]
+    missing_docargs = []
+    for idoc, d in tqdm(enumerate(doc_lists), total=len(doc_lists), desc="Concatenating results"):
+        if d is not None:
+            docs.extend(d)
+        else:
+            missing_docargs.append(to_load[idoc])
     assert None not in docs
+
+    if missing_docargs:
+        missing_docargs = sorted(missing_docargs, key=lambda x: json.dumps(x))
+        red(f"Number of failed documents: {len(missing_docargs)}:")
+        missed_recur = []
+        for imissed, missed in enumerate(missing_docargs):
+            if len(missing_docargs) > 99:
+                red(f"- {imissed + 1:03d}]: '{missed}'")
+            else:
+                red(f"- {imissed + 1:02d}]: '{missed}'")
+            if missed["filetype"] in recursive_types:
+                missed_recur.append(missed)
+
+        if missed_recur:
+            missed_recur = sorted(missed_recur, key=lambda x: json.dumps(x))
+            red("Crashing because some recursive filetypes failed:")
+            for imr, mr in enumerate(missed_recur):
+                red(f"- {imr + 1}]: '{mr}'")
+            raise Exception(f"{len(missed_recur)} recursive filetypes failed to load.")
+    else:
+        red("No document failed to load!")
+
     assert docs, "No documents were successfully loaded!"

     size = sum([get_tkn_length(d.page_content) for d in docs])
```
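
The rewritten concatenation step exploits the fact that the parallel map returns exactly one result per entry of `to_load`: a None at index i traces back to the precise arguments that failed, and failed recursive filetypes escalate to a crash, since every document they would have expanded into is missing. A small sketch of the index-based failure mapping, with illustrative data:

```python
import json

to_load = [{"path": "a.txt", "filetype": "txt"}, {"path": "b.bad", "filetype": "txt"}]
doc_lists = [["doc_a"], None]  # None marks the loader that failed

docs, missing_docargs = [], []
for idoc, d in enumerate(doc_lists):
    if d is not None:
        docs.extend(d)
    else:
        missing_docargs.append(to_load[idoc])  # map the failure back to its arguments

# sort by JSON form so the failure report is deterministic, then print it
for i, missed in enumerate(sorted(missing_docargs, key=json.dumps)):
    print(f"- {i + 1:02d}]: '{missed}'")
```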

DocToolsLLM/utils/embeddings.py (0 additions, 1 deletion)
```diff
@@ -10,7 +10,6 @@
 import faiss
 import random
 import time
-import copy
 from pathlib import Path, PosixPath
 from tqdm import tqdm
 import threading
```
