
Commit e510980

fix(openai): broken traces when using openai-guardrails
1 parent 5bd93d9 commit e510980
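
For context on the mechanics: the fix follows the standard OpenTelemetry pattern of capturing the active context when a response is first tracked and passing it explicitly when the span is eventually started, so spans emitted later (on retrieve, completion, or cancel) keep their original parent instead of attaching to whatever context happens to be current. A minimal sketch of that pattern, separate from the instrumentation code below (the span name matches the one used in the tests; the variable names are illustrative):

from opentelemetry import context as context_api, trace

tracer = trace.get_tracer(__name__)

# At create time, while the caller's span is still active:
stored_ctx = context_api.get_current()

# Later, possibly from a different task or execution context
# (e.g. a guardrails wrapper retrieving the same response):
span = tracer.start_span(
    "openai.response",
    context=stored_ctx,  # explicit parent instead of the now-unrelated current context
)
span.end()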

File tree

6 files changed: +1793 -98 lines changed


packages/opentelemetry-instrumentation-openai/opentelemetry/instrumentation/openai/v1/responses_wrappers.py

Lines changed: 34 additions & 0 deletions
@@ -139,6 +139,12 @@ class TracedData(pydantic.BaseModel):
     request_service_tier: Optional[str] = pydantic.Field(default=None)
     response_service_tier: Optional[str] = pydantic.Field(default=None)
 
+    # Trace context - to maintain trace continuity across async operations
+    trace_context: Any = pydantic.Field(default=None)
+
+    class Config:
+        arbitrary_types_allowed = True
+
 
 responses: dict[str, TracedData] = {}
@@ -509,16 +515,22 @@ def responses_get_or_create_wrapper(tracer: Tracer, wrapped, instance, args, kwargs):
                 response_reasoning_effort=kwargs.get("reasoning", {}).get("effort"),
                 request_service_tier=kwargs.get("service_tier"),
                 response_service_tier=existing_data.get("response_service_tier"),
+                # Capture trace context to maintain continuity
+                trace_context=existing_data.get("trace_context", context_api.get_current()),
             )
         except Exception:
             traced_data = None
 
+        # Restore the original trace context to maintain trace continuity
+        ctx = (traced_data.trace_context if traced_data and traced_data.trace_context
+               else context_api.get_current())
         span = tracer.start_span(
             SPAN_NAME,
             kind=SpanKind.CLIENT,
             start_time=(
                 start_time if traced_data is None else int(traced_data.start_time)
             ),
+            context=ctx,
         )
         _set_request_attributes(span, prepare_kwargs_for_shared_attributes(kwargs), instance)
         span.set_attribute(ERROR_TYPE, e.__class__.__name__)
@@ -575,16 +587,21 @@ def responses_get_or_create_wrapper(tracer: Tracer, wrapped, instance, args, kwargs):
             response_reasoning_effort=kwargs.get("reasoning", {}).get("effort"),
             request_service_tier=existing_data.get("request_service_tier", kwargs.get("service_tier")),
             response_service_tier=existing_data.get("response_service_tier", parsed_response.service_tier),
+            # Capture trace context to maintain continuity across async operations
+            trace_context=existing_data.get("trace_context", context_api.get_current()),
         )
         responses[parsed_response.id] = traced_data
     except Exception:
         return response
 
     if parsed_response.status == "completed":
+        # Restore the original trace context to maintain trace continuity
+        ctx = traced_data.trace_context if traced_data.trace_context else context_api.get_current()
         span = tracer.start_span(
             SPAN_NAME,
             kind=SpanKind.CLIENT,
             start_time=int(traced_data.start_time),
+            context=ctx,
         )
         _set_request_attributes(span, prepare_kwargs_for_shared_attributes(kwargs), instance)
         set_data_attributes(traced_data, span)
@@ -654,16 +671,22 @@ async def async_responses_get_or_create_wrapper(
                 response_reasoning_effort=kwargs.get("reasoning", {}).get("effort"),
                 request_service_tier=kwargs.get("service_tier"),
                 response_service_tier=existing_data.get("response_service_tier"),
+                # Capture trace context to maintain continuity
+                trace_context=existing_data.get("trace_context", context_api.get_current()),
             )
         except Exception:
             traced_data = None
 
+        # Restore the original trace context to maintain trace continuity
+        ctx = (traced_data.trace_context if traced_data and traced_data.trace_context
+               else context_api.get_current())
         span = tracer.start_span(
             SPAN_NAME,
             kind=SpanKind.CLIENT,
             start_time=(
                 start_time if traced_data is None else int(traced_data.start_time)
             ),
+            context=ctx,
         )
         _set_request_attributes(span, prepare_kwargs_for_shared_attributes(kwargs), instance)
         span.set_attribute(ERROR_TYPE, e.__class__.__name__)
@@ -721,16 +744,21 @@ async def async_responses_get_or_create_wrapper(
             response_reasoning_effort=kwargs.get("reasoning", {}).get("effort"),
             request_service_tier=existing_data.get("request_service_tier", kwargs.get("service_tier")),
             response_service_tier=existing_data.get("response_service_tier", parsed_response.service_tier),
+            # Capture trace context to maintain continuity across async operations
+            trace_context=existing_data.get("trace_context", context_api.get_current()),
         )
         responses[parsed_response.id] = traced_data
     except Exception:
         return response
 
     if parsed_response.status == "completed":
+        # Restore the original trace context to maintain trace continuity
+        ctx = traced_data.trace_context if traced_data.trace_context else context_api.get_current()
         span = tracer.start_span(
             SPAN_NAME,
             kind=SpanKind.CLIENT,
             start_time=int(traced_data.start_time),
+            context=ctx,
         )
         _set_request_attributes(span, prepare_kwargs_for_shared_attributes(kwargs), instance)
         set_data_attributes(traced_data, span)
@@ -751,11 +779,14 @@ def responses_cancel_wrapper(tracer: Tracer, wrapped, instance, args, kwargs):
     parsed_response = parse_response(response)
     existing_data = responses.pop(parsed_response.id, None)
     if existing_data is not None:
+        # Restore the original trace context to maintain trace continuity
+        ctx = existing_data.trace_context if existing_data.trace_context else context_api.get_current()
         span = tracer.start_span(
             SPAN_NAME,
             kind=SpanKind.CLIENT,
             start_time=existing_data.start_time,
             record_exception=True,
+            context=ctx,
         )
         _set_request_attributes(span, prepare_kwargs_for_shared_attributes(kwargs), instance)
         span.record_exception(Exception("Response cancelled"))
@@ -778,11 +809,14 @@ async def async_responses_cancel_wrapper(
     parsed_response = parse_response(response)
     existing_data = responses.pop(parsed_response.id, None)
     if existing_data is not None:
+        # Restore the original trace context to maintain trace continuity
+        ctx = existing_data.trace_context if existing_data.trace_context else context_api.get_current()
         span = tracer.start_span(
             SPAN_NAME,
             kind=SpanKind.CLIENT,
             start_time=existing_data.start_time,
             record_exception=True,
+            context=ctx,
         )
         _set_request_attributes(span, prepare_kwargs_for_shared_attributes(kwargs), instance)
         span.record_exception(Exception("Response cancelled"))

packages/opentelemetry-instrumentation-openai/tests/traces/test_responses.py

Lines changed: 80 additions & 0 deletions
@@ -505,3 +505,83 @@ def test_response_stream_init_with_none_tools():
     assert stream._traced_data is not None
     # Tools should be an empty list, not None
     assert stream._traced_data.tools == [] or stream._traced_data.tools is None
+
+
+def test_responses_trace_context_propagation_unit():
+    """Unit test for trace context propagation in responses API.
+
+    This test verifies that when TracedData is created with a trace context,
+    and later a span is created from that TracedData, the span uses the correct
+    trace context that was captured at creation time.
+
+    This is critical for guardrails and other wrappers that make multiple calls
+    across different execution contexts.
+
+    Note: This is a unit test that simulates what guardrails does. For integration
+    testing with the actual openai-guardrails library, see the sample app at:
+    packages/sample-app/sample_app/openai_guardrails_example.py
+    """
+    from unittest.mock import MagicMock, Mock
+    from opentelemetry import trace, context
+    from opentelemetry.sdk.trace import TracerProvider
+    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+    from opentelemetry.instrumentation.openai.v1.responses_wrappers import TracedData
+    from openai.types.responses import Response, ResponseOutputItem, ResponseUsage
+    import time
+
+    # Set up tracing
+    provider = TracerProvider()
+    exporter = InMemorySpanExporter()
+    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+    provider.add_span_processor(SimpleSpanProcessor(exporter))
+    trace.set_tracer_provider(provider)
+    tracer = trace.get_tracer(__name__)
+
+    # Create a parent span and capture its trace context
+    with tracer.start_as_current_span("parent-span") as parent_span:
+        parent_trace_id = parent_span.get_span_context().trace_id
+        parent_context = context.get_current()
+
+        # Create TracedData with the current trace context (simulating responses.create)
+        traced_data = TracedData(
+            start_time=time.time_ns(),
+            response_id="test-response-id",
+            input="What is 2+2?",
+            instructions=None,
+            tools=None,
+            output_blocks={},
+            usage=None,
+            output_text="4",
+            request_model="gpt-4.1-nano",
+            response_model="gpt-4.1-nano-2025-04-14",
+            trace_context=parent_context,
+        )
+
+    # Now we're outside the parent span context
+    # Simulate creating a span with the stored trace context (like responses.retrieve does)
+    ctx = traced_data.trace_context
+    span = tracer.start_span(
+        "openai.response",
+        context=ctx,
+        start_time=traced_data.start_time,
+    )
+    span.end()
+
+    # Verify the span has the correct trace context
+    spans = exporter.get_finished_spans()
+    parent_spans = [s for s in spans if s.name == "parent-span"]
+    openai_spans = [s for s in spans if s.name == "openai.response"]
+
+    assert len(parent_spans) == 1
+    assert len(openai_spans) == 1
+
+    # The openai.response span should have the same trace_id as the parent
+    assert openai_spans[0].context.trace_id == parent_trace_id, (
+        f"openai.response span trace_id ({openai_spans[0].context.trace_id}) "
+        f"should match parent trace_id ({parent_trace_id})"
+    )
+
+    # The openai.response span should be a child of the parent span
+    assert openai_spans[0].parent.span_id == parent_spans[0].context.span_id, (
+        "openai.response span should be a child of parent-span"
+    )
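
As a follow-up illustration of the scenario this unit test simulates, a hypothetical end-to-end reproduction might look like the sketch below. It is not part of this commit: it assumes the openai SDK and this instrumentation package are installed, that an OPENAI_API_KEY is set, and that the Responses API background mode and its "queued"/"in_progress" statuses behave as documented; the "app-root" span name is illustrative.

import time

from openai import OpenAI
from opentelemetry import trace
from opentelemetry.instrumentation.openai import OpenAIInstrumentor

OpenAIInstrumentor().instrument()
tracer = trace.get_tracer(__name__)
client = OpenAI()

with tracer.start_as_current_span("app-root"):
    # The instrumentation stores the current context on the TracedData entry here.
    created = client.responses.create(
        model="gpt-4.1-nano",
        input="What is 2+2?",
        background=True,  # assumption: completion is only observed by a later retrieve
    )

# Later, outside "app-root" (e.g. from a wrapper such as openai-guardrails):
final = client.responses.retrieve(created.id)
while final.status in ("queued", "in_progress"):
    time.sleep(1)
    final = client.responses.retrieve(created.id)

# The "openai.response" span is emitted on the retrieve that sees
# status == "completed"; with the stored trace_context it stays in the
# same trace as "app-root" instead of starting a new one.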
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-3.9.5
+3.11
