Skip to content

Commit d2b75fe

Browse files
committed
chore: organize the code
1 parent 92b170a commit d2b75fe

File tree

1 file changed

+78
-50
lines changed

1 file changed

+78
-50
lines changed

veadk/models/ark_llm.py

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,77 @@ def __init__(self):
6060
async def acompletion(
6161
self, model, messages, tools, **kwargs
6262
) -> Union[ModelResponse, CustomStreamWrapper]:
63-
# 1.1. Get optional_params using get_optional_params function
63+
# 1 Modify messages
64+
# Keep the header system-prompt and the user's messages
65+
messages = messages[:1] + messages[-1:]
66+
67+
# 2 Get request params
68+
(
69+
request_data,
70+
optional_params,
71+
litellm_params,
72+
logging_obj,
73+
custom_llm_provider,
74+
) = self._get_request_data(model, messages, tools, **kwargs)
75+
76+
# 3. Call litellm.aresponses with the transformed request data
77+
raw_response = await aresponses(
78+
**request_data,
79+
)
80+
# 4. Transform ResponsesAPIResponse
81+
# 4.1 Create model_response object
82+
model_response = ModelResponse()
83+
setattr(model_response, "usage", litellm.Usage())
84+
85+
# 4.2 Transform ResponsesAPIResponse to ModelResponses
86+
if isinstance(raw_response, ResponsesAPIResponse):
87+
response = self.transformation_handler.transform_response(
88+
model=model,
89+
raw_response=raw_response,
90+
model_response=model_response,
91+
logging_obj=logging_obj,
92+
request_data=request_data,
93+
messages=messages,
94+
optional_params=optional_params,
95+
litellm_params=litellm_params,
96+
encoding=kwargs.get("encoding"),
97+
api_key=kwargs.get("api_key"),
98+
json_mode=kwargs.get("json_mode"),
99+
)
100+
# 4.2.1 Modify ModelResponse id
101+
if raw_response and hasattr(raw_response, "id"):
102+
response.id = raw_response.id
103+
return response
104+
105+
else:
106+
completion_stream = self.transformation_handler.get_model_response_iterator(
107+
streaming_response=raw_response, # type: ignore
108+
sync_stream=True,
109+
json_mode=kwargs.get("json_mode"),
110+
)
111+
streamwrapper = CustomStreamWrapper(
112+
completion_stream=completion_stream,
113+
model=model,
114+
custom_llm_provider=custom_llm_provider,
115+
logging_obj=logging_obj,
116+
)
117+
return streamwrapper
118+
119+
def _get_request_data(self, model, messages, tools, **kwargs) -> tuple:
120+
# 1. Get optional_params using get_optional_params function
64121
optional_params = get_optional_params(model=model, tools=tools, **kwargs)
65122

66-
# 1.2. Get litellm_params using get_litellm_params function
123+
# 2. Get litellm_params using get_litellm_params function
67124
litellm_params = get_litellm_params(**kwargs)
68125

69-
# 1.3. Get headers by merging kwargs headers and extra_headers
126+
# 3. Get headers by merging kwargs headers and extra_headers
70127
headers = kwargs.get("headers", None) or kwargs.get("extra_headers", None)
71128
if headers is None:
72129
headers = {}
73130
if kwargs.get("extra_headers") is not None:
74131
headers.update(kwargs.get("extra_headers"))
75132

76-
# 1.4. Get logging_obj from kwargs or create new LiteLLMLoggingObj
133+
# 4. Get logging_obj from kwargs or create new LiteLLMLoggingObj
77134
logging_obj = kwargs.get("litellm_logging_obj", None)
78135
if logging_obj is None:
79136
logging_obj = Logging(
@@ -86,7 +143,7 @@ async def acompletion(
86143
start_time=datetime.now(),
87144
kwargs=kwargs,
88145
)
89-
# 1.5. Convert Message to `llm_provider` format
146+
# 4. Convert Message to `llm_provider` format
90147
_, custom_llm_provider, _, _ = get_llm_provider(model=model)
91148
if custom_llm_provider is not None and custom_llm_provider in [
92149
provider.value for provider in LlmProviders
@@ -98,10 +155,8 @@ async def acompletion(
98155
messages = provider_config.translate_developer_role_to_system_role(
99156
messages=messages
100157
)
101-
# 1.6 Add response_id to llm_response
102-
# Keep the header system-prompt and the user's messages
103-
messages = messages[:1] + messages[-1:]
104-
# 1.7 Transform request to responses api format
158+
159+
# 5 Transform request to responses api format
105160
request_data = self.transformation_handler.transform_request(
106161
model=model,
107162
messages=messages,
@@ -112,49 +167,22 @@ async def acompletion(
112167
client=kwargs.get("client"),
113168
)
114169

115-
# 2. Call litellm.aresponses with the transformed request data
116-
raw_response = await aresponses(
117-
**request_data,
170+
# 6 handler Missing field supply
171+
if "extra_body" not in request_data and kwargs.get("extra_body"):
172+
request_data["extra_body"] = kwargs.get("extra_body")
173+
if "extra_query" not in request_data and kwargs.get("extra_query"):
174+
request_data["extra_query"] = kwargs.get("extra_query")
175+
if "extra_headers" not in request_data and kwargs.get("extra_headers"):
176+
request_data["extra_headers"] = kwargs.get("extra_headers")
177+
178+
return (
179+
request_data,
180+
optional_params,
181+
litellm_params,
182+
logging_obj,
183+
custom_llm_provider,
118184
)
119185

120-
# 3.1 Create model_response object
121-
model_response = ModelResponse()
122-
setattr(model_response, "usage", litellm.Usage())
123-
124-
# 3.2 Transform ResponsesAPIResponse to ModelResponses
125-
if isinstance(raw_response, ResponsesAPIResponse):
126-
response = self.transformation_handler.transform_response(
127-
model=model,
128-
raw_response=raw_response,
129-
model_response=model_response,
130-
logging_obj=logging_obj,
131-
request_data=request_data,
132-
messages=messages,
133-
optional_params=optional_params,
134-
litellm_params=litellm_params,
135-
encoding=kwargs.get("encoding"),
136-
api_key=kwargs.get("api_key"),
137-
json_mode=kwargs.get("json_mode"),
138-
)
139-
# 3.2.1 Modify ModelResponse id
140-
if raw_response and hasattr(raw_response, "id"):
141-
response.id = raw_response.id
142-
return response
143-
144-
else:
145-
completion_stream = self.transformation_handler.get_model_response_iterator(
146-
streaming_response=raw_response, # type: ignore
147-
sync_stream=True,
148-
json_mode=kwargs.get("json_mode"),
149-
)
150-
streamwrapper = CustomStreamWrapper(
151-
completion_stream=completion_stream,
152-
model=model,
153-
custom_llm_provider=custom_llm_provider,
154-
logging_obj=logging_obj,
155-
)
156-
return streamwrapper
157-
158186

159187
class ArkLlm(LiteLlm):
160188
llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient)

0 commit comments

Comments
 (0)