
Commit d70bea6

fix(ollama): pre-imported funcs instrumentation failure (#2871)
Co-authored-by: Nir Gazit <[email protected]>
1 parent 0ec0877 commit d70bea6

File tree

3 files changed: +113 -54 lines changed

packages/opentelemetry-instrumentation-ollama/opentelemetry/instrumentation/ollama/__init__.py

Lines changed: 98 additions & 44 deletions
@@ -48,6 +48,29 @@
 ]
 
 
+def _sanitize_copy_messages(wrapped, instance, args, kwargs):
+    # original signature: _copy_messages(messages)
+    messages = args[0] if args else []
+    sanitized = []
+    for msg in messages or []:
+        if isinstance(msg, dict):
+            msg_copy = dict(msg)
+            tc_list = msg_copy.get("tool_calls")
+            if tc_list:
+                for tc in tc_list:
+                    func = tc.get("function")
+                    arg = func.get("arguments") if func else None
+                    if isinstance(arg, str):
+                        try:
+                            func["arguments"] = json.loads(arg)
+                        except Exception:
+                            pass
+            sanitized.append(msg_copy)
+        else:
+            sanitized.append(msg)
+    return wrapped(sanitized)
+
+
 def should_send_prompts():
     return (
         os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
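
A note on what the new _sanitize_copy_messages hook guards against: when a tool call's "arguments" field arrives as a JSON string (for example, echoed back from a previous assistant turn), ollama's Pydantic validation of messages rejects the string where it expects a mapping. A minimal, self-contained sketch of the repair — the message shape and function name are hypothetical:

    import json

    # Hypothetical chat message whose tool-call arguments arrive as a JSON
    # string instead of the dict that Pydantic validation expects.
    message = {
        "role": "assistant",
        "tool_calls": [
            {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
        ],
    }

    for tc in message.get("tool_calls", []):
        func = tc.get("function")
        if func and isinstance(func.get("arguments"), str):
            try:
                func["arguments"] = json.loads(func["arguments"])
            except ValueError:
                pass  # leave non-JSON strings untouched

    assert message["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}
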
@@ -89,15 +112,18 @@ def _set_prompts(span, messages):
                     f"{prefix}.tool_calls.{i}.name",
                     function.get("name"),
                 )
+                # record arguments: ensure it's a JSON string for span attributes
+                raw_args = function.get("arguments")
+                if isinstance(raw_args, dict):
+                    arg_str = json.dumps(raw_args)
+                else:
+                    arg_str = raw_args
                 _set_span_attribute(
                     span,
                     f"{prefix}.tool_calls.{i}.arguments",
-                    function.get("arguments"),
+                    arg_str,
                 )
 
-                if function.get("arguments"):
-                    function["arguments"] = json.loads(function.get("arguments"))
-
 
 def set_tools_attributes(span, tools):
     if not tools:
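
The normalization above is the mirror image of the sanitizer: OpenTelemetry span attributes only accept primitive values (str, bool, int, float, or sequences of those), so dict-shaped arguments must be serialized before being recorded. A sketch of the idea as a standalone helper (the helper name is illustrative, not part of the commit):

    import json

    def to_attribute_value(raw_args):
        # Span attribute values must be primitives; serialize dicts to JSON.
        return json.dumps(raw_args) if isinstance(raw_args, dict) else raw_args

    assert to_attribute_value({"city": "Paris"}) == '{"city": "Paris"}'
    assert to_attribute_value('{"city": "Paris"}') == '{"city": "Paris"}'
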
@@ -118,15 +144,15 @@ def set_tools_attributes(span, tools):
 
 @dont_throw
 def _set_input_attributes(span, llm_request_type, kwargs):
-    _set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model"))
+    json_data = kwargs.get("json", {})
+    _set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, json_data.get("model"))
     _set_span_attribute(
         span, SpanAttributes.LLM_IS_STREAMING, kwargs.get("stream") or False
     )
-
     if should_send_prompts():
         if llm_request_type == LLMRequestTypeValues.CHAT:
             _set_span_attribute(span, f"{SpanAttributes.LLM_PROMPTS}.0.role", "user")
-            for index, message in enumerate(kwargs.get("messages")):
+            for index, message in enumerate(json_data.get("messages")):
                 _set_span_attribute(
                     span,
                     f"{SpanAttributes.LLM_PROMPTS}.{index}.content",
@@ -137,13 +163,13 @@ def _set_input_attributes(span, llm_request_type, kwargs):
                     f"{SpanAttributes.LLM_PROMPTS}.{index}.role",
                     message.get("role"),
                 )
-            _set_prompts(span, kwargs.get("messages"))
-            if kwargs.get("tools"):
-                set_tools_attributes(span, kwargs.get("tools"))
+            _set_prompts(span, json_data.get("messages"))
+            if json_data.get("tools"):
+                set_tools_attributes(span, json_data.get("tools"))
         else:
             _set_span_attribute(span, f"{SpanAttributes.LLM_PROMPTS}.0.role", "user")
             _set_span_attribute(
-                span, f"{SpanAttributes.LLM_PROMPTS}.0.content", kwargs.get("prompt")
+                span, f"{SpanAttributes.LLM_PROMPTS}.0.content", json_data.get("prompt")
             )
 
 
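
Why kwargs.get("json", {}) everywhere above: the wrapper now sits on Client._request rather than on the user-facing chat/generate methods, so by the time it runs, model, messages, tools, and prompt have been folded into the HTTP request body and travel in the json= keyword argument — as this diff's own accesses imply. A schematic of what the wrapper sees, with illustrative values:

    # Illustrative only: the kwargs a _request wrapper receives, per this diff.
    kwargs = {
        "json": {
            "model": "llama3",
            "messages": [{"role": "user", "content": "hi"}],
        },
        "stream": False,
    }

    json_data = kwargs.get("json", {})
    assert json_data.get("model") == "llama3"
    assert kwargs.get("model") is None  # no longer a top-level kwarg
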
@@ -240,7 +266,8 @@ def _accumulate_streaming_response(span, token_histogram, llm_request_type, resp
             accumulated_response["message"]["content"] += res["message"]["content"]
             accumulated_response["message"]["role"] = res["message"]["role"]
         elif llm_request_type == LLMRequestTypeValues.COMPLETION:
-            accumulated_response["response"] += res["response"]
+            text = res.get("response", "")
+            accumulated_response["response"] += text
 
     response_data = res.model_dump() if hasattr(res, 'model_dump') else res
     _set_response_attributes(span, token_histogram, llm_request_type, response_data | accumulated_response)
@@ -260,7 +287,8 @@ async def _aaccumulate_streaming_response(span, token_histogram, llm_request_typ
             accumulated_response["message"]["content"] += res["message"]["content"]
             accumulated_response["message"]["role"] = res["message"]["role"]
         elif llm_request_type == LLMRequestTypeValues.COMPLETION:
-            accumulated_response["response"] += res["response"]
+            text = res.get("response", "")
+            accumulated_response["response"] += text
 
     response_data = res.model_dump() if hasattr(res, 'model_dump') else res
     _set_response_attributes(span, token_histogram, llm_request_type, response_data | accumulated_response)
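
Both accumulators switch from res["response"] to res.get("response", "") because a streamed completion's terminal chunk can omit the response field (carrying only completion stats), and direct indexing would raise KeyError there. A toy illustration — the chunk shapes are hypothetical:

    chunks = [
        {"response": "Hello"},
        {"response": " world"},
        {"done": True, "eval_count": 12},  # final chunk: no "response" key
    ]

    accumulated = ""
    for res in chunks:
        accumulated += res.get("response", "")

    assert accumulated == "Hello world"
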
@@ -336,13 +364,11 @@ def _wrap(
     if response:
         if duration_histogram:
             duration = end_time - start_time
-            duration_histogram.record(
-                duration,
-                attributes={
-                    SpanAttributes.LLM_SYSTEM: "Ollama",
-                    SpanAttributes.LLM_RESPONSE_MODEL: kwargs.get("model"),
-                },
-            )
+            attrs = {SpanAttributes.LLM_SYSTEM: "Ollama"}
+            model = kwargs.get("model")
+            if model is not None:
+                attrs[SpanAttributes.LLM_RESPONSE_MODEL] = model
+            duration_histogram.record(duration, attributes=attrs)
 
     if span.is_recording():
         if kwargs.get("stream"):
@@ -392,13 +418,11 @@ async def _awrap(
     if response:
         if duration_histogram:
             duration = end_time - start_time
-            duration_histogram.record(
-                duration,
-                attributes={
-                    SpanAttributes.LLM_SYSTEM: "Ollama",
-                    SpanAttributes.LLM_RESPONSE_MODEL: kwargs.get("model"),
-                },
-            )
+            attrs = {SpanAttributes.LLM_SYSTEM: "Ollama"}
+            model = kwargs.get("model")
+            if model is not None:
+                attrs[SpanAttributes.LLM_RESPONSE_MODEL] = model
+            duration_histogram.record(duration, attributes=attrs)
 
     if span.is_recording():
         if kwargs.get("stream"):
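
The histogram change in _wrap and _awrap does two things at once: None is not a valid OpenTelemetry attribute value, and with _request wrapped, kwargs.get("model") is typically None because the model name now lives in the json= body. A minimal sketch of recording a duration histogram with guarded attributes — the metric name and the literal keys behind the SpanAttributes constants are assumptions here:

    from opentelemetry import metrics

    meter = metrics.get_meter(__name__)
    duration_histogram = meter.create_histogram(
        name="llm.operation.duration",  # illustrative metric name
        unit="s",
    )

    kwargs = {"stream": False}           # _request kwargs: no "model" key
    attrs = {"gen_ai.system": "Ollama"}  # assumed value of SpanAttributes.LLM_SYSTEM
    model = kwargs.get("model")          # None: model lives in the json= body
    if model is not None:                # skip rather than record a None value
        attrs["gen_ai.response.model"] = model
    duration_histogram.record(0.42, attributes=attrs)
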
@@ -459,23 +483,23 @@ def _instrument(self, **kwargs):
             duration_histogram,
         ) = (None, None)
 
-        for wrapped_method in WRAPPED_METHODS:
-            wrap_method = wrapped_method.get("method")
-            wrap_function_wrapper(
-                "ollama._client",
-                f"Client.{wrap_method}",
-                _wrap(tracer, token_histogram, duration_histogram, wrapped_method),
-            )
-            wrap_function_wrapper(
-                "ollama._client",
-                f"AsyncClient.{wrap_method}",
-                _awrap(tracer, token_histogram, duration_histogram, wrapped_method),
-            )
-            wrap_function_wrapper(
-                "ollama",
-                f"{wrap_method}",
-                _wrap(tracer, token_histogram, duration_histogram, wrapped_method),
-            )
+        # Patch _copy_messages to sanitize tool_calls arguments before Pydantic validation
+        wrap_function_wrapper(
+            "ollama._client",
+            "_copy_messages",
+            _sanitize_copy_messages,
+        )
+        # instrument all llm methods (generate/chat/embeddings) via _request dispatch wrapper
+        wrap_function_wrapper(
+            "ollama._client",
+            "Client._request",
+            _dispatch_wrap(tracer, token_histogram, duration_histogram),
+        )
+        wrap_function_wrapper(
+            "ollama._client",
+            "AsyncClient._request",
+            _dispatch_awrap(tracer, token_histogram, duration_histogram),
+        )
 
     def _uninstrument(self, **kwargs):
         for wrapped_method in WRAPPED_METHODS:
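
This rewiring is the heart of the fix named in the commit title. `from ollama import chat` binds the original function object at import time, so later patching of the module attribute ollama.chat never reaches that reference; Client._request, by contrast, is looked up on the class at call time and is shared by every entry point. A runnable sketch of the failure mode, using a fabricated stand-in module ("mylib" is hypothetical, not ollama):

    import sys
    import types

    from wrapt import wrap_function_wrapper

    # Fabricate a tiny module to stand in for ollama.
    mylib = types.ModuleType("mylib")
    mylib.do_work = lambda: "original"
    sys.modules["mylib"] = mylib

    do_work = mylib.do_work  # simulates `from ollama import chat` at import time

    def traced(wrapped, instance, args, kwargs):
        print("traced!")
        return wrapped(*args, **kwargs)

    wrap_function_wrapper("mylib", "do_work", traced)

    mylib.do_work()  # prints "traced!" -- attribute looked up after patching
    do_work()        # silent: the pre-imported reference bypasses the patch
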
@@ -491,3 +515,33 @@ def _uninstrument(self, **kwargs):
                 "ollama",
                 wrapped_method.get("method"),
             )
+
+
+def _dispatch_wrap(tracer, token_histogram, duration_histogram):
+    def wrapper(wrapped, instance, args, kwargs):
+        to_wrap = None
+        if len(args) > 2 and isinstance(args[2], str):
+            path = args[2]
+            op = path.rstrip('/').split('/')[-1]
+            to_wrap = next((m for m in WRAPPED_METHODS if m.get("method") == op), None)
+        if to_wrap:
+            return _wrap(tracer, token_histogram, duration_histogram, to_wrap)(
+                wrapped, instance, args, kwargs
+            )
+        return wrapped(*args, **kwargs)
+    return wrapper
+
+
+def _dispatch_awrap(tracer, token_histogram, duration_histogram):
+    async def wrapper(wrapped, instance, args, kwargs):
+        to_wrap = None
+        if len(args) > 2 and isinstance(args[2], str):
+            path = args[2]
+            op = path.rstrip('/').split('/')[-1]
+            to_wrap = next((m for m in WRAPPED_METHODS if m.get("method") == op), None)
+        if to_wrap:
+            return await _awrap(tracer, token_histogram, duration_histogram, to_wrap)(
+                wrapped, instance, args, kwargs
+            )
+        return await wrapped(*args, **kwargs)
+    return wrapper
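
The dispatchers recover the logical operation from the request path (positional argument 2 of _request, per the guards above) and hand recognized LLM endpoints to the existing _wrap/_awrap machinery; anything else passes through untraced. A sketch of the mapping, assuming WRAPPED_METHODS entries carry a "method" key as the diff's m.get("method") implies:

    WRAPPED_METHODS = [{"method": "generate"}, {"method": "chat"}, {"method": "embeddings"}]

    for path in ("/api/chat", "/api/generate/", "/api/embeddings", "/api/tags"):
        op = path.rstrip("/").split("/")[-1]
        match = next((m for m in WRAPPED_METHODS if m.get("method") == op), None)
        print(f"{path} -> {op if match else 'passthrough'}")

    # /api/chat -> chat
    # /api/generate/ -> generate
    # /api/embeddings -> embeddings
    # /api/tags -> passthrough  (not an LLM call, left uninstrumented)
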

packages/opentelemetry-instrumentation-ollama/poetry.lock

Lines changed: 9 additions & 4 deletions
Some generated files are not rendered by default.

packages/opentelemetry-instrumentation-ollama/tests/test_chat.py

Lines changed: 6 additions & 6 deletions
@@ -1,5 +1,5 @@
 import pytest
-import ollama
+from ollama import AsyncClient, chat
 from opentelemetry.semconv_ai import SpanAttributes
 from unittest.mock import MagicMock
 from opentelemetry.instrumentation.ollama import _set_response_attributes
@@ -8,7 +8,7 @@
 
 @pytest.mark.vcr
 def test_ollama_chat(exporter):
-    response = ollama.chat(
+    response = chat(
         model="llama3",
         messages=[
             {
@@ -45,7 +45,7 @@ def test_ollama_chat(exporter):
 
 @pytest.mark.vcr
 def test_ollama_chat_tool_calls(exporter):
-    ollama.chat(
+    chat(
         model="llama3.1",
         messages=[
             {
@@ -93,7 +93,7 @@ def test_ollama_chat_tool_calls(exporter):
 
 @pytest.mark.vcr
 def test_ollama_streaming_chat(exporter):
-    gen = ollama.chat(
+    gen = chat(
         model="llama3",
         messages=[
             {
@@ -136,7 +136,7 @@ def test_ollama_streaming_chat(exporter):
 @pytest.mark.vcr
 @pytest.mark.asyncio
 async def test_ollama_async_chat(exporter):
-    client = ollama.AsyncClient()
+    client = AsyncClient()
     response = await client.chat(
         model="llama3",
         messages=[
@@ -176,7 +176,7 @@ async def test_ollama_async_chat(exporter):
 @pytest.mark.vcr
 @pytest.mark.asyncio
 async def test_ollama_async_streaming_chat(exporter):
-    client = ollama.AsyncClient()
+    client = AsyncClient()
     gen = await client.chat(
         model="llama3",
         messages=[
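
The test changes are the regression check for the commit title: by importing chat and AsyncClient directly — binding them before instrumentation runs — instead of going through the ollama module attribute, the suite only passes if tracing hooks a code path shared by pre-imported references, which the _request wrappers above provide.
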
