Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitleaks.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ description = "Empty environment variables with KEY pattern"
regex = '''os\.environ\[".*?KEY"\]\s*=\s*".+"'''

[allowlist]
paths = ["requirements.txt"]
paths = ["requirements.txt", "tests"]
12 changes: 4 additions & 8 deletions docs/content/6.memory/3.long-term-memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ navigation:

## 使用方法

VeADK 的长期记忆通常存储在数据库中,通过如下方式定义一个长期记忆:
VeADK 的长期记忆通常存储在数据库中,你需要在初始化长期记忆时定义 `index` 来指定后端索引名称。通过如下方式定义一个长期记忆:

```python
from veadk.memory.long_term_memory import LongTermMemory

# 由于长期记忆需要构建索引,因此你必须在初始化长期记忆时定义 `app_name` 以及 `user_id`
long_term_memory = LongTermMemory(app_name="my_app_name", user_id="user_id")
#
long_term_memory = LongTermMemory(index="my_index")
```

通过如下例子说明长期记忆:
Expand All @@ -32,7 +32,7 @@ user_id = "temp_user"
teaching_session_id = "teaching_session"
student_session_id = "student_session"

long_term_memory = LongTermMemory(backend="local", app_name=app_name, user_id=user_id)
long_term_memory = LongTermMemory(backend="local", index=app_name)

agent = Agent(long_term_memory=long_term_memory)

Expand Down Expand Up @@ -90,8 +90,4 @@ print(response)
::field{name="app_name" type="string"}
Agent 应用名称,用于多应用区分。默认空字符串。
::

::field{name="user_id" type="string"}
Agent 用户 ID,用于区分不同用户的长期记忆。默认空字符串。
::
::
2 changes: 1 addition & 1 deletion docs/content/7.knowledgebase/1.knowledgebase.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ navigation:

VeADK 基于 Llama-index 作为知识库的主要处理入口。开发者可上传文本、文件、目录,我们会为您进行自动切片。

创建知识库时,您必须要提供您的 `app_name`(将会用来自动构建索引名称),或指定一个知识库的索引
创建知识库时,您必须要提供您的知识库后端索引名称 `index`,或指定 `app_name` 来作为索引名称

