Skip to content

Commit cd11f78

Browse files
upgrade more tests
1 parent fd8f26a commit cd11f78

8 files changed

+279
-196
lines changed

tests/openai_agents/basic/test_agent_lifecycle_workflow.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,40 @@
22
from concurrent.futures import ThreadPoolExecutor
33

44
from temporalio.client import Client
5+
from temporalio.contrib.openai_agents.testing import AgentEnvironment, ResponseBuilders, TestModel
56
from temporalio.worker import Worker
67

78
from openai_agents.basic.workflows.agent_lifecycle_workflow import (
89
AgentLifecycleWorkflow,
910
)
1011

1112

13+
def agent_lifecycle_test_model():
14+
return TestModel.returning_responses(
15+
[ResponseBuilders.output_message('{"number": 10}')]
16+
)
17+
18+
1219
async def test_execute_workflow(client: Client):
1320
task_queue_name = str(uuid.uuid4())
1421

15-
async with Worker(
16-
client,
17-
task_queue=task_queue_name,
18-
workflows=[AgentLifecycleWorkflow],
19-
activity_executor=ThreadPoolExecutor(5),
20-
):
21-
result = await client.execute_workflow(
22-
AgentLifecycleWorkflow.run,
23-
10, # max_number parameter
24-
id=str(uuid.uuid4()),
22+
async with AgentEnvironment(model=agent_lifecycle_test_model()) as agent_env:
23+
client = agent_env.applied_on_client(client)
24+
async with Worker(
25+
client,
2526
task_queue=task_queue_name,
26-
)
27+
workflows=[AgentLifecycleWorkflow],
28+
activity_executor=ThreadPoolExecutor(5),
29+
):
30+
result = await client.execute_workflow(
31+
AgentLifecycleWorkflow.run,
32+
10, # max_number parameter
33+
id=str(uuid.uuid4()),
34+
task_queue=task_queue_name,
35+
)
2736

28-
# Verify the result has the expected structure
29-
assert isinstance(result.number, int)
30-
assert (
31-
0 <= result.number <= 20
32-
) # Should be between 0 and max*2 due to multiply operation
37+
# Verify the result has the expected structure
38+
assert isinstance(result.number, int)
39+
assert (
40+
0 <= result.number <= 20
41+
) # Should be between 0 and max*2 due to multiply operation

tests/openai_agents/basic/test_dynamic_system_prompt_workflow.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,62 @@
22
from concurrent.futures import ThreadPoolExecutor
33

44
from temporalio.client import Client
5+
from temporalio.contrib.openai_agents.testing import AgentEnvironment, ResponseBuilders, TestModel
56
from temporalio.worker import Worker
67

78
from openai_agents.basic.workflows.dynamic_system_prompt_workflow import (
89
DynamicSystemPromptWorkflow,
910
)
1011

1112

13+
def dynamic_system_prompt_test_model():
14+
return TestModel.returning_responses(
15+
[ResponseBuilders.output_message("Style: haiku\nResponse: The weather is cloudy with a chance of meatballs.")]
16+
)
17+
18+
1219
async def test_execute_workflow_with_random_style(client: Client):
1320
task_queue_name = str(uuid.uuid4())
1421

15-
async with Worker(
16-
client,
17-
task_queue=task_queue_name,
18-
workflows=[DynamicSystemPromptWorkflow],
19-
activity_executor=ThreadPoolExecutor(5),
20-
):
21-
result = await client.execute_workflow(
22-
DynamicSystemPromptWorkflow.run,
23-
"Tell me about the weather today.",
24-
id=str(uuid.uuid4()),
22+
async with AgentEnvironment(model=dynamic_system_prompt_test_model()) as agent_env:
23+
client = agent_env.applied_on_client(client)
24+
async with Worker(
25+
client,
2526
task_queue=task_queue_name,
26-
)
27-
28-
# Verify the result has the expected format
29-
assert "Style:" in result
30-
assert "Response:" in result
31-
assert any(style in result for style in ["haiku", "pirate", "robot"])
27+
workflows=[DynamicSystemPromptWorkflow],
28+
activity_executor=ThreadPoolExecutor(5),
29+
):
30+
result = await client.execute_workflow(
31+
DynamicSystemPromptWorkflow.run,
32+
"Tell me about the weather today.",
33+
id=str(uuid.uuid4()),
34+
task_queue=task_queue_name,
35+
)
36+
37+
# Verify the result has the expected format
38+
assert "Style:" in result
39+
assert "Response:" in result
40+
assert any(style in result for style in ["haiku", "pirate", "robot"])
3241

3342

3443
async def test_execute_workflow_with_specific_style(client: Client):
3544
task_queue_name = str(uuid.uuid4())
3645

