@@ -916,16 +916,6 @@ class CustomerServiceModel(StaticTestModel):
916916 ]
917917
918918
919- class AssertDifferentModelProvider (ModelProvider ):
920- model_names = set ()
921-
922- def __init__ (self , model : Model ):
923- self ._model = model
924-
925- def get_model (self , model_name : Union [str , None ]) -> Model :
926- self .model_names .add (model_name )
927- return self ._model
928-
929919
930920@workflow .defn
931921class CustomerServiceWorkflow :
@@ -999,7 +989,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
999989 if not use_local_model and not os .environ .get ("OPENAI_API_KEY" ):
1000990 pytest .skip ("No openai API key" )
1001991 new_config = client .config ()
1002- provider = AssertDifferentModelProvider (CustomerServiceModel ())
992+ provider = TestModelProvider (CustomerServiceModel ())
1003993 new_config ["plugins" ] = [
1004994 openai_agents .OpenAIAgentsPlugin (
1005995 model_params = ModelActivityParameters (
@@ -1089,7 +1079,6 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
10891079 .data .decode ()
10901080 )
10911081
1092- assert len (provider .model_names ) == 2
10931082
10941083class InputGuardrailModel (OpenAIResponsesModel ):
10951084 __test__ = False
@@ -2055,3 +2044,71 @@ async def test_hosted_mcp_tool(client: Client, use_local_model):
20552044 result = await workflow_handle .result ()
20562045 if use_local_model :
20572046 assert result == "Some language"
2047+
2048+
2049+ class AssertDifferentModelProvider (ModelProvider ):
2050+ model_names = set ()
2051+
2052+ def __init__ (self , model : Model ):
2053+ self ._model = model
2054+
2055+ def get_model (self , model_name : Union [str , None ]) -> Model :
2056+ self .model_names .add (model_name )
2057+ return self ._model
2058+
2059+
2060+ class MultipleModelsModel (StaticTestModel ):
2061+ responses = [
2062+ ResponseBuilders .tool_call ("{}" , "transfer_to_underling" ),
2063+ ResponseBuilders .output_message (
2064+ "I'm here to help! Was there a specific task you needed assistance with regarding the storeroom?"
2065+ ),
2066+ ]
2067+
2068+ @workflow .defn
2069+ class MultipleModelWorkflow :
2070+ @workflow .run
2071+ async def run (self ):
2072+ underling = Agent [None ](
2073+ name = "Underling" ,
2074+ instructions = "You do all the work you are told." ,
2075+ )
2076+
2077+ starting_agent = Agent [None ](
2078+ name = "Lazy Assistant" ,
2079+ model = "gpt-4o-mini" ,
2080+ instructions = "You delegate all your work to another agent." ,
2081+ handoffs = [underling ]
2082+ )
2083+ result = await Runner .run (
2084+ starting_agent = starting_agent ,
2085+ input = "Have you cleaned the store room yet?" ,
2086+ )
2087+ return result .final_output
2088+
2089+
2090+ async def test_multiple_models (client : Client ):
2091+ provider = AssertDifferentModelProvider (MultipleModelsModel ())
2092+ new_config = client .config ()
2093+ new_config ["plugins" ] = [
2094+ openai_agents .OpenAIAgentsPlugin (
2095+ model_params = ModelActivityParameters (
2096+ start_to_close_timeout = timedelta (seconds = 120 )
2097+ ),
2098+ model_provider = provider
2099+ )
2100+ ]
2101+ client = Client (** new_config )
2102+
2103+ async with new_worker (
2104+ client ,
2105+ MultipleModelWorkflow ,
2106+ ) as worker :
2107+ workflow_handle = await client .start_workflow (
2108+ MultipleModelWorkflow .run ,
2109+ id = f"multiple-model-{ uuid .uuid4 ()} " ,
2110+ task_queue = worker .task_queue ,
2111+ execution_timeout = timedelta (seconds = 120 ),
2112+ )
2113+ result = await workflow_handle .result ()
2114+ assert provider .model_names == {None , "gpt-4o-mini" }
0 commit comments