Skip to content

Commit 3a14c6b

Browse files
committed
refine(tests): refine tests for agent
1 parent 59f6da7 commit 3a14c6b

File tree

3 files changed

+187
-8
lines changed

3 files changed

+187
-8
lines changed

tests/test_agent.py

Lines changed: 170 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest.mock import Mock, patch
16+
17+
from google.adk.agents.llm_agent import LlmAgent
18+
from google.adk.models.lite_llm import LiteLlm
1519
from google.adk.tools import load_memory
1620

1721
from veadk import Agent
18-
from veadk.consts import DEFAULT_MODEL_EXTRA_CONFIG
22+
from veadk.consts import (
23+
DEFAULT_AGENT_NAME,
24+
DEFAULT_MODEL_AGENT_API_BASE,
25+
DEFAULT_MODEL_AGENT_NAME,
26+
DEFAULT_MODEL_AGENT_PROVIDER,
27+
DEFAULT_MODEL_EXTRA_CONFIG,
28+
)
1929
from veadk.knowledgebase import KnowledgeBase
2030
from veadk.memory.long_term_memory import LongTermMemory
2131
from veadk.tools import load_knowledgebase_tool
@@ -46,7 +56,7 @@ def test_agent():
4656
serve_url="",
4757
)
4858

49-
assert agent.model.model == f"{agent.model_provider}/{agent.model_name}"
59+
assert agent.model.model == f"{agent.model_provider}/{agent.model_name}" # type: ignore
5060

5161
expected_config = DEFAULT_MODEL_EXTRA_CONFIG.copy()
5262
expected_config["extra_headers"] |= extra_config["extra_headers"]
@@ -55,9 +65,165 @@ def test_agent():
5565
assert agent.model_extra_config == expected_config
5666

5767
assert agent.knowledgebase == knowledgebase
58-
assert agent.knowledgebase.backend == "local"
68+
assert agent.knowledgebase.backend == "local" # type: ignore
5969
assert load_knowledgebase_tool.knowledgebase == agent.knowledgebase
6070
assert load_knowledgebase_tool.load_knowledgebase_tool in agent.tools
6171

62-
assert agent.long_term_memory.backend == "local"
72+
assert agent.long_term_memory.backend == "local" # type: ignore
6373
assert load_memory in agent.tools
74+
75+
76+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
77+
def test_agent_default_values():
78+
agent = Agent()
79+
80+
assert agent.name == DEFAULT_AGENT_NAME
81+
82+
assert agent.model_name == DEFAULT_MODEL_AGENT_NAME
83+
assert agent.model_provider == DEFAULT_MODEL_AGENT_PROVIDER
84+
assert agent.model_api_base == DEFAULT_MODEL_AGENT_API_BASE
85+
86+
assert agent.tools == []
87+
assert agent.sub_agents == []
88+
assert agent.knowledgebase is None
89+
assert agent.long_term_memory is None
90+
assert agent.tracers == []
91+
92+
assert agent.serve_url == ""
93+
94+
95+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
96+
def test_agent_without_knowledgebase():
97+
agent = Agent()
98+
99+
assert agent.knowledgebase is None
100+
assert load_knowledgebase_tool.load_knowledgebase_tool not in agent.tools
101+
102+
103+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
104+
def test_agent_without_long_term_memory():
105+
agent = Agent()
106+
107+
assert agent.long_term_memory is None
108+
assert load_memory not in agent.tools
109+
110+
111+
@patch("veadk.agent.LiteLlm")
112+
def test_agent_model_creation(mock_lite_llm):
113+
mock_model = Mock()
114+
mock_lite_llm.return_value = mock_model
115+
116+
agent = Agent(
117+
model_name="test_model",
118+
model_provider="test_provider",
119+
model_api_key="test_key",
120+
model_api_base="test_base",
121+
)
122+
123+
mock_lite_llm.assert_called_once()
124+
assert agent.model == mock_model
125+
126+
127+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
128+
def test_agent_with_existing_model():
129+
existing_model = LiteLlm(model="test_model")
130+
agent = Agent(model=existing_model)
131+
132+
assert agent.model == existing_model
133+
134+
135+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
136+
def test_agent_model_extra_config_merge():
137+
user_config = {
138+
"extra_headers": {"custom": "header"},
139+
"extra_body": {"custom": "body"},
140+
"other_param": "value",
141+
}
142+
143+
agent = Agent(model_extra_config=user_config)
144+
145+
expected_headers = DEFAULT_MODEL_EXTRA_CONFIG["extra_headers"].copy()
146+
expected_headers["custom"] = "header"
147+
148+
expected_body = DEFAULT_MODEL_EXTRA_CONFIG["extra_body"].copy()
149+
expected_body["custom"] = "body"
150+
151+
assert agent.model_extra_config["extra_headers"] == expected_headers
152+
assert agent.model_extra_config["extra_body"] == expected_body
153+
assert agent.model_extra_config["other_param"] == "value"
154+
155+
156+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
157+
def test_agent_empty_model_extra_config():
158+
agent = Agent(model_extra_config={})
159+
160+
assert (
161+
agent.model_extra_config["extra_headers"]
162+
== DEFAULT_MODEL_EXTRA_CONFIG["extra_headers"]
163+
)
164+
assert (
165+
agent.model_extra_config["extra_body"]
166+
== DEFAULT_MODEL_EXTRA_CONFIG["extra_body"]
167+
)
168+
169+
170+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
171+
def test_agent_with_tools():
172+
mock_tool = Mock()
173+
agent = Agent(tools=[mock_tool])
174+
175+
assert mock_tool in agent.tools
176+
177+
178+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
179+
def test_agent_with_sub_agents():
180+
adk_agent = LlmAgent(name="agent")
181+
veadk_agent = Agent(name="agent")
182+
agent = Agent(sub_agents=[adk_agent, veadk_agent])
183+
184+
assert adk_agent in agent.sub_agents
185+
assert veadk_agent in agent.sub_agents
186+
assert adk_agent.parent_agent == agent
187+
assert veadk_agent.parent_agent == agent
188+
189+
190+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
191+
def test_agent_with_tracers():
192+
tracer1 = OpentelemetryTracer()
193+
tracer2 = OpentelemetryTracer()
194+
195+
agent = Agent(tracers=[tracer1, tracer2])
196+
197+
assert len(agent.tracers) == 2
198+
assert tracer1 in agent.tracers
199+
assert tracer2 in agent.tracers
200+
201+
202+
@patch.dict(
203+
"os.environ",
204+
{"MODEL_AGENT_NAME": "env_model_name", "MODEL_AGENT_API_KEY": "mock_api_key"},
205+
clear=True,
206+
)
207+
def test_agent_environment_variables():
208+
agent = Agent()
209+
print(agent)
210+
assert agent.model_name == "env_model_name"
211+
212+
213+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
214+
def test_agent_custom_name_and_description():
215+
custom_name = "CustomAgent"
216+
custom_description = "A custom agent for testing"
217+
218+
agent = Agent(name=custom_name, description=custom_description)
219+
220+
assert agent.name == custom_name
221+
assert agent.description == custom_description
222+
223+
224+
@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
225+
def test_agent_serve_url():
226+
serve_url = "http://localhost:8080"
227+
agent = Agent(serve_url=serve_url)
228+
229+
assert agent.serve_url == serve_url

