Skip to content

Commit 088ecc6

Browse files
Expose some agent testing utils to users
1 parent 162cff7 commit 088ecc6

File tree

5 files changed

+198
-187
lines changed

5 files changed

+198
-187
lines changed

temporalio/contrib/openai_agents/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,13 @@
2121
from temporalio.contrib.openai_agents._temporal_openai_agents import (
2222
OpenAIAgentsPlugin,
2323
OpenAIPayloadConverter,
24-
TestModel,
25-
TestModelProvider,
2624
)
2725
from temporalio.contrib.openai_agents._trace_interceptor import (
2826
OpenAIAgentsTracingInterceptor,
2927
)
3028
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
3129

32-
from . import workflow
30+
from . import test, workflow
3331

3432
__all__ = [
3533
"AgentsWorkflowError",
@@ -38,7 +36,6 @@
3836
"OpenAIPayloadConverter",
3937
"StatelessMCPServerProvider",
4038
"StatefulMCPServerProvider",
41-
"TestModel",
42-
"TestModelProvider",
39+
"test",
4340
"workflow",
4441
]

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 4 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,23 @@
66
from datetime import timedelta
77
from typing import AsyncIterator, Callable, Optional, Sequence, Union
88

9-
from agents import (
10-
AgentOutputSchemaBase,
11-
Handoff,
12-
Model,
13-
ModelProvider,
14-
ModelResponse,
15-
ModelSettings,
16-
ModelTracing,
17-
Tool,
18-
TResponseInputItem,
19-
set_trace_provider,
20-
)
21-
from agents.items import TResponseStreamEvent
9+
from agents import ModelProvider, set_trace_provider
2210
from agents.run import get_default_agent_runner, set_default_agent_runner
2311
from agents.tracing import get_trace_provider
2412
from agents.tracing.provider import DefaultTraceProvider
2513

