Skip to content

Commit 2c61ecf

Browse files
authored
Add an optional extension module for Google ADK integration (#161)
1 parent 2261661 commit 2c61ecf

File tree

6 files changed

+2447
-47
lines changed

6 files changed

+2447
-47
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"]
2626
harness = ["testcontainers", "hypercorn", "httpx"]
2727
serde = ["dacite", "pydantic", "msgspec"]
2828
client = ["httpx[http2]"]
29+
adk = ["google-adk>=1.20.0"]
2930

3031
[build-system]
3132
requires = ["maturin>=1.6,<2.0"]

python/restate/ext/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#
2+
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#

python/restate/ext/adk/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#
2+
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#
11+
12+
from .session import RestateSessionService
13+
from .plugin import RestatePlugin
14+
15+
__all__ = [
16+
"RestateSessionService",
17+
"RestatePlugin",
18+
]

python/restate/ext/adk/plugin.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#
2+
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#
11+
"""
12+
ADK plugin implementation for restate.
13+
"""
14+
15+
import asyncio
16+
import restate
17+
18+
from datetime import timedelta
19+
from typing import Optional, Any, cast
20+
21+
from google.genai import types
22+
23+
from google.adk.agents import BaseAgent, LlmAgent
24+
from google.adk.agents.callback_context import CallbackContext
25+
from google.adk.plugins import BasePlugin
26+
from google.adk.tools.base_tool import BaseTool
27+
from google.adk.tools.tool_context import ToolContext
28+
from google.adk.models.llm_request import LlmRequest
29+
from google.adk.models.llm_response import LlmResponse
30+
from google.adk.models import LLMRegistry
31+
from google.adk.models.base_llm import BaseLlm
32+
from google.adk.flows.llm_flows.functions import generate_client_function_call_id
33+
34+
35+
from restate.extensions import current_context
36+
37+
from .session import flush_session_state
38+
39+
40+
class RestatePlugin(BasePlugin):
41+
"""A plugin to integrate Restate with the ADK framework."""
42+
43+
_models: dict[str, BaseLlm]
44+
_locks: dict[str, asyncio.Lock]
45+
46+
def __init__(self, *, max_model_call_retries: int = 10):
47+
super().__init__(name="restate_plugin")
48+
self._models = {}
49+
self._locks = {}
50+
self._max_model_call_retries = max_model_call_retries
51+
52+
async def before_agent_callback(
53+
self, *, agent: BaseAgent, callback_context: CallbackContext
54+
) -> Optional[types.Content]:
55+
if not isinstance(agent, LlmAgent):
56+
raise restate.TerminalError("RestatePlugin only supports LlmAgent agents.")
57+
ctx = current_context() # Ensure we have a Restate context
58+
if ctx is None:
59+
raise restate.TerminalError(
60+
"""No Restate context found for RestatePlugin.
61+
Ensure that the agent is invoked within a restate handler and,
62+
using a ```with restate_overrides(ctx):``` block. around your agent use."""
63+
)
64+
model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model)
65+
self._models[callback_context.invocation_id] = model
66+
self._locks[callback_context.invocation_id] = asyncio.Lock()
67+
68+
id = callback_context.invocation_id
69+
event = ctx.request().attempt_finished_event
70+
71+
async def release_task():
72+
"""make sure to release resources when the agent finishes"""
73+
try:
74+
await event.wait()
75+
finally:
76+
self._models.pop(id, None)
77+
self._locks.pop(id, None)
78+
79+
_ = asyncio.create_task(release_task())
80+
return None
81+
82+
async def after_agent_callback(
83+
self, *, agent: BaseAgent, callback_context: CallbackContext
84+
) -> Optional[types.Content]:
85+
self._models.pop(callback_context.invocation_id, None)
86+
self._locks.pop(callback_context.invocation_id, None)
87+
88+
ctx = cast(restate.ObjectContext, current_context())
89+
await flush_session_state(ctx, callback_context.session)
90+
91+
return None
92+
93+
async def before_model_callback(
94+
self, *, callback_context: CallbackContext, llm_request: LlmRequest
95+
) -> Optional[LlmResponse]:
96+
model = self._models[callback_context.invocation_id]
97+
ctx = current_context()
98+
if ctx is None:
99+
raise RuntimeError(
100+
"No Restate context found, the restate plugin must be used from within a restate handler."
101+
)
102+
response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request)
103+
return response
104+
105+
async def before_tool_callback(
106+
self,
107+
*,
108+
tool: BaseTool,
109+
tool_args: dict[str, Any],
110+
tool_context: ToolContext,
111+
) -> Optional[dict]:
112+
tool_context.session.state["restate_context"] = current_context()
113+
lock = self._locks[tool_context.invocation_id]
114+
await lock.acquire()
115+
# TODO: if we want we can also automatically wrap tools with ctx.run_typed here
116+
return None
117+
118+
async def after_tool_callback(
119+
self,
120+
*,
121+
tool: BaseTool,
122+
tool_args: dict[str, Any],
123+
tool_context: ToolContext,
124+
result: dict,
125+
) -> Optional[dict]:
126+
lock = self._locks[tool_context.invocation_id]
127+
lock.release()
128+
tool_context.session.state.pop("restate_context", None)
129+
return None
130+
131+
async def on_tool_error_callback(
132+
self,
133+
*,
134+
tool: BaseTool,
135+
tool_args: dict[str, Any],
136+
tool_context: ToolContext,
137+
error: Exception,
138+
) -> Optional[dict]:
139+
lock = self._locks[tool_context.invocation_id]
140+
lock.release()
141+
tool_context.session.state.pop("restate_context", None)
142+
return None
143+
144+
async def close(self):
145+
self._models.clear()
146+
self._locks.clear()
147+
148+
149+
def _generate_client_function_call_id(s: LlmResponse) -> None:
150+
"""Generate client function call IDs for function calls in the LlmResponse.
151+
It is important for the function call IDs to be stable across retries, as they
152+
are used to correlate function call results with their invocations.
153+
"""
154+
if s.content and s.content.parts:
155+
for part in s.content.parts:
156+
if part.function_call:
157+
if not part.function_call.id:
158+
id = generate_client_function_call_id()
159+
part.function_call.id = id
160+
161+
162+
async def _generate_content_async(
163+
ctx: restate.Context, max_attempts: int, model: BaseLlm, llm_request: LlmRequest
164+
) -> LlmResponse:
165+
"""Generate content using Restate's context."""
166+
167+
async def call_llm() -> LlmResponse:
168+
a_gen = model.generate_content_async(llm_request, stream=False)
169+
try:
170+
result = await anext(a_gen)
171+
_generate_client_function_call_id(result)
172+
return result
173+
finally:
174+
await a_gen.aclose()
175+
176+
return await ctx.run_typed(
177+
"call LLM",
178+
call_llm,
179+
restate.RunOptions(max_attempts=max_attempts, initial_retry_interval=timedelta(seconds=1)),
180+
)

