Skip to content

Commit c34e7a6

Browse files
committed
Merge remote-tracking branch 'origin/main' into openai/mcp
2 parents 783eb9e + 5080b68 commit c34e7a6

File tree

5 files changed

+307
-45
lines changed

5 files changed

+307
-45
lines changed

temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
11
"""Parameters for configuring Temporal activity execution for model calls."""
22

3+
from abc import ABC, abstractmethod
34
from dataclasses import dataclass
45
from datetime import timedelta
5-
from typing import Optional
6+
from typing import Any, Callable, Optional, Union
7+
8+
from agents import Agent, TResponseInputItem
69

710
from temporalio.common import Priority, RetryPolicy
811
from temporalio.workflow import ActivityCancellationType, VersioningIntent
912

1013

14+
class ModelSummaryProvider(ABC):
15+
"""Abstract base class for providing model summaries. Essentially just a callable,
16+
but the arguments are sufficiently complex to benefit from names.
17+
"""
18+
19+
@abstractmethod
20+
def provide(
21+
self,
22+
agent: Optional[Agent[Any]],
23+
instructions: Optional[str],
24+
input: Union[str, list[TResponseInputItem]],
25+
) -> str:
26+
"""Given the provided information, produce a summary for the model invocation activity."""
27+
pass
28+
29+
1130
@dataclass
1231
class ModelActivityParameters:
1332
"""Parameters for configuring Temporal activity execution for model calls.
@@ -41,7 +60,12 @@ class ModelActivityParameters:
4160
versioning_intent: Optional[VersioningIntent] = None
4261
"""Versioning intent for the activity."""
4362

44-
summary_override: Optional[str] = None
63+
summary_override: Optional[
64+
Union[
65+
str,
66+
ModelSummaryProvider,
67+
]
68+
] = None
4569
"""Summary for the activity execution."""
4670

4771
priority: Priority = Priority.default

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import dataclasses
12
import json
23
import typing
3-
import warnings
4-
from dataclasses import replace
5-
from typing import Any, Union
4+
from typing import Any, Optional, Union
65

76
from agents import (
87
Agent,
8+
Handoff,
99
RunConfig,
10+
RunContextWrapper,
1011
RunResult,
1112
RunResultStreaming,
1213
SQLiteSession,
@@ -91,26 +92,70 @@ async def run(
9192
if run_config is None:
9293
run_config = RunConfig()
9394

94-
model_name = run_config.model or starting_agent.model
95-
if model_name is not None and not isinstance(model_name, str):
96-
raise ValueError(
97-
"Temporal workflows require a model name to be a string in the run config and/or agent."
95+
if run_config.model:
96+
if not isinstance(run_config.model, str):
97+
raise ValueError(
98+
"Temporal workflows require a model name to be a string in the run config."
99+
)
100+
run_config = dataclasses.replace(
101+
run_config,
102+
model=_TemporalModelStub(
103+
run_config.model, model_params=self.model_params, agent=None
104+
),
105+
)
106+
107+
# Recursively replace models in all agents
108+
def convert_agent(agent: Agent[Any], seen: Optional[set[int]]) -> Agent[Any]:
109+
if seen is None:
110+
seen = set()
111+
112+
# Short circuit if this model was already seen to prevent looping from circular handoffs
113+
if id(agent) in seen:
114+
return agent
115+
seen.add(id(agent))
116+
117+
# This agent has already been processed in some other run
118+
if isinstance(agent.model, _TemporalModelStub):
119+
return agent
120+
121+
name = _model_name(agent)
122+
123+
new_handoffs: list[Union[Agent, Handoff]] = []
124+
for handoff in agent.handoffs:
125+
if isinstance(handoff, Agent):
126+
new_handoffs.append(convert_agent(handoff, seen))
127+
elif isinstance(handoff, Handoff):
128+
original_invoke = handoff.on_invoke_handoff
129+
130+
async def on_invoke(
131+
context: RunContextWrapper[Any], args: str
132+
) -> Agent:
133+
handoff_agent = await original_invoke(context, args)
134+
return convert_agent(handoff_agent, seen)
135+
136+
new_handoffs.append(
137+
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
138+
)
139+
else:
140+
raise ValueError(f"Unknown handoff type: {type(handoff)}")
141+
142+
return dataclasses.replace(
143+
agent,
144+
model=_TemporalModelStub(
145+
model_name=name,
146+
model_params=self.model_params,
147+
agent=agent,
148+
),
149+
handoffs=new_handoffs,
98150
)
99-
updated_run_config = replace(
100-
run_config,
101-
model=_TemporalModelStub(
102-
model_name=model_name,
103-
model_params=self.model_params,
104-
),
105-
)
106151

107152
return await self._runner.run(
108-
starting_agent=starting_agent,
153+
starting_agent=convert_agent(starting_agent, None),
109154
input=input,
110155
context=context,
111156
max_turns=max_turns,
112157
hooks=hooks,
113-
run_config=updated_run_config,
158+
run_config=run_config,
114159
previous_response_id=previous_response_id,
115160
session=session,
116161
)
@@ -144,3 +189,12 @@ def run_streamed(
144189
**kwargs,
145190
)
146191
raise RuntimeError("Temporal workflows do not support streaming.")
192+
193+
194+
def _model_name(agent: Agent[Any]) -> Optional[str]:
195+
name = agent.model
196+
if name is not None and not isinstance(name, str):
197+
raise ValueError(
198+
"Temporal workflows require a model name to be a string in the agent."
199+
)
200+
return name

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Any, AsyncIterator, Union, cast
1212

1313
from agents import (
14+
Agent,
1415
AgentOutputSchema,
1516
AgentOutputSchemaBase,
1617
CodeInterpreterTool,
@@ -50,9 +51,11 @@ def __init__(
5051
model_name: Optional[str],
5152
*,
5253
model_params: ModelActivityParameters,
54+
agent: Optional[Agent[Any]],
5355
) -> None:
5456
self.model_name = model_name
5557
self.model_params = model_params
58+
self.agent = agent
5659

5760
async def get_response(
5861
self,
@@ -124,7 +127,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
124127
activity_input = ActivityModelInput(
125128
model_name=self.model_name,
126129
system_instructions=system_instructions,
127-
input=cast(Union[str, list[TResponseInputItem]], input),
130+
input=input,
128131
model_settings=model_settings,
129132
tools=tool_infos,
130133
output_schema=output_schema_input,
@@ -134,10 +137,25 @@ def make_tool_info(tool: Tool) -> ToolInput:
134137
prompt=prompt,
135138
)
136139

140+
if self.model_params.summary_override:
141+
summary = (
142+
self.model_params.summary_override
143+
if isinstance(self.model_params.summary_override, str)
144+
else (
145+
self.model_params.summary_override.provide(
146+
self.agent, system_instructions, input
147+
)
148+
)
149+
)
150+
elif self.agent:
151+
summary = self.agent.name
152+
else:
153+
summary = None
154+
137155
return await workflow.execute_activity_method(
138156
ModelActivity.invoke_model_activity,
139157
activity_input,
140-
summary=self.model_params.summary_override or _extract_summary(input),
158+
summary=summary,
141159
task_queue=self.model_params.task_queue,
142160
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
143161
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,

temporalio/contrib/openai_agents/_trace_interceptor.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import random
66
import uuid
77
from contextlib import contextmanager
8-
from typing import Any, Mapping, Protocol, Type
8+
from typing import Any, Mapping, Optional, Protocol, Type
99

1010
from agents import CustomSpanData, custom_span, get_current_span, trace
1111
from agents.tracing import (
1212
get_trace_provider,
1313
)
1414
from agents.tracing.scope import Scope
15-
from agents.tracing.spans import NoOpSpan
15+
from agents.tracing.spans import NoOpSpan, Span
1616

1717
import temporalio.activity
1818
import temporalio.api.common.v1
@@ -370,55 +370,78 @@ class _ContextPropagationWorkflowOutboundInterceptor(
370370
async def signal_child_workflow(
371371
self, input: temporalio.worker.SignalChildWorkflowInput
372372
) -> None:
373-
with custom_span(
374-
name="temporal:signalChildWorkflow",
375-
data={"workflowId": input.child_workflow_id},
376-
):
373+
trace = get_trace_provider().get_current_trace()
374+
if trace:
375+
with custom_span(
376+
name="temporal:signalChildWorkflow",
377+
data={"workflowId": input.child_workflow_id},
378+
):
379+
set_header_from_context(input, temporalio.workflow.payload_converter())
380+
await self.next.signal_child_workflow(input)
381+
else:
377382
set_header_from_context(input, temporalio.workflow.payload_converter())
378383
await self.next.signal_child_workflow(input)
379384

380385
async def signal_external_workflow(
381386
self, input: temporalio.worker.SignalExternalWorkflowInput
382387
) -> None:
383-
with custom_span(
384-
name="temporal:signalExternalWorkflow",
385-
data={"workflowId": input.workflow_id},
386-
):
388+
trace = get_trace_provider().get_current_trace()
389+
if trace:
390+
with custom_span(
391+
name="temporal:signalExternalWorkflow",
392+
data={"workflowId": input.workflow_id},
393+
):
394+
set_header_from_context(input, temporalio.workflow.payload_converter())
395+
await self.next.signal_external_workflow(input)
396+
else:
387397
set_header_from_context(input, temporalio.workflow.payload_converter())
388398
await self.next.signal_external_workflow(input)
389399

390400
def start_activity(
391401
self, input: temporalio.worker.StartActivityInput
392402
) -> temporalio.workflow.ActivityHandle:
393-
span = custom_span(
394-
name="temporal:startActivity", data={"activity": input.activity}
395-
)
396-
span.start(mark_as_current=True)
403+
trace = get_trace_provider().get_current_trace()
404+
span: Optional[Span] = None
405+
if trace:
406+
span = custom_span(
407+
name="temporal:startActivity", data={"activity": input.activity}
408+
)
409+
span.start(mark_as_current=True)
410+
397411
set_header_from_context(input, temporalio.workflow.payload_converter())
398412
handle = self.next.start_activity(input)
399-
handle.add_done_callback(lambda _: span.finish())
413+
if span:
414+
handle.add_done_callback(lambda _: span.finish()) # type: ignore
400415
return handle
401416

402417
async def start_child_workflow(
403418
self, input: temporalio.worker.StartChildWorkflowInput
404419
) -> temporalio.workflow.ChildWorkflowHandle:
405-
span = custom_span(
406-
name="temporal:startChildWorkflow", data={"workflow": input.workflow}
407-
)
408-
span.start(mark_as_current=True)
420+
trace = get_trace_provider().get_current_trace()
421+
span: Optional[Span] = None
422+
if trace:
423+
span = custom_span(
424+
name="temporal:startChildWorkflow", data={"workflow": input.workflow}
425+
)
426+
span.start(mark_as_current=True)
409427
set_header_from_context(input, temporalio.workflow.payload_converter())
410428
handle = await self.next.start_child_workflow(input)
411-
handle.add_done_callback(lambda _: span.finish())
429+
if span:
430+
handle.add_done_callback(lambda _: span.finish()) # type: ignore
412431
return handle
413432

414433
def start_local_activity(
415434
self, input: temporalio.worker.StartLocalActivityInput
416435
) -> temporalio.workflow.ActivityHandle:
417-
span = custom_span(
418-
name="temporal:startLocalActivity", data={"activity": input.activity}
419-
)
420-
span.start(mark_as_current=True)
436+
trace = get_trace_provider().get_current_trace()
437+
span: Optional[Span] = None
438+
if trace:
439+
span = custom_span(
440+
name="temporal:startLocalActivity", data={"activity": input.activity}
441+
)
442+
span.start(mark_as_current=True)
421443
set_header_from_context(input, temporalio.workflow.payload_converter())
422444
handle = self.next.start_local_activity(input)
423-
handle.add_done_callback(lambda _: span.finish())
445+
if span:
446+
handle.add_done_callback(lambda _: span.finish()) # type: ignore
424447
return handle

0 commit comments

Comments
 (0)