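"""Ark model integration for VeADK, built on google-adk's LiteLlm.

`ArkLlmClient` reroutes chat completions through LiteLLM's Responses API;
`ArkLlm` is the `LiteLlm` subclass that uses it.
"""
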
import uuid
from datetime import datetime
from typing import Any, Dict, Optional, Union

import litellm
from google.adk.models.lite_llm import (
    LiteLlm,
    LiteLLMClient,
)
from litellm import Logging, aresponses
from litellm.completion_extras.litellm_responses_transformation.transformation import (
    LiteLLMResponsesTransformationHandler,
)
from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.types.llms.openai import ResponsesAPIResponse
from litellm.types.utils import LlmProviders, ModelResponse
from litellm.utils import ProviderConfigManager, get_optional_params
from pydantic import Field

from veadk.utils.logger import get_logger

# If functions/tools are provided, let LiteLLM inline them into the prompt
# for providers without native function-calling support.
litellm.add_function_to_prompt = True

logger = get_logger(__name__)


class ArkLlmClient(LiteLLMClient):
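    """LiteLLM client that serves chat completions via LiteLLM's Responses API."""
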
    def __init__(self):
        super().__init__()
        self.transformation_handler = LiteLLMResponsesTransformationHandler()

    async def acompletion(
        self, model, messages, tools, **kwargs
    ) -> Union[ModelResponse, CustomStreamWrapper]:
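        """Serve an async chat completion through the Responses API.

        Pipeline: (1) assemble optional params, litellm params, headers, and
        the logging object, then transform the request into Responses API
        format; (2) call ``litellm.aresponses``; (3) transform the result back
        into a ``ModelResponse``, or a ``CustomStreamWrapper`` when streaming.
        """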
        # 1.1. Resolve provider-specific optional params via get_optional_params.
        optional_params = get_optional_params(model=model, **kwargs)

        # 1.2. Resolve litellm-level params via get_litellm_params.
        litellm_params = get_litellm_params(**kwargs)

        # 1.3. Merge `headers` and `extra_headers` into one dict, copying so
        # the caller's dicts are never mutated.
        headers = dict(kwargs.get("headers") or {})
        extra_headers = kwargs.get("extra_headers")
        if extra_headers:
            headers.update(extra_headers)

        # 1.4. Reuse the logging object passed in kwargs, or create a new one.
        logging_obj = kwargs.get("litellm_logging_obj", None)
        if logging_obj is None:
            logging_obj = Logging(
                model=model,
                messages=messages,
                stream=kwargs.get("stream", False),
                call_type="acompletion",
                litellm_call_id=str(uuid.uuid4()),
                function_id=str(uuid.uuid4()),
                start_time=datetime.now(),
                kwargs=kwargs,
            )

        # 1.5. Translate the OpenAI "developer" role to the "system" role for
        # providers whose chat config requires it.
        _, custom_llm_provider, _, _ = get_llm_provider(model=model)
        if custom_llm_provider is not None and custom_llm_provider in [
            provider.value for provider in LlmProviders
        ]:
            provider_config = ProviderConfigManager.get_provider_chat_config(
                model=model, provider=LlmProviders(custom_llm_provider)
            )
            if provider_config is not None:
                messages = provider_config.translate_developer_role_to_system_role(
                    messages=messages
                )

        # 1.6. Transform the chat-completion request into Responses API format.
        request_data = self.transformation_handler.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            litellm_logging_obj=logging_obj,
            client=kwargs.get("client"),
        )

        # 2. Call litellm.aresponses with the transformed request data.
        result = await aresponses(**request_data)

        # 3.1. Create an empty ModelResponse to receive the transformed result.
        model_response = ModelResponse()
        setattr(model_response, "usage", litellm.Usage())

        # 3.2. Non-streaming: transform the ResponsesAPIResponse back into a
        # ModelResponse; streaming: wrap the iterator in a CustomStreamWrapper.
        if isinstance(result, ResponsesAPIResponse):
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=result,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        else:
            completion_stream = self.transformation_handler.get_model_response_iterator(
                streaming_response=result,  # type: ignore
                sync_stream=True,
                json_mode=kwargs.get("json_mode"),
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )


class ArkLlm(LiteLlm):
    llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient)
    _additional_args: Optional[Dict[str, Any]] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # async def generate_content_async(
    #     self, llm_request: LlmRequest, stream: bool = False
    # ) -> AsyncGenerator[LlmResponse, None]:
    #     pass
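

# A minimal usage sketch (an assumption, not part of this module): wiring
# ArkLlm into an ADK agent. The agent name, instruction, and model id below
# are hypothetical placeholders.
#
#     from google.adk.agents import Agent
#
#     agent = Agent(
#         name="ark_agent",  # hypothetical
#         model=ArkLlm(model="openai/<your-ark-endpoint-id>"),  # hypothetical model id
#         instruction="You are a helpful assistant.",
#     )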