python/restate/ext/adk/session.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#
2+
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#
11+
"""
12+
ADK session service implementation using Restate Virtual Objects as the backing store.
13+
"""
14+
15+
import restate
16+
17+
from typing import Optional, Any, cast
18+
19+
from google.adk.sessions import Session
20+
from google.adk.events.event import Event
21+
from google.adk.sessions.base_session_service import (
22+
BaseSessionService,
23+
ListSessionsResponse,
24+
GetSessionConfig,
25+
)
26+
27+
from restate.extensions import current_context
28+
29+
30+
class RestateSessionService(BaseSessionService):
31+
def ctx(self) -> restate.ObjectContext:
32+
return cast(restate.ObjectContext, current_context())
33+
34+
async def create_session(
35+
self,
36+
*,
37+
app_name: str,
38+
user_id: str,
39+
state: Optional[dict[str, Any]] = None,
40+
session_id: Optional[str] = None,
41+
) -> Session:
42+
if session_id is None:
43+
session_id = str(self.ctx().uuid())
44+
45+
session = await self.ctx().get(f"session_store::{session_id}", type_hint=Session) or Session(
46+
app_name=app_name,
47+
user_id=user_id,
48+
id=session_id,
49+
state=state or {},
50+
)
51+
self.ctx().set(f"session_store::{session_id}", session)
52+
return session
53+
54+
async def get_session(
55+
self,
56+
*,
57+
app_name: str,
58+
user_id: str,
59+
session_id: str,
60+
config: Optional[GetSessionConfig] = None,
61+
) -> Optional[Session]:
62+
# TODO : Handle config options
63+
return await self.ctx().get(f"session_store::{session_id}", type_hint=Session) or Session(
64+
app_name=app_name,
65+
user_id=user_id,
66+
id=session_id,
67+
)
68+
69+
async def list_sessions(self, *, app_name: str, user_id: Optional[str] = None) -> ListSessionsResponse:
70+
state_keys = await self.ctx().state_keys()
71+
sessions = []
72+
for key in state_keys:
73+
if key.startswith("session_store::"):
74+
session = await self.ctx().get(key, type_hint=Session)
75+
if session is not None:
76+
sessions.append(session)
77+
return ListSessionsResponse(sessions=sessions)
78+
79+
async def delete_session(self, *, app_name: str, user_id: str, session_id: str) -> None:
80+
self.ctx().clear(f"session_store::{session_id}")
81+
82+
async def append_event(self, session: Session, event: Event) -> Event:
83+
"""Appends an event to a session object."""
84+
if event.partial:
85+
return event
86+
# For now, we also store temp state
87+
event = self._trim_temp_delta_state(event)
88+
self._update_session_state(session, event)
89+
session.events.append(event)
90+
return event
91+
92+
93+
async def flush_session_state(ctx: restate.ObjectContext, session: Session):
94+
session_to_store = session.model_copy()
95+
# Remove restate-specific context that got added by the plugin before storing
96+
session_to_store.state.pop("restate_context", None)
97+
deterministic_session = await ctx.run_typed(
98+
"store session", lambda: session_to_store, restate.RunOptions(type_hint=Session)
99+
)
100+
ctx.set(f"session_store::{session.id}", deterministic_session)

0 commit comments

Comments
 (0)