1414from typing import Union
1515
1616from google .adk .agents import RunConfig
17+ from google .adk .agents .invocation_context import LlmCallsLimitExceededError
1718from google .adk .agents .run_config import StreamingMode
1819from google .adk .plugins .base_plugin import BasePlugin
1920from google .adk .runners import Runner as ADKRunner
@@ -49,20 +50,25 @@ class Runner:
4950 def __init__ (
5051 self ,
5152 agent : VeAgent ,
52- short_term_memory : ShortTermMemory ,
53+ short_term_memory : ShortTermMemory | None = None ,
5354 plugins : list [BasePlugin ] | None = None ,
5455 app_name : str = "veadk_default_app" ,
5556 user_id : str = "veadk_default_user" ,
5657 ):
57- # basic settings
5858 self .app_name = app_name
5959 self .user_id = user_id
6060
61- # agent settings
6261 self .agent = agent
6362
64- self .short_term_memory = short_term_memory
65- self .session_service = short_term_memory .session_service
63+ if not short_term_memory :
64+ logger .info (
65+ "No short term memory provided, using a in-memory memory by default."
66+ )
67+ self .short_term_memory = ShortTermMemory ()
68+ else :
69+ self .short_term_memory = short_term_memory
70+
71+ self .session_service = self .short_term_memory .session_service
6672
6773 # prevent VeRemoteAgent has no long-term memory attr
6874 if isinstance (self .agent , Agent ):
@@ -114,35 +120,44 @@ async def _run(
114120 self ,
115121 session_id : str ,
116122 message : types .Content ,
123+ run_config : RunConfig | None = None ,
117124 stream : bool = False ,
118125 ):
119126 stream_mode = StreamingMode .SSE if stream else StreamingMode .NONE
120127
121- async def event_generator ():
122- async for event in self .runner .run_async (
123- user_id = self .user_id ,
124- session_id = session_id ,
125- new_message = message ,
126- run_config = RunConfig (streaming_mode = stream_mode ),
127- ):
128- if event .get_function_calls ():
129- for function_call in event .get_function_calls ():
130- logger .debug (f"Function call: { function_call } " )
131- elif (
132- event .content is not None
133- and event .content .parts
134- and event .content .parts [0 ].text is not None
135- and len (event .content .parts [0 ].text .strip ()) > 0
136- ):
137- yield event .content .parts [0 ].text
128+ if run_config is not None :
129+ stream_mode = run_config .streaming_mode
130+ else :
131+ run_config = RunConfig (streaming_mode = stream_mode )
132+ try :
138133
139- final_output = ""
140- async for chunk in event_generator ():
134+ async def event_generator ():
135+ async for event in self .runner .run_async (
136+ user_id = self .user_id ,
137+ session_id = session_id ,
138+ new_message = message ,
139+ run_config = run_config ,
140+ ):
141+ if event .get_function_calls ():
142+ for function_call in event .get_function_calls ():
143+ logger .debug (f"Function call: { function_call } " )
144+ elif (
145+ event .content is not None
146+ and event .content .parts
147+ and event .content .parts [0 ].text is not None
148+ and len (event .content .parts [0 ].text .strip ()) > 0
149+ ):
150+ yield event .content .parts [0 ].text
151+
152+ final_output = ""
153+ async for chunk in event_generator ():
154+ if stream :
155+ print (chunk , end = "" , flush = True )
156+ final_output += chunk
141157 if stream :
142- print (chunk , end = "" , flush = True )
143- final_output += chunk
144- if stream :
145- print () # end with a new line
158+ print () # end with a new line
159+ except LlmCallsLimitExceededError as e :
160+ logger .warning (f"Max number of llm calls limit exceeded: { e } " )
146161
147162 return final_output
148163
@@ -151,6 +166,7 @@ async def run(
151166 messages : RunnerMessage ,
152167 session_id : str ,
153168 stream : bool = False ,
169+ run_config : RunConfig | None = None ,
154170 save_tracing_data : bool = False ,
155171 ):
156172 converted_messages : list = self ._convert_messages (messages )
@@ -163,7 +179,9 @@ async def run(
163179
164180 final_output = ""
165181 for converted_message in converted_messages :
166- final_output = await self ._run (session_id , converted_message , stream )
182+ final_output = await self ._run (
183+ session_id , converted_message , run_config , stream
184+ )
167185
168186 # try to save tracing file
169187 if save_tracing_data :
@@ -193,6 +211,47 @@ def get_trace_id(self) -> str:
193211 logger .warning (f"Get tracer id failed as { e } " )
194212 return "<unknown_trace_id>"
195213
214+ async def run_with_raw_message (
215+ self ,
216+ message : types .Content ,
217+ session_id : str ,
218+ run_config : RunConfig | None = None ,
219+ ):
220+ run_config = RunConfig () if not run_config else run_config
221+
222+ await self .short_term_memory .create_session (
223+ app_name = self .app_name , user_id = self .user_id , session_id = session_id
224+ )
225+
226+ try :
227+
228+ async def event_generator ():
229+ async for event in self .runner .run_async (
230+ user_id = self .user_id ,
231+ session_id = session_id ,
232+ new_message = message ,
233+ run_config = run_config ,
234+ ):
235+ if event .get_function_calls ():
236+ for function_call in event .get_function_calls ():
237+ logger .debug (f"Function call: { function_call } " )
238+ elif (
239+ event .content is not None
240+ and event .content .parts
241+ and event .content .parts [0 ].text is not None
242+ and len (event .content .parts [0 ].text .strip ()) > 0
243+ ):
244+ yield event .content .parts [0 ].text
245+
246+ final_output = ""
247+
248+ async for chunk in event_generator ():
249+ final_output += chunk
250+ except LlmCallsLimitExceededError as e :
251+ logger .warning (f"Max number of llm calls limit exceeded: { e } " )
252+
253+ return final_output
254+
196255 def _print_trace_id (self ) -> None :
197256 if not isinstance (self .agent , Agent ):
198257 logger .warning (
0 commit comments