Skip to content

Commit 396bf4b

Browse files
add rag service
1 parent 6835e9b commit 396bf4b

File tree

3 files changed

+119
-8
lines changed

3 files changed

+119
-8
lines changed

src/agentscope_runtime/engine/services/context_manager.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@
44

55
from .manager import ServiceManager
66
from .memory_service import MemoryService, InMemoryMemoryService
7+
from .rag_service import RAGService
78
from .session_history_service import (
89
SessionHistoryService,
910
Session,
1011
InMemorySessionHistoryService,
1112
)
12-
from ..schemas.agent_schemas import Message
13+
from ..schemas.agent_schemas import (
14+
Message,
15+
MessageType,
16+
Role,
17+
TextContent,
18+
ContentType,
19+
)
1320

1421

1522
class ContextComposer:
@@ -19,6 +26,7 @@ async def compose(
1926
session: Session, # session
2027
memory_service: MemoryService = None,
2128
session_history_service: SessionHistoryService = None,
29+
rag_service: RAGService = None,
2230
):
2331
# session
2432
if session_history_service:
@@ -42,6 +50,18 @@ async def compose(
4250
)
4351
session.messages = memories + session.messages
4452

53+
# rag
54+
if rag_service:
55+
query = await rag_service.get_query_text(request_input[-1])
56+
docs = await rag_service.retrieve(query=query, k=5)
57+
cooked_doc = "\n".join(docs)
58+
message = Message(
59+
type=MessageType.MESSAGE,
60+
role=Role.SYSTEM,
61+
content=[TextContent(type=ContentType.TEXT, text=cooked_doc)],
62+
)
63+
session.messages.append(message)
64+
4565

