Skip to content

Commit 6c4a237

Browse files
committed
Use interceptors instead of forking open ai agents plugin
1 parent 1a4e622 commit 6c4a237

File tree

12 files changed

+489
-1341
lines changed

12 files changed

+489
-1341
lines changed

src/agentex/lib/core/clients/temporal/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig
77
from temporalio.contrib.pydantic import pydantic_data_converter
88
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
9+
from temporalio.worker import Interceptor
910

1011
# class DateTimeJSONEncoder(AdvancedJSONEncoder):
1112
# def default(self, o: Any) -> Any:
@@ -61,6 +62,24 @@ def validate_client_plugins(plugins: list[Any]) -> None:
6162
)
6263

6364

65+
def validate_worker_interceptors(interceptors: list[Any]) -> None:
66+
"""
67+
Validate that all items in the interceptors list are valid Temporal worker interceptors.
68+
69+
Args:
70+
interceptors: List of interceptors to validate
71+
72+
Raises:
73+
TypeError: If any interceptor is not a valid Interceptor instance
74+
"""
75+
for i, interceptor in enumerate(interceptors):
76+
if not isinstance(interceptor, Interceptor):
77+
raise TypeError(
78+
f"Interceptor at index {i} must be an instance of temporalio.worker.Interceptor, "
79+
f"got {type(interceptor).__name__}"
80+
)
81+
82+
6483
async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list[Any] = []) -> Client:
6584
"""
6685
Create a Temporal client with plugin integration.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""OpenAI Agents SDK Temporal Plugin with Streaming Support.
2+
3+
This module provides streaming capabilities for the OpenAI Agents SDK in Temporal
4+
using interceptors to thread task_id through workflows to activities.
5+
6+
The streaming implementation works by:
7+
1. Using Temporal interceptors to thread task_id through the execution
8+
2. Streaming LLM responses to Redis in real-time from activities
9+
3. Returning complete responses to maintain Temporal determinism
10+
11+
Example:
12+
>>> from agentex.lib.core.temporal.plugins.openai_agents import (
13+
... StreamingModelProvider,
14+
... StreamingInterceptor,
15+
... )
16+
>>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters
17+
>>> from datetime import timedelta
18+
>>>
19+
>>> # Create streaming model provider
20+
>>> model_provider = StreamingModelProvider()
21+
>>>
22+
>>> # Create STANDARD plugin with streaming model provider
23+
>>> plugin = OpenAIAgentsPlugin(
24+
... model_params=ModelActivityParameters(
25+
... start_to_close_timeout=timedelta(seconds=120),
26+
... ),
27+
... model_provider=model_provider,
28+
... )
29+
>>>
30+
>>> # Register interceptor with worker
31+
>>> interceptor = StreamingInterceptor()
32+
>>> # Add interceptor to worker configuration
33+
"""
34+
35+
from agentex.lib.core.temporal.plugins.openai_agents import (
36+
StreamingModel,
37+
StreamingModelProvider,
38+
StreamingInterceptor,
39+
streaming_task_id,
40+
streaming_trace_id,
41+
streaming_parent_span_id,
42+
TemporalStreamingHooks,
43+
stream_lifecycle_content,
44+
)
45+
46+
__all__ = [
47+
"StreamingModel",
48+
"StreamingModelProvider",
49+
"StreamingInterceptor",
50+
"streaming_task_id",
51+
"streaming_trace_id",
52+
"streaming_parent_span_id",
53+
"TemporalStreamingHooks",
54+
"stream_lifecycle_content",
55+
]

src/agentex/lib/core/temporal/plugins/openai_agents/README.md

