
Commit 974d132

Add dedicated test that models are configured correctly
1 parent 6d6aba4 · commit 974d132


2 files changed: +70 −12 lines


temporalio/contrib/openai_agents/_temporal_model_stub.py

Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ async def get_response(
         previous_response_id: Optional[str],
         prompt: Optional[ResponsePromptParam],
     ) -> ModelResponse:
+        print("Model stub invocation:", self.model_name)
         def make_tool_info(tool: Tool) -> ToolInput:
             if isinstance(
                 tool,

tests/contrib/openai_agents/test_openai.py

Lines changed: 69 additions & 12 deletions

@@ -916,16 +916,6 @@ class CustomerServiceModel(StaticTestModel):
     ]


-class AssertDifferentModelProvider(ModelProvider):
-    model_names = set()
-
-    def __init__(self, model: Model):
-        self._model = model
-
-    def get_model(self, model_name: Union[str, None]) -> Model:
-        self.model_names.add(model_name)
-        return self._model
-

 @workflow.defn
 class CustomerServiceWorkflow:

@@ -999,7 +989,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
     if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
         pytest.skip("No openai API key")
     new_config = client.config()
-    provider = AssertDifferentModelProvider(CustomerServiceModel())
+    provider = TestModelProvider(CustomerServiceModel())
     new_config["plugins"] = [
         openai_agents.OpenAIAgentsPlugin(
             model_params=ModelActivityParameters(

@@ -1089,7 +1079,6 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
         .data.decode()
     )

-    assert len(provider.model_names) == 2

 class InputGuardrailModel(OpenAIResponsesModel):
     __test__ = False

@@ -2055,3 +2044,71 @@ async def test_hosted_mcp_tool(client: Client, use_local_model):
         result = await workflow_handle.result()
         if use_local_model:
             assert result == "Some language"
+
+
+class AssertDifferentModelProvider(ModelProvider):
+    model_names = set()
+
+    def __init__(self, model: Model):
+        self._model = model
+
+    def get_model(self, model_name: Union[str, None]) -> Model:
+        self.model_names.add(model_name)
+        return self._model
+
+
+class MultipleModelsModel(StaticTestModel):
+    responses = [
+        ResponseBuilders.tool_call("{}", "transfer_to_underling"),
+        ResponseBuilders.output_message(
+            "I'm here to help! Was there a specific task you needed assistance with regarding the storeroom?"
+        ),
+    ]
+
+@workflow.defn
+class MultipleModelWorkflow:
+    @workflow.run
+    async def run(self):
+        underling = Agent[None](
+            name="Underling",
+            instructions="You do all the work you are told.",
+        )
+
+        starting_agent = Agent[None](
+            name="Lazy Assistant",
+            model="gpt-4o-mini",
+            instructions="You delegate all your work to another agent.",
+            handoffs=[underling]
+        )
+        result = await Runner.run(
+            starting_agent=starting_agent,
+            input="Have you cleaned the store room yet?",
+        )
+        return result.final_output
+
+
+async def test_multiple_models(client: Client):
+    provider = AssertDifferentModelProvider(MultipleModelsModel())
+    new_config = client.config()
+    new_config["plugins"] = [
+        openai_agents.OpenAIAgentsPlugin(
+            model_params=ModelActivityParameters(
+                start_to_close_timeout=timedelta(seconds=120)
+            ),
+            model_provider=provider
+        )
+    ]
+    client = Client(**new_config)
+
+    async with new_worker(
+        client,
+        MultipleModelWorkflow,
+    ) as worker:
+        workflow_handle = await client.start_workflow(
+            MultipleModelWorkflow.run,
+            id=f"multiple-model-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+            execution_timeout=timedelta(seconds=120),
+        )
+        result = await workflow_handle.result()
+        assert provider.model_names == {None, "gpt-4o-mini"}
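
The new test relies on a small recording-provider pattern: the provider hands back the same stub model no matter which name is requested, but remembers every model name it was asked to resolve, so the final assertion can check that the run requested both the default model (looked up as None) and the per-agent "gpt-4o-mini" override. Below is a minimal standalone sketch of that pattern; RecordingModelProvider is a hypothetical name, and the `agents` import path for Model/ModelProvider is assumed from the OpenAI Agents SDK and may differ by version. Unlike the test's AssertDifferentModelProvider, it keeps model_names as an instance attribute rather than a shared class-level set.

    # Minimal sketch of a model provider that records requested model names.
    # Assumes the OpenAI Agents SDK ModelProvider/Model interface shown in the
    # diff (get_model(model_name) -> Model); import path may differ by version.
    from typing import Optional, Set

    from agents import Model, ModelProvider


    class RecordingModelProvider(ModelProvider):
        """Returns one fixed Model while recording every model name requested."""

        def __init__(self, model: Model) -> None:
            self._model = model
            # Instance-level set, so two providers never share recorded names.
            self.model_names: Set[Optional[str]] = set()

        def get_model(self, model_name: Optional[str]) -> Model:
            self.model_names.add(model_name)
            return self._model

In the test above, the handoff from "Lazy Assistant" (configured with model="gpt-4o-mini") to "Underling" (no explicit model) is what produces the two distinct lookups, so the assertion on {None, "gpt-4o-mini"} only passes if each agent's model setting actually reaches the provider.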
