Skip to content

Commit f6595ca

Browse files
hacking
1 parent 9b1004a commit f6595ca

File tree

2 files changed

+45
-34
lines changed

2 files changed

+45
-34
lines changed

tests/openai_agents/basic/test_hello_world_workflow.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,43 @@
88
from temporalio.worker import Worker
99

1010
from openai_agents.basic.workflows.hello_world_workflow import HelloWorldAgent
11+
from tests.openai_agents.conftest import sequential_test_model
1112

1213

1314
@pytest.fixture
14-
def mocked_model(mocker):
15-
mock = mocker.AsyncMock()
16-
mock.get_response.side_effect = [
17-
ModelResponse(
18-
output=[
19-
ResponseOutputMessage(
20-
id="1",
21-
content=[
22-
ResponseOutputText(
23-
annotations=[],
24-
text="This is a haiku (not really)",
25-
type="output_text",
26-
)
27-
],
28-
role="assistant",
29-
status="completed",
30-
type="message",
31-
)
32-
],
33-
usage=Usage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1),
34-
response_id="1",
35-
)
36-
]
15+
def test_model():
16+
return sequential_test_model(
17+
[
18+
ModelResponse(
19+
output=[
20+
ResponseOutputMessage(
21+
id="1",
22+
content=[
23+
ResponseOutputText(
24+
annotations=[],
25+
text="This is a haiku (not really)",
26+
type="output_text",
27+
)
28+
],
29+
role="assistant",
30+
status="completed",
31+
type="message",
32+
)
33+
],
34+
usage=Usage(
35+
requests=1, input_tokens=1, output_tokens=1, total_tokens=1
36+
),
37+
response_id="1",
38+
)
39+
]
40+
)
41+
3742

38-
return mock
43+
@pytest.fixture
44+
def test_model():
45+
return TestModel.returning_responses(
46+
[ResponseBuilders.output_message("This is a haiku (not really)")]
47+
)
3948

4049

4150
async def test_execute_workflow(client: Client):

tests/openai_agents/conftest.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,23 @@
22
from typing import Optional
33

44
import pytest
5-
from agents import Model, ModelProvider
6-
from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin
5+
from agents import ModelProvider, ModelResponse
6+
from temporalio.contrib.openai_agents import (
7+
ModelActivityParameters,
8+
OpenAIAgentsPlugin,
9+
TestModel,
10+
TestModelProvider,
11+
)
712

813

9-
class MockedModelProvider(ModelProvider):
10-
def __init__(self, mocked_model):
11-
self.mocked_model = mocked_model
12-
13-
def get_model(self, model_name: str | None) -> Model:
14-
return self.mocked_model
14+
def sequential_test_model(responses: list[ModelResponse]) -> TestModel:
15+
responses = iter(responses)
16+
return TestModel(lambda: next(responses))
1517

1618

1719
@pytest.fixture
18-
def model_provider(mocked_model):
19-
return MockedModelProvider(mocked_model)
20+
def model_provider(test_model):
21+
return TestModelProvider(test_model)
2022

2123

2224
@pytest.fixture

0 commit comments

Comments
 (0)