Skip to content

Commit 5aae5a4

Browse files
committed
Add option to run models as local activities
1 parent 5994a45 commit 5aae5a4

File tree

3 files changed

+59
-14
lines changed

3 files changed

+59
-14
lines changed

temporalio/contrib/openai_agents/_model_parameters.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,5 @@ class ModelActivityParameters:
7070

7171
priority: Priority = Priority.default
7272
"""Priority for the activity execution."""
73+
74+
use_local_activity: bool = False

temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -154,20 +154,32 @@ def make_tool_info(tool: Tool) -> ToolInput:
154154
else:
155155
summary = None
156156

157-
return await workflow.execute_activity_method(
158-
ModelActivity.invoke_model_activity,
159-
activity_input,
160-
summary=summary,
161-
task_queue=self.model_params.task_queue,
162-
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
163-
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
164-
start_to_close_timeout=self.model_params.start_to_close_timeout,
165-
heartbeat_timeout=self.model_params.heartbeat_timeout,
166-
retry_policy=self.model_params.retry_policy,
167-
cancellation_type=self.model_params.cancellation_type,
168-
versioning_intent=self.model_params.versioning_intent,
169-
priority=self.model_params.priority,
170-
)
157+
if self.model_params.use_local_activity:
158+
return await workflow.execute_local_activity_method(
159+
ModelActivity.invoke_model_activity,
160+
activity_input,
161+
summary=summary,
162+
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
163+
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
164+
start_to_close_timeout=self.model_params.start_to_close_timeout,
165+
retry_policy=self.model_params.retry_policy,
166+
cancellation_type=self.model_params.cancellation_type,
167+
)
168+
else:
169+
return await workflow.execute_activity_method(
170+
ModelActivity.invoke_model_activity,
171+
activity_input,
172+
summary=summary,
173+
task_queue=self.model_params.task_queue,
174+
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
175+
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
176+
start_to_close_timeout=self.model_params.start_to_close_timeout,
177+
heartbeat_timeout=self.model_params.heartbeat_timeout,
178+
retry_policy=self.model_params.retry_policy,
179+
cancellation_type=self.model_params.cancellation_type,
180+
versioning_intent=self.model_params.versioning_intent,
181+
priority=self.model_params.priority,
182+
)
171183

172184
def stream_response(
173185
self,

tests/contrib/openai_agents/test_openai.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2670,3 +2670,34 @@ async def test_model_conversion_loops():
26702670
triage_agent = seat_booking_agent.handoffs[0]
26712671
assert isinstance(triage_agent, Agent)
26722672
assert isinstance(triage_agent.model, _TemporalModelStub)
2673+
2674+
2675+
async def test_local_hello_world_agent(client: Client):
2676+
new_config = client.config()
2677+
new_config["plugins"] = [
2678+
openai_agents.OpenAIAgentsPlugin(
2679+
model_params=ModelActivityParameters(
2680+
start_to_close_timeout=timedelta(seconds=30),
2681+
use_local_activity=True,
2682+
),
2683+
model_provider=TestModelProvider(TestHelloModel())
2684+
)
2685+
]
2686+
client = Client(**new_config)
2687+
2688+
async with new_worker(client, HelloWorldAgent) as worker:
2689+
handle = await client.start_workflow(
2690+
HelloWorldAgent.run,
2691+
"Tell me about recursion in programming.",
2692+
id=f"hello-workflow-{uuid.uuid4()}",
2693+
task_queue=worker.task_queue,
2694+
execution_timeout=timedelta(seconds=5),
2695+
)
2696+
result = await handle.result()
2697+
assert result == "test"
2698+
2699+
local_activity_found = False
2700+
async for e in handle.fetch_history_events():
2701+
if e.HasField("marker_recorded_event_attributes"):
2702+
local_activity_found = True
2703+
assert local_activity_found

0 commit comments

Comments
 (0)