@@ -1934,6 +1934,51 @@ def stream_response(
19341934 raise NotImplementedError ()
19351935
19361936
1937+ @workflow .defn
1938+ class AlternateModelAgent :
1939+ @workflow .run
1940+ async def run (self , prompt : str ) -> str :
1941+ agent = Agent [None ](
1942+ name = "Assistant" ,
1943+ instructions = "You only respond in haikus." ,
1944+ model = "test_model" ,
1945+ )
1946+ result = await Runner .run (starting_agent = agent , input = prompt )
1947+ return result .final_output
1948+
1949+
1950+ class CheckModelNameProvider (ModelProvider ):
1951+ def get_model (self , model_name : Optional [str ]) -> Model :
1952+ assert model_name == "test_model"
1953+ return TestHelloModel ()
1954+
1955+
1956+ async def test_alternative_model (client : Client ):
1957+ new_config = client .config ()
1958+ new_config ["plugins" ] = [
1959+ openai_agents .OpenAIAgentsPlugin (
1960+ model_params = ModelActivityParameters (
1961+ start_to_close_timeout = timedelta (seconds = 30 )
1962+ ),
1963+ model_provider = CheckModelNameProvider (),
1964+ )
1965+ ]
1966+ client = Client (** new_config )
1967+
1968+ async with new_worker (
1969+ client ,
1970+ AlternateModelAgent ,
1971+ ) as worker :
1972+ workflow_handle = await client .start_workflow (
1973+ AlternateModelAgent .run ,
1974+ "Hello" ,
1975+ id = f"alternative-model-{ uuid .uuid4 ()} " ,
1976+ task_queue = worker .task_queue ,
1977+ execution_timeout = timedelta (seconds = 10 ),
1978+ )
1979+ await workflow_handle .result ()
1980+
1981+
19371982async def test_heartbeat (client : Client , env : WorkflowEnvironment ):
19381983 if env .supports_time_skipping :
19391984 pytest .skip ("Relies on real timing, skip." )
0 commit comments