66import enum
77import json
88from dataclasses import dataclass
9+ from datetime import timedelta
910from typing import Any , Optional , Union , cast
1011
1112from agents import (
1718 ModelResponse ,
1819 ModelSettings ,
1920 ModelTracing ,
21+ OpenAIProvider ,
2022 RunContextWrapper ,
2123 Tool ,
2224 TResponseInputItem ,
2325 UserError ,
2426 WebSearchTool ,
2527)
2628from agents .models .multi_provider import MultiProvider
29+ from openai import (
30+ APIStatusError ,
31+ AsyncOpenAI ,
32+ AuthenticationError ,
33+ PermissionDeniedError ,
34+ )
2735from typing_extensions import Required , TypedDict
2836
2937from temporalio import activity
3038from temporalio .contrib .openai_agents ._heartbeat_decorator import _auto_heartbeater
39+ from temporalio .exceptions import ApplicationError
3140
3241
3342@dataclass
@@ -117,11 +126,15 @@ class ActivityModelInput(TypedDict, total=False):
117126
118127
119128class ModelActivity :
120- """Class wrapper for model invocation activities to allow model customization."""
129+ """Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
130+ Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
131+ """
121132
122133 def __init__ (self , model_provider : Optional [ModelProvider ] = None ):
123134 """Initialize the activity with a model provider."""
124- self ._model_provider = model_provider or MultiProvider ()
135+ self ._model_provider = model_provider or OpenAIProvider (
136+ openai_client = AsyncOpenAI (max_retries = 0 )
137+ )
125138
126139 @activity .defn
127140 @_auto_heartbeater
@@ -160,7 +173,7 @@ def make_tool(tool: ToolInput) -> Tool:
160173 raise UserError (f"Unknown tool type: { tool .name } " )
161174
162175 tools = [make_tool (x ) for x in input .get ("tools" , [])]
163- handoffs = [
176+ handoffs : list [ Handoff [ Any , Any ]] = [
164177 Handoff (
165178 tool_name = x .tool_name ,
166179 tool_description = x .tool_description ,
@@ -171,14 +184,51 @@ def make_tool(tool: ToolInput) -> Tool:
171184 )
172185 for x in input .get ("handoffs" , [])
173186 ]
174- return await model .get_response (
175- system_instructions = input .get ("system_instructions" ),
176- input = input_input ,
177- model_settings = input ["model_settings" ],
178- tools = tools ,
179- output_schema = input .get ("output_schema" ),
180- handoffs = handoffs ,
181- tracing = ModelTracing (input ["tracing" ]),
182- previous_response_id = input .get ("previous_response_id" ),
183- prompt = input .get ("prompt" ),
184- )
187+
188+ try :
189+ return await model .get_response (
190+ system_instructions = input .get ("system_instructions" ),
191+ input = input_input ,
192+ model_settings = input ["model_settings" ],
193+ tools = tools ,
194+ output_schema = input .get ("output_schema" ),
195+ handoffs = handoffs ,
196+ tracing = ModelTracing (input ["tracing" ]),
197+ previous_response_id = input .get ("previous_response_id" ),
198+ prompt = input .get ("prompt" ),
199+ )
200+ except APIStatusError as e :
201+ # Listen to server hints
202+ retry_after = None
203+ retry_after_ms_header = e .response .headers .get ("retry-after-ms" )
204+ if retry_after_ms_header is not None :
205+ retry_after = timedelta (milliseconds = float (retry_after_ms_header ))
206+
207+ if retry_after is None :
208+ retry_after_header = e .response .headers .get ("retry-after" )
209+ if retry_after_header is not None :
210+ retry_after = timedelta (seconds = float (retry_after_header ))
211+
212+ should_retry_header = e .response .headers .get ("x-should-retry" )
213+ if should_retry_header == "true" :
214+ raise e
215+ if should_retry_header == "false" :
216+ raise ApplicationError (
217+ "Non retryable OpenAI error" ,
218+ non_retryable = True ,
219+ next_retry_delay = retry_after ,
220+ ) from e
221+
222+ # Specifically retryable status codes
223+ if e .response .status_code in [408 , 409 , 429 , 500 ]:
224+ raise ApplicationError (
225+ "Retryable OpenAI status code" ,
226+ non_retryable = False ,
227+ next_retry_delay = retry_after ,
228+ ) from e
229+
230+ raise ApplicationError (
231+ "Non retryable OpenAI status code" ,
232+ non_retryable = True ,
233+ next_retry_delay = retry_after ,
234+ ) from e
0 commit comments