|
1 | 1 | import dataclasses |
2 | | -import json |
3 | 2 | import typing |
4 | 3 | from typing import Any, Optional, Union |
5 | 4 |
|
|
17 | 16 | TResponseInputItem, |
18 | 17 | ) |
19 | 18 | from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner |
20 | | -from pydantic_core import to_json |
21 | 19 |
|
22 | 20 | from temporalio import workflow |
23 | 21 | from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters |
24 | 22 | from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub |
25 | 23 | from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError |
26 | 24 |
|
27 | 25 |
|
| 26 | +# Recursively replace models in all agents |
| 27 | +def _convert_agent( |
| 28 | + model_params: ModelActivityParameters, |
| 29 | + agent: Agent[Any], |
| 30 | + seen: Optional[dict[int, Agent]], |
| 31 | +) -> Agent[Any]: |
| 32 | + if seen is None: |
| 33 | + seen = dict() |
| 34 | + |
| 35 | + # Short circuit if this model was already seen to prevent looping from circular handoffs |
| 36 | + if id(agent) in seen: |
| 37 | + return seen[id(agent)] |
| 38 | + |
| 39 | + # This agent has already been processed in some other run |
| 40 | + if isinstance(agent.model, _TemporalModelStub): |
| 41 | + return agent |
| 42 | + |
| 43 | + # Save the new version of the agent so that we can replace loops |
| 44 | + new_agent = dataclasses.replace(agent) |
| 45 | + seen[id(agent)] = new_agent |
| 46 | + |
| 47 | + name = _model_name(agent) |
| 48 | + |
| 49 | + new_handoffs: list[Union[Agent, Handoff]] = [] |
| 50 | + for handoff in agent.handoffs: |
| 51 | + if isinstance(handoff, Agent): |
| 52 | + new_handoffs.append(_convert_agent(model_params, handoff, seen)) |
| 53 | + elif isinstance(handoff, Handoff): |
| 54 | + original_invoke = handoff.on_invoke_handoff |
| 55 | + |
| 56 | + async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent: |
| 57 | + handoff_agent = await original_invoke(context, args) |
| 58 | + return _convert_agent(model_params, handoff_agent, seen) |
| 59 | + |
| 60 | + new_handoffs.append( |
| 61 | + dataclasses.replace(handoff, on_invoke_handoff=on_invoke) |
| 62 | + ) |
| 63 | + else: |
| 64 | + raise ValueError(f"Unknown handoff type: {type(handoff)}") |
| 65 | + |
| 66 | + new_agent.model = _TemporalModelStub( |
| 67 | + model_name=name, |
| 68 | + model_params=model_params, |
| 69 | + agent=agent, |
| 70 | + ) |
| 71 | + new_agent.handoffs = new_handoffs |
| 72 | + return new_agent |
| 73 | + |
| 74 | + |
28 | 75 | class TemporalOpenAIRunner(AgentRunner): |
29 | 76 | """Temporal Runner for OpenAI agents. |
30 | 77 |
|
@@ -101,54 +148,9 @@ async def run( |
101 | 148 | ), |
102 | 149 | ) |
103 | 150 |
|
104 | | - # Recursively replace models in all agents |
105 | | - def convert_agent(agent: Agent[Any], seen: Optional[set[int]]) -> Agent[Any]: |
106 | | - if seen is None: |
107 | | - seen = set() |
108 | | - |
109 | | - # Short circuit if this model was already seen to prevent looping from circular handoffs |
110 | | - if id(agent) in seen: |
111 | | - return agent |
112 | | - seen.add(id(agent)) |
113 | | - |
114 | | - # This agent has already been processed in some other run |
115 | | - if isinstance(agent.model, _TemporalModelStub): |
116 | | - return agent |
117 | | - |
118 | | - name = _model_name(agent) |
119 | | - |
120 | | - new_handoffs: list[Union[Agent, Handoff]] = [] |
121 | | - for handoff in agent.handoffs: |
122 | | - if isinstance(handoff, Agent): |
123 | | - new_handoffs.append(convert_agent(handoff, seen)) |
124 | | - elif isinstance(handoff, Handoff): |
125 | | - original_invoke = handoff.on_invoke_handoff |
126 | | - |
127 | | - async def on_invoke( |
128 | | - context: RunContextWrapper[Any], args: str |
129 | | - ) -> Agent: |
130 | | - handoff_agent = await original_invoke(context, args) |
131 | | - return convert_agent(handoff_agent, seen) |
132 | | - |
133 | | - new_handoffs.append( |
134 | | - dataclasses.replace(handoff, on_invoke_handoff=on_invoke) |
135 | | - ) |
136 | | - else: |
137 | | - raise ValueError(f"Unknown handoff type: {type(handoff)}") |
138 | | - |
139 | | - return dataclasses.replace( |
140 | | - agent, |
141 | | - model=_TemporalModelStub( |
142 | | - model_name=name, |
143 | | - model_params=self.model_params, |
144 | | - agent=agent, |
145 | | - ), |
146 | | - handoffs=new_handoffs, |
147 | | - ) |
148 | | - |
149 | 151 | try: |
150 | 152 | return await self._runner.run( |
151 | | - starting_agent=convert_agent(starting_agent, None), |
| 153 | + starting_agent=_convert_agent(self.model_params, starting_agent, None), |
152 | 154 | input=input, |
153 | 155 | context=context, |
154 | 156 | max_turns=max_turns, |
|
0 commit comments