Skip to content

Commit 168236a

Browse files
fix(langchain): ensure llm spans are created for sync cases (#3201)
Co-authored-by: Abhishek Rao <[email protected]>
1 parent 7d529be commit 168236a

File tree

13 files changed

+1251
-244
lines changed

13 files changed

+1251
-244
lines changed

packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/__init__.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import logging
44
from typing import Collection
55

6+
from opentelemetry import context as context_api
7+
8+
69
from opentelemetry._events import get_event_logger
710
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
811
from opentelemetry.instrumentation.langchain.callback_handler import (
@@ -13,7 +16,7 @@
1316
from opentelemetry.instrumentation.langchain.version import __version__
1417
from opentelemetry.instrumentation.utils import unwrap
1518
from opentelemetry.metrics import get_meter
16-
from opentelemetry.semconv_ai import Meters
19+
from opentelemetry.semconv_ai import Meters, SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
1720
from opentelemetry.trace import get_tracer
1821
from opentelemetry.trace.propagation import set_span_in_context
1922
from opentelemetry.trace.propagation.tracecontext import (
@@ -183,8 +186,8 @@ def _uninstrument(self, **kwargs):
183186

184187

185188
class _BaseCallbackManagerInitWrapper:
186-
def __init__(self, callback_manager: "TraceloopCallbackHandler"):
187-
self._callback_manager = callback_manager
189+
def __init__(self, callback_handler: "TraceloopCallbackHandler"):
190+
self._callback_handler = callback_handler
188191

189192
def __call__(
190193
self,
@@ -195,10 +198,14 @@ def __call__(
195198
) -> None:
196199
wrapped(*args, **kwargs)
197200
for handler in instance.inheritable_handlers:
198-
if isinstance(handler, type(self._callback_manager)):
201+
if isinstance(handler, type(self._callback_handler)):
199202
break
200203
else:
201-
instance.add_handler(self._callback_manager, True)
204+
# Add a property to the handler which indicates the CallbackManager instance.
205+
# Since the CallbackHandler only propagates context for sync callbacks,
206+
# we need a way to determine the type of CallbackManager being wrapped.
207+
self._callback_handler._callback_manager = instance
208+
instance.add_handler(self._callback_handler, True)
202209

203210

204211
# This class wraps a function call to inject tracing information (trace headers) into
@@ -233,4 +240,10 @@ def __call__(
233240
# Update kwargs to include the modified headers
234241
kwargs["extra_headers"] = extra_headers
235242

243+
# In legacy chains like LLMChain, suppressing model instrumentations
244+
# within create_llm_span doesn't work, so this should help as a fallback
245+
context_api.attach(
246+
context_api.set_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY, True)
247+
)
248+
236249
return wrapped(*args, **kwargs)

packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/callback_handler.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from langchain_core.callbacks import (
77
BaseCallbackHandler,
8+
CallbackManager,
9+
AsyncCallbackManager,
810
)
911
from langchain_core.messages import (
1012
AIMessage,
@@ -85,19 +87,6 @@ def _extract_class_name_from_serialized(serialized: Optional[dict[str, Any]]) ->
8587
return ""
8688

8789

88-
def _message_type_to_role(message_type: str) -> str:
89-
if message_type == "human":
90-
return "user"
91-
elif message_type == "system":
92-
return "system"
93-
elif message_type == "ai":
94-
return "assistant"
95-
elif message_type == "tool":
96-
return "tool"
97-
else:
98-
return "unknown"
99-
100-
10190
def _sanitize_metadata_value(value: Any) -> Any:
10291
"""Convert metadata values to OpenTelemetry-compatible types."""
10392
if value is None:
@@ -163,6 +152,7 @@ def __init__(
163152
self.token_histogram = token_histogram
164153
self.spans: dict[UUID, SpanHolder] = {}
165154
self.run_inline = True
155+
self._callback_manager: CallbackManager | AsyncCallbackManager = None
166156

167157
@staticmethod
168158
def _get_name_from_callback(
@@ -192,6 +182,9 @@ def _end_span(self, span: Span, run_id: UUID) -> None:
192182
if child_span.end_time is None: # avoid warning on ended spans
193183
child_span.end()
194184
span.end()
185+
token = self.spans[run_id].token
186+
if token:
187+
context_api.detach(token)
195188

196189
def _create_span(
197190
self,
@@ -230,13 +223,17 @@ def _create_span(
230223
else:
231224
span = self.tracer.start_span(span_name, kind=kind)
232225

226+
token = None
227+
# TODO: make this unconditional once attach/detach works properly with async callbacks.
228+
# Currently, it doesn't work due to this - https://github.com/langchain-ai/langchain/issues/31398
229+
# As a sidenote, OTel Python users also report similar issues -
230+
# https://github.com/open-telemetry/opentelemetry-python/issues/2606
231+
if self._callback_manager and not self._callback_manager.is_async:
232+
token = context_api.attach(set_span_in_context(span))
233+
233234
_set_span_attribute(span, SpanAttributes.TRACELOOP_WORKFLOW_NAME, workflow_name)
234235
_set_span_attribute(span, SpanAttributes.TRACELOOP_ENTITY_PATH, entity_path)
235236

236-
token = context_api.attach(
237-
context_api.set_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY, True)
238-
)
239-
240237
self.spans[run_id] = SpanHolder(
241238
span, token, None, [], workflow_name, entity_name, entity_path
242239
)
@@ -300,6 +297,16 @@ def _create_llm_span(
300297
_set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
301298
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, request_type.value)
302299

300+
# we already have an LLM span by this point,
301+
# so skip any downstream instrumentation from here
302+
token = context_api.attach(
303+
context_api.set_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY, True)
304+
)
305+
306+
self.spans[run_id] = SpanHolder(
307+
span, token, None, [], workflow_name, None, entity_path
308+
)
309+
303310
return span
304311

305312
@dont_throw
@@ -464,7 +471,7 @@ def on_llm_end(
464471
"model_name"
465472
) or response.llm_output.get("model_id")
466473
if model_name is not None:
467-
_set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model_name)
474+
_set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model_name or "unknown")
468475

469476
if self.spans[run_id].request_model is None:
470477
_set_span_attribute(

0 commit comments

Comments
 (0)