
Commit 552bfe0

fix: better token counting and fixes cache

1 parent a6a5db6 · 9 files changed · 498 additions & 44 deletions

application/api/answer/services/compression/token_counter.py (32 additions & 2 deletions)
@@ -12,6 +12,12 @@
 class TokenCounter:
     """Centralized token counting for conversations and messages."""
 
+    # Per-image token estimate. Provider tokenizers vary widely
+    # (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
+    # depends on resolution/detail we can't see here. Errs slightly high
+    # so the threshold check stays conservative.
+    _IMAGE_PART_TOKEN_ESTIMATE = 1500
+
     @staticmethod
     def count_message_tokens(messages: List[Dict]) -> int:
         """
@@ -29,12 +35,36 @@ def count_message_tokens(messages: List[Dict]) -> int:
             if isinstance(content, str):
                 total_tokens += num_tokens_from_string(content)
             elif isinstance(content, list):
-                # Handle structured content (tool calls, etc.)
+                # Handle structured content (tool calls, image parts, etc.)
                 for item in content:
                     if isinstance(item, dict):
-                        total_tokens += num_tokens_from_string(str(item))
+                        total_tokens += TokenCounter._count_content_part(item)
         return total_tokens
 
+    @staticmethod
+    def _count_content_part(item: Dict) -> int:
+        # Image/file attachments are billed by the provider per image,
+        # not proportional to the inline bytes/base64 string.
+        # ``str(item)`` on a 1MB image inflates the count by ~10000x,
+        # which trips spurious compression and overflows downstream
+        # input limits.
+        item_type = item.get("type")
+
+        if "files" in item:
+            files = item.get("files")
+            count = len(files) if isinstance(files, list) and files else 1
+            return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
+
+        if "image_url" in item or item_type in {
+            "image",
+            "image_url",
+            "input_image",
+            "file",
+        }:
+            return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
+
+        return num_tokens_from_string(str(item))
+
     @staticmethod
     def count_query_tokens(
         queries: List[Dict[str, Any]], include_tool_calls: bool = True
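To see the effect of the flat estimate, here is a minimal standalone sketch. The rough_tokens helper is a hypothetical 4-characters-per-token stand-in for the repo's num_tokens_from_string, not its actual implementation:

import base64

def rough_tokens(text: str) -> int:
    # Hypothetical stand-in for num_tokens_from_string (~4 chars per token).
    return max(1, len(text) // 4)

IMAGE_PART_TOKEN_ESTIMATE = 1500  # mirrors TokenCounter._IMAGE_PART_TOKEN_ESTIMATE

# A ~1MB image inlined as base64 inside a content part.
image_part = {
    "type": "image_url",
    "image_url": {
        "url": "data:image/png;base64,"
        + base64.b64encode(b"\x89PNG" + b"\x00" * 1_000_000).decode()
    },
}

# Old behavior: str(item) counts the base64 payload as if it were prompt text.
print(rough_tokens(str(image_part)))  # ~333,000 "tokens" for a single image
# New behavior: flat per-image estimate, regardless of payload size.
print(IMAGE_PART_TOKEN_ESTIMATE)      # 1500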

application/cache.py (10 additions & 1 deletion)
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import time
@@ -10,6 +11,14 @@
 
 logger = logging.getLogger(__name__)
 
+
+def _cache_default(value):
+    # Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
+    # hash so the cache key stays bounded in size and stable across identical content.
+    if isinstance(value, (bytes, bytearray, memoryview)):
+        return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
+    return repr(value)
+
 _redis_instance = None
 _redis_creation_failed = False
 _instance_lock = Lock()
@@ -36,7 +45,7 @@ def get_redis_instance():
 def gen_cache_key(messages, model="docgpt", tools=None):
     if not all(isinstance(msg, dict) for msg in messages):
         raise ValueError("All messages must be dictionaries.")
-    messages_str = json.dumps(messages)
+    messages_str = json.dumps(messages, default=_cache_default)
     tools_str = json.dumps(str(tools)) if tools else ""
     combined = f"{model}_{messages_str}_{tools_str}"
     cache_key = get_hash(combined)
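A short sketch of what the hook buys, re-declaring _cache_default locally so the snippet runs standalone: without it, json.dumps raises TypeError on bytes; with it, identical bytes always serialize to the same short digest, so the derived key stays bounded and stable.

import hashlib
import json

def _cache_default(value):
    # Same logic as the hook above, re-declared so this sketch is self-contained.
    if isinstance(value, (bytes, bytearray, memoryview)):
        return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
    return repr(value)

messages = [
    {
        "role": "user",
        "content": [
            {"files": [{"file_bytes": b"\x89PNG" + b"\x00" * 100_000, "mime_type": "image/png"}]}
        ],
    }
]

key_a = json.dumps(messages, default=_cache_default)
key_b = json.dumps(messages, default=_cache_default)
assert key_a == key_b   # identical content -> identical key material
assert len(key_a) < 300  # 100KB of image bytes collapse to one 64-char digest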

application/llm/base.py (62 additions & 16 deletions)
@@ -166,7 +166,7 @@ def decorated_method():
 
         if is_stream:
             return self._stream_with_fallback(
-                decorated_method, method_name, *args, **kwargs
+                decorated_method, method_name, decorators, *args, **kwargs
             )
 
         try:
@@ -187,14 +187,27 @@ def decorated_method():
                 f"{fallback.model_id}. Error: {str(e)}"
             )
 
-            fallback_method = getattr(
-                fallback, method_name.replace("_raw_", "")
-            )
+            # Apply decorators to fallback's raw method directly — calling
+            # fallback.gen() would re-enter the orchestrator and recurse via
+            # fallback.fallback_llm.
+            fallback_method = getattr(fallback, method_name)
+            for decorator in decorators:
+                fallback_method = decorator(fallback_method)
             fallback_kwargs = {**kwargs, "model": fallback.model_id}
-            return fallback_method(*args, **fallback_kwargs)
+            try:
+                return fallback_method(fallback, *args, **fallback_kwargs)
+            except Exception as e2:
+                if self._is_non_retriable_client_error(e2):
+                    logger.error(
+                        f"Fallback LLM failed with non-retriable client "
+                        f"error; giving up: {str(e2)}"
+                    )
+                else:
+                    logger.error(f"Fallback LLM also failed; giving up: {str(e2)}")
+                raise
 
     def _stream_with_fallback(
-        self, decorated_method, method_name, *args, **kwargs
+        self, decorated_method, method_name, decorators, *args, **kwargs
     ):
         """
         Wrapper generator that catches mid-stream errors and falls back.
@@ -223,11 +236,37 @@ def _stream_with_fallback(
                 f"Primary LLM failed mid-stream. Falling back to "
                 f"{fallback.model_id}. Error: {str(e)}"
             )
-            fallback_method = getattr(
-                fallback, method_name.replace("_raw_", "")
+            # Apply decorators to fallback's raw stream method directly —
+            # calling fallback.gen_stream() would re-enter the orchestrator
+            # and recurse via fallback.fallback_llm. Emit the stream-start
+            # event manually so dashboards still see the fallback's
+            # provider/model when the response actually comes from it.
+            fallback._emit_stream_start_log(
+                fallback.model_id,
+                kwargs.get("messages"),
+                kwargs.get("tools"),
+                bool(
+                    kwargs.get("_usage_attachments")
+                    or kwargs.get("attachments")
+                ),
             )
+            fallback_method = getattr(fallback, method_name)
+            for decorator in decorators:
+                fallback_method = decorator(fallback_method)
             fallback_kwargs = {**kwargs, "model": fallback.model_id}
-            yield from fallback_method(*args, **fallback_kwargs)
+            try:
+                yield from fallback_method(fallback, *args, **fallback_kwargs)
+            except Exception as e2:
+                if self._is_non_retriable_client_error(e2):
+                    logger.error(
+                        f"Fallback LLM failed mid-stream with non-retriable "
+                        f"client error; giving up: {str(e2)}"
+                    )
+                else:
+                    logger.error(
+                        f"Fallback LLM also failed mid-stream; giving up: {str(e2)}"
+                    )
+                raise
 
     def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
         decorators = [gen_token_usage, gen_cache]
@@ -242,22 +281,29 @@ def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
             **kwargs,
         )
 
-    def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
-        # Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
-        # the ``stream_token_usage`` decorator pops that key, but the log
-        # fires before the decorator runs so it's still in ``kwargs`` here.
+    def _emit_stream_start_log(self, model, messages, tools, has_attachments):
+        # Stamped with ``self.provider_name`` so dashboards can group calls
+        # by vendor; the fallback path emits its own copy on the fallback
+        # instance so the actual responding provider is recorded.
         logging.info(
             "llm_stream_start",
             extra={
                 "model": model,
                 "provider": self.provider_name,
                 "message_count": len(messages) if messages is not None else 0,
-                "has_attachments": bool(
-                    kwargs.get("_usage_attachments") or kwargs.get("attachments")
-                ),
+                "has_attachments": bool(has_attachments),
                 "has_tools": bool(tools),
             },
         )
+
+    def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
+        # Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
+        # the ``stream_token_usage`` decorator pops that key, but the log
+        # fires before the decorator runs so it's still in ``kwargs`` here.
+        has_attachments = bool(
+            kwargs.get("_usage_attachments") or kwargs.get("attachments")
+        )
+        self._emit_stream_start_log(model, messages, tools, has_attachments)
         decorators = [stream_cache, stream_token_usage]
         return self._execute_with_fallback(
             "_raw_gen_stream",

application/llm/handlers/base.py (20 additions & 1 deletion)
@@ -280,7 +280,26 @@ def _extract_text_from_content(self, content: Any) -> str:
                     # Keep serialized function calls/responses so the compressor sees actions
                     parts_text.append(str(item))
                 elif "files" in item:
-                    parts_text.append(str(item))
+                    # Image attachments arrive with raw bytes / base64
+                    # inline (see GoogleLLM.prepare_messages_with_attachments).
+                    # ``str(item)`` would dump the whole byte/base64
+                    # blob into the compression prompt and bust the
+                    # compression LLM's input limit.
+                    files = item.get("files") or []
+                    descriptors = []
+                    if isinstance(files, list):
+                        for f in files:
+                            if isinstance(f, dict):
+                                descriptors.append(
+                                    f.get("mime_type") or "file"
+                                )
+                            elif isinstance(f, str):
+                                descriptors.append(f)
+                    if not descriptors:
+                        descriptors = ["file"]
+                    parts_text.append(
+                        f"[attachment: {', '.join(descriptors)}]"
+                    )
             return "\n".join(parts_text)
         return ""
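For reference, the descriptor logic in isolation. This is a standalone re-implementation for illustration only, showing the three input shapes it handles: plain path strings, dicts carrying a mime_type, and anything else degrading to a generic "file" placeholder. The tests below exercise the same behavior through the handler itself.

def describe_files_part(item: dict) -> str:
    # Mirrors the branch above: summarize attachments instead of str()-ing them.
    files = item.get("files") or []
    descriptors = []
    if isinstance(files, list):
        for f in files:
            if isinstance(f, dict):
                descriptors.append(f.get("mime_type") or "file")
            elif isinstance(f, str):
                descriptors.append(f)
    if not descriptors:
        descriptors = ["file"]
    return f"[attachment: {', '.join(descriptors)}]"

print(describe_files_part({"files": ["/tmp/a.txt"]}))  # [attachment: /tmp/a.txt]
print(describe_files_part({"files": [{"file_bytes": b"...", "mime_type": "image/png"}]}))  # [attachment: image/png]
print(describe_files_part({"files": None}))  # [attachment: file]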

tests/llm/handlers/test_llm_handlers.py (32 additions & 1 deletion)
@@ -360,7 +360,38 @@ def test_list_with_files(self):
         handler = ConcreteHandler()
         content = [{"files": ["/tmp/a.txt"]}]
         result = handler._extract_text_from_content(content)
-        assert "files" in result
+        assert result == "[attachment: /tmp/a.txt]"
+
+    def test_list_with_inline_image_bytes(self):
+        # Google attaches images as inline bytes; stringifying them into
+        # the compression prompt would bust the compression LLM's input
+        # limit. The placeholder must describe the attachment without
+        # embedding the bytes.
+        handler = ConcreteHandler()
+        content = [
+            {
+                "files": [
+                    {"file_bytes": b"\x89PNG" + b"\x00" * 1000, "mime_type": "image/png"}
+                ]
+            }
+        ]
+        result = handler._extract_text_from_content(content)
+        assert result == "[attachment: image/png]"
+        assert "PNG" not in result
+        assert "\\x" not in result
+
+    def test_list_with_multiple_files(self):
+        handler = ConcreteHandler()
+        content = [
+            {
+                "files": [
+                    {"file_bytes": b"a", "mime_type": "image/png"},
+                    {"file_uri": "https://x", "mime_type": "image/jpeg"},
+                ]
+            }
+        ]
+        result = handler._extract_text_from_content(content)
+        assert result == "[attachment: image/png, image/jpeg]"
 
     def test_list_with_none_text(self):
         handler = ConcreteHandler()

tests/llm/test_base.py (6 additions & 11 deletions)
@@ -49,6 +49,9 @@ def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **
 
 
 class FallbackLLM(BaseLLM):
+    # _execute_with_fallback applies decorators to the fallback's raw method
+    # directly and never calls .gen() / .gen_stream() on it, so
+    # tracking lives on the raw methods.
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.gen_called = False
@@ -62,14 +65,6 @@ def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **
         self.gen_stream_called = True
         yield "fallback_chunk"
 
-    def gen(self, *args, **kwargs):
-        self.gen_called = True
-        return "fallback_gen_result"
-
-    def gen_stream(self, *args, **kwargs):
-        self.gen_stream_called = True
-        yield "fallback_stream_chunk"
-
 
 # ---------------------------------------------------------------------------
 # gen / gen_stream decorator application
@@ -230,7 +225,7 @@ def test_fallback_called_on_failure(self):
         llm._fallback_llm = fallback
 
         result = llm.gen(model="m", messages=[])
-        assert result == "fallback_gen_result"
+        assert result == "fallback_result"
         assert fallback.gen_called
 
 
@@ -257,7 +252,7 @@ def test_fallback_called_on_stream_failure(self):
         llm._fallback_llm = fallback
 
         result = list(llm.gen_stream(model="m", messages=[]))
-        assert "fallback_stream_chunk" in result
+        assert "fallback_chunk" in result
         assert fallback.gen_stream_called
 
 
@@ -344,7 +339,7 @@ def test_5xx_still_falls_back(self):
         llm._fallback_llm = fallback
 
         result = llm.gen(model="m", messages=[])
-        assert result == "fallback_gen_result"
+        assert result == "fallback_result"
         assert fallback.gen_called