Commit ecdc6f6

fix: send response.id usage metadata
1 parent c668164 commit ecdc6f6

File tree

2 files changed: +76 additions, -33 deletions


veadk/models/ark_llm.py

Lines changed: 21 additions & 28 deletions
@@ -14,7 +14,6 @@
     TextChunk,
     _message_to_generate_content_response,
     UsageMetadataChunk,
-    _model_response_to_generate_content_response,
 )
 from google.genai import types
 from litellm import ChatCompletionAssistantMessage
@@ -91,14 +90,14 @@ async def generate_content_async(
         previous_response_id = None
         if llm_request.cache_metadata and llm_request.cache_metadata.cache_name:
             previous_response_id = llm_request.cache_metadata.cache_name
-        # ------------------------------------------------------ #
         completion_args = {
             "model": self.model,
             "messages": messages,
             "tools": tools,
             "response_format": response_format,
             "previous_response_id": previous_response_id,  # supply previous_response_id
         }
+        # ------------------------------------------------------ #
         completion_args.update(self._additional_args)

         if generation_params:
@@ -117,6 +116,7 @@ async def generate_content_async(
             raw_response = await self.llm_client.aresponse(**response_args)
             async for part in raw_response:
                 for (
+                    model_response,
                     chunk,
                     finish_reason,
                 ) in self.transform_handler.stream_event_to_chunk(
@@ -158,6 +158,14 @@ async def generate_content_async(
                            candidates_token_count=chunk.completion_tokens,
                            total_token_count=chunk.total_tokens,
                        )
+                        # ------------------------------------------------------ #
+                        if model_response.get("usage", {}).get("prompt_tokens_details"):
+                            usage_metadata.cached_content_token_count = (
+                                model_response.get("usage", {})
+                                .get("prompt_tokens_details")
+                                .cached_tokens
+                            )
+                        # ------------------------------------------------------ #

                    if (
                        finish_reason == "tool_calls" or finish_reason == "stop"
@@ -185,6 +193,11 @@ async def generate_content_async(
                                )
                            )
                        )
+                        self.transform_handler.adapt_responses_api(
+                            model_response,
+                            aggregated_llm_response_with_tool_call,
+                            stream=True,
+                        )
                        text = ""
                        function_calls.clear()
                    elif finish_reason == "stop" and text:
@@ -193,6 +206,9 @@ async def generate_content_async(
                                role="assistant", content=text
                            )
                        )
+                        self.transform_handler.adapt_responses_api(
+                            model_response, aggregated_llm_response, stream=True
+                        )
                        text = ""

            # waiting until streaming ends to yield the llm_response as litellm tends
@@ -213,32 +229,9 @@ async def generate_content_async(

        else:
            raw_response = await self.llm_client.aresponse(**response_args)
-            yield self._openai_response_to_generate_content_response(raw_response)
-
-    def _openai_response_to_generate_content_response(
-        self, raw_response: OpenAITypeResponse
-    ) -> LlmResponse:
-        """
-        OpenAITypeResponse -> litellm.ModelResponse -> LlmResponse
-        """
-        model_response = self.transform_handler.transform_response(
-            openai_response=raw_response, stream=False
-        )
-        llm_response = _model_response_to_generate_content_response(model_response)
-
-        if not model_response.id.startswith("chatcmpl"):
-            if llm_response.custom_metadata is None:
-                llm_response.custom_metadata = {}
-            llm_response.custom_metadata["response_id"] = model_response["id"]
-        # add responses cache data
-        if model_response.get("usage", {}).get("prompt_tokens_details"):
-            if llm_response.usage_metadata:
-                llm_response.usage_metadata.cached_content_token_count = (
-                    model_response.get("usage", {})
-                    .get("prompt_tokens_details")
-                    .cached_tokens
-                )
-        return llm_response
+            yield self.transform_handler.openai_response_to_generate_content_response(
+                raw_response
+            )


    # before_model_callback
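
For context on the streaming change above: `stream_event_to_chunk` now yields the raw litellm `ModelResponse` alongside each chunk, which is what lets the loop read `usage.prompt_tokens_details.cached_tokens` and lets `adapt_responses_api` stamp the Responses API id onto the aggregated `LlmResponse`. A minimal sketch of a consumer of the new three-tuple generator follows; the `handler` and `events` names are illustrative assumptions, not part of this commit.

# Hypothetical consumer of the (model_response, chunk, finish_reason) tuples now
# yielded by stream_event_to_chunk; `handler` and `events` are illustrative names.
from google.adk.models.lite_llm import UsageMetadataChunk


def collect_cached_tokens(handler, events, model: str):
    """Return the cached prompt-token count reported by a Responses API stream, if any."""
    cached = None
    for event in events:
        for model_response, chunk, finish_reason in handler.stream_event_to_chunk(
            event, model=model
        ):
            if isinstance(chunk, UsageMetadataChunk):
                # The raw ModelResponse rides along with every chunk, so usage details
                # that the litellm chunk types do not carry stay reachable here.
                details = model_response.get("usage", {}).get("prompt_tokens_details")
                if details:
                    cached = details.cached_tokens
    return cached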

veadk/models/ark_transform.py

Lines changed: 55 additions & 5 deletions
@@ -2,11 +2,13 @@
 from typing import Any, Dict, Optional, cast, List, Generator, Tuple, Union

 import litellm
+from google.adk.models import LlmResponse
 from google.adk.models.lite_llm import (
     TextChunk,
     FunctionChunk,
     UsageMetadataChunk,
     _model_response_to_chunk,
+    _model_response_to_generate_content_response,
 )
 from openai.types.responses import (
     Response as OpenAITypeResponse,
@@ -192,16 +194,64 @@ def transform_response(
         response.id = raw_response.id
         return response

+    def openai_response_to_generate_content_response(
+        self, raw_response: OpenAITypeResponse
+    ) -> LlmResponse:
+        """
+        OpenAITypeResponse -> litellm.ModelResponse -> LlmResponse
+        instead of `_model_response_to_generate_content_response`,
+        """
+        # no stream response
+        model_response = self.transform_response(
+            openai_response=raw_response, stream=False
+        )
+        llm_response = _model_response_to_generate_content_response(model_response)
+
+        llm_response = self.adapt_responses_api(
+            model_response,
+            llm_response,
+        )
+        return llm_response
+
+    def adapt_responses_api(
+        self,
+        model_response: ModelResponse,
+        llm_response: LlmResponse,
+        stream: bool = False,
+    ):
+        """
+        Adapt responses api.
+        """
+        if not model_response.id.startswith("chatcmpl"):
+            if llm_response.custom_metadata is None:
+                llm_response.custom_metadata = {}
+            llm_response.custom_metadata["response_id"] = model_response["id"]
+        # add responses cache data
+        if not stream:
+            if model_response.get("usage", {}).get("prompt_tokens_details"):
+                if llm_response.usage_metadata:
+                    llm_response.usage_metadata.cached_content_token_count = (
+                        model_response.get("usage", {})
+                        .get("prompt_tokens_details")
+                        .cached_tokens
+                    )
+        return llm_response
+
     def stream_event_to_chunk(
         self, event: ResponseStreamEvent, model: str
     ) -> Generator[
         Tuple[
+            ModelResponse,
             Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
             Optional[str],
         ],
         None,
         None,
     ]:
+        """
+        instead of using `_model_response_to_chunk`,
+        we use our own implementation to support the responses api.
+        """
         choices = []
         model_response = None

@@ -214,21 +264,21 @@ def stream_event_to_chunk(
                 stream=True, choices=choices, model=model, id=str(uuid.uuid4())
             )
         elif isinstance(event, ResponseCompletedEvent):
-            pass
             response = event.response
             model_response = self.transform_response(response, stream=True)
-            model_response = fix_response(model_response)
+            model_response = fix_model_response(model_response)
         else:
             # Ignore other event types like ResponseOutputItemAddedEvent, etc.
             pass

         if model_response:
-            yield from _model_response_to_chunk(model_response)
+            for chunk, finish_reason in _model_response_to_chunk(model_response):
+                yield model_response, chunk, finish_reason


-def fix_response(model_response: ModelResponse) -> ModelResponse:
+def fix_model_response(model_response: ModelResponse) -> ModelResponse:
     """
-    Fix the response to ensure some fields that cannot be transferred through direct conversion.
+    fix: tool_call has no attribute `index` in `_model_response_to_chunk`
     """
     for i, choice in enumerate(model_response.choices):
         if choice.message.tool_calls:
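
As a usage note for the non-streaming path: callers now go through the transform handler instead of a private method on the model class. A rough sketch, assuming `handler` is the transform handler instance from this module and `raw_response` is an already-fetched `openai.types.responses.Response` (both names are illustrative):

# Illustrative only: `handler` and `raw_response` are assumed to exist already.
# The conversion chain is Response -> litellm.ModelResponse -> LlmResponse, after
# which adapt_responses_api copies the Responses API id into custom_metadata and,
# for non-streaming calls, the cached prompt-token count into usage_metadata.
llm_response = handler.openai_response_to_generate_content_response(raw_response)

response_id = (llm_response.custom_metadata or {}).get("response_id")
cached = (
    llm_response.usage_metadata.cached_content_token_count
    if llm_response.usage_metadata
    else None
)
print(f"response_id={response_id} cached_content_token_count={cached}")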
