Skip to content

Commit ac1c60d

Browse files
authored
feat(beta): support responses api (#297)
* feat: ark litellm client * feat: use the patch method to pass it * chore: organize the code * feat: openai sdk for responses api * feat: use callback instead of patch * feat: format patches * feat: system-prompt * fix: async openai * fix: acompletion to aresponse * fix: stream type done but response.id * fix: send response.id usage metadata * fix: add license header * fix: finish reason and chunk send * fix: without instruction bug * fix: update google-adk version >=1.18 * fix: back * fix: version and cache metadata * fix: multi-agent and multi llm_response scenario * fix: transport response_id by session state * fix: multi-agent bug * fix: remove before_model_callback * fix: clarify the transmission of response_id * feat: enable persistent short-term memory to pass `response-id` * feat: add package * fix: check litellm version * chore: litellm version for pyproject
1 parent 368cdb3 commit ac1c60d

File tree

6 files changed

+762
-9
lines changed

6 files changed

+762
-9
lines changed

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ dependencies = [
1616
"a2a-sdk==0.3.7", # For Google Agent2Agent protocol
1717
"deprecated==1.2.18",
1818
"google-adk==1.19.0", # For basic agent architecture
19-
"litellm==1.74.3", # For model inference
19+
"litellm>=1.74.3", # For model inference
2020
"loguru==0.7.3", # For better logging
2121
"opentelemetry-exporter-otlp==1.37.0",
2222
"opentelemetry-instrumentation-logging>=0.56b0",
@@ -73,6 +73,9 @@ dev = [
7373
"pytest-xdist>=3.8.0",
7474
]
7575

76+
responses = [
77+
"litellm>=1.79.3"
78+
]
7679

7780
[dependency-groups]
7881
dev = [

veadk/agent.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import os
18-
from typing import Optional, Union
18+
from typing import Optional, Union, AsyncGenerator
1919

2020
# If user didn't set LITELLM_LOCAL_MODEL_COST_MAP, set it to True
2121
# to enable local model cost map.
@@ -24,10 +24,12 @@
2424
if not os.getenv("LITELLM_LOCAL_MODEL_COST_MAP"):
2525
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
2626

27-
from google.adk.agents import LlmAgent, RunConfig
27+
from google.adk.agents import LlmAgent, RunConfig, InvocationContext
2828
from google.adk.agents.base_agent import BaseAgent
29+
from google.adk.agents.context_cache_config import ContextCacheConfig
2930
from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
3031
from google.adk.agents.run_config import StreamingMode
32+
from google.adk.events import Event, EventActions
3133
from google.adk.models.lite_llm import LiteLlm
3234
from google.adk.runners import Runner
3335
from google.genai import types
@@ -52,6 +54,7 @@
5254
from veadk.tracing.base_tracer import BaseTracer
5355
from veadk.utils.logger import get_logger
5456
from veadk.utils.patches import patch_asyncio, patch_tracer
57+
from veadk.utils.misc import check_litellm_version
5558
from veadk.version import VERSION
5659

5760
patch_tracer()
@@ -109,6 +112,10 @@ class Agent(LlmAgent):
109112

110113
tracers: list[BaseTracer] = []
111114

115+
enable_responses: bool = False
116+
117+
context_cache_config: Optional[ContextCacheConfig] = None
118+
112119
run_processor: Optional[BaseRunProcessor] = Field(default=None, exclude=True)
113120
"""Optional run processor for intercepting and processing agent execution flows.
114121
@@ -157,12 +164,31 @@ def model_post_init(self, __context: Any) -> None:
157164
logger.info(f"Model extra config: {self.model_extra_config}")
158165

159166
if not self.model:
160-
self.model = LiteLlm(
161-
model=f"{self.model_provider}/{self.model_name}",
162-
api_key=self.model_api_key,
163-
api_base=self.model_api_base,
164-
**self.model_extra_config,
165-
)
167+
if self.enable_responses:
168+
min_version = "1.79.3"
169+
check_litellm_version(min_version)
170+
171+
from veadk.models.ark_llm import ArkLlm
172+
173+
self.model = ArkLlm(
174+
model=f"{self.model_provider}/{self.model_name}",
175+
api_key=self.model_api_key,
176+
api_base=self.model_api_base,
177+
**self.model_extra_config,
178+
)
179+
if not self.context_cache_config:
180+
self.context_cache_config = ContextCacheConfig(
181+
cache_intervals=100, # maximum number
182+
ttl_seconds=315360000,
183+
min_tokens=0,
184+
)
185+
else:
186+
self.model = LiteLlm(
187+
model=f"{self.model_provider}/{self.model_name}",
188+
api_key=self.model_api_key,
189+
api_base=self.model_api_base,
190+
**self.model_extra_config,
191+
)
166192
logger.debug(
167193
f"LiteLLM client created with config: {self.model_extra_config}"
168194
)
@@ -218,6 +244,28 @@ def model_post_init(self, __context: Any) -> None:
218244
f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}"
219245
)
220246

247+
async def _run_async_impl(
248+
self, ctx: InvocationContext
249+
) -> AsyncGenerator[Event, None]:
250+
if self.enable_responses:
251+
if not ctx.context_cache_config:
252+
ctx.context_cache_config = self.context_cache_config
253+
254+
async for event in super()._run_async_impl(ctx):
255+
yield event
256+
if self.enable_responses and event.cache_metadata:
257+
# for persistent short-term memory with response api
258+
session_state_event = Event(
259+
invocation_id=event.invocation_id,
260+
author=event.author,
261+
actions=EventActions(
262+
state_delta={
263+
"response_id": event.cache_metadata.cache_name,
264+
}
265+
),
266+
)
267+
yield session_state_event
268+
221269
async def _run(
222270
self,
223271
runner,

veadk/memory/short_term_memory.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from veadk.memory.short_term_memory_backends.sqlite_backend import (
3333
SQLiteSTMBackend,
3434
)
35+
from veadk.models.ark_transform import build_cache_metadata
3536
from veadk.utils.logger import get_logger
3637

3738
logger = get_logger(__name__)
@@ -49,6 +50,21 @@ async def wrapper(*args, **kwargs):
4950
setattr(obj, "get_session", wrapper)
5051

5152

53+
def enable_responses_api_for_session_service(result, *args, **kwargs):
54+
if result and isinstance(result, Session):
55+
if result.events:
56+
for event in result.events:
57+
if (
58+
event.actions
59+
and event.actions.state_delta
60+
and not event.cache_metadata
61+
and "response_id" in event.actions.state_delta
62+
):
63+
event.cache_metadata = build_cache_metadata(
64+
response_id=event.actions.state_delta.get("response_id"),
65+
)
66+
67+
5268
class ShortTermMemory(BaseModel):
5369
"""Short term memory for agent execution.
5470
@@ -170,6 +186,11 @@ def model_post_init(self, __context: Any) -> None:
170186
db_kwargs=self.db_kwargs, **self.backend_configs
171187
).session_service
172188

189+
if self.backend != "local":
190+
wrap_get_session_with_callbacks(
191+
self._session_service, enable_responses_api_for_session_service
192+
)
193+
173194
if self.after_load_memory_callback:
174195
wrap_get_session_with_callbacks(
175196
self._session_service, self.after_load_memory_callback

0 commit comments

Comments (0)