Lines changed: 189 additions & 378 deletions
Large diffs are not rendered by default.
Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,65 @@
11
"""OpenAI Agents SDK Temporal Plugin with Streaming Support.
22
3-
This module provides a custom Temporal plugin for the OpenAI Agents SDK that adds
4-
real-time streaming capabilities while maintaining Temporal's durability guarantees.
3+
This module provides streaming capabilities for the OpenAI Agents SDK in Temporal
4+
using interceptors to thread task_id through workflows to activities.
55
66
The streaming implementation works by:
7-
1. Threading a task_id through the workflow execution
7+
1. Using Temporal interceptors to thread task_id through the execution
88
2. Streaming LLM responses to Redis in real-time from activities
99
3. Streaming lifecycle events (tool calls, handoffs) via hooks and activities
1010
4. Returning complete responses to maintain Temporal determinism
1111
1212
Example - Complete Setup:
1313
>>> from agentex.lib.core.temporal.plugins.openai_agents import (
14-
... CustomStreamingOpenAIAgentsPlugin,
1514
... StreamingModelProvider,
1615
... TemporalStreamingHooks,
16+
... StreamingInterceptor,
1717
... )
18-
>>> from temporalio.contrib.openai_agents import ModelActivityParameters
18+
>>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters
1919
>>> from datetime import timedelta
2020
>>> from agents import Agent, Runner
2121
>>>
2222
>>> # 1. Create streaming model provider
2323
>>> model_provider = StreamingModelProvider()
2424
>>>
25-
>>> # 2. Create plugin with streaming support
26-
>>> plugin = CustomStreamingOpenAIAgentsPlugin(
25+
>>> # 2. Create STANDARD plugin with streaming model provider
26+
>>> plugin = OpenAIAgentsPlugin(
2727
... model_params=ModelActivityParameters(
2828
... start_to_close_timeout=timedelta(seconds=120),
2929
... ),
3030
... model_provider=model_provider,
3131
... )
3232
>>>
33-
>>> # 3. In workflow, create hooks for streaming lifecycle events
34-
>>> # The hooks automatically use the built-in streaming activity
33+
>>> # 3. Register interceptor with worker
34+
>>> interceptor = StreamingInterceptor()
35+
>>> # Add interceptor to worker configuration
36+
>>>
37+
>>> # 4. In workflow, store task_id in instance variable
38+
>>> self._task_id = params.task.id
39+
>>>
40+
>>> # 5. Create hooks for streaming lifecycle events
3541
>>> hooks = TemporalStreamingHooks(task_id="your-task-id")
3642
>>>
37-
>>> # 4. Run agent with streaming context and hooks
38-
>>> context = {"task_id": "your-task-id"}
39-
>>> result = await Runner.run(agent, input, context=context, hooks=hooks)
43+
>>> # 6. Run agent - interceptor handles task_id threading automatically
44+
>>> result = await Runner.run(agent, input, hooks=hooks)
4045
4146
This gives you:
42-
- Real-time streaming of LLM responses (via StreamingModel)
47+
- Real-time streaming of LLM responses (via StreamingModel + interceptors)
4348
- Real-time streaming of tool calls (via TemporalStreamingHooks)
4449
- Real-time streaming of agent handoffs (via TemporalStreamingHooks)
4550
- Full Temporal durability and observability
51+
- No forked plugin required - uses standard OpenAIAgentsPlugin
4652
"""
4753