37-
async with Worker(
38-
client,
39-
task_queue=task_queue_name,
40-
workflows=[DynamicSystemPromptWorkflow],
41-
activity_executor=ThreadPoolExecutor(5),
42-
):
43-
result = await client.execute_workflow(
44-
DynamicSystemPromptWorkflow.run,
45-
args=["Tell me about the weather today.", "haiku"],
46-
id=str(uuid.uuid4()),
46+
async with AgentEnvironment(model=dynamic_system_prompt_test_model()) as agent_env:
47+
client = agent_env.applied_on_client(client)
48+
async with Worker(
49+
client,
4750
task_queue=task_queue_name,
48-
)
49-
50-
# Verify the result has the expected format and style
51-
assert "Style: haiku" in result
52-
assert "Response:" in result
51+
workflows=[DynamicSystemPromptWorkflow],
52+
activity_executor=ThreadPoolExecutor(5),
53+
):
54+
result = await client.execute_workflow(
55+
DynamicSystemPromptWorkflow.run,
56+
args=["Tell me about the weather today.", "haiku"],
57+
id=str(uuid.uuid4()),
58+
task_queue=task_queue_name,
59+
)
60+
61+
# Verify the result has the expected format and style
62+
assert "Style: haiku" in result
63+
assert "Response:" in result

tests/openai_agents/basic/test_lifecycle_workflow.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,38 @@
22
from concurrent.futures import ThreadPoolExecutor
33

44
from temporalio.client import Client
5+
from temporalio.contrib.openai_agents.testing import AgentEnvironment, ResponseBuilders, TestModel
56
from temporalio.worker import Worker
67

78
from openai_agents.basic.workflows.lifecycle_workflow import LifecycleWorkflow
89

910

11+
def lifecycle_test_model():
12+
return TestModel.returning_responses(
13+
[ResponseBuilders.output_message('{"number": 10}')]
14+
)
15+
16+
1017
async def test_execute_workflow(client: Client):
1118
task_queue_name = str(uuid.uuid4())
1219

13-
async with Worker(
14-
client,
15-
task_queue=task_queue_name,
16-
workflows=[LifecycleWorkflow],
17-
activity_executor=ThreadPoolExecutor(5),
18-
):
19-
result = await client.execute_workflow(
20-
LifecycleWorkflow.run,
21-
10, # max_number parameter
22-
id=str(uuid.uuid4()),
20+
async with AgentEnvironment(model=lifecycle_test_model()) as agent_env:
21+
client = agent_env.applied_on_client(client)
22+
async with Worker(
23+
client,
2324
task_queue=task_queue_name,
24-
)
25+
workflows=[LifecycleWorkflow],
26+
activity_executor=ThreadPoolExecutor(5),
27+
):
28+
result = await client.execute_workflow(
29+
LifecycleWorkflow.run,
30+
10, # max_number parameter
31+
id=str(uuid.uuid4()),
32+
task_queue=task_queue_name,
33+
)
2534

26-
# Verify the result has the expected structure
27-
assert isinstance(result.number, int)
28-
assert (
29-
0 <= result.number <= 20
30-
) # Should be between 0 and max*2 due to multiply operation
35+
# Verify the result has the expected structure
36+
assert isinstance(result.number, int)
37+
assert (
38+
0 <= result.number <= 20
39+
) # Should be between 0 and max*2 due to multiply operation

tests/openai_agents/basic/test_local_image_workflow.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,63 @@
22
from concurrent.futures import ThreadPoolExecutor
33

44
from temporalio.client import Client
5+
from temporalio.contrib.openai_agents.testing import AgentEnvironment, ResponseBuilders, TestModel
56
from temporalio.worker import Worker
67

78
from openai_agents.basic.activities.image_activities import read_image_as_base64
89
from openai_agents.basic.workflows.local_image_workflow import LocalImageWorkflow
910

1011

12+
def local_image_test_model():
13+
return TestModel.returning_responses(
14+
[ResponseBuilders.output_message("I can see a bison in the image.")]
15+
)
16+
17+
1118
async def test_execute_workflow_default_question(client: Client):
1219
task_queue_name = str(uuid.uuid4())
1320

14-
async with Worker(
15-
client,
16-
task_queue=task_queue_name,
17-
workflows=[LocalImageWorkflow],
18-
activity_executor=ThreadPoolExecutor(5),
19-
activities=[read_image_as_base64],
20-
):
21-
result = await client.execute_workflow(
22-
LocalImageWorkflow.run,
23-
"openai_agents/basic/media/image_bison.jpg", # Path to test image
24-
id=str(uuid.uuid4()),
21+
async with AgentEnvironment(model=local_image_test_model()) as agent_env:
22+
client = agent_env.applied_on_client(client)
23+
async with Worker(
24+
client,
2525
task_queue=task_queue_name,
26-
)
26+
workflows=[LocalImageWorkflow],
27+
activity_executor=ThreadPoolExecutor(5),
28+
activities=[read_image_as_base64],
29+
):
30+
result = await client.execute_workflow(
31+
LocalImageWorkflow.run,
32+
"openai_agents/basic/media/image_bison.jpg", # Path to test image
33+
id=str(uuid.uuid4()),
34+
task_queue=task_queue_name,
35+
)
2736

