|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from typing import Optional |
| 15 | +from functools import wraps |
16 | 16 |
|
17 | 17 | import click |
18 | 18 |
|
19 | | -from veadk.memory.long_term_memory import LongTermMemory |
20 | | -from veadk.memory.short_term_memory import ShortTermMemory |
| 19 | +from veadk.utils.logger import get_logger |
21 | 20 |
|
22 | | - |
23 | | -def _get_stm_from_module(module) -> ShortTermMemory: |
24 | | - return module.agent_run_config.short_term_memory |
25 | | - |
26 | | - |
27 | | -def _get_stm_from_env() -> ShortTermMemory: |
28 | | - import os |
29 | | - |
30 | | - from veadk.utils.logger import get_logger |
31 | | - |
32 | | - logger = get_logger(__name__) |
33 | | - |
34 | | - short_term_memory_backend = os.getenv("SHORT_TERM_MEMORY_BACKEND") |
35 | | - if not short_term_memory_backend: # prevent None or empty string |
36 | | - short_term_memory_backend = "local" |
37 | | - logger.info(f"Short term memory: backend={short_term_memory_backend}") |
38 | | - |
39 | | - return ShortTermMemory(backend=short_term_memory_backend) # type: ignore |
40 | | - |
41 | | - |
42 | | -def _get_ltm_from_module(module) -> LongTermMemory | None: |
43 | | - agent = module.agent_run_config.agent |
44 | | - |
45 | | - if not hasattr(agent, "long_term_memory"): |
46 | | - return None |
47 | | - else: |
48 | | - return agent.long_term_memory |
49 | | - |
50 | | - |
51 | | -def _get_ltm_from_env() -> LongTermMemory | None: |
52 | | - import os |
53 | | - |
54 | | - from veadk.utils.logger import get_logger |
55 | | - |
56 | | - logger = get_logger(__name__) |
57 | | - |
58 | | - long_term_memory_backend = os.getenv("LONG_TERM_MEMORY_BACKEND") |
59 | | - app_name = os.getenv("VEADK_WEB_APP_NAME", "") |
60 | | - user_id = os.getenv("VEADK_WEB_USER_ID", "") |
61 | | - |
62 | | - if long_term_memory_backend: |
63 | | - logger.info(f"Long term memory: backend={long_term_memory_backend}") |
64 | | - return LongTermMemory( |
65 | | - backend=long_term_memory_backend, app_name=app_name, user_id=user_id |
66 | | - ) # type: ignore |
67 | | - else: |
68 | | - logger.warning("No long term memory backend settings detected.") |
69 | | - return None |
70 | | - |
71 | | - |
72 | | -def _get_memory( |
73 | | - module_path: str, |
74 | | -) -> tuple[ShortTermMemory, LongTermMemory | None]: |
75 | | - from veadk.utils.logger import get_logger |
76 | | - from veadk.utils.misc import load_module_from_file |
77 | | - |
78 | | - logger = get_logger(__name__) |
79 | | - |
80 | | - # 1. load user module |
81 | | - try: |
82 | | - module_file_path = module_path |
83 | | - module = load_module_from_file( |
84 | | - module_name="agent_and_mem", file_path=f"{module_file_path}/agent.py" |
85 | | - ) |
86 | | - except Exception as e: |
87 | | - logger.error( |
88 | | - f"Failed to get memory config from `agent.py`: {e}. Fallback to get memory from environment variables." |
89 | | - ) |
90 | | - return _get_stm_from_env(), _get_ltm_from_env() |
91 | | - |
92 | | - if not hasattr(module, "agent_run_config"): |
93 | | - logger.error( |
94 | | - "You must export `agent_run_config` as a global variable in `agent.py`. Fallback to get memory from environment variables." |
95 | | - ) |
96 | | - return _get_stm_from_env(), _get_ltm_from_env() |
97 | | - |
98 | | - # 2. try to get short term memory |
99 | | - # short term memory must exist in user code, as we use `default_factory` to init it |
100 | | - short_term_memory = _get_stm_from_module(module) |
101 | | - |
102 | | - # 3. try to get long term memory |
103 | | - long_term_memory = _get_ltm_from_module(module) |
104 | | - if not long_term_memory: |
105 | | - long_term_memory = _get_ltm_from_env() |
106 | | - |
107 | | - return short_term_memory, long_term_memory |
| 21 | +logger = get_logger(__name__) |
108 | 22 |
|
109 | 23 |
|
110 | 24 | def patch_adkwebserver_disable_openapi(): |
@@ -133,71 +47,66 @@ def wrapped_get_fast_api(self, *args, **kwargs): |
133 | 47 | google.adk.cli.adk_web_server.AdkWebServer.get_fast_api_app = wrapped_get_fast_api |
134 | 48 |
|
135 | 49 |
|
136 | | -@click.command() |
137 | | -@click.option("--host", default="127.0.0.1", help="Host to run the web server on") |
138 | | -@click.option( |
139 | | - "--app_name", default="", help="The `app_name` for initializing long term memory" |
140 | | -) |
141 | | -@click.option( |
142 | | - "--user_id", default="", help="The `user_id` for initializing long term memory" |
| 50 | +@click.command( |
| 51 | + context_settings=dict(ignore_unknown_options=True, allow_extra_args=True) |
143 | 52 | ) |
144 | | -def web(host: str, app_name: str, user_id: str) -> None: |
| 53 | +@click.pass_context |
| 54 | +def web(ctx, *args, **kwargs) -> None: |
145 | 55 | """Launch web with long term and short term memory.""" |
146 | | - import os |
147 | | - from typing import Any |
148 | | - |
149 | | - from google.adk.cli.utils.shared_value import SharedValue |
150 | | - |
151 | | - from veadk.utils.logger import get_logger |
152 | | - |
153 | | - logger = get_logger(__name__) |
154 | | - |
155 | | - def init_for_veadk( |
156 | | - self, |
157 | | - *, |
158 | | - agent_loader: Any, |
159 | | - session_service: Any, |
160 | | - memory_service: Any, |
161 | | - artifact_service: Any, |
162 | | - credential_service: Any, |
163 | | - eval_sets_manager: Any, |
164 | | - eval_set_results_manager: Any, |
165 | | - agents_dir: str, |
166 | | - extra_plugins: Optional[list[str]] = None, |
167 | | - **kwargs: Any, |
168 | | - ): |
169 | | - self.agent_loader = agent_loader |
170 | | - self.artifact_service = artifact_service |
171 | | - self.credential_service = credential_service |
172 | | - self.eval_sets_manager = eval_sets_manager |
173 | | - self.eval_set_results_manager = eval_set_results_manager |
174 | | - self.agents_dir = agents_dir |
175 | | - self.runners_to_clean = set() |
176 | | - self.current_app_name_ref = SharedValue(value="") |
177 | | - self.runner_dict = {} |
178 | | - self.extra_plugins = extra_plugins or [] |
179 | | - |
180 | | - for key, value in kwargs.items(): |
181 | | - setattr(self, key, value) |
182 | | - |
183 | | - # parse VeADK memories |
184 | | - short_term_memory, long_term_memory = _get_memory(module_path=agents_dir) |
185 | | - self.session_service = short_term_memory.session_service |
186 | | - self.memory_service = long_term_memory |
187 | | - |
188 | | - os.environ["VEADK_WEB_APP_NAME"] = app_name |
189 | | - os.environ["VEADK_WEB_USER_ID"] = user_id |
190 | | - |
191 | | - import google.adk.cli.adk_web_server |
| 56 | + from google.adk.cli import adk_web_server |
| 57 | + from google.adk.runners import Runner as ADKRunner |
| 58 | + |
| 59 | + from veadk import Agent |
| 60 | + from veadk.agents.loop_agent import LoopAgent |
| 61 | + from veadk.agents.parallel_agent import ParallelAgent |
| 62 | + from veadk.agents.sequential_agent import SequentialAgent |
| 63 | + |
| 64 | + def before_get_runner_async(func): |
| 65 | + logger.info("Hook before `get_runner_async`") |
| 66 | + |
| 67 | + @wraps(func) |
| 68 | + async def wrapper(*args, **kwargs) -> ADKRunner: |
| 69 | + self: adk_web_server.AdkWebServer = args[0] |
| 70 | + app_name: str = args[1] |
| 71 | + """Returns the cached runner for the given app.""" |
| 72 | + agent_or_app = self.agent_loader.load_agent(app_name) |
| 73 | + |
| 74 | + if isinstance(agent_or_app, (SequentialAgent, LoopAgent, ParallelAgent)): |
| 75 | + logger.warning( |
| 76 | + "Detect VeADK workflow agent, the short-term memory and long-term memory of each sub agent are useless." |
| 77 | + ) |
| 78 | + |
| 79 | + if isinstance(agent_or_app, Agent): |
| 80 | + logger.info("Detect VeADK Agent.") |
| 81 | + |
| 82 | + if agent_or_app.short_term_memory: |
| 83 | + self.session_service = ( |
| 84 | + agent_or_app.short_term_memory.session_service |
| 85 | + ) |
| 86 | + |
| 87 | + if agent_or_app.long_term_memory: |
| 88 | + self.memory_service = agent_or_app.long_term_memory |
| 89 | + logger.info( |
| 90 | + f"Long term memory backend is {self.memory_service.backend}" |
| 91 | + ) |
| 92 | + |
| 93 | + logger.info( |
| 94 | + f"Current session_service={self.session_service.__class__.__name__}, memory_service={self.memory_service.__class__.__name__}" |
| 95 | + ) |
| 96 | + |
| 97 | + runner = await func(*args, **kwargs) |
| 98 | + return runner |
| 99 | + |
| 100 | + return wrapper |
| 101 | + |
| 102 | + adk_web_server.AdkWebServer.get_runner_async = before_get_runner_async( |
| 103 | + adk_web_server.AdkWebServer.get_runner_async |
| 104 | + ) |
192 | 105 |
|
193 | | - google.adk.cli.adk_web_server.AdkWebServer.__init__ = init_for_veadk |
194 | 106 | patch_adkwebserver_disable_openapi() |
195 | 107 |
|
196 | | - import google.adk.cli.cli_tools_click as cli_tools_click |
197 | | - |
198 | | - agents_dir = os.getcwd() |
199 | | - logger.info(f"Load agents from {agents_dir}") |
| 108 | + from google.adk.cli.cli_tools_click import cli_web |
200 | 109 |
|
201 | | - cli_tools_click.cli_web.main( |
202 | | - args=[agents_dir, "--host", host, "--log_level", "ERROR"] |
203 | | - ) |
| 110 | + extra_args = ctx.args |
| 111 | + logger.debug(f"User args: {ctx.args}") |
| 112 | + cli_web.main(args=extra_args, standalone_mode=False) |
0 commit comments