Skip to content

Commit 9b1004a

Browse files
Figure out mocks for one test
1 parent 10082e1 commit 9b1004a

File tree

7 files changed

+103
-6
lines changed

7 files changed

+103
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dev = [
2626
"types-pyyaml>=6.0.12.20241230,<7",
2727
"pytest-pretty>=1.3.0",
2828
"poethepoet>=0.36.0",
29+
"pytest-mock>=3.15.1",
2930
]
3031
bedrock = ["boto3>=1.34.92,<2"]
3132
dsl = [

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ def event_loop():
3737
loop.close()
3838

3939

40-
@pytest.fixture(scope="session")
40+
@pytest.fixture
4141
def plugins():
4242
# By default, no plugins.
4343
# Other tests can override this fixture, such as in tests/openai_agents/conftest.py
4444
return []
4545

4646

47-
@pytest_asyncio.fixture(scope="session")
47+
@pytest_asyncio.fixture
4848
async def env(request, plugins) -> AsyncGenerator[WorkflowEnvironment, None]:
4949
env_type = request.config.getoption("--workflow-environment")
5050
if env_type == "local":

tests/openai_agents/basic/test_hello_world_workflow.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,43 @@
11
import uuid
22
from concurrent.futures import ThreadPoolExecutor
33

4+
import pytest
5+
from agents import ModelResponse, Usage
6+
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
47
from temporalio.client import Client
58
from temporalio.worker import Worker
69

710
from openai_agents.basic.workflows.hello_world_workflow import HelloWorldAgent
811

912

13+
@pytest.fixture
14+
def mocked_model(mocker):
15+
mock = mocker.AsyncMock()
16+
mock.get_response.side_effect = [
17+
ModelResponse(
18+
output=[
19+
ResponseOutputMessage(
20+
id="1",
21+
content=[
22+
ResponseOutputText(
23+
annotations=[],
24+
text="This is a haiku (not really)",
25+
type="output_text",
26+
)
27+
],
28+
role="assistant",
29+
status="completed",
30+
type="message",
31+
)
32+
],
33+
usage=Usage(requests=1, input_tokens=1, output_tokens=1, total_tokens=1),
34+
response_id="1",
35+
)
36+
]
37+
38+
return mock
39+
40+
1041
async def test_execute_workflow(client: Client):
1142
task_queue_name = str(uuid.uuid4())
1243

tests/openai_agents/conftest.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,31 @@
11
from datetime import timedelta
2+
from typing import Optional
23

34
import pytest
5+
from agents import Model, ModelProvider
46
from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin
57

68

7-
@pytest.fixture(scope="session")
8-
def plugins():
9+
class MockedModelProvider(ModelProvider):
10+
def __init__(self, mocked_model):
11+
self.mocked_model = mocked_model
12+
13+
def get_model(self, model_name: str | None) -> Model:
14+
return self.mocked_model
15+
16+
17+
@pytest.fixture
18+
def model_provider(mocked_model):
19+
return MockedModelProvider(mocked_model)
20+
21+
22+
@pytest.fixture
23+
def plugins(model_provider: Optional[ModelProvider]):
924
return [
1025
OpenAIAgentsPlugin(
1126
model_params=ModelActivityParameters(
1227
start_to_close_timeout=timedelta(seconds=30)
13-
)
28+
),
29+
model_provider=model_provider,
1430
)
1531
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from datetime import timedelta
2+
3+
import pytest
4+
from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin
5+
6+
7+
@pytest.fixture
8+
def model_provider():
9+
return None
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import uuid
2+
from concurrent.futures import ThreadPoolExecutor
3+
4+
from temporalio.client import Client
5+
from temporalio.worker import Worker
6+
7+
from openai_agents.basic.workflows.hello_world_workflow import HelloWorldAgent
8+
9+
10+
async def test_execute_workflow(client: Client):
11+
task_queue_name = str(uuid.uuid4())
12+
13+
async with Worker(
14+
client,
15+
task_queue=task_queue_name,
16+
workflows=[HelloWorldAgent],
17+
activity_executor=ThreadPoolExecutor(5),
18+
):
19+
result = await client.execute_workflow(
20+
HelloWorldAgent.run,
21+
"Write a recursive haiku about recursive haikus.",
22+
id=str(uuid.uuid4()),
23+
task_queue=task_queue_name,
24+
)
25+
assert isinstance(result, str)
26+
assert len(result) > 0

uv.lock

Lines changed: 15 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)