Skip to content

Commit bce93b9

Browse files
committed
chore(runner): support run config
1 parent 5401f93 commit bce93b9

File tree

1 file changed

+88
-29
lines changed

1 file changed

+88
-29
lines changed

veadk/runner.py

Lines changed: 88 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Union
1515

1616
from google.adk.agents import RunConfig
17+
from google.adk.agents.invocation_context import LlmCallsLimitExceededError
1718
from google.adk.agents.run_config import StreamingMode
1819
from google.adk.plugins.base_plugin import BasePlugin
1920
from google.adk.runners import Runner as ADKRunner
@@ -49,20 +50,25 @@ class Runner:
4950
def __init__(
5051
self,
5152
agent: VeAgent,
52-
short_term_memory: ShortTermMemory,
53+
short_term_memory: ShortTermMemory | None = None,
5354
plugins: list[BasePlugin] | None = None,
5455
app_name: str = "veadk_default_app",
5556
user_id: str = "veadk_default_user",
5657
):
57-
# basic settings
5858
self.app_name = app_name
5959
self.user_id = user_id
6060

61-
# agent settings
6261
self.agent = agent
6362

64-
self.short_term_memory = short_term_memory
65-
self.session_service = short_term_memory.session_service
63+
if not short_term_memory:
64+
logger.info(
65+
"No short term memory provided, using a in-memory memory by default."
66+
)
67+
self.short_term_memory = ShortTermMemory()
68+
else:
69+
self.short_term_memory = short_term_memory
70+
71+
self.session_service = self.short_term_memory.session_service
6672

6773
# prevent VeRemoteAgent has no long-term memory attr
6874
if isinstance(self.agent, Agent):
@@ -114,35 +120,44 @@ async def _run(
114120
self,
115121
session_id: str,
116122
message: types.Content,
123+
run_config: RunConfig | None = None,
117124
stream: bool = False,
118125
):
119126
stream_mode = StreamingMode.SSE if stream else StreamingMode.NONE
120127

121-
async def event_generator():
122-
async for event in self.runner.run_async(
123-
user_id=self.user_id,
124-
session_id=session_id,
125-
new_message=message,
126-
run_config=RunConfig(streaming_mode=stream_mode),
127-
):
128-
if event.get_function_calls():
129-
for function_call in event.get_function_calls():
130-
logger.debug(f"Function call: {function_call}")
131-
elif (
132-
event.content is not None
133-
and event.content.parts
134-
and event.content.parts[0].text is not None
135-
and len(event.content.parts[0].text.strip()) > 0
136-
):
137-
yield event.content.parts[0].text
128+
if run_config is not None:
129+
stream_mode = run_config.streaming_mode
130+
else:
131+
run_config = RunConfig(streaming_mode=stream_mode)
132+
try:
138133

139-
final_output = ""
140-
async for chunk in event_generator():
134+
async def event_generator():
135+
async for event in self.runner.run_async(
136+
user_id=self.user_id,
137+
session_id=session_id,
138+
new_message=message,
139+
run_config=run_config,
140+
):
141+
if event.get_function_calls():
142+
for function_call in event.get_function_calls():
143+
logger.debug(f"Function call: {function_call}")
144+
elif (
145+
event.content is not None
146+
and event.content.parts
147+
and event.content.parts[0].text is not None
148+
and len(event.content.parts[0].text.strip()) > 0
149+
):
150+
yield event.content.parts[0].text
151+
152+
final_output = ""
153+
async for chunk in event_generator():
154+
if stream:
155+
print(chunk, end="", flush=True)
156+
final_output += chunk
141157
if stream:
142-
print(chunk, end="", flush=True)
143-
final_output += chunk
144-
if stream:
145-
print() # end with a new line
158+
print() # end with a new line
159+
except LlmCallsLimitExceededError as e:
160+
logger.warning(f"Max number of llm calls limit exceeded: {e}")
146161

147162
return final_output
148163

@@ -151,6 +166,7 @@ async def run(
151166
messages: RunnerMessage,
152167
session_id: str,
153168
stream: bool = False,
169+
run_config: RunConfig | None = None,
154170
save_tracing_data: bool = False,
155171
):
156172
converted_messages: list = self._convert_messages(messages)
@@ -163,7 +179,9 @@ async def run(
163179

164180
final_output = ""
165181
for converted_message in converted_messages:
166-
final_output = await self._run(session_id, converted_message, stream)
182+
final_output = await self._run(
183+
session_id, converted_message, run_config, stream
184+
)
167185

168186
# try to save tracing file
169187
if save_tracing_data:
@@ -193,6 +211,47 @@ def get_trace_id(self) -> str:
193211
logger.warning(f"Get tracer id failed as {e}")
194212
return "<unknown_trace_id>"
195213

214+
async def run_with_raw_message(
215+
self,
216+
message: types.Content,
217+
session_id: str,
218+
run_config: RunConfig | None = None,
219+
):
220+
run_config = RunConfig() if not run_config else run_config
221+
222+
await self.short_term_memory.create_session(
223+
app_name=self.app_name, user_id=self.user_id, session_id=session_id
224+
)
225+
226+
try:
227+
228+
async def event_generator():
229+
async for event in self.runner.run_async(
230+
user_id=self.user_id,
231+
session_id=session_id,
232+
new_message=message,
233+
run_config=run_config,
234+
):
235+
if event.get_function_calls():
236+
for function_call in event.get_function_calls():
237+
logger.debug(f"Function call: {function_call}")
238+
elif (
239+
event.content is not None
240+
and event.content.parts
241+
and event.content.parts[0].text is not None
242+
and len(event.content.parts[0].text.strip()) > 0
243+
):
244+
yield event.content.parts[0].text
245+
246+
final_output = ""
247+
248+
async for chunk in event_generator():
249+
final_output += chunk
250+
except LlmCallsLimitExceededError as e:
251+
logger.warning(f"Max number of llm calls limit exceeded: {e}")
252+
253+
return final_output
254+
196255
def _print_trace_id(self) -> None:
197256
if not isinstance(self.agent, Agent):
198257
logger.warning(

0 commit comments

Comments
 (0)