Skip to content

Commit ecc27d7

Browse files
feat(builder): support building agent from yaml file (#146)
* feat(builder): support building agent from yaml file
1 parent 31a9f0f commit ecc27d7

File tree

5 files changed

+168
-43
lines changed

5 files changed

+168
-43
lines changed

docs/docs/agent.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,41 @@ planner_agent = Agent(
115115
runner = Runner(agent=planner_agent, short_term_memory=ShortTermMemory())
116116
response = await runner.run(messages=prompt, session_id=session_id)
117117
```
118+
119+
## 从 Agent 配置文件构建
120+
121+
你可以通过一个 Agent 配置文件来构建 Agent 运行时实例,例如:
122+
123+
```yaml
124+
root_agent:
125+
type: Agent # Agent | SequencialAgent | LoopAgent | ParallelAgent
126+
name: test
127+
description: A test agent
128+
instruction: A test instruction
129+
long_term_memory:
130+
backend: local
131+
knowledgebase:
132+
backend: opensearch
133+
sub_agents:
134+
- ${sub_agent_1}
135+
136+
sub_agent_1:
137+
type: Agent
138+
name: agent1
139+
```
140+
141+
其中,每个`agent`的`type`负责指定 Agent 的类名。
142+
143+
可以通过如下代码来实例化这个 Agent:
144+
145+
```python
146+
from veadk.agent_builder import AgentBuilder
147+
148+
agent = AgentBuilder().build(path="./agent.yaml")
149+
```
150+
151+
函数`build`接收3个参数:
152+
153+
- `path`:配置文件路径
154+
- `root_agent_identifier`:配置文件中主 Agent 的名称,默认为`root_agent`
155+
- `tools`:主 agent 挂载的工具列表(子 Agent 工具列表暂未推出)

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ dependencies = [
2323
"wrapt>=1.17.2", # For patching built-in functions
2424
"openai<1.100", # For fix https://github.com/BerriAI/litellm/issues/13710
2525
"volcengine-python-sdk==4.0.3", # For Volcengine API
26-
"agent-pilot-sdk>=0.0.9", # Prompt optimization by Volcengine AgentPilot/PromptPilot toolkits
27-
"fastmcp>=2.11.3", # For running MCP
28-
"cookiecutter>=2.6.0", # For cloud deploy
29-
"opensearch-py==2.8.0" # For OpenSearch database
26+
"agent-pilot-sdk>=0.0.9", # Prompt optimization by Volcengine AgentPilot/PromptPilot toolkits
27+
"fastmcp>=2.11.3", # For running MCP
28+
"cookiecutter>=2.6.0", # For cloud deploy # For OpenSearch database
29+
"opensearch-py==2.8.0",
30+
"omegaconf>=2.3.0", # For agent builder
3031
]
3132

3233
[project.scripts]

veadk/agent_builder.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.adk.agents import BaseAgent
16+
from google.adk.agents.llm_agent import ToolUnion
17+
from omegaconf import OmegaConf
18+
19+
from veadk.a2a.remote_ve_agent import RemoteVeAgent
20+
from veadk.agent import Agent
21+
from veadk.agents.loop_agent import LoopAgent
22+
from veadk.agents.parallel_agent import ParallelAgent
23+
from veadk.agents.sequential_agent import SequentialAgent
24+
from veadk.utils.logger import get_logger
25+
26+
logger = get_logger(__name__)
27+
28+
AGENT_TYPES = {
29+
"Agent": Agent,
30+
"SequentialAgent": SequentialAgent,
31+
"ParallelAgent": ParallelAgent,
32+
"LoopAgent": LoopAgent,
33+
"RemoteVeAgent": RemoteVeAgent,
34+
}
35+
36+
37+
class AgentBuilder:
38+
def __init__(self) -> None:
39+
pass
40+
41+
def _build(self, agent_config: dict) -> BaseAgent:
42+
logger.info(f"Building agent with config: {agent_config}")
43+
44+
sub_agents = []
45+
if agent_config.get("sub_agents", None):
46+
for sub_agent_config in agent_config["sub_agents"]:
47+
agent = self._build(sub_agent_config)
48+
sub_agents.append(agent)
49+
agent_config.pop("sub_agents")
50+
51+
agent_cls = AGENT_TYPES[agent_config["type"]]
52+
agent = agent_cls(**agent_config, sub_agents=sub_agents)
53+
54+
logger.debug("Build agent done.")
55+
56+
return agent
57+
58+
def _read_config(self, path: str) -> dict:
59+
"""Read config file (from `path`) to a in-memory dict."""
60+
assert path.endswith(".yaml"), "Agent config file must be a `.yaml` file."
61+
62+
config = OmegaConf.load(path)
63+
config_dict = OmegaConf.to_container(config, resolve=True)
64+
65+
assert isinstance(config_dict, dict), (
66+
"Parsed config must in `dict` format. Pls check your building file format."
67+
)
68+
69+
return config_dict
70+
71+
def build(
72+
self,
73+
path: str,
74+
root_agent_identifier: str = "root_agent",
75+
tools: list[ToolUnion] | None = None,
76+
) -> BaseAgent:
77+
config = self._read_config(path)
78+
79+
agent_config = config[root_agent_identifier]
80+
agent = self._build(agent_config)
81+
82+
if tools and isinstance(agent, Agent):
83+
agent.tools = tools
84+
85+
return agent

veadk/knowledgebase/knowledgebase.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import BinaryIO, Literal, TextIO
15+
from typing import Any, BinaryIO, Literal, TextIO
16+
17+
from pydantic import BaseModel
1618

1719
from veadk.database.database_adapter import get_knowledgebase_database_adapter
1820
from veadk.database.database_factory import DatabaseFactory
@@ -25,23 +27,23 @@ def build_knowledgebase_index(app_name: str):
2527
return f"veadk_kb_{app_name}"
2628

2729

28-
class KnowledgeBase:
29-
def __init__(
30-
self,
31-
backend: Literal["local", "opensearch", "viking", "redis", "mysql"] = "local",
32-
top_k: int = 10,
33-
db_config=None,
34-
):
35-
logger.info(f"Initializing knowledgebase: backend={backend} top_k={top_k}")
30+
class KnowledgeBase(BaseModel):
31+
backend: Literal["local", "opensearch", "viking", "redis", "mysql"] = "local"
32+
top_k: int = 10
33+
db_config: Any | None = None
3634

37-
self.backend = backend
38-
self.top_k = top_k
35+
def model_post_init(self, __context: Any) -> None:
36+
logger.info(
37+
f"Initializing knowledgebase: backend={self.backend} top_k={self.top_k}"
38+
)
3939

40-
self.db_client = DatabaseFactory.create(backend=backend, config=db_config)
41-
self.adapter = get_knowledgebase_database_adapter(self.db_client)
40+
self._db_client = DatabaseFactory.create(
41+
backend=self.backend, config=self.db_config
42+
)
43+
self._adapter = get_knowledgebase_database_adapter(self._db_client)
4244

4345
logger.info(
44-
f"Initialized knowledgebase: db_client={self.db_client.__class__.__name__} adapter={self.adapter}"
46+
f"Initialized knowledgebase: db_client={self._db_client.__class__.__name__} adapter={self._adapter}"
4547
)
4648

4749
def add(
@@ -67,7 +69,7 @@ def add(
6769

6870
logger.info(f"Adding documents to knowledgebase: index={index}")
6971

70-
self.adapter.add(data=data, index=index)
72+
self._adapter.add(data=data, index=index)
7173

7274
def search(self, query: str, app_name: str, top_k: int | None = None) -> list[str]:
7375
top_k = self.top_k if top_k is None else top_k
@@ -76,7 +78,7 @@ def search(self, query: str, app_name: str, top_k: int | None = None) -> list[st
7678
f"Searching knowledgebase: app_name={app_name} query={query} top_k={top_k}"
7779
)
7880
index = build_knowledgebase_index(app_name)
79-
result = self.adapter.query(query=query, index=index, top_k=top_k)
81+
result = self._adapter.query(query=query, index=index, top_k=top_k)
8082
if len(result) == 0:
8183
logger.warning(f"No documents found in knowledgebase. Query: {query}")
8284
return result
@@ -87,8 +89,8 @@ def delete(self, app_name: str) -> bool:
8789

8890
def delete_doc(self, app_name: str, id: str) -> bool:
8991
index = build_knowledgebase_index(app_name)
90-
return self.adapter.delete_doc(index=index, id=id)
92+
return self._adapter.delete_doc(index=index, id=id)
9193

9294
def list_docs(self, app_name: str, offset: int = 0, limit: int = 100) -> list[dict]:
9395
index = build_knowledgebase_index(app_name)
94-
return self.adapter.list_docs(index=index, offset=offset, limit=limit)
96+
return self._adapter.list_docs(index=index, offset=offset, limit=limit)

veadk/memory/long_term_memory.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# limitations under the License.
1414

1515
# adapted from Google ADK memory service adk-python/src/google/adk/memory/vertex_ai_memory_bank_service.py at 0a9e67dbca67789247e882d16b139dbdc76a329a · google/adk-python
16+
1617
import json
17-
from typing import Literal
18+
from typing import Any, Literal
1819

1920
from google.adk.events.event import Event
2021
from google.adk.memory.base_memory_service import (
@@ -24,6 +25,7 @@
2425
from google.adk.memory.memory_entry import MemoryEntry
2526
from google.adk.sessions import Session
2627
from google.genai import types
28+
from pydantic import BaseModel
2729
from typing_extensions import override
2830

2931
from veadk.database import DatabaseFactory
@@ -37,33 +39,30 @@ def build_long_term_memory_index(app_name: str, user_id: str):
3739
return f"{app_name}_{user_id}"
3840

3941

40-
class LongTermMemory(BaseMemoryService):
41-
def __init__(
42-
self,
43-
backend: Literal[
44-
"local", "opensearch", "redis", "mysql", "viking", "viking_mem"
45-
] = "opensearch",
46-
top_k: int = 5,
47-
):
48-
if backend == "viking":
42+
class LongTermMemory(BaseMemoryService, BaseModel):
43+
backend: Literal[
44+
"local", "opensearch", "redis", "mysql", "viking", "viking_mem"
45+
] = "opensearch"
46+
top_k: int = 5
47+
48+
def model_post_init(self, __context: Any) -> None:
49+
if self.backend == "viking":
4950
logger.warning(
5051
"`viking` backend is deprecated, switching to `viking_mem` backend."
5152
)
52-
backend = "viking_mem"
53-
self.top_k = top_k
54-
self.backend = backend
53+
self.backend = "viking_mem"
5554

5655
logger.info(
5756
f"Initializing long term memory: backend={self.backend} top_k={self.top_k}"
5857
)
5958

60-
self.db_client = DatabaseFactory.create(
61-
backend=backend,
59+
self._db_client = DatabaseFactory.create(
60+
backend=self.backend,
6261
)
63-
self.adapter = get_long_term_memory_database_adapter(self.db_client)
62+
self._adapter = get_long_term_memory_database_adapter(self._db_client)
6463

6564
logger.info(
66-
f"Initialized long term memory: db_client={self.db_client.__class__.__name__} adapter={self.adapter}"
65+
f"Initialized long term memory: db_client={self._db_client.__class__.__name__} adapter={self._adapter}"
6766
)
6867

6968
def _filter_and_convert_events(self, events: list[Event]) -> list[str]:
@@ -101,9 +100,9 @@ async def add_session_to_memory(
101100

102101
# check if viking memory database, should give a user id: if/else
103102
if self.backend == "viking_mem":
104-
self.adapter.add(data=event_strings, index=index, user_id=session.user_id)
103+
self._adapter.add(data=event_strings, index=index, user_id=session.user_id)
105104
else:
106-
self.adapter.add(data=event_strings, index=index)
105+
self._adapter.add(data=event_strings, index=index)
107106

108107
logger.info(
109108
f"Added {len(event_strings)} events to long term memory: index={index}"
@@ -119,11 +118,11 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str):
119118

120119
# user id if viking memory db
121120
if self.backend == "viking_mem":
122-
memory_chunks = self.adapter.query(
121+
memory_chunks = self._adapter.query(
123122
query=query, index=index, top_k=self.top_k, user_id=user_id
124123
)
125124
else:
126-
memory_chunks = self.adapter.query(
125+
memory_chunks = self._adapter.query(
127126
query=query, index=index, top_k=self.top_k
128127
)
129128

0 commit comments

Comments
 (0)