Skip to content

Commit 4118998

Browse files
fix(web): fix short-term memory and long-term memory support in veadk web (#259)
1 parent 8fc7b9a commit 4118998

File tree

2 files changed

+63
-155
lines changed

2 files changed

+63
-155
lines changed

docs/content/90.cli/2.commands.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,14 @@ veadk deploy
6363

6464
## 本地调试
6565

66-
可以通过`adk web``veadk studio`来启动Web页面,运行智能体:
66+
可以通过`adk web``veadk web`来启动Web页面,运行智能体:
6767

6868
```bash
6969
# basic usage:
7070
adk web
7171

72-
# if you need to use long-term memory, you should use `veadk web`.
73-
# if the `session_service_uri` is not set, it will use `opensearch` as your long-term memory backend
74-
veadk web --session_service_uri="mysql+pymysql://{user}:{password}@{host}/{database}"
72+
# or
73+
veadk web
7574
```
7675

7776
它们能够自动读取执行命令目录中的`agent.py`文件,并加载`root_agent`全局变量。

veadk/cli/cli_web.py

Lines changed: 60 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -12,99 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
15+
from functools import wraps
1616

1717
import click
1818

19-
from veadk.memory.long_term_memory import LongTermMemory
20-
from veadk.memory.short_term_memory import ShortTermMemory
19+
from veadk.utils.logger import get_logger
2120

22-
23-
def _get_stm_from_module(module) -> ShortTermMemory:
24-
return module.agent_run_config.short_term_memory
25-
26-
27-
def _get_stm_from_env() -> ShortTermMemory:
28-
import os
29-
30-
from veadk.utils.logger import get_logger
31-
32-
logger = get_logger(__name__)
33-
34-
short_term_memory_backend = os.getenv("SHORT_TERM_MEMORY_BACKEND")
35-
if not short_term_memory_backend: # prevent None or empty string
36-
short_term_memory_backend = "local"
37-
logger.info(f"Short term memory: backend={short_term_memory_backend}")
38-
39-
return ShortTermMemory(backend=short_term_memory_backend) # type: ignore
40-
41-
42-
def _get_ltm_from_module(module) -> LongTermMemory | None:
43-
agent = module.agent_run_config.agent
44-
45-
if not hasattr(agent, "long_term_memory"):
46-
return None
47-
else:
48-
return agent.long_term_memory
49-
50-
51-
def _get_ltm_from_env() -> LongTermMemory | None:
52-
import os
53-
54-
from veadk.utils.logger import get_logger
55-
56-
logger = get_logger(__name__)
57-
58-
long_term_memory_backend = os.getenv("LONG_TERM_MEMORY_BACKEND")
59-
app_name = os.getenv("VEADK_WEB_APP_NAME", "")
60-
user_id = os.getenv("VEADK_WEB_USER_ID", "")
61-
62-
if long_term_memory_backend:
63-
logger.info(f"Long term memory: backend={long_term_memory_backend}")
64-
return LongTermMemory(
65-
backend=long_term_memory_backend, app_name=app_name, user_id=user_id
66-
) # type: ignore
67-
else:
68-
logger.warning("No long term memory backend settings detected.")
69-
return None
70-
71-
72-
def _get_memory(
73-
module_path: str,
74-
) -> tuple[ShortTermMemory, LongTermMemory | None]:
75-
from veadk.utils.logger import get_logger
76-
from veadk.utils.misc import load_module_from_file
77-
78-
logger = get_logger(__name__)
79-
80-
# 1. load user module
81-
try:
82-
module_file_path = module_path
83-
module = load_module_from_file(
84-
module_name="agent_and_mem", file_path=f"{module_file_path}/agent.py"
85-
)
86-
except Exception as e:
87-
logger.error(
88-
f"Failed to get memory config from `agent.py`: {e}. Fallback to get memory from environment variables."
89-
)
90-
return _get_stm_from_env(), _get_ltm_from_env()
91-
92-
if not hasattr(module, "agent_run_config"):
93-
logger.error(
94-
"You must export `agent_run_config` as a global variable in `agent.py`. Fallback to get memory from environment variables."
95-
)
96-
return _get_stm_from_env(), _get_ltm_from_env()
97-
98-
# 2. try to get short term memory
99-
# short term memory must exist in user code, as we use `default_factory` to init it
100-
short_term_memory = _get_stm_from_module(module)
101-
102-
# 3. try to get long term memory
103-
long_term_memory = _get_ltm_from_module(module)
104-
if not long_term_memory:
105-
long_term_memory = _get_ltm_from_env()
106-
107-
return short_term_memory, long_term_memory
21+
logger = get_logger(__name__)
10822

10923

11024
def patch_adkwebserver_disable_openapi():
@@ -133,71 +47,66 @@ def wrapped_get_fast_api(self, *args, **kwargs):
13347
google.adk.cli.adk_web_server.AdkWebServer.get_fast_api_app = wrapped_get_fast_api
13448

13549

136-
@click.command()
137-
@click.option("--host", default="127.0.0.1", help="Host to run the web server on")
138-
@click.option(
139-
"--app_name", default="", help="The `app_name` for initializing long term memory"
140-
)
141-
@click.option(
142-
"--user_id", default="", help="The `user_id` for initializing long term memory"
50+
@click.command(
51+
context_settings=dict(ignore_unknown_options=True, allow_extra_args=True)
14352
)
144-
def web(host: str, app_name: str, user_id: str) -> None:
53+
@click.pass_context
54+
def web(ctx, *args, **kwargs) -> None:
14555
"""Launch web with long term and short term memory."""
146-
import os
147-
from typing import Any
148-
149-
from google.adk.cli.utils.shared_value import SharedValue
150-
151-
from veadk.utils.logger import get_logger
152-
153-
logger = get_logger(__name__)
154-
155-
def init_for_veadk(
156-
self,
157-
*,
158-
agent_loader: Any,
159-
session_service: Any,
160-
memory_service: Any,
161-
artifact_service: Any,
162-
credential_service: Any,
163-
eval_sets_manager: Any,
164-
eval_set_results_manager: Any,
165-
agents_dir: str,
166-
extra_plugins: Optional[list[str]] = None,
167-
**kwargs: Any,
168-
):
169-
self.agent_loader = agent_loader
170-
self.artifact_service = artifact_service
171-
self.credential_service = credential_service
172-
self.eval_sets_manager = eval_sets_manager
173-
self.eval_set_results_manager = eval_set_results_manager
174-
self.agents_dir = agents_dir
175-
self.runners_to_clean = set()
176-
self.current_app_name_ref = SharedValue(value="")
177-
self.runner_dict = {}
178-
self.extra_plugins = extra_plugins or []
179-
180-
for key, value in kwargs.items():
181-
setattr(self, key, value)
182-
183-
# parse VeADK memories
184-
short_term_memory, long_term_memory = _get_memory(module_path=agents_dir)
185-
self.session_service = short_term_memory.session_service
186-
self.memory_service = long_term_memory
187-
188-
os.environ["VEADK_WEB_APP_NAME"] = app_name
189-
os.environ["VEADK_WEB_USER_ID"] = user_id
190-
191-
import google.adk.cli.adk_web_server
56+
from google.adk.cli import adk_web_server
57+
from google.adk.runners import Runner as ADKRunner
58+
59+
from veadk import Agent
60+
from veadk.agents.loop_agent import LoopAgent
61+
from veadk.agents.parallel_agent import ParallelAgent
62+
from veadk.agents.sequential_agent import SequentialAgent
63+
64+
def before_get_runner_async(func):
65+
logger.info("Hook before `get_runner_async`")
66+
67+
@wraps(func)
68+
async def wrapper(*args, **kwargs) -> ADKRunner:
69+
self: adk_web_server.AdkWebServer = args[0]
70+
app_name: str = args[1]
71+
"""Returns the cached runner for the given app."""
72+
agent_or_app = self.agent_loader.load_agent(app_name)
73+
74+
if isinstance(agent_or_app, (SequentialAgent, LoopAgent, ParallelAgent)):
75+
logger.warning(
76+
"Detect VeADK workflow agent, the short-term memory and long-term memory of each sub agent are useless."
77+
)
78+
79+
if isinstance(agent_or_app, Agent):
80+
logger.info("Detect VeADK Agent.")
81+
82+
if agent_or_app.short_term_memory:
83+
self.session_service = (
84+
agent_or_app.short_term_memory.session_service
85+
)
86+
87+
if agent_or_app.long_term_memory:
88+
self.memory_service = agent_or_app.long_term_memory
89+
logger.info(
90+
f"Long term memory backend is {self.memory_service.backend}"
91+
)
92+
93+
logger.info(
94+
f"Current session_service={self.session_service.__class__.__name__}, memory_service={self.memory_service.__class__.__name__}"
95+
)
96+
97+
runner = await func(*args, **kwargs)
98+
return runner
99+
100+
return wrapper
101+
102+
adk_web_server.AdkWebServer.get_runner_async = before_get_runner_async(
103+
adk_web_server.AdkWebServer.get_runner_async
104+
)
192105

193-
google.adk.cli.adk_web_server.AdkWebServer.__init__ = init_for_veadk
194106
patch_adkwebserver_disable_openapi()
195107

196-
import google.adk.cli.cli_tools_click as cli_tools_click
197-
198-
agents_dir = os.getcwd()
199-
logger.info(f"Load agents from {agents_dir}")
108+
from google.adk.cli.cli_tools_click import cli_web
200109

201-
cli_tools_click.cli_web.main(
202-
args=[agents_dir, "--host", host, "--log_level", "ERROR"]
203-
)
110+
extra_args = ctx.args
111+
logger.debug(f"User args: {ctx.args}")
112+
cli_web.main(args=extra_args, standalone_mode=False)

0 commit comments

Comments
 (0)