@@ -60,20 +60,77 @@ def __init__(self):
6060 async def acompletion (
6161 self , model , messages , tools , ** kwargs
6262 ) -> Union [ModelResponse , CustomStreamWrapper ]:
63- # 1.1. Get optional_params using get_optional_params function
63+ # 1. Modify messages
64+ # Keep the header system-prompt and the user's messages
65+ messages = messages [:1 ] + messages [- 1 :]
66+
67+ # 2. Get request params
68+ (
69+ request_data ,
70+ optional_params ,
71+ litellm_params ,
72+ logging_obj ,
73+ custom_llm_provider ,
74+ ) = self ._get_request_data (model , messages , tools , ** kwargs )
75+
76+ # 3. Call litellm.aresponses with the transformed request data
77+ raw_response = await aresponses (
78+ ** request_data ,
79+ )
80+ # 4. Transform ResponsesAPIResponse
81+ # 4.1 Create model_response object
82+ model_response = ModelResponse ()
83+ setattr (model_response , "usage" , litellm .Usage ())
84+
85+ # 4.2 Transform ResponsesAPIResponse to ModelResponses
86+ if isinstance (raw_response , ResponsesAPIResponse ):
87+ response = self .transformation_handler .transform_response (
88+ model = model ,
89+ raw_response = raw_response ,
90+ model_response = model_response ,
91+ logging_obj = logging_obj ,
92+ request_data = request_data ,
93+ messages = messages ,
94+ optional_params = optional_params ,
95+ litellm_params = litellm_params ,
96+ encoding = kwargs .get ("encoding" ),
97+ api_key = kwargs .get ("api_key" ),
98+ json_mode = kwargs .get ("json_mode" ),
99+ )
100+ # 4.2.1 Modify ModelResponse id
101+ if raw_response and hasattr (raw_response , "id" ):
102+ response .id = raw_response .id
103+ return response
104+
105+ else :
106+ completion_stream = self .transformation_handler .get_model_response_iterator (
107+ streaming_response = raw_response , # type: ignore
108+ sync_stream = True ,
109+ json_mode = kwargs .get ("json_mode" ),
110+ )
111+ streamwrapper = CustomStreamWrapper (
112+ completion_stream = completion_stream ,
113+ model = model ,
114+ custom_llm_provider = custom_llm_provider ,
115+ logging_obj = logging_obj ,
116+ )
117+ return streamwrapper
118+
119+ def _get_request_data (self , model , messages , tools , ** kwargs ) -> tuple :
120+ # 1. Get optional_params using get_optional_params function
64121 optional_params = get_optional_params (model = model , tools = tools , ** kwargs )
65122
66- # 1. 2. Get litellm_params using get_litellm_params function
123+ # 2. Get litellm_params using get_litellm_params function
67124 litellm_params = get_litellm_params (** kwargs )
68125
69- # 1. 3. Get headers by merging kwargs headers and extra_headers
126+ # 3. Get headers by merging kwargs headers and extra_headers
70127 headers = kwargs .get ("headers" , None ) or kwargs .get ("extra_headers" , None )
71128 if headers is None :
72129 headers = {}
73130 if kwargs .get ("extra_headers" ) is not None :
74131 headers .update (kwargs .get ("extra_headers" ))
75132
76- # 1. 4. Get logging_obj from kwargs or create new LiteLLMLoggingObj
133+ # 4. Get logging_obj from kwargs or create new LiteLLMLoggingObj
77134 logging_obj = kwargs .get ("litellm_logging_obj" , None )
78135 if logging_obj is None :
79136 logging_obj = Logging (
@@ -86,7 +143,7 @@ async def acompletion(
86143 start_time = datetime .now (),
87144 kwargs = kwargs ,
88145 )
89- # 1.5 . Convert Message to `llm_provider` format
146+ # 4.1. Convert messages to the `llm_provider` role format
90147 _ , custom_llm_provider , _ , _ = get_llm_provider (model = model )
91148 if custom_llm_provider is not None and custom_llm_provider in [
92149 provider .value for provider in LlmProviders
@@ -98,10 +155,8 @@ async def acompletion(
98155 messages = provider_config .translate_developer_role_to_system_role (
99156 messages = messages
100157 )
101- # 1.6 Add response_id to llm_response
102- # Keep the header system-prompt and the user's messages
103- messages = messages [:1 ] + messages [- 1 :]
104- # 1.7 Transform request to responses api format
158+
159+ # 5. Transform the request to Responses API format
105160 request_data = self .transformation_handler .transform_request (
106161 model = model ,
107162 messages = messages ,
@@ -112,49 +167,22 @@ async def acompletion(
112167 client = kwargs .get ("client" ),
113168 )
114169
115- # 2. Call litellm.aresponses with the transformed request data
116- raw_response = await aresponses (
117- ** request_data ,
170+ # 6. Supply any missing extra_* fields from kwargs
171+ if "extra_body" not in request_data and kwargs .get ("extra_body" ):
172+ request_data ["extra_body" ] = kwargs .get ("extra_body" )
173+ if "extra_query" not in request_data and kwargs .get ("extra_query" ):
174+ request_data ["extra_query" ] = kwargs .get ("extra_query" )
175+ if "extra_headers" not in request_data and kwargs .get ("extra_headers" ):
176+ request_data ["extra_headers" ] = kwargs .get ("extra_headers" )
177+
178+ return (
179+ request_data ,
180+ optional_params ,
181+ litellm_params ,
182+ logging_obj ,
183+ custom_llm_provider ,
118184 )
119185
120- # 3.1 Create model_response object
121- model_response = ModelResponse ()
122- setattr (model_response , "usage" , litellm .Usage ())
123-
124- # 3.2 Transform ResponsesAPIResponse to ModelResponses
125- if isinstance (raw_response , ResponsesAPIResponse ):
126- response = self .transformation_handler .transform_response (
127- model = model ,
128- raw_response = raw_response ,
129- model_response = model_response ,
130- logging_obj = logging_obj ,
131- request_data = request_data ,
132- messages = messages ,
133- optional_params = optional_params ,
134- litellm_params = litellm_params ,
135- encoding = kwargs .get ("encoding" ),
136- api_key = kwargs .get ("api_key" ),
137- json_mode = kwargs .get ("json_mode" ),
138- )
139- # 3.2.1 Modify ModelResponse id
140- if raw_response and hasattr (raw_response , "id" ):
141- response .id = raw_response .id
142- return response
143-
144- else :
145- completion_stream = self .transformation_handler .get_model_response_iterator (
146- streaming_response = raw_response , # type: ignore
147- sync_stream = True ,
148- json_mode = kwargs .get ("json_mode" ),
149- )
150- streamwrapper = CustomStreamWrapper (
151- completion_stream = completion_stream ,
152- model = model ,
153- custom_llm_provider = custom_llm_provider ,
154- logging_obj = logging_obj ,
155- )
156- return streamwrapper
157-
158186
159187class ArkLlm (LiteLlm ):
160188 llm_client : ArkLlmClient = Field (default_factory = ArkLlmClient )
0 commit comments