|
6 | 6 | from datetime import timedelta |
7 | 7 | from typing import AsyncIterator, Callable, Optional, Sequence, Union |
8 | 8 |
|
9 | | -from agents import ( |
10 | | - AgentOutputSchemaBase, |
11 | | - Handoff, |
12 | | - Model, |
13 | | - ModelProvider, |
14 | | - ModelResponse, |
15 | | - ModelSettings, |
16 | | - ModelTracing, |
17 | | - Tool, |
18 | | - TResponseInputItem, |
19 | | - set_trace_provider, |
20 | | -) |
21 | | -from agents.items import TResponseStreamEvent |
| 9 | +from agents import ModelProvider, set_trace_provider |
22 | 10 | from agents.run import get_default_agent_runner, set_default_agent_runner |
23 | 11 | from agents.tracing import get_trace_provider |
24 | 12 | from agents.tracing.provider import DefaultTraceProvider |
@@ -97,58 +85,6 @@ def set_open_ai_agent_temporal_overrides( |
97 | 85 | set_trace_provider(previous_trace_provider or DefaultTraceProvider()) |
98 | 86 |
|
99 | 87 |
|
100 | | -class TestModelProvider(ModelProvider): |
101 | | - """Test model provider which simply returns the given module.""" |
102 | | - |
103 | | - __test__ = False |
104 | | - |
105 | | - def __init__(self, model: Model): |
106 | | - """Initialize a test model provider with a model.""" |
107 | | - self._model = model |
108 | | - |
109 | | - def get_model(self, model_name: Union[str, None]) -> Model: |
110 | | - """Get a model from the model provider.""" |
111 | | - return self._model |
112 | | - |
113 | | - |
114 | | -class TestModel(Model): |
115 | | - """Test model for use mocking model responses.""" |
116 | | - |
117 | | - __test__ = False |
118 | | - |
119 | | - def __init__(self, fn: Callable[[], ModelResponse]) -> None: |
120 | | - """Initialize a test model with a callable.""" |
121 | | - self.fn = fn |
122 | | - |
123 | | - async def get_response( |
124 | | - self, |
125 | | - system_instructions: Union[str, None], |
126 | | - input: Union[str, list[TResponseInputItem]], |
127 | | - model_settings: ModelSettings, |
128 | | - tools: list[Tool], |
129 | | - output_schema: Union[AgentOutputSchemaBase, None], |
130 | | - handoffs: list[Handoff], |
131 | | - tracing: ModelTracing, |
132 | | - **kwargs, |
133 | | - ) -> ModelResponse: |
134 | | - """Get a response from the model.""" |
135 | | - return self.fn() |
136 | | - |
137 | | - def stream_response( |
138 | | - self, |
139 | | - system_instructions: Optional[str], |
140 | | - input: Union[str, list[TResponseInputItem]], |
141 | | - model_settings: ModelSettings, |
142 | | - tools: list[Tool], |
143 | | - output_schema: Optional[AgentOutputSchemaBase], |
144 | | - handoffs: list[Handoff], |
145 | | - tracing: ModelTracing, |
146 | | - **kwargs, |
147 | | - ) -> AsyncIterator[TResponseStreamEvent]: |
148 | | - """Get a streamed response from the model. Unimplemented.""" |
149 | | - raise NotImplementedError() |
150 | | - |
151 | | - |
152 | 88 | class OpenAIPayloadConverter(PydanticPayloadConverter): |
153 | 89 | """PayloadConverter for OpenAI agents.""" |
154 | 90 |
|
|
0 commit comments