4666
class ContextManager(ServiceManager):
4767
"""
@@ -53,10 +73,12 @@ def __init__(
5373
context_composer_cls=ContextComposer,
5474
session_history_service: SessionHistoryService = None,
5575
memory_service: MemoryService = None,
76+
rag_service: RAGService = None,
5677
):
5778
self._context_composer_cls = context_composer_cls
5879
self._session_history_service = session_history_service
5980
self._memory_service = memory_service
81+
self._rag_service = rag_service
6082
super().__init__()
6183

6284
def _register_default_services(self):
@@ -68,6 +90,7 @@ def _register_default_services(self):
6890

6991
self.register_service("session", self._session_history_service)
7092
self.register_service("memory", self._memory_service)
93+
self.register_service("rag", self._rag_service)
7194

7295
async def compose_context(
7396
self,
@@ -77,6 +100,7 @@ async def compose_context(
77100
await self._context_composer_cls.compose(
78101
memory_service=self._memory_service,
79102
session_history_service=self._session_history_service,
103+
rag_service=self._rag_service,
80104
session=session,
81105
request_input=request_input,
82106
)
@@ -119,10 +143,12 @@ async def append(self, session: Session, event_output: List[Message]):
119143
async def create_context_manager(
120144
memory_service: MemoryService = None,
121145
session_history_service: SessionHistoryService = None,
146+
rag_service: RAGService = None,
122147
):
123148
manager = ContextManager(
124149
memory_service=memory_service,
125150
session_history_service=session_history_service,
151+
rag_service=rag_service,
126152
)
127153

128154
async with manager:

src/agentscope_runtime/engine/services/rag_service.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,34 @@
44
from langchain_milvus import Milvus
55

66
from .base import ServiceWithLifecycleManager
7+
from ..schemas.agent_schemas import Message, MessageType
78

89

910
class RAGService(ServiceWithLifecycleManager):
1011
"""
1112
RAG Service
1213
"""
1314

15+
async def get_query_text(self, message: Message) -> str:
16+
"""
17+
Gets the query text from the messages.
18+
19+
Args:
20+
message: A list of messages.
21+
22+
Returns:
23+
The query text.
24+
"""
25+
if message:
26+
if message.type == MessageType.MESSAGE:
27+
for content in message.content:
28+
if content.type == "text":
29+
return content.text
30+
return ""
31+
32+
async def retrieve(self, query: str, k: int = 1) -> list[str]:
33+
raise NotImplementedError
34+
1435

1536
DEFAULT_URI = "milvus_demo.db"
1637

@@ -31,7 +52,9 @@ def __init__(self, uri=None, docs=None):
3152
self.uri = DEFAULT_URI
3253
self.from_docs(docs)
3354
else:
34-
raise ValueError("Either uri or docs must be provided.")
55+
docs = []
56+
self.uri = DEFAULT_URI
57+
self.from_docs(docs)
3558

3659
def from_docs(self, docs=None):
3760
if docs is None:
@@ -53,17 +76,16 @@ def from_db(self):
5376
index_params={"index_type": "FLAT", "metric_type": "L2"},
5477
)
5578

56-
async def retrieve(self, query: str, k: int = 1) -> list:
79+
async def retrieve(self, query: str, k: int = 1) -> list[str]:
5780
if self.vectorstore is None:
5881
raise ValueError(
5982
"Vector store not initialized. Call build_index first.",
6083
)
61-
return self.vectorstore.similarity_search(query, k=k)
84+
docs = self.vectorstore.similarity_search(query, k=k)
85+
return [doc.page_content for doc in docs]
6286

6387
async def start(self) -> None:
6488
"""Starts the service."""
65-
self.embeddings = DashScopeEmbeddings()
66-
self.vectorstore = None
6789

6890
async def stop(self) -> None:
6991
"""Stops the service."""

tests/unit/test_rag_service.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@
44
import pytest
55
from dotenv import load_dotenv
66

7+
from agentscope_runtime.engine import Runner
8+
from agentscope_runtime.engine.agents.llm_agent import LLMAgent
9+
from agentscope_runtime.engine.llms import QwenLLM
10+
from agentscope_runtime.engine.schemas.agent_schemas import (
11+
MessageType,
12+
AgentRequest,
13+
RunStatus,
14+
)
15+
from agentscope_runtime.engine.services.context_manager import (
16+
create_context_manager,
17+
)
718
from agentscope_runtime.engine.services.rag_service import LangChainRAGService
819

920
if os.path.exists("../../.env"):
@@ -46,7 +57,7 @@ async def test_from_docs():
4657
"What is self-reflection of an AI Agent?",
4758
)
4859
assert len(ret_docs) == 1
49-
assert ret_docs[0].page_content.startswith("Self-Reflection")
60+
assert ret_docs[0].startswith("Self-Reflection")
5061

5162

5263
@pytest.mark.asyncio
@@ -56,4 +67,56 @@ async def test_from_db():
5667
"What is self-reflection of an AI Agent?",
5768
)
5869
assert len(ret_docs) == 1
59-
assert ret_docs[0].page_content.startswith("Self-Reflection")
70+
assert ret_docs[0].startswith("Self-Reflection")
71+
72+
73+
@pytest.mark.asyncio
74+
async def test_rag():
75+
rag_service = LangChainRAGService(uri="./assets/milvus_demo.db")
76+
USER_ID = "user2"
77+
SESSION_ID = "session1"
78+
query = "What is self-reflection of an AI Agent?"
79+
80+
llm_agent = LLMAgent(
81+
model=QwenLLM(),
82+
name="llm_agent",
83+
description="A simple LLM agent",
84+
)
85+
86+
async with create_context_manager(
87+
rag_service=rag_service,
88+
) as context_manager:
89+
runner = Runner(
90+
agent=llm_agent,
91+
context_manager=context_manager,
92+
environment_manager=None,
93+
)
94+
95+
all_result = ""
96+
# print("\n")
97+
request = AgentRequest(
98+
input=[
99+
{
100+
"role": "user",
101+
"content": [
102+
{
103+
"type": "text",
104+
"text": query,
105+
},
106+
],
107+
},
108+
],
109+
session_id=SESSION_ID,
110+
)
111+
112+
async for message in runner.stream_query(
113+
user_id=USER_ID,
114+
request=request,
115+
):
116+
if (
117+
message.object == "message"
118+
and MessageType.MESSAGE == message.type
119+
and RunStatus.Completed == message.status
120+
):
121+
all_result = message.content[0].text
122+
print(all_result)

0 commit comments

Comments
 (0)