2614
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
2715
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
28-
from temporalio.contrib.openai_agents._openai_runner import (
29-
TemporalOpenAIRunner,
30-
)
16+
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
3117
from temporalio.contrib.openai_agents._temporal_trace_provider import (
3218
TemporalTraceProvider,
3319
)
3420
from temporalio.contrib.openai_agents._trace_interceptor import (
3521
OpenAIAgentsTracingInterceptor,
3622
)
3723
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
38-
from temporalio.contrib.pydantic import (
39-
PydanticPayloadConverter,
40-
ToJsonOptions,
41-
)
42-
from temporalio.converter import (
43-
DataConverter,
44-
DefaultPayloadConverter,
45-
)
24+
from temporalio.contrib.pydantic import PydanticPayloadConverter, ToJsonOptions
25+
from temporalio.converter import DataConverter, DefaultPayloadConverter
4626
from temporalio.plugin import SimplePlugin
4727
from temporalio.worker import WorkflowRunner
4828
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
@@ -103,58 +83,6 @@ def set_open_ai_agent_temporal_overrides(
10383
set_trace_provider(previous_trace_provider or DefaultTraceProvider())
10484

10585

106-
class TestModelProvider(ModelProvider):
107-
"""Test model provider which simply returns the given module."""
108-
109-
__test__ = False
110-
111-
def __init__(self, model: Model):
112-
"""Initialize a test model provider with a model."""
113-
self._model = model
114-
115-
def get_model(self, model_name: Union[str, None]) -> Model:
116-
"""Get a model from the model provider."""
117-
return self._model
118-
119-
120-
class TestModel(Model):
121-
"""Test model for use mocking model responses."""
122-
123-
__test__ = False
124-
125-
def __init__(self, fn: Callable[[], ModelResponse]) -> None:
126-
"""Initialize a test model with a callable."""
127-
self.fn = fn
128-
129-
async def get_response(
130-
self,
131-
system_instructions: Union[str, None],
132-
input: Union[str, list[TResponseInputItem]],
133-
model_settings: ModelSettings,
134-
tools: list[Tool],
135-
output_schema: Union[AgentOutputSchemaBase, None],
136-
handoffs: list[Handoff],
137-
tracing: ModelTracing,
138-
**kwargs,
139-
) -> ModelResponse:
140-
"""Get a response from the model."""
141-
return self.fn()
142-
143-
def stream_response(
144-
self,
145-
system_instructions: Optional[str],
146-
input: Union[str, list[TResponseInputItem]],
147-
model_settings: ModelSettings,
148-
tools: list[Tool],
149-
output_schema: Optional[AgentOutputSchemaBase],
150-
handoffs: list[Handoff],
151-
tracing: ModelTracing,
152-
**kwargs,
153-
) -> AsyncIterator[TResponseStreamEvent]:
154-
"""Get a streamed response from the model. Unimplemented."""
155-
raise NotImplementedError()
156-
157-
15886
class OpenAIPayloadConverter(PydanticPayloadConverter):
15987
"""PayloadConverter for OpenAI agents."""
16088

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from typing import AsyncIterator, Callable, Optional, Union
2+
3+
from agents import (
4+
AgentOutputSchemaBase,
5+
Handoff,
6+
Model,
7+
ModelProvider,
8+
ModelResponse,
9+
ModelSettings,
10+
ModelTracing,
11+
Tool,
12+
TResponseInputItem,
13+
Usage,
14+
)
15+
from agents.items import TResponseOutputItem, TResponseStreamEvent
16+
from openai.types.responses import (
17+
ResponseFunctionToolCall,
18+
ResponseOutputMessage,
19+
ResponseOutputText,
20+
)
21+
22+
23+
class ResponseBuilders:
24+
"""
25+
Builders for creating model responses for testing.
26+
"""
27+
28+
@staticmethod
29+
def model_response(output: TResponseOutputItem) -> ModelResponse:
30+
return ModelResponse(
31+
output=[output],
32+
usage=Usage(),
33+
response_id=None,
34+
)
35+
36+
@staticmethod
37+
def response_output_message(text: str) -> ResponseOutputMessage:
38+
return ResponseOutputMessage(
39+
id="",
40+
content=[
41+
ResponseOutputText(
42+
text=text,
43+
annotations=[],
44+
type="output_text",
45+
)
46+
],
47+
role="assistant",
48+
status="completed",
49+
type="message",
50+
)
51+
52+
@staticmethod
53+
def tool_call(arguments: str, name: str) -> ModelResponse:
54+
return ResponseBuilders.model_response(
55+
ResponseFunctionToolCall(
56+
arguments=arguments,
57+
call_id="call",
58+
name=name,
59+
type="function_call",
60+
id="id",
61+
status="completed",
62+
)
63+
)
64+
65+
@staticmethod
66+
def output_message(text: str) -> ModelResponse:
67+
return ResponseBuilders.model_response(
68+
ResponseBuilders.response_output_message(text)
69+
)
70+
71+
72+
class TestModelProvider(ModelProvider):
73+
"""Test model provider which simply returns the given module."""
74+
75+
__test__ = False
76+
77+
def __init__(self, model: Model):
78+
"""Initialize a test model provider with a model."""
79+
self._model = model
80+
81+
def get_model(self, model_name: Union[str, None]) -> Model:
82+
"""Get a model from the model provider."""
83+
return self._model
84+
85+
86+
class TestModel(Model):
87+
"""Test model for use mocking model responses."""
88+
89+
__test__ = False
90+
91+
def __init__(self, fn: Callable[[], ModelResponse]) -> None:
92+
"""Initialize a test model with a callable."""
93+
self.fn = fn
94+
95+
async def get_response(
96+
self,
97+
system_instructions: Union[str, None],
98+
input: Union[str, list[TResponseInputItem]],
99+
model_settings: ModelSettings,
100+
tools: list[Tool],
101+
output_schema: Union[AgentOutputSchemaBase, None],
102+
handoffs: list[Handoff],
103+
tracing: ModelTracing,
104+
**kwargs,
105+
) -> ModelResponse:
106+
"""Get a response from the model."""
107+
return self.fn()
108+
109+
def stream_response(
110+
self,
111+
system_instructions: Optional[str],
112+
input: Union[str, list[TResponseInputItem]],
113+
model_settings: ModelSettings,
114+
tools: list[Tool],
115+
output_schema: Optional[AgentOutputSchemaBase],
116+
handoffs: list[Handoff],
117+
tracing: ModelTracing,
118+
**kwargs,
119+
) -> AsyncIterator[TResponseStreamEvent]:
120+
"""Get a streamed response from the model. Unimplemented."""
121+
raise NotImplementedError()
122+
123+
124+
class StaticTestModel(TestModel):
125+
"""Static test model for use mocking model responses.
126+
Set a responses attribute to a list of model responses, which will be returned sequentially.
127+
"""
128+
129+
__test__ = False
130+
responses: list[ModelResponse] = []
131+
132+
def __init__(
133+
self,
134+
) -> None:
135+
self._responses = iter(self.responses)
136+
super().__init__(lambda: next(self._responses))

0 commit comments

Comments
 (0)