Skip to content

Commit 0b7bd05

Browse files
authored
Update Google ADK examples to use Restate's ADK extensions (#32)
* use ADK extensions from Restate SDK * fix session ids and user ids * Make examples more concise * Update ADK examples * Move runner out of the handler * Cleanup * Cleanup * Fix broken link * Fix broken link
1 parent 06696ce commit 0b7bd05

File tree

18 files changed

+308
-1672
lines changed

18 files changed

+308
-1672
lines changed

a2a/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
**This integration is work-in-progress.**
44

5-
These examples use [Restate](https://ai.restate.dev/) to implement the [Agent2Agent (A2A) protocol](https://github.com/google/A2A).
5+
These examples use [Restate](https://ai.restate.dev/) to implement the [Agent2Agent (A2A) protocol](https://github.com/a2aproject/A2A).
66

77
Restate acts as a scalable, resilient task orchestrator that speaks the A2A protocol and gives you:
88
- 🔁 **Automatic retries** - Handles LLM API downtime, timeouts, and infrastructure failures

google-adk/example/app/chat.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,38 @@
11
import restate
2+
23
from google.adk import Runner
34
from google.adk.apps import App
4-
from google.adk.sessions import Session
5-
from google.genai import types as genai_types
6-
7-
from app.utils.models import ChatMessage
5+
from google.genai.types import Content, Part
86
from google.adk.agents.llm_agent import Agent
7+
from restate.ext.adk import RestateSessionService, RestatePlugin
98

10-
from middleware.restate_plugin import RestatePlugin
11-
from middleware.restate_session_service import RestateSessionService
12-
from middleware.restate_utils import restate_overrides
9+
from app.utils.models import ChatMessage
1310

1411
APP_NAME = "agents"
1512

16-
# AGENT
1713
agent = Agent(
1814
model="gemini-2.5-flash",
1915
name="assistant",
20-
description="A helpful assistant that can answer questions.",
2116
instruction="You are a helpful assistant. Be concise and helpful.",
2217
)
2318

2419
# Enables retries and recovery for model calls and tool executions
2520
app = App(name=APP_NAME, root_agent=agent, plugins=[RestatePlugin()])
26-
session_service = RestateSessionService()
21+
runner = Runner(app=app, session_service=RestateSessionService())
2722

2823
chat = restate.VirtualObject("Chat")
2924

3025

31-
# HANDLER
3226
@chat.handler()
33-
async def message(ctx: restate.ObjectContext, req: ChatMessage) -> str:
34-
session_id = ctx.key()
35-
with restate_overrides(ctx):
36-
await session_service.create_session(
37-
app_name=APP_NAME, user_id=req.user_id, session_id=session_id
38-
)
39-
40-
runner = Runner(app=app, session_service=session_service)
41-
events = runner.run_async(
42-
user_id=req.user_id,
43-
session_id=session_id,
44-
new_message=genai_types.Content(
45-
role="user", parts=[genai_types.Part.from_text(text=req.message)]
46-
),
47-
)
48-
final_response = ""
49-
async for event in events:
50-
if event.is_final_response() and event.content and event.content.parts:
51-
if event.content.parts[0].text:
52-
final_response = event.content.parts[0].text
53-
return final_response
27+
async def message(ctx: restate.ObjectContext, req: ChatMessage) -> str | None:
28+
events = runner.run_async(
29+
user_id=ctx.key(),
30+
session_id=req.session_id,
31+
new_message=Content(role="user", parts=[Part.from_text(text=req.message)]),
32+
)
33+
final_response = None
34+
async for event in events:
35+
if event.is_final_response() and event.content and event.content.parts:
36+
if event.content.parts[0].text:
37+
final_response = event.content.parts[0].text
38+
return final_response

google-adk/example/app/durable_agent.py

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,19 @@
44
from google.adk.apps import App
55
from google.adk.sessions import InMemorySessionService
66
from google.genai.types import Content, Part
7-
from app.utils.models import WeatherResponse, WeatherPrompt
8-
from app.utils.utils import call_weather_api
9-
from google.adk.tools.tool_context import ToolContext
107
from google.adk.agents.llm_agent import Agent
8+
from restate.ext.adk import RestatePlugin, restate_object_context
119

12-
from middleware.restate_plugin import RestatePlugin
13-
from middleware.restate_utils import restate_overrides
10+
from app.utils.models import WeatherResponse, WeatherPrompt
11+
from app.utils.utils import call_weather_api
1412

1513
APP_NAME = "agents"
1614

1715

18-
# TOOLS
19-
async def get_weather(tool_context: ToolContext, city: str) -> WeatherResponse:
16+
async def get_weather(city: str) -> WeatherResponse:
2017
"""Get the current weather for a given city."""
21-
restate_context = tool_context.session.state["restate_context"]
2218
# Do one or more durable steps using the Restate context
23-
return await restate_context.run_typed(
19+
return await restate_object_context().run_typed(
2420
f"Get weather {city}", call_weather_api, city=city
2521
)
2622

@@ -29,9 +25,7 @@ async def get_weather(tool_context: ToolContext, city: str) -> WeatherResponse:
2925
agent = Agent(
3026
model="gemini-2.5-flash",
3127
name="weather_agent",
32-
description="Agent that provides weather updates for cities.",
33-
instruction="""You are a helpful agent that provides weather updates.
34-
Use the get_weather tool to fetch current weather information.""",
28+
instruction="You are a helpful agent that provides weather updates.",
3529
tools=[get_weather],
3630
)
3731

@@ -41,23 +35,22 @@ async def get_weather(tool_context: ToolContext, city: str) -> WeatherResponse:
4135

4236

4337
@agent_service.handler()
44-
async def run(ctx: restate.Context, req: WeatherPrompt) -> str:
45-
session_id = str(ctx.uuid())
46-
with restate_overrides(ctx):
47-
session_service = InMemorySessionService()
48-
await session_service.create_session(
49-
app_name=APP_NAME, user_id=req.user_id, session_id=session_id
50-
)
51-
runner = Runner(app=app, session_service=session_service)
52-
53-
events = runner.run_async(
54-
user_id=req.user_id,
55-
session_id=session_id,
56-
new_message=Content(role="user", parts=[Part.from_text(text=req.message)]),
57-
)
58-
final_response = ""
59-
async for event in events:
60-
if event.is_final_response() and event.content and event.content.parts:
61-
if event.content.parts[0].text:
62-
final_response = event.content.parts[0].text
63-
return final_response
38+
async def run(_ctx: restate.Context, req: WeatherPrompt) -> str | None:
39+
session_service = InMemorySessionService()
40+
await session_service.create_session(
41+
app_name=APP_NAME, user_id="user-123", session_id=req.session_id
42+
)
43+
44+
runner = Runner(app=app, session_service=session_service)
45+
events = runner.run_async(
46+
user_id="user-123",
47+
session_id=req.session_id,
48+
new_message=Content(role="user", parts=[Part.from_text(text=req.message)]),
49+
)
50+
51+
final_response = None
52+
async for event in events:
53+
if event.is_final_response() and event.content and event.content.parts:
54+
if event.content.parts[0].text:
55+
final_response = event.content.parts[0].text
56+
return final_response
Lines changed: 23 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,49 @@
1-
from datetime import timedelta
2-
31
import restate
2+
43
from google.adk import Runner
54
from google.adk.apps import App
65
from google.genai.types import Content, Part
7-
from app.utils.models import WeatherResponse, WeatherPrompt
8-
from app.utils.utils import call_weather_api
9-
from google.adk.tools.tool_context import ToolContext
106
from google.adk.agents.llm_agent import Agent
7+
from restate.ext.adk import RestateSessionService, RestatePlugin, restate_object_context
118

12-
from middleware.restate_plugin import RestatePlugin
13-
from middleware.restate_session_service import RestateSessionService
14-
from middleware.restate_utils import restate_overrides
9+
from app.utils.models import WeatherResponse, WeatherPrompt
10+
from app.utils.utils import call_weather_api
1511

1612
APP_NAME = "agents"
1713

1814

19-
# TOOLS
20-
async def get_weather(tool_context: ToolContext, city: str) -> WeatherResponse:
15+
async def get_weather(city: str) -> WeatherResponse:
2116
"""Get the current weather for a given city."""
22-
restate_context = tool_context.session.state["restate_context"]
23-
2417
# call tool wrapped as Restate durable step
25-
return await restate_context.run_typed("Get weather", call_weather_api, city=city)
18+
return await restate_object_context().run_typed(
19+
f"Get weather {city}", call_weather_api, city=city
20+
)
2621

2722

2823
agent = Agent(
2924
model="gemini-2.5-flash",
3025
name="weather_agent",
31-
description="Agent that provides weather updates for cities.",
32-
instruction="""You are a helpful agent that provides weather updates.
33-
Use the get_weather tool to fetch current weather information.""",
26+
instruction="You are a helpful agent that provides weather updates.",
3427
tools=[get_weather],
3528
)
3629

3730
app = App(name=APP_NAME, root_agent=agent, plugins=[RestatePlugin()])
38-
session_service = RestateSessionService()
39-
31+
runner = Runner(app=app, session_service=RestateSessionService())
4032

4133
agent_service = restate.VirtualObject("StatefulWeatherAgent")
4234

4335

44-
# HANDLER
4536
@agent_service.handler()
46-
async def run(ctx: restate.ObjectContext, req: WeatherPrompt) -> str:
47-
session_id = ctx.key()
48-
with restate_overrides(ctx):
49-
# Use Restate session service to persist session state in Restate
50-
await session_service.create_session(
51-
app_name=APP_NAME, user_id=req.user_id, session_id=session_id
52-
)
53-
runner = Runner(app=app, session_service=session_service)
54-
55-
events = runner.run_async(
56-
user_id=req.user_id,
57-
session_id=session_id,
58-
new_message=Content(role="user", parts=[Part.from_text(text=req.message)]),
59-
)
60-
final_response = ""
61-
async for event in events:
62-
if event.is_final_response() and event.content and event.content.parts:
63-
if event.content.parts[0].text:
64-
final_response = event.content.parts[0].text
65-
return final_response
37+
async def run(ctx: restate.ObjectContext, req: WeatherPrompt) -> str | None:
38+
events = runner.run_async(
39+
user_id=ctx.key(),
40+
session_id=req.session_id,
41+
new_message=Content(role="user", parts=[Part.from_text(text=req.message)]),
42+
)
43+
44+
final_response = None
45+
async for event in events:
46+
if event.is_final_response() and event.content and event.content.parts:
47+
if event.content.parts[0].text:
48+
final_response = event.content.parts[0].text
49+
return final_response

google-adk/example/app/human_approval_agent.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,25 @@
22
from google.adk import Runner
33
from google.adk.agents.llm_agent import Agent
44
from google.adk.apps import App
5-
from google.adk.tools.tool_context import ToolContext
65
from google.genai.types import Content, Part
6+
from restate.ext.adk import RestatePlugin, RestateSessionService, restate_object_context
77

88
from app.utils.models import ClaimPrompt, InsuranceClaim
99
from app.utils.utils import request_human_review
1010

11-
from middleware.restate_plugin import RestatePlugin
12-
from middleware.restate_session_service import RestateSessionService
13-
from middleware.restate_utils import restate_overrides
14-
1511
APP_NAME = "agents"
1612

1713

1814
# TOOLS
19-
async def human_approval(tool_context: ToolContext, claim: InsuranceClaim) -> str:
15+
async def human_approval(claim: InsuranceClaim) -> str:
2016
"""Ask for human approval for high-value claims."""
21-
restate_context = tool_context.session.state["restate_context"]
17+
ctx = restate_object_context()
2218

2319
# Create an awakeable for human approval
24-
approval_id, approval_promise = restate_context.awakeable(type_hint=str)
20+
approval_id, approval_promise = ctx.awakeable(type_hint=str)
2521

2622
# Request human review
27-
await restate_context.run_typed(
23+
await ctx.run_typed(
2824
"Request review",
2925
request_human_review,
3026
claim=claim,
@@ -39,7 +35,6 @@ async def human_approval(tool_context: ToolContext, claim: InsuranceClaim) -> st
3935
agent = Agent(
4036
model="gemini-2.5-flash",
4137
name="claim_approval_agent",
42-
description="Insurance claim evaluation agent that handles human approval workflows.",
4338
instruction="""You are an insurance claim evaluation agent. Use these rules:
4439
- if the amount is more than 1000, ask for human approval using tools;
4540
- if the amount is less than 1000, decide by yourself.""",
@@ -48,29 +43,23 @@ async def human_approval(tool_context: ToolContext, claim: InsuranceClaim) -> st
4843

4944

5045
app = App(name=APP_NAME, root_agent=agent, plugins=[RestatePlugin()])
51-
session_service = RestateSessionService()
46+
runner = Runner(app=app, session_service=RestateSessionService())
5247

5348
agent_service = restate.VirtualObject("HumanClaimApprovalAgent")
5449

5550

5651
# HANDLER
5752
@agent_service.handler()
58-
async def run(ctx: restate.ObjectContext, req: ClaimPrompt) -> str:
59-
session_id = ctx.key()
60-
with restate_overrides(ctx):
61-
await session_service.create_session(
62-
app_name=APP_NAME, user_id=req.user_id, session_id=session_id
63-
)
64-
runner = Runner(app=app, session_service=session_service)
53+
async def run(ctx: restate.ObjectContext, req: ClaimPrompt) -> str | None:
54+
events = runner.run_async(
55+
user_id=ctx.key(),
56+
session_id=req.session_id,
57+
new_message=Content(role="user", parts=[Part.from_text(text=req.message)]),
58+
)
6559

66-
events = runner.run_async(
67-
user_id=req.user_id,
68-
session_id=session_id,
69-
new_message=Content(role="user", parts=[Part.from_text(text=req.message)]),
70-
)
71-
final_response = ""
72-
async for event in events:
73-
if event.is_final_response() and event.content and event.content.parts:
74-
if event.content.parts[0].text:
75-
final_response = event.content.parts[0].text
76-
return final_response
60+
final_response = None
61+
async for event in events:
62+
if event.is_final_response() and event.content and event.content.parts:
63+
if event.content.parts[0].text:
64+
final_response = event.content.parts[0].text
65+
return final_response

0 commit comments

Comments
 (0)