Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

_logger = logging.getLogger(__name__)

_INSTRUMENT_TAGS_KEY = "instrument_tags"

# ContextVar for managing active instrument tags
active_instrument_tags: ContextVar[Dict[str, Any]] = ContextVar(
"instrument_tags", default={}
Expand Down Expand Up @@ -115,6 +117,15 @@ def root(self) -> "Dispatcher":
assert self.manager is not None
return self.manager.dispatchers[self.root_name]

def _walk_span_handlers(self) -> Generator[BaseSpanHandler, None, None]:
"""Yield every span handler reachable via the propagation chain."""
c: Optional[Dispatcher] = self
while c:
yield from c.span_handlers
if not c.propagate:
break
c = c.parent

def add_event_handler(self, handler: BaseEventHandler) -> None:
"""Add handler to set of handlers."""
self.event_handlers += [handler]
Expand Down Expand Up @@ -261,6 +272,73 @@ def span_exit(
else:
c = c.parent

def capture_propagation_context(self) -> Dict[str, Any]:
"""
Capture trace propagation context from all registered span handlers.

Each span handler namespaces its data under its own key. The Dispatcher
also captures active instrument_tags. The returned dict can be serialized
and passed to restore_propagation_context() in another process.
"""
result: Dict[str, Any] = {}
for h in self._walk_span_handlers():
try:
result.update(h.capture_propagation_context())
except BaseException:
_logger.warning("Error capturing propagation context", exc_info=True)
tags = active_instrument_tags.get()
if tags:
result[_INSTRUMENT_TAGS_KEY] = dict(tags)
return result

def restore_propagation_context(self, context: Dict[str, Any]) -> None:
"""
Restore trace propagation context on all registered span handlers.

Also restores instrument_tags so that subsequent spans see them.
"""
for h in self._walk_span_handlers():
try:
h.restore_propagation_context(context)
except BaseException:
_logger.warning("Error restoring propagation context", exc_info=True)
tags = context.get(_INSTRUMENT_TAGS_KEY)
if tags:
active_instrument_tags.set(dict(tags))

def shutdown(self) -> None:
"""
Drop all open spans and close all handlers.

Walks the dispatcher parent chain (same as other span methods),
drops every open span on every handler, then calls close() on
each handler. Exceptions are swallowed to match existing convention.
"""
_synthetic_bound_args = inspect.signature(lambda: None).bind()
_shutdown_err = RuntimeError("dispatcher shutdown")

for h in self._walk_span_handlers():
# Drop all open spans — snapshot keys since span_drop mutates the dict
for span_id in list(h.open_spans.keys()):
try:
h.span_drop(
id_=span_id,
bound_args=_synthetic_bound_args,
instance=None,
err=_shutdown_err,
)
except BaseException:
_logger.debug(
"Error dropping span %s during shutdown",
span_id,
exc_info=True,
)
# Close the handler
try:
h.close()
except BaseException:
_logger.warning("Error closing handler %s", h, exc_info=True)

def span(self, func: Callable[..., _R]) -> Callable[..., _R]:
# The `span` decorator should be idempotent.
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,30 @@ def span_drop(
with self.lock:
del self.open_spans[id_]

def capture_propagation_context(self) -> Dict[str, Any]:
"""
Capture trace propagation context for serialization across process boundaries.

Returns a dict that can be serialized and passed to restore_propagation_context()
in another process to re-establish trace continuity.
"""
return {}

def restore_propagation_context(self, context: Dict[str, Any]) -> None:
"""
Restore trace propagation context received from another process.

Should be called BEFORE span_enter so that new spans parent correctly.
"""

def close(self) -> None:
"""
Optional cleanup hook called during dispatcher shutdown.

Subclasses can override to flush buffers, release resources, etc.
Default is a no-op.
"""

@abstractmethod
def new_span(
self,
Expand Down
186 changes: 186 additions & 0 deletions llama-index-instrumentation/tests/test_propagation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from typing import Any, Dict

from llama_index_instrumentation.dispatcher import (
Dispatcher,
Manager,
_INSTRUMENT_TAGS_KEY,
active_instrument_tags,
)
from llama_index_instrumentation.span_handlers.simple import SimpleSpanHandler


class PropagatingHandler(SimpleSpanHandler):
"""Handler that captures/restores a fake trace context."""

def capture_propagation_context(self) -> Dict[str, Any]:
return {"test_handler": {"trace_id": "abc123", "span_id": "def456"}}

def restore_propagation_context(self, context: Dict[str, Any]) -> None:
self._restored_context = context


def test_capture_propagation_context_basic():
handler = PropagatingHandler()
d = Dispatcher(span_handlers=[handler], propagate=False)

ctx = d.capture_propagation_context()

assert ctx["test_handler"]["trace_id"] == "abc123"
assert ctx["test_handler"]["span_id"] == "def456"


def test_capture_includes_instrument_tags():
handler = SimpleSpanHandler()
d = Dispatcher(span_handlers=[handler], propagate=False)

token = active_instrument_tags.set({"user_id": "u1", "session": "s1"})
try:
ctx = d.capture_propagation_context()
finally:
active_instrument_tags.reset(token)

assert ctx[_INSTRUMENT_TAGS_KEY] == {"user_id": "u1", "session": "s1"}


def test_capture_omits_instrument_tags_when_empty():
handler = SimpleSpanHandler()
d = Dispatcher(span_handlers=[handler], propagate=False)

ctx = d.capture_propagation_context()

assert _INSTRUMENT_TAGS_KEY not in ctx


def test_restore_propagation_context_basic():
handler = PropagatingHandler()
d = Dispatcher(span_handlers=[handler], propagate=False)

context = {"test_handler": {"trace_id": "abc123"}}
d.restore_propagation_context(context)

assert handler._restored_context == context


def test_restore_sets_instrument_tags():
handler = SimpleSpanHandler()
d = Dispatcher(span_handlers=[handler], propagate=False)

d.restore_propagation_context({_INSTRUMENT_TAGS_KEY: {"user_id": "u1"}})

assert active_instrument_tags.get() == {"user_id": "u1"}
# cleanup
active_instrument_tags.set({})


def test_capture_walks_parent_chain():
parent_handler = PropagatingHandler()
child_handler = SimpleSpanHandler()

parent = Dispatcher(name="parent", span_handlers=[parent_handler], propagate=False)
child = Dispatcher(
name="child",
span_handlers=[child_handler],
propagate=True,
parent_name="parent",
)
manager = Manager(parent)
manager.add_dispatcher(child)
child.manager = manager
parent.manager = manager

ctx = child.capture_propagation_context()

# Should include parent handler's context via propagation
assert "test_handler" in ctx


def test_restore_walks_parent_chain():
parent_handler = PropagatingHandler()
child_handler = PropagatingHandler()

parent = Dispatcher(name="parent", span_handlers=[parent_handler], propagate=False)
child = Dispatcher(
name="child",
span_handlers=[child_handler],
propagate=True,
parent_name="parent",
)
manager = Manager(parent)
manager.add_dispatcher(child)
child.manager = manager
parent.manager = manager

context = {"test_handler": {"trace_id": "xyz"}}
child.restore_propagation_context(context)

assert child_handler._restored_context == context
assert parent_handler._restored_context == context


def test_capture_stops_at_propagate_false():
parent_handler = PropagatingHandler()
child_handler = SimpleSpanHandler()

parent = Dispatcher(name="parent", span_handlers=[parent_handler], propagate=False)
child = Dispatcher(
name="child",
span_handlers=[child_handler],
propagate=False, # does NOT propagate
parent_name="parent",
)
manager = Manager(parent)
manager.add_dispatcher(child)
child.manager = manager
parent.manager = manager

ctx = child.capture_propagation_context()

# Should NOT include parent handler's context
assert "test_handler" not in ctx


def test_roundtrip_capture_restore():
"""Capture from one dispatcher, restore on another — simulates cross-process."""
source_handler = PropagatingHandler()
source = Dispatcher(span_handlers=[source_handler], propagate=False)

token = active_instrument_tags.set({"env": "prod"})
try:
ctx = source.capture_propagation_context()
finally:
active_instrument_tags.reset(token)

dest_handler = PropagatingHandler()
dest = Dispatcher(span_handlers=[dest_handler], propagate=False)

dest.restore_propagation_context(ctx)

assert dest_handler._restored_context == ctx
assert active_instrument_tags.get() == {"env": "prod"}
# cleanup
active_instrument_tags.set({})


def test_capture_swallows_handler_exceptions():
class BrokenHandler(SimpleSpanHandler):
def capture_propagation_context(self) -> Dict[str, Any]:
raise RuntimeError("boom")

handler = BrokenHandler()
d = Dispatcher(span_handlers=[handler], propagate=False)

# Should not raise
ctx = d.capture_propagation_context()
assert isinstance(ctx, dict)


def test_restore_swallows_handler_exceptions():
class BrokenHandler(SimpleSpanHandler):
def restore_propagation_context(self, context: Dict[str, Any]) -> None:
raise RuntimeError("boom")

handler = BrokenHandler()
d = Dispatcher(span_handlers=[handler], propagate=False)

# Should not raise
d.restore_propagation_context({"some": "data"})
Loading