Skip to content

Commit 098a973

Browse files
committed
feat: ark litellm client
1 parent 6016f10 commit 098a973

File tree

2 files changed

+156
-6
lines changed

2 files changed

+156
-6
lines changed

veadk/agent.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ class Agent(LlmAgent):
151151

152152
tracers: list[BaseTracer] = []
153153

154+
enable_responses: bool = False
155+
154156
run_processor: Optional[BaseRunProcessor] = Field(default=None, exclude=True)
155157
"""Optional run processor for intercepting and processing agent execution flows.
156158
@@ -197,12 +199,22 @@ def model_post_init(self, __context: Any) -> None:
197199
logger.info(f"Model extra config: {self.model_extra_config}")
198200

199201
if not self.model:
200-
self.model = LiteLlm(
201-
model=f"{self.model_provider}/{self.model_name}",
202-
api_key=self.model_api_key,
203-
api_base=self.model_api_base,
204-
**self.model_extra_config,
205-
)
202+
if self.enable_responses:
203+
from veadk.models.ark_llm import ArkLlm
204+
205+
self.model = ArkLlm(
206+
model=f"{self.model_provider}/{self.model_name}",
207+
api_key=self.model_api_key,
208+
api_base=self.model_api_base,
209+
**self.model_extra_config,
210+
)
211+
else:
212+
self.model = LiteLlm(
213+
model=f"{self.model_provider}/{self.model_name}",
214+
api_key=self.model_api_key,
215+
api_base=self.model_api_base,
216+
**self.model_extra_config,
217+
)
206218
logger.debug(
207219
f"LiteLLM client created with config: {self.model_extra_config}"
208220
)

veadk/models/ark_llm.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import uuid
2+
from datetime import datetime
3+
from typing import Any, Dict, Union
4+
5+
import litellm
6+
from google.adk.models.lite_llm import (
7+
LiteLlm,
8+
LiteLLMClient,
9+
)
10+
from litellm import Logging
11+
from litellm import aresponses
12+
from litellm.completion_extras.litellm_responses_transformation.transformation import (
13+
LiteLLMResponsesTransformationHandler,
14+
)
15+
from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
16+
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
17+
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
18+
from litellm.types.llms.openai import ResponsesAPIResponse
19+
from litellm.types.utils import ModelResponse, LlmProviders
20+
from litellm.utils import get_optional_params, ProviderConfigManager
21+
from pydantic import Field
22+
23+
from veadk.utils.logger import get_logger
24+
25+
# This will add functions to prompts if functions are provided.
26+
litellm.add_function_to_prompt = True
27+
28+
logger = get_logger(__name__)
29+
30+
31+
class ArkLlmClient(LiteLLMClient):
32+
def __init__(self):
33+
super().__init__()
34+
self.transformation_handler = LiteLLMResponsesTransformationHandler()
35+
36+
async def acompletion(
37+
self, model, messages, tools, **kwargs
38+
) -> Union[ModelResponse, CustomStreamWrapper]:
39+
# 1.1. Get optional_params using get_optional_params function
40+
optional_params = get_optional_params(model=model, **kwargs)
41+
42+
# 1.2. Get litellm_params using get_litellm_params function
43+
litellm_params = get_litellm_params(**kwargs)
44+
45+
# 1.3. Get headers by merging kwargs headers and extra_headers
46+
headers = kwargs.get("headers", None) or kwargs.get("extra_headers", None)
47+
if headers is None:
48+
headers = {}
49+
if kwargs.get("extra_headers") is not None:
50+
headers.update(kwargs.get("extra_headers"))
51+
52+
# 1.4. Get logging_obj from kwargs or create new LiteLLMLoggingObj
53+
logging_obj = kwargs.get("litellm_logging_obj", None)
54+
if logging_obj is None:
55+
logging_obj = Logging(
56+
model=model,
57+
messages=messages,
58+
stream=kwargs.get("stream", False),
59+
call_type="acompletion",
60+
litellm_call_id=str(uuid.uuid4()),
61+
function_id=str(uuid.uuid4()),
62+
start_time=datetime.now(),
63+
kwargs=kwargs,
64+
)
65+
# 1.5. Convert Message to `llm_provider` format
66+
_, custom_llm_provider, _, _ = get_llm_provider(model=model)
67+
if custom_llm_provider is not None and custom_llm_provider in [
68+
provider.value for provider in LlmProviders
69+
]:
70+
provider_config = ProviderConfigManager.get_provider_chat_config(
71+
model=model, provider=LlmProviders(custom_llm_provider)
72+
)
73+
if provider_config is not None:
74+
messages = provider_config.translate_developer_role_to_system_role(
75+
messages=messages
76+
)
77+
78+
# 1.6 Transform request to responses api format
79+
request_data = self.transformation_handler.transform_request(
80+
model=model,
81+
messages=messages,
82+
optional_params=optional_params,
83+
litellm_params=litellm_params,
84+
headers=headers,
85+
litellm_logging_obj=logging_obj,
86+
client=kwargs.get("client"),
87+
)
88+
89+
# 2. Call litellm.aresponses with the transformed request data
90+
result = await aresponses(
91+
**request_data,
92+
)
93+
94+
# 3.1 Create model_response object
95+
model_response = ModelResponse()
96+
setattr(model_response, "usage", litellm.Usage())
97+
98+
# 3.2 Transform ResponsesAPIResponse to ModelResponses
99+
if isinstance(result, ResponsesAPIResponse):
100+
return self.transformation_handler.transform_response(
101+
model=model,
102+
raw_response=result,
103+
model_response=model_response,
104+
logging_obj=logging_obj,
105+
request_data=request_data,
106+
messages=messages,
107+
optional_params=optional_params,
108+
litellm_params=litellm_params,
109+
encoding=kwargs.get("encoding"),
110+
api_key=kwargs.get("api_key"),
111+
json_mode=kwargs.get("json_mode"),
112+
)
113+
else:
114+
completion_stream = self.transformation_handler.get_model_response_iterator(
115+
streaming_response=result, # type: ignore
116+
sync_stream=True,
117+
json_mode=kwargs.get("json_mode"),
118+
)
119+
streamwrapper = CustomStreamWrapper(
120+
completion_stream=completion_stream,
121+
model=model,
122+
custom_llm_provider=custom_llm_provider,
123+
logging_obj=logging_obj,
124+
)
125+
return streamwrapper
126+
127+
128+
class ArkLlm(LiteLlm):
129+
llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient)
130+
_additional_args: Dict[str, Any] = None
131+
132+
def __init__(self, **kwargs):
133+
super().__init__(**kwargs)
134+
135+
# async def generate_content_async(
136+
# self, llm_request: LlmRequest, stream: bool = False
137+
# ) -> AsyncGenerator[LlmResponse, None]:
138+
# pass

0 commit comments

Comments
 (0)