veadk/agent.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from veadk.config import getenv
3030
from veadk.consts import (
31+
DEFAULT_AGENT_NAME,
3132
DEFAULT_MODEL_AGENT_API_BASE,
3233
DEFAULT_MODEL_AGENT_NAME,
3334
DEFAULT_MODEL_AGENT_PROVIDER,
@@ -53,7 +54,7 @@ class Agent(LlmAgent):
5354
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
5455
"""The model config"""
5556

56-
name: str = "veAgent"
57+
name: str = DEFAULT_AGENT_NAME
5758
"""The name of the agent."""
5859

5960
description: str = DEFAULT_DESCRIPTION
@@ -62,13 +63,23 @@ class Agent(LlmAgent):
6263
instruction: str = DEFAULT_INSTRUCTION
6364
"""The instruction for the agent, such as principles of function calling."""
6465

65-
model_name: str = getenv("MODEL_AGENT_NAME", DEFAULT_MODEL_AGENT_NAME)
66+
model_name: str = Field(
67+
default_factory=lambda: getenv("MODEL_AGENT_NAME", DEFAULT_MODEL_AGENT_NAME)
68+
)
6669
"""The name of the model for agent running."""
6770

68-
model_provider: str = getenv("MODEL_AGENT_PROVIDER", DEFAULT_MODEL_AGENT_PROVIDER)
71+
model_provider: str = Field(
72+
default_factory=lambda: getenv(
73+
"MODEL_AGENT_PROVIDER", DEFAULT_MODEL_AGENT_PROVIDER
74+
)
75+
)
6976
"""The provider of the model for agent running."""
7077

71-
model_api_base: str = getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE)
78+
model_api_base: str = Field(
79+
default_factory=lambda: getenv(
80+
"MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE
81+
)
82+
)
7283
"""The api base of the model for agent running."""
7384

7485
model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY"))

veadk/consts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from veadk.config import getenv
1818
from veadk.version import VERSION
1919

20+
DEFAULT_AGENT_NAME = "veAgent"
21+
2022
DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615"
2123
DEFAULT_MODEL_AGENT_PROVIDER = "openai"
2224
DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"

0 commit comments

Comments
 (0)