4854
from agentex.lib.core.temporal.plugins.openai_agents.streaming_model import (
4955
StreamingModel,
5056
StreamingModelProvider,
5157
)
52-
from agentex.lib.core.temporal.plugins.openai_agents.streaming_plugin import (
53-
CustomStreamingOpenAIAgentsPlugin,
54-
StreamingModelActivity,
55-
StreamingActivityModelInput,
56-
)
57-
from agentex.lib.core.temporal.plugins.openai_agents.streaming_runner import (
58-
StreamingTemporalRunner,
59-
StreamingTemporalModelStub,
58+
from agentex.lib.core.temporal.plugins.openai_agents.streaming_interceptor import (
59+
StreamingInterceptor,
60+
streaming_task_id,
61+
streaming_trace_id,
62+
streaming_parent_span_id,
6063
)
6164
from agentex.lib.core.temporal.plugins.openai_agents.hooks import (
6265
TemporalStreamingHooks,
@@ -66,13 +69,12 @@
6669
)
6770

6871
__all__ = [
69-
"CustomStreamingOpenAIAgentsPlugin",
7072
"StreamingModel",
7173
"StreamingModelProvider",
72-
"StreamingModelActivity",
73-
"StreamingActivityModelInput",
74-
"StreamingTemporalRunner",
75-
"StreamingTemporalModelStub",
74+
"StreamingInterceptor",
75+
"streaming_task_id",
76+
"streaming_trace_id",
77+
"streaming_parent_span_id",
7678
"TemporalStreamingHooks",
7779
"stream_lifecycle_content",
7880
]
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""
2+
Simple Temporal interceptors for threading task_id to enable streaming.
3+
4+
This module provides minimal interceptors to pass task_id from workflows
5+
to activities via headers, making it available to the StreamingModel.
6+
"""
7+
8+
from contextvars import ContextVar
9+
from typing import Optional, Any, Type
10+
import logging
11+
12+
from temporalio import workflow
13+
from temporalio.worker import (
14+
Interceptor,
15+
WorkflowInboundInterceptor,
16+
WorkflowOutboundInterceptor,
17+
ActivityInboundInterceptor,
18+
ExecuteWorkflowInput,
19+
StartActivityInput,
20+
ExecuteActivityInput,
21+
)
22+
from temporalio.converter import default
23+
24+
# Set up logging
25+
logger = logging.getLogger("streaming.interceptor")
26+
27+
# Global context variable that StreamingModel will read
28+
# This is thread-safe and works across async boundaries
29+
streaming_task_id: ContextVar[Optional[str]] = ContextVar('streaming_task_id', default=None)
30+
streaming_trace_id: ContextVar[Optional[str]] = ContextVar('streaming_trace_id', default=None)
31+
streaming_parent_span_id: ContextVar[Optional[str]] = ContextVar('streaming_parent_span_id', default=None)
32+
# Header key for passing task_id
33+
TASK_ID_HEADER = "streaming-task-id"
34+
TRACE_ID_HEADER = "trace-id"
35+
PARENT_SPAN_ID_HEADER = "parent-span-id"
36+
37+
class StreamingInterceptor(Interceptor):
38+
"""Main interceptor that enables task_id threading."""
39+
40+
def __init__(self):
41+
self._payload_converter = default().payload_converter
42+
logger.info("[StreamingInterceptor] Initialized")
43+
44+
def intercept_activity(self, next: ActivityInboundInterceptor) -> ActivityInboundInterceptor:
45+
"""Create activity interceptor to read task_id from headers."""
46+
return StreamingActivityInboundInterceptor(next, self._payload_converter)
47+
48+
def workflow_interceptor_class(self, input: Any) -> Optional[Type[WorkflowInboundInterceptor]]:
49+
"""Return workflow interceptor class."""
50+
return StreamingWorkflowInboundInterceptor
51+
52+
53+
class StreamingWorkflowInboundInterceptor(WorkflowInboundInterceptor):
54+
"""Workflow interceptor that creates the outbound interceptor."""
55+
56+
def __init__(self, next: WorkflowInboundInterceptor):
57+
super().__init__(next)
58+
self._payload_converter = default().payload_converter
59+
60+
async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
61+
"""Execute workflow - just pass through."""
62+
return await self.next.execute_workflow(input)
63+
64+
def init(self, outbound: WorkflowOutboundInterceptor) -> None:
65+
"""Initialize with our custom outbound interceptor."""
66+
self.next.init(StreamingWorkflowOutboundInterceptor(
67+
outbound, self._payload_converter
68+
))
69+
70+
71+
class StreamingWorkflowOutboundInterceptor(WorkflowOutboundInterceptor):
72+
"""Outbound interceptor that adds task_id to activity headers."""
73+
74+
def __init__(self, next, payload_converter):
75+
super().__init__(next)
76+
self._payload_converter = payload_converter
77+
78+
def start_activity(self, input: StartActivityInput) -> workflow.ActivityHandle:
79+
"""Add task_id, trace_id, and parent_span_id to headers when starting model activities."""
80+
81+
# Only add headers for invoke_model_activity calls
82+
activity_name = str(input.activity) if hasattr(input, 'activity') else ""
83+
84+
if "invoke_model_activity" in activity_name or "invoke-model-activity" in activity_name:
85+
# Get task_id, trace_id, and parent_span_id from workflow instance instead of inbound interceptor
86+
try:
87+
workflow_instance = workflow.instance()
88+
task_id = getattr(workflow_instance, '_task_id', None)
89+
trace_id = getattr(workflow_instance, '_trace_id', None)
90+
parent_span_id = getattr(workflow_instance, '_parent_span_id', None)
91+
92+
if task_id and trace_id and parent_span_id:
93+
# Initialize headers if needed
94+
if not input.headers:
95+
input.headers = {}
96+
97+
# Add task_id to headers
98+
input.headers[TASK_ID_HEADER] = self._payload_converter.to_payload(task_id)
99+
input.headers[TRACE_ID_HEADER] = self._payload_converter.to_payload(trace_id)
100+
input.headers[PARENT_SPAN_ID_HEADER] = self._payload_converter.to_payload(parent_span_id)
101+
logger.debug(f"[OutboundInterceptor] Added task_id, trace_id, and parent_span_id to activity headers: {task_id}, {trace_id}, {parent_span_id}")
102+
else:
103+
logger.warning("[OutboundInterceptor] No _task_id, _trace_id, or _parent_span_id found in workflow instance")
104+
except Exception as e:
105+
logger.error(f"[OutboundInterceptor] Failed to get task_id, trace_id, or parent_span_id from workflow instance: {e}")
106+
107+
return self.next.start_activity(input)
108+
109+
110+
class StreamingActivityInboundInterceptor(ActivityInboundInterceptor):
111+
"""Activity interceptor that extracts task_id, trace_id, and parent_span_id from headers and sets context variables."""
112+
113+
def __init__(self, next, payload_converter):
114+
super().__init__(next)
115+
self._payload_converter = payload_converter
116+
117+
async def execute_activity(self, input: ExecuteActivityInput) -> Any:
118+
"""Extract task_id, trace_id, and parent_span_id from headers and set context variables."""
119+
120+
# Extract task_id from headers if present
121+
if input.headers and TASK_ID_HEADER in input.headers:
122+
task_id_value = self._payload_converter.from_payload(
123+
input.headers[TASK_ID_HEADER], str
124+
)
125+
trace_id_value = self._payload_converter.from_payload(
126+
input.headers[TRACE_ID_HEADER], str
127+
)
128+
parent_span_id_value = self._payload_converter.from_payload(
129+
input.headers[PARENT_SPAN_ID_HEADER], str
130+
)
131+
132+
# P THIS IS THE KEY PART - Set the context variable!
133+
# This makes task_id available to StreamingModel.get_response()
134+
streaming_task_id.set(task_id_value)
135+
streaming_trace_id.set(trace_id_value)
136+
streaming_parent_span_id.set(parent_span_id_value)
137+
logger.info(f"[ActivityInterceptor] Set task_id, trace_id, and parent_span_id in context: {task_id_value}, {trace_id_value}, {parent_span_id_value}")
138+
else:
139+
logger.debug("[ActivityInterceptor] No task_id, trace_id, or parent_span_id in headers")
140+
141+
try:
142+
# Execute the activity
143+
# The StreamingModel can now read streaming_task_id.get()
144+
result = await self.next.execute_activity(input)
145+
return result
146+
finally:
147+
# Clean up context after activity
148+
streaming_task_id.set(None)
149+
streaming_trace_id.set(None)
150+
streaming_parent_span_id.set(None)
151+
logger.debug("[ActivityInterceptor] Cleared task_id, trace_id, and parent_span_id from context")
152+

src/agentex/lib/core/temporal/plugins/openai_agents/streaming_model.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@
5959
from agentex.types.task_message_delta import ReasoningSummaryDelta, TextDelta, ReasoningContentDelta
6060
from agentex.types.task_message_content import TextContent, ReasoningContent
6161
from agentex.types.task_message_update import StreamTaskMessageDelta, StreamTaskMessageFull
62-
62+
from agentex.lib.core.temporal.plugins.openai_agents.streaming_interceptor import streaming_task_id
63+
from agentex.lib.core.temporal.plugins.openai_agents.streaming_interceptor import streaming_trace_id
64+
from agentex.lib.core.temporal.plugins.openai_agents.streaming_interceptor import streaming_parent_span_id
6365
# Create logger for this module
6466
logger = logging.getLogger("agentex.temporal.streaming")
6567

@@ -383,20 +385,15 @@ async def get_response(
383385
This method is used by Temporal activities and needs to return a complete
384386
response, but we stream the response to Redis while generating it.
385387
"""
386-
# Get optional parameters from kwargs
387-
trace_id = kwargs.get('trace_id')
388-
parent_span_id = kwargs.get('parent_span_id')
389-
390-
# Use the class's tracer instance
391-
tracer = self.tracer
388+
389+
task_id = streaming_task_id.get()
390+
trace_id = streaming_trace_id.get()
391+
parent_span_id = streaming_parent_span_id.get()
392392

393-
# Use task_id as trace_id if not provided
394-
task_id = kwargs.get('task_id')
395-
if trace_id is None:
396-
trace_id = task_id
397-
logger.info(f"[StreamingModel] Using task_id as trace_id: {trace_id}")
393+
if not task_id or not trace_id or not parent_span_id:
394+
raise ValueError("task_id, trace_id, and parent_span_id are required for streaming with Responses API")
398395

399-
trace = tracer.trace(trace_id)
396+
trace = self.tracer.trace(trace_id)
400397

401398
async with trace.span(
402399
parent_id=parent_span_id,

0 commit comments

Comments
 (0)