
Commit eab1d81

Merge branch 'dev'
2 parents: bb9ecf1 + 16df89f

File tree: 5 files changed, +78 / -77 lines

DocToolsLLM/DocToolsLLM.py
DocToolsLLM/utils/llm.py
DocToolsLLM/utils/tasks/summary.py
bumpver.toml
setup.py


DocToolsLLM/DocToolsLLM.py

Lines changed: 53 additions & 68 deletions
@@ -78,7 +78,7 @@
 class DocToolsLLM_class:
     "This docstring is dynamically replaced by the content of DocToolsLLM/docs/USAGE.md"

-    VERSION: str = "0.36"
+    VERSION: str = "0.37"

     #@optional_typecheck
     @typechecked

@@ -282,6 +282,8 @@ def handle_exception(exc_type, exc_value, exc_traceback):
             else:
                 self.llm_cache = SQLiteCache(database_path=(cache_dir / "private_langchain.db").resolve().absolute())
             set_llm_cache(self.llm_cache)
+        else:
+            self.llm_cache = not no_llm_cache

         if llms_api_bases["model"]:
             red(f"Disabling price computation for model because api_base was modified")

@@ -364,7 +366,7 @@ def ntfy(text: str) -> str:
         self.llm = load_llm(
             modelname=modelname,
             backend=self.modelbackend,
-            llm_cache=self.llm_cache if not self.no_llm_cache else False,
+            llm_cache=self.llm_cache,
             temperature=0,
             verbose=self.llm_verbosity,
             api_base=self.llms_api_bases["model"],

@@ -1021,11 +1023,37 @@ def _query(self, query: Optional[str]) -> Optional[str]:
         whi(f"Question to answer: {query_an}")

         # the eval doc chain needs its own caching
-        if not self.no_llm_cache:
+        if self.llm_cache:
             eval_cache_wrapper = doc_eval_cache.cache
         else:
             def eval_cache_wrapper(func): return func

+        # answer 0 or 1 if the document is related
+        if not hasattr(self, "eval_llm"):
+            self.eval_llm_params = litellm.get_supported_openai_params(
+                model=self.query_eval_modelname,
+                custom_llm_provider=self.query_eval_modelbackend,
+            )
+            eval_args = {}
+            if "n" in self.eval_llm_params:
+                eval_args["n"] = self.query_eval_check_number
+            else:
+                red(f"Model {self.query_eval_modelname} does not support parameter 'n' so will be called multiple times instead. This might cost more.")
+            if "max_tokens" in self.eval_llm_params:
+                eval_args["max_tokens"] = 2
+            else:
+                red(f"Model {self.query_eval_modelname} does not support parameter 'max_token' so the result might be of less quality.")
+            self.eval_llm = load_llm(
+                modelname=self.query_eval_modelname,
+                backend=self.query_eval_modelbackend,
+                llm_cache=False,  # disables caching because another caching is used on top
+                verbose=self.llm_verbosity,
+                temperature=1,
+                api_base=self.llms_api_bases["query_eval_model"],
+                private=self.private,
+                **eval_args,
+            )
+
         @chain
         @optional_typecheck
         @eval_cache_wrapper
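Note on the block added above: it probes which OpenAI-style parameters the eval model actually supports before deciding how to call it. A minimal stand-alone sketch of that probe-then-build pattern follows; it only assumes litellm is installed, and the model name and values below are placeholders rather than the project's settings.

import litellm

def build_eval_kwargs(model: str, n_wanted: int) -> dict:
    # Ask litellm which OpenAI-style params this model accepts.
    supported = litellm.get_supported_openai_params(model=model) or []
    kwargs = {}
    if "n" in supported:
        kwargs["n"] = n_wanted      # one request can return several generations
    if "max_tokens" in supported:
        kwargs["max_tokens"] = 2    # the eval answer is just "0" or "1"
    return kwargs

print(build_eval_kwargs("gpt-4o-mini", n_wanted=3))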
@@ -1039,8 +1067,12 @@ def evaluate_doc_chain(
                 outputs = [gen.text for gen in out.generations]
                 assert outputs, "No generations found by query eval llm"
                 outputs = [parse_eval_output(o) for o in outputs]
-                new_p = out.llm_output["token_usage"]["prompt_tokens"]
-                new_c = out.llm_output["token_usage"]["completion_tokens"]
+                if out.llm_output:
+                    new_p = out.llm_output["token_usage"]["prompt_tokens"]
+                    new_c = out.llm_output["token_usage"]["completion_tokens"]
+                else:
+                    new_p = 0
+                    new_c = 0
             else:
                 outputs = []
                 new_p = 0
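The guard introduced here exists because a response served from the LLM cache may carry no llm_output at all, so indexing token_usage unconditionally used to raise. A small illustrative helper (not project code) showing the same defensive accounting:

from typing import Optional, Tuple

def token_usage(llm_output: Optional[dict]) -> Tuple[int, int]:
    # Report usage only when the backend actually returned it;
    # cache hits typically return nothing.
    if not llm_output:
        return 0, 0
    usage = llm_output.get("token_usage", {})
    return usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0)

new_p, new_c = token_usage(None)  # cache hit -> (0, 0)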
@@ -1060,8 +1092,9 @@ async def eval(inputs):
                 for out in outs:
                     assert len(out.generations) == 1, f"Query eval llm produced more than 1 evaluations: '{out.generations}'"
                     outputs.append(out.generations[0].text)
-                    new_p += out.llm_output["token_usage"]["prompt_tokens"]
-                    new_c += out.llm_output["token_usage"]["completion_tokens"]
+                    if out.llm_output:
+                        new_p += out.llm_output["token_usage"]["prompt_tokens"]
+                        new_c += out.llm_output["token_usage"]["completion_tokens"]
                 assert outputs, "No generations found by query eval llm"
                 outputs = [parse_eval_output(o) for o in outputs]

@@ -1072,36 +1105,12 @@ async def eval(inputs):
             self.eval_llm.callbacks[0].total_tokens += new_p + new_c
             return outputs

+        # uses in most places to increase concurrency limit
+        multi = {"max_concurrency": 50 if not self.debug else 1}
+
         if self.task == "search":
             if self.query_eval_modelname:
-                # uses in most places to increase concurrency limit
-                multi = {"max_concurrency": 50 if not self.debug else 1}
-
-                # answer 0 or 1 if the document is related
-                if not hasattr(self, "eval_llm"):
-                    self.eval_llm_params = litellm.get_supported_openai_params(
-                        model=self.query_eval_modelname,
-                        custom_llm_provider=self.query_eval_modelbackend,
-                    )
-                    eval_args = {}
-                    if "n" in self.eval_llm_params:
-                        eval_args["n"] = self.query_eval_check_number
-                    else:
-                        red(f"Model {self.query_eval_modelname} does not support parameter 'n' so will be called multiple times instead. This might cost more.")
-                    if "max_tokens" in self.eval_llm_params:
-                        eval_args["max_tokens"] = 2
-                    else:
-                        red(f"Model {self.query_eval_modelname} does not support parameter 'max_token' so the result might be of less quality.")
-                    self.eval_llm = load_llm(
-                        modelname=self.query_eval_modelname,
-                        backend=self.query_eval_modelbackend,
-                        llm_cache=self.llm_cache if not self.no_llm_cache else False,
-                        verbose=self.llm_verbosity,
-                        temperature=1,
-                        api_base=self.llms_api_bases["query_eval_model"],
-                        private=self.private,
-                        **eval_args,
-                    )
+

                 # for some reason I needed to have at least one chain object otherwise rag_chain is a dict
                 @chain

@@ -1210,35 +1219,6 @@ def retrieve_documents(inputs):
                     | StrOutputParser()
             }

-            # uses in most places to increase concurrency limit
-            multi = {"max_concurrency": 50 if not self.debug else 1}
-
-            # answer 0 or 1 if the document is related
-            if not hasattr(self, "eval_llm"):
-                self.eval_llm_params = litellm.get_supported_openai_params(
-                    model=self.query_eval_modelname,
-                    custom_llm_provider=self.query_eval_modelbackend,
-                )
-                eval_args = {}
-                if "n" in self.eval_llm_params:
-                    eval_args["n"] = self.query_eval_check_number
-                else:
-                    red(f"Model {self.query_eval_modelname} does not support parameter 'n' so will be called multiple times instead. This might cost more.")
-                if "max_tokens" in self.eval_llm_params:
-                    eval_args["max_tokens"] = 2
-                else:
-                    red(f"Model {self.query_eval_modelname} does not support parameter 'max_token' so the result might be of less quality.")
-                self.eval_llm = load_llm(
-                    modelname=self.query_eval_modelname,
-                    backend=self.query_eval_modelbackend,
-                    llm_cache=self.llm_cache if not self.no_llm_cache else False,
-                    verbose=self.llm_verbosity,
-                    temperature=1,
-                    api_base=self.llms_api_bases["query_eval_model"],
-                    private=self.private,
-                    **eval_args,
-                )
-
             # the eval doc chain needs its own caching
             if self.no_llm_cache:
                 def eval_cache_wrapper(func): return func
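This hunk removes a duplicated copy of the eval-LLM setup; the surviving context shows the eval_cache_wrapper trick, where the decorator applied to the evaluation chain is either a real cache or a no-op. A hedged, self-contained sketch of that pattern (the real project uses a joblib-style cache object; lru_cache below is only a stand-in):

import functools

caching_enabled = True

def identity(func):
    # no-op decorator used when caching is disabled
    return func

# Pick the decorator once, then apply it unconditionally.
eval_cache_wrapper = functools.lru_cache(maxsize=None) if caching_enabled else identity

@eval_cache_wrapper
def evaluate_doc(doc: str) -> int:
    return 1 if "relevant" in doc else 0

print(evaluate_doc("a relevant document"))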
@@ -1260,8 +1240,12 @@ def evaluate_doc_chain(
                 outputs = [gen.text for gen in out.generations]
                 assert outputs, "No generations found by query eval llm"
                 outputs = [parse_eval_output(o) for o in outputs]
-                new_p = out.llm_output["token_usage"]["prompt_tokens"]
-                new_c = out.llm_output["token_usage"]["completion_tokens"]
+                if out.llm_output:
+                    new_p = out.llm_output["token_usage"]["prompt_tokens"]
+                    new_c = out.llm_output["token_usage"]["completion_tokens"]
+                else:
+                    new_p = 0
+                    new_c = 0
             else:
                 outputs = []
                 new_p = 0
@@ -1283,8 +1267,9 @@ async def eval(inputs):
                     outputs.append(out.generations[0].text)
                     finish_reason = out.generations[0].generation_info["finish_reason"]
                     assert finish_reason == "stop", f"unexpected finish_reason: '{finish_reason}'"
-                    new_p += out.llm_output["token_usage"]["prompt_tokens"]
-                    new_c += out.llm_output["token_usage"]["completion_tokens"]
+                    if out.llm_output:
+                        new_p += out.llm_output["token_usage"]["prompt_tokens"]
+                        new_c += out.llm_output["token_usage"]["completion_tokens"]
                 assert outputs, "No generations found by query eval llm"
                 outputs = [parse_eval_output(o) for o in outputs]

DocToolsLLM/utils/llm.py

Lines changed: 17 additions & 3 deletions
@@ -39,7 +39,7 @@ def load_llm(
     modelname: str,
     backend: str,
     verbose: bool,
-    llm_cache: Union[bool, SQLiteCache],
+    llm_cache: Union[None, bool, SQLiteCache],
     api_base: Optional[str],
     private: bool,
     **extra_model_args,

@@ -81,7 +81,7 @@ def load_llm(
     else:
         assert os.environ["DOCTOOLS_PRIVATEMODE"] == "false"

-    if not private and backend == "openai" and api_base is None and llm_cache is not False:
+    if not private and backend == "openai" and api_base is None:
         red("Using ChatOpenAI instead of litellm because calling openai server anyway and the caching has a bug on langchain side :( The caching works on ChatOpenAI though. More at https://github.com/langchain-ai/langchain/issues/22389")
         max_tokens = litellm.get_model_info(modelname)["max_tokens"]
         if "max_tokens" not in extra_model_args:
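The widened condition above routes every plain OpenAI call through ChatOpenAI so the langchain cache can be used (per the linked issue), regardless of whether a cache object was passed in. For context, a hedged sketch of how a global SQLite cache is typically wired to ChatOpenAI; it assumes langchain, langchain-community and langchain-openai are installed, and the path and model name are placeholders, not the project's values:

from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from langchain_openai import ChatOpenAI

# Register a process-wide cache; chat models consult it on each call.
set_llm_cache(SQLiteCache(database_path="/tmp/langchain_cache.db"))
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# A second identical llm.invoke(...) should be answered from the SQLite cache.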
@@ -98,7 +98,7 @@ def load_llm(
         max_tokens = litellm.get_model_info(modelname)["max_tokens"]
         if "max_tokens" not in extra_model_args:
             extra_model_args["max_tokens"] = max_tokens
-        if llm_cache is not False:
+        if llm_cache is not None:
             red(f"Reminder: caching is disabled for non openai models until langchain approves the fix.")
         llm = ChatLiteLLM(
             model_name=modelname,

@@ -111,6 +111,12 @@ def load_llm(
     if private:
         assert llm.api_base, "private is set but no api_base for llm were found"
         assert llm.api_base == api_base, "private is set but found unexpected llm.api_base value: '{litellm.api_base}'"
+
+    # fix: the SQLiteCache's str appearance is cancelling its own cache lookup!
+    if llm.cache:
+        cur = str(llm.cache)
+        llm.cache.__class__.__repr__ = lambda: cur.split(" at ")[0]
+        llm.cache.__class__.__str__ = lambda: cur.split(" at ")[0]
     return llm

@@ -133,6 +139,14 @@ def __init__(self, verbose, *args, **kwargs):
             "on_chain_error",
         ]

+    def __repr__(self) -> str:
+        # setting __repr__ and __str__ is important because it can
+        # maybe be used for caching?
+        return "PriceCountingCallback"
+
+    def __str__(self) -> str:
+        return "PriceCountingCallback"
+
     def _check_methods_called(self) -> None:
         assert all(meth in dir(self) for meth in self.methods_called), (
             "unexpected method names!")
DocToolsLLM/utils/tasks/summary.py

Lines changed: 6 additions & 4 deletions
@@ -35,8 +35,6 @@ def do_summarize(
     assert "[PROGRESS]" in metadata
     for ird, rd in tqdm(enumerate(docs), desc="Summarising splits", total=len(docs)):
         fixed_index = f"{ird + 1}/{len(docs)}"
-        if ird > 0:
-            assert llm.callbacks[0].total_tokens > 0

         messages = BASE_SUMMARY_PROMPT.format_messages(
             text=rd.page_content,
@@ -50,8 +48,12 @@ def do_summarize(
         assert finish == "stop", f"Unexpected finish_reason: '{finish}'"
         assert len(output.generations) == 1
         out = output.generations[0].text
-        new_p = output.llm_output["token_usage"]["prompt_tokens"]
-        new_c = output.llm_output["token_usage"]["completion_tokens"]
+        if output.llm_output:  # only present if not caching
+            new_p = output.llm_output["token_usage"]["prompt_tokens"]
+            new_c = output.llm_output["token_usage"]["completion_tokens"]
+        else:
+            new_p = 0
+            new_c = 0
         total_tokens += new_p + new_c
         total_cost += (new_p * llm_price[0] + new_c + llm_price[1]) / 1e6
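The last two context lines fold per-million-token prices into running totals. A small worked example of that arithmetic, under the assumption that llm_price holds (prompt, completion) prices in dollars per million tokens (the figures below are invented); the completion count is multiplied by its own price here, which is the usual form of this computation:

llm_price = (5.0, 15.0)    # $5 per 1M prompt tokens, $15 per 1M completion tokens
new_p, new_c = 1_200, 350  # tokens reported for one call

cost = (new_p * llm_price[0] + new_c * llm_price[1]) / 1e6
print(f"${cost:.6f}")      # -> $0.011250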

bumpver.toml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [bumpver]
-current_version = "0.36"
+current_version = "0.37"
 version_pattern = "MAJOR.MINOR"
 commit_message = "bump version {old_version} -> {new_version}"
 tag_message = "{new_version}"

setup.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def run(self):

 setup(
     name="DocToolsLLM",
-    version="0.36",
+    version="0.37",
     description="A perfect RAG and AI summary setup for my needs. Supports all LLM, virt. any filetypes (epub, youtube_playlist, pdf, mp3, etc)",
     long_description=long_description,
     long_description_content_type="text/markdown",
