Skip to content

Commit 992ea34

Browse files
## Description [Describe what this PR does and why] **Related Issue:** Fixes #[issue_number] or Relates to #[issue_number] **Security Considerations:** [If applicable, especially for sandbox changes] ## Type of Change - [ ] Bug fix - [x] New feature - [ ] Breaking change - [x] Documentation - [ ] Refactoring ## Component(s) Affected - [x] Engine - [ ] Sandbox - [x] Documentation - [x] Tests - [ ] CI/CD ## Checklist - [x] Pre-commit hooks pass - [x] Tests pass locally - [x] Documentation updated (if needed) - [x] Ready for review ## Testing [How to test these changes] ## Additional Notes [Optional: any other context] --------- Co-authored-by: Bruce <godot.lzl@alibaba-inc.com>
1 parent 84321a8 commit 992ea34

File tree

7 files changed

+277
-1
lines changed

7 files changed

+277
-1
lines changed

cookbook/en/context_manager.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,20 @@ The `MemoryService` contains the following methods:
5959

6060
Like `SessionHistoryService`, prefer using a concrete implementation such as `InMemoryMemoryService`. For details, see {ref}`here <memory-service>`
6161

62+
### RAGService
63+
64+
The `RAGService` is a basic class to provide retrieval augmented generation (RAG) capabilities.
65+
When asked by an end-user, the agent may need to retrieve relevant information from the knowledge base.
66+
The knowledge base can be a database or a collection of documents.
67+
The `RAGService` contains the following methods:
68+
- `retrieve`: retrieve relevant information from the knowledge base
69+
70+
The `LangChainRAGService` is a concrete implementation of `RAGService` that uses LangChain to retrieve relevant information from Milvus.
71+
It can be initialized by:
72+
- `uri` the Milvus URI, either a local file (`.\xxx.db`) or a remote URL (`http://localhost:19530`).
73+
- `docs` the documents to be indexed.
74+
75+
6276
## Life-cycle of a context manager
6377
The context manager can be initialized by two ways:
6478

cookbook/zh/context_manager.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ kernelspec:
5555

5656
`SessionHistoryService`一样,优先使用具体实现,如`InMemoryMemoryService`。详细信息请参见{ref}`这里 <memory-service-zh>`
5757

58+
### RAGService
59+
`RAGService` 是一个基本类,用于提供检索增强生成(RAG)功能。当最终用户提出请求时,代理可能需要从知识库中检索相关信息。知识库可以是数据库或文档集合。`RAGService` 包含以下方法:
60+
- `retrieve`:从知识库中检索相关信息。
61+
62+
`LangChainRAGService``RAGService` 的具体实现,它使用 LangChain 从 Milvus 中检索相关信息。可以通过以下方式初始化:
63+
- `uri`:Milvus 的 URI,可以是本地文件(例如 `.\xxx.db`)或远程 URL(例如 `http://localhost:19530`)。
64+
- `docs`:要索引的文档。
65+
5866
## 上下文管理器的生命周期
5967

6068
上下文管理器可以通过两种方式初始化:

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,9 @@ autogen = [
6666
"autogen-agentchat>=0.7.4",
6767
"autogen-ext[openai]>=0.7.4",
6868
]
69+
70+
langchain_rag=[
71+
"langchain>= 0.3.25",
72+
"pymilvus>=2.6.0",
73+
"langchain_milvus"
74+
]

src/agentscope_runtime/engine/services/context_manager.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@
44

55
from .manager import ServiceManager
66
from .memory_service import MemoryService, InMemoryMemoryService
7+
from .rag_service import RAGService
78
from .session_history_service import (
89
SessionHistoryService,
910
Session,
1011
InMemorySessionHistoryService,
1112
)
12-
from ..schemas.agent_schemas import Message
13+
from ..schemas.agent_schemas import (
14+
Message,
15+
MessageType,
16+
Role,
17+
TextContent,
18+
ContentType,
19+
)
1320

1421

1522
class ContextComposer:
@@ -19,6 +26,7 @@ async def compose(
1926
session: Session, # session
2027
memory_service: MemoryService = None,
2128
session_history_service: SessionHistoryService = None,
29+
rag_service: RAGService = None,
2230
):
2331
# session
2432
if session_history_service:
@@ -42,6 +50,18 @@ async def compose(
4250
)
4351
session.messages = memories + session.messages
4452