```python
from veadk.knowledgebase import KnowledgeBase
Expand Down
12 changes: 5 additions & 7 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest.mock import Mock, patch

from google.adk.agents.llm_agent import LlmAgent
Expand All @@ -33,11 +34,10 @@


def test_agent():
knowledgebase = KnowledgeBase(
index="test_index",
backend="local",
backend_config={"embedding_config": {"api_key": "test"}},
)
os.environ["MODEL_EMBEDDING_API_KEY"] = "mocked_api_key"

knowledgebase = KnowledgeBase(index="test_index", backend="local")

long_term_memory = LongTermMemory(backend="local")
tracer = OpentelemetryTracer()

Expand Down Expand Up @@ -69,8 +69,6 @@ def test_agent():

assert agent.knowledgebase == knowledgebase
assert agent.knowledgebase.backend == "local" # type: ignore
assert load_knowledgebase_tool.knowledgebase == agent.knowledgebase
assert load_knowledgebase_tool.load_knowledgebase_tool in agent.tools

assert agent.long_term_memory.backend == "local" # type: ignore
assert load_memory in agent.tools
Expand Down
9 changes: 4 additions & 5 deletions tests/test_knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

Expand All @@ -21,11 +22,9 @@

@pytest.mark.asyncio
async def test_knowledgebase():
os.environ["MODEL_EMBEDDING_API_KEY"] = "mocked_api_key"

app_name = "kb_test_app"
kb = KnowledgeBase(
backend="local",
app_name=app_name,
backend_config={"embedding_config": {"api_key": "test"}},
)
kb = KnowledgeBase(backend="local", app_name=app_name)

assert isinstance(kb._backend, InMemoryKnowledgeBackend)
16 changes: 7 additions & 9 deletions tests/test_long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,20 @@
# limitations under the License.


import os

import pytest
from google.adk.tools import load_memory

from veadk.agent import Agent
from veadk.memory.long_term_memory import LongTermMemory

app_name = "test_ltm"
user_id = "test_user"


@pytest.mark.asyncio
async def test_long_term_memory():
long_term_memory = LongTermMemory(
backend="local",
# app_name=app_name,
# user_id=user_id,
)
os.environ["MODEL_EMBEDDING_API_KEY"] = "mocked_api_key"
long_term_memory = LongTermMemory(backend="local")

agent = Agent(
name="all_name",
model_name="test_model_name",
Expand All @@ -43,7 +40,8 @@ async def test_long_term_memory():

assert load_memory in agent.tools, "load_memory tool not found in agent tools"

assert not agent.long_term_memory._backend
assert agent.long_term_memory
assert agent.long_term_memory._backend

# assert agent.long_term_memory._backend.index == build_long_term_memory_index(
# app_name, user_id
Expand Down
8 changes: 5 additions & 3 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from google.genai import types

from veadk.agent import Agent
from veadk.memory.long_term_memory import LongTermMemory
from veadk.memory.short_term_memory import ShortTermMemory
from veadk.runner import Runner


# Import the standalone function instead of accessing as class method
from veadk.runner import _convert_messages
from veadk.runner import Runner, _convert_messages


def _test_convert_messages(runner):
Expand Down Expand Up @@ -67,6 +67,8 @@ def _test_convert_messages(runner):

def test_runner():
"""Test Runner class initialization and core properties"""
os.environ["MODEL_EMBEDDING_API_KEY"] = "mocked_api_key"

short_term_memory = ShortTermMemory()
long_term_memory = LongTermMemory(backend="local")
agent = Agent(
Expand Down
10 changes: 7 additions & 3 deletions veadk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,14 @@ def model_post_init(self, __context: Any) -> None:
)

if self.knowledgebase:
from veadk.tools import load_knowledgebase_tool
from veadk.tools.builtin_tools.load_knowledgebase import (
LoadKnowledgebaseTool,
)

load_knowledgebase_tool.knowledgebase = self.knowledgebase
self.tools.append(load_knowledgebase_tool.load_knowledgebase_tool)
load_knowledgebase_tool = LoadKnowledgebaseTool(
knowledgebase=self.knowledgebase
)
self.tools.append(load_knowledgebase_tool)

if self.long_term_memory is not None:
from google.adk.tools import load_memory
Expand Down
51 changes: 19 additions & 32 deletions veadk/knowledgebase/knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Literal
from __future__ import annotations

from typing import Any, Callable, Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Union

from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
from veadk.knowledgebase.entry import KnowledgebaseEntry
Expand Down Expand Up @@ -54,11 +55,11 @@ def _get_backend_cls(backend: str) -> type[BaseKnowledgebaseBackend]:
raise ValueError(f"Unsupported knowledgebase backend: {backend}")


def build_knowledgebase_index(app_name: str):
return f"veadk_kb_{app_name}"
class KnowledgeBase(BaseModel):
name: str = "user_knowledgebase"

description: str = "This knowledgebase stores some user-related information."

class KnowledgeBase(BaseModel):
backend: Union[
Literal["local", "opensearch", "viking", "redis"], BaseKnowledgebaseBackend
] = "local"
Expand All @@ -73,9 +74,7 @@ class KnowledgeBase(BaseModel):
"""Configuration for the backend"""

top_k: int = 10
"""Number of top similar documents to retrieve during search.

Default is 10."""
"""Number of top similar documents to retrieve during search"""

app_name: str = ""

Expand All @@ -85,38 +84,27 @@ class KnowledgeBase(BaseModel):
def model_post_init(self, __context: Any) -> None:
if isinstance(self.backend, BaseKnowledgebaseBackend):
self._backend = self.backend
self.index = self._backend.index
logger.info(
f"Initialized knowledgebase with provided backend instance {self._backend.__class__.__name__}"
)
return

# must provide at least one of them
if not self.app_name and not self.index:
raise ValueError(
"Either `app_name` or `index` must be provided one of them."
)

# priority use index
if self.app_name and self.index:
logger.warning(
"`app_name` and `index` are both provided, using `index` as the knowledgebase index name."
)
# Once user define backend config, use it directly
if self.backend_config:
self._backend = _get_backend_cls(self.backend)(**self.backend_config)
return

# generate index name if `index` not provided but `app_name` is provided
if self.app_name and not self.index:
self.index = build_knowledgebase_index(self.app_name)
logger.info(
f"Knowledgebase index is set to {self.index} (generated by the app_name: {self.app_name})."
)
self.index = self.index or self.app_name
if not self.index:
raise ValueError("Either `index` or `app_name` must be provided.")

logger.info(
f"Initializing knowledgebase: backend={self.backend} top_k={self.top_k}"
)
self._backend = _get_backend_cls(self.backend)(
index=self.index, **self.backend_config if self.backend_config else {}
f"Initializing knowledgebase: backend={self.backend} index={self.index} top_k={self.top_k}"
)
self._backend = _get_backend_cls(self.backend)(index=self.index)
logger.info(
f"Initialized knowledgebase with backend {self._backend.__class__.__name__}"
f"Initialized knowledgebase with backend {self.backend.__class__.__name__}"
)

def add_from_directory(self, directory: str, **kwargs) -> bool:
Expand All @@ -133,8 +121,7 @@ def add_from_text(self, text: str | list[str], **kwargs) -> bool:

def search(self, query: str, top_k: int = 0, **kwargs) -> list[KnowledgebaseEntry]:
"""Search knowledge from knowledgebase"""
if top_k == 0:
top_k = self.top_k
top_k = top_k if top_k != 0 else self.top_k

_entries = self._backend.search(query=query, top_k=top_k, **kwargs)

Expand Down
Loading