1 | 1 | import json |
2 | 2 | import uuid |
3 | 3 | from datetime import datetime |
4 | | -from typing import Any, Dict, Union, AsyncGenerator |
| 4 | +from typing import Any, Dict, Union, AsyncGenerator, Optional |
5 | 5 |
6 | 6 | import litellm |
| 7 | +from google.adk.agents.callback_context import CallbackContext |
| 8 | +from google.adk.models.cache_metadata import CacheMetadata |
7 | 9 | from openai import OpenAI |
8 | 10 | from google.adk.models import LlmRequest, LlmResponse |
9 | 11 | from google.adk.models.lite_llm import ( |
77 | 79 | ] |
78 | 80 |
79 | 81 |
80 | | -def _add_response_id_to_llm_response( |
| 82 | +def _add_response_data_to_llm_response( |
81 | 83 | llm_response: LlmResponse, response: ModelResponse |
82 | 84 | ) -> LlmResponse: |
| 85 | + # add responses id |
83 | 86 | if not response.id.startswith("chatcmpl"): |
84 | 87 | if llm_response.custom_metadata is None: |
85 | 88 | llm_response.custom_metadata = {} |
86 | 89 | llm_response.custom_metadata["response_id"] = response["id"] |
| 90 | + # add responses cache data |
| 91 | + if response.get("usage", {}).get("prompt_tokens_details"): |
| 92 | + if llm_response.usage_metadata: |
| 93 | + llm_response.usage_metadata.cached_content_token_count = ( |
| 94 | + response.get("usage", {}).get("prompt_tokens_details").cached_tokens |
| 95 | + ) |
87 | 96 | return llm_response |
88 | 97 |
89 | 98 |
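For reference, the two pieces of metadata attached above can be read back downstream, for example in an `after_model_callback`. A minimal sketch, assuming ADK's standard after-model callback signature; the callback name and the logging are illustrative only, not part of this patch:

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse


def log_responses_metadata(
    callback_context: CallbackContext, llm_response: LlmResponse
) -> Optional[LlmResponse]:
    # Illustrative after_model_callback: read back what
    # _add_response_data_to_llm_response attached to the response.
    metadata = llm_response.custom_metadata or {}
    if "response_id" in metadata:
        print(f"Responses API id: {metadata['response_id']}")
    usage = llm_response.usage_metadata
    if usage and usage.cached_content_token_count:
        print(f"cached prompt tokens: {usage.cached_content_token_count}")
    # Returning None keeps the original LlmResponse unchanged.
    return None
```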
@@ -423,5 +432,27 @@ async def generate_content_async( |
423 | 432 | # Transport response id |
424 | 433 | # yield _model_response_to_generate_content_response(response) |
425 | 434 | llm_response = _model_response_to_generate_content_response(response) |
426 | | - yield _add_response_id_to_llm_response(llm_response, response) |
| 435 | + yield _add_response_data_to_llm_response(llm_response, response) |
427 | 436 | # ------------------------------------------------------ # |
| 437 | +
| 439 | +def add_previous_response_id( |
| 440 | + callback_context: CallbackContext, llm_request: LlmRequest |
| 441 | +) -> Optional[LlmResponse]: |
| 442 | + invocation_context = callback_context._invocation_context |
| 443 | + events = invocation_context.session.events |
| 444 | + if ( |
| 445 | + events |
| 446 | + and len(events) >= 2 |
| 447 | + and events[-2].custom_metadata |
| 448 | + and "response_id" in events[-2].custom_metadata |
| 449 | + ): |
| 450 | + previous_response_id = events[-2].custom_metadata["response_id"] |
| 451 | + llm_request.cache_metadata = CacheMetadata( |
| 452 | + cache_name=previous_response_id, |
| 453 | + expire_time=0, |
| 454 | + fingerprint="", |
| 455 | + invocations_used=0, |
| 456 | + cached_contents_count=0, |
| 457 | + ) |
| 458 | + return |
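`add_previous_response_id` repurposes `CacheMetadata` purely as a carrier: `cache_name` holds the previous turn's Responses API id, read from `events[-2]` (presumably the prior model event, with the current user message as the last event), while the remaining required fields are filled with placeholder values. To use it, the callback is registered as a `before_model_callback` on the agent. A minimal sketch, assuming the patched `LiteLlm` wrapper defined in this module; the agent name, instruction, and model string are illustrative:

```python
from google.adk.agents import LlmAgent

# LiteLlm and add_previous_response_id are assumed to come from this patched module.
agent = LlmAgent(
    name="responses_api_agent",            # illustrative
    model=LiteLlm(model="openai/gpt-4o"),  # assumes the patched wrapper above
    instruction="You are a helpful assistant.",
    before_model_callback=add_previous_response_id,
)
```

On turns after the first, the callback hands the stored `response_id` to the model layer via `cache_metadata.cache_name`, so the request path can forward it as the Responses API `previous_response_id` (the consuming side of `cache_metadata` is outside this hunk).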