53+
# rag
54+
if rag_service:
55+
query = await rag_service.get_query_text(request_input[-1])
56+
docs = await rag_service.retrieve(query=query, k=5)
57+
cooked_doc = "\n".join(docs)
58+
message = Message(
59+
type=MessageType.MESSAGE,
60+
role=Role.SYSTEM,
61+
content=[TextContent(type=ContentType.TEXT, text=cooked_doc)],
62+
)
63+
session.messages.append(message)
64+
4565

4666
class ContextManager(ServiceManager):
4767
"""
@@ -53,10 +73,12 @@ def __init__(
5373
context_composer_cls=ContextComposer,
5474
session_history_service: SessionHistoryService = None,
5575
memory_service: MemoryService = None,
76+
rag_service: RAGService = None,
5677
):
5778
self._context_composer_cls = context_composer_cls
5879
self._session_history_service = session_history_service
5980
self._memory_service = memory_service
81+
self._rag_service = rag_service
6082
super().__init__()
6183

6284
def _register_default_services(self):
@@ -68,6 +90,7 @@ def _register_default_services(self):
6890

6991
self.register_service("session", self._session_history_service)
7092
self.register_service("memory", self._memory_service)
93+
self.register_service("rag", self._rag_service)
7194

7295
async def compose_context(
7396
self,
@@ -77,6 +100,7 @@ async def compose_context(
77100
await self._context_composer_cls.compose(
78101
memory_service=self._memory_service,
79102
session_history_service=self._session_history_service,
103+
rag_service=self._rag_service,
80104
session=session,
81105
request_input=request_input,
82106
)
@@ -119,10 +143,12 @@ async def append(self, session: Session, event_output: List[Message]):
119143
async def create_context_manager(
120144
memory_service: MemoryService = None,
121145
session_history_service: SessionHistoryService = None,
146+
rag_service: RAGService = None,
122147
):
123148
manager = ContextManager(
124149
memory_service=memory_service,
125150
session_history_service=session_history_service,
151+
rag_service=rag_service,
126152
)
127153

128154
async with manager:
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# -*- coding: utf-8 -*-
2+
from typing import Optional
3+
4+
from langchain_community.embeddings import DashScopeEmbeddings
5+
from langchain_milvus import Milvus
6+
7+
from .base import ServiceWithLifecycleManager
8+
from ..schemas.agent_schemas import Message, MessageType
9+
10+
11+
class RAGService(ServiceWithLifecycleManager):
12+
"""
13+
RAG Service
14+
"""
15+
16+
async def get_query_text(self, message: Message) -> str:
17+
"""
18+
Gets the query text from the messages.
19+
20+
Args:
21+
message: A list of messages.
22+
23+
Returns:
24+
The query text.
25+
"""
26+
if message:
27+
if message.type == MessageType.MESSAGE:
28+
for content in message.content:
29+
if content.type == "text":
30+
return content.text
31+
return ""
32+
33+
async def retrieve(self, query: str, k: int = 1) -> list[str]:
34+
raise NotImplementedError
35+
36+
37+
DEFAULT_URI = "milvus_demo.db"
38+
39+
40+
class LangChainRAGService(RAGService):
41+
"""
42+
RAG Service using LangChain
43+
"""
44+
45+
def __init__(
46+
self,
47+
uri: Optional[str] = None,
48+
docs: Optional[list[str]] = None,
49+
):
50+
self.embeddings = DashScopeEmbeddings()
51+
self.vectorstore = None
52+
53+
if uri:
54+
self.uri = uri
55+
self.from_db()
56+
elif docs:
57+
self.uri = DEFAULT_URI
58+
self.from_docs(docs)
59+
else:
60+
docs = []
61+
self.uri = DEFAULT_URI
62+
self.from_docs(docs)
63+
64+
def from_docs(self, docs=None):
65+
if docs is None:
66+
docs = []
67+
68+
self.vectorstore = Milvus.from_documents(
69+
documents=docs,
70+
embedding=self.embeddings,
71+
connection_args={
72+
"uri": self.uri,
73+
},
74+
drop_old=False,
75+
)
76+
77+
def from_db(self):
78+
self.vectorstore = Milvus(
79+
embedding_function=self.embeddings,
80+
connection_args={"uri": self.uri},
81+
index_params={"index_type": "FLAT", "metric_type": "L2"},
82+
)
83+
84+
async def retrieve(self, query: str, k: int = 1) -> list[str]:
85+
if self.vectorstore is None:
86+
raise ValueError(
87+
"Vector store not initialized. Call build_index first.",
88+
)
89+
docs = self.vectorstore.similarity_search(query, k=k)
90+
return [doc.page_content for doc in docs]
91+
92+
async def start(self) -> None:
93+
"""Starts the service."""
94+
95+
async def stop(self) -> None:
96+
"""Stops the service."""
97+
98+
async def health(self) -> bool:
99+
"""Checks the health of the service."""
100+
return True

tests/unit/assets/milvus_demo.db

928 KB
Binary file not shown.

tests/unit/test_rag_service.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# -*- coding: utf-8 -*-
2+
import os
3+
4+
import pytest
5+
from dotenv import load_dotenv
6+
7+
from agentscope_runtime.engine import Runner
8+
from agentscope_runtime.engine.agents.llm_agent import LLMAgent
9+
from agentscope_runtime.engine.llms import QwenLLM
10+
from agentscope_runtime.engine.schemas.agent_schemas import (
11+
MessageType,
12+
AgentRequest,
13+
RunStatus,
14+
)
15+
from agentscope_runtime.engine.services.context_manager import (
16+
create_context_manager,
17+
)
18+
from agentscope_runtime.engine.services.rag_service import LangChainRAGService
19+
20+
if os.path.exists("../../.env"):
21+
load_dotenv("../../.env")
22+
23+
24+
def load_docs():
25+
import bs4
26+
from langchain_community.document_loaders import WebBaseLoader
27+
from langchain_text_splitters import RecursiveCharacterTextSplitter
28+
29+
loader = WebBaseLoader(
30+
web_paths=(
31+
"https://lilianweng.github.io/posts/2023-06-23-agent/",
32+
"https://lilianweng.github.io/posts/2023-03-15-prompt"
33+
"-engineering/",
34+
),
35+
bs_kwargs={
36+
"parse_only": bs4.SoupStrainer(
37+
class_=("post-content", "post-title", "post-header"),
38+
),
39+
},
40+
)
41+
documents = loader.load()
42+
text_splitter = RecursiveCharacterTextSplitter(
43+
chunk_size=2000,
44+
chunk_overlap=200,
45+
)
46+
47+
docs = text_splitter.split_documents(documents)
48+
return docs
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_from_docs():
53+
docs = load_docs()
54+
rag_service = LangChainRAGService(docs=docs)
55+
56+
ret_docs = await rag_service.retrieve(
57+
"What is self-reflection of an AI Agent?",
58+
)
59+
assert len(ret_docs) == 1
60+
assert ret_docs[0].startswith("Self-Reflection")
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_from_db():
65+
rag_service = LangChainRAGService(uri="./assets/milvus_demo.db")
66+
ret_docs = await rag_service.retrieve(
67+
"What is self-reflection of an AI Agent?",
68+
)
69+
assert len(ret_docs) == 1
70+
assert ret_docs[0].startswith("Self-Reflection")
71+
72+
73+
@pytest.mark.asyncio
74+
async def test_rag():
75+
rag_service = LangChainRAGService(uri="./assets/milvus_demo.db")
76+
USER_ID = "user2"
77+
SESSION_ID = "session1"
78+
query = "What is self-reflection of an AI Agent?"
79+
80+
llm_agent = LLMAgent(
81+
model=QwenLLM(),
82+
name="llm_agent",
83+
description="A simple LLM agent",
84+
)
85+
86+
async with create_context_manager(
87+
rag_service=rag_service,
88+
) as context_manager:
89+
runner = Runner(
90+
agent=llm_agent,
91+
context_manager=context_manager,
92+
environment_manager=None,
93+
)
94+
95+
all_result = ""
96+
# print("\n")
97+
request = AgentRequest(
98+
input=[
99+
{
100+
"role": "user",
101+
"content": [
102+
{
103+
"type": "text",
104+
"text": query,
105+
},
106+
],
107+
},
108+
],
109+
session_id=SESSION_ID,
110+
)
111+
112+
async for message in runner.stream_query(
113+
user_id=USER_ID,
114+
request=request,
115+
):
116+
if (
117+
message.object == "message"
118+
and MessageType.MESSAGE == message.type
119+
and RunStatus.Completed == message.status
120+
):
121+
all_result = message.content[0].text
122+
print(all_result)

0 commit comments

Comments
 (0)