28-
# Verify the result is a string response
29-
assert isinstance(result, str)
30-
assert len(result) > 0
37+
# Verify the result is a string response
38+
assert isinstance(result, str)
39+
assert len(result) > 0
3140

3241

3342
async def test_execute_workflow_custom_question(client: Client):
3443
task_queue_name = str(uuid.uuid4())
3544

36-
async with Worker(
37-
client,
38-
task_queue=task_queue_name,
39-
workflows=[LocalImageWorkflow],
40-
activity_executor=ThreadPoolExecutor(5),
41-
activities=[read_image_as_base64],
42-
):
43-
custom_question = "What animals do you see in this image?"
44-
result = await client.execute_workflow(
45-
LocalImageWorkflow.run,
46-
args=["openai_agents/basic/media/image_bison.jpg", custom_question],
47-
id=str(uuid.uuid4()),
45+
async with AgentEnvironment(model=local_image_test_model()) as agent_env:
46+
client = agent_env.applied_on_client(client)
47+
async with Worker(
48+
client,
4849
task_queue=task_queue_name,
49-
)
50+
workflows=[LocalImageWorkflow],
51+
activity_executor=ThreadPoolExecutor(5),
52+
activities=[read_image_as_base64],
53+
):
54+
custom_question = "What animals do you see in this image?"
55+
result = await client.execute_workflow(
56+
LocalImageWorkflow.run,
57+
args=["openai_agents/basic/media/image_bison.jpg", custom_question],
58+
id=str(uuid.uuid4()),
59+
task_queue=task_queue_name,
60+
)
5061

51-
# Verify the result is a string response
52-
assert isinstance(result, str)
53-
assert len(result) > 0
62+
# Verify the result is a string response
63+
assert isinstance(result, str)
64+
assert len(result) > 0

tests/openai_agents/basic/test_non_strict_output_workflow.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,51 @@
22
from concurrent.futures import ThreadPoolExecutor
33

44
from temporalio.client import Client
5+
from temporalio.contrib.openai_agents.testing import AgentEnvironment, ResponseBuilders, TestModel
56
from temporalio.worker import Worker
67

78
from openai_agents.basic.workflows.non_strict_output_workflow import (
89
NonStrictOutputWorkflow,
910
)
1011

1112

13+
def non_strict_output_test_model():
14+
return TestModel.returning_responses(
15+
[ResponseBuilders.output_message('{"jokes": {"1": "Why do programmers prefer dark mode? Because light attracts bugs!", "2": "How many programmers does it take to change a light bulb? None, that\'s a hardware problem.", "3": "Why do Java developers wear glasses? Because they can\'t C#!"}}')
16+
]
17+
)
18+
19+
1220
async def test_execute_workflow(client: Client):
1321
task_queue_name = str(uuid.uuid4())
1422

15-
async with Worker(
16-
client,
17-
task_queue=task_queue_name,
18-
workflows=[NonStrictOutputWorkflow],
19-
activity_executor=ThreadPoolExecutor(5),
20-
# No external activities needed
21-
):
22-
result = await client.execute_workflow(
23-
NonStrictOutputWorkflow.run,
24-
"Tell me 3 funny jokes about programming.",
25-
id=str(uuid.uuid4()),
23+
async with AgentEnvironment(model=non_strict_output_test_model()) as agent_env:
24+
client = agent_env.applied_on_client(client)
25+
async with Worker(
26+
client,
2627
task_queue=task_queue_name,
27-
)
28-
29-
# Verify the result has the expected structure
30-
assert isinstance(result, dict)
31-
32-
assert "strict_error" in result
33-
assert "non_strict_result" in result
34-
35-
# If there's a strict_error, it should be a string
36-
if "strict_error" in result:
37-
assert isinstance(result["strict_error"], str)
38-
assert len(result["strict_error"]) > 0
39-
40-
jokes = result["non_strict_result"]["jokes"]
41-
assert isinstance(jokes, dict)
42-
assert isinstance(jokes[list(jokes.keys())[0]], str)
28+
workflows=[NonStrictOutputWorkflow],
29+
activity_executor=ThreadPoolExecutor(5),
30+
# No external activities needed
31+
):
32+
result = await client.execute_workflow(
33+
NonStrictOutputWorkflow.run,
34+
"Tell me 3 funny jokes about programming.",
35+
id=str(uuid.uuid4()),
36+
task_queue=task_queue_name,
37+
)
38+
39+
# Verify the result has the expected structure
40+
assert isinstance(result, dict)
41+
42+
assert "strict_error" in result
43+
assert "non_strict_result" in result
44+
45+
# If there's a strict_error, it should be a string
46+
if "strict_error" in result:
47+
assert isinstance(result["strict_error"], str)
48+
assert len(result["strict_error"]) > 0
49+
50+
jokes = result["non_strict_result"]["jokes"]
51+
assert isinstance(jokes, dict)
52+
assert isinstance(jokes[list(jokes.keys())[0]], str)

0 commit comments

Comments
 (0)