Skip to content

Commit 66ffe4f

Browse files
committed
fix: fix todo
1 parent f7f25cb commit 66ffe4f

File tree

6 files changed

+111
-63
lines changed

6 files changed

+111
-63
lines changed

veadk/database/database_adapter.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import re
1515
import time
1616
from typing import BinaryIO, TextIO
1717

@@ -120,8 +120,19 @@ class VectorDatabaseAdapter(BaseModel):
120120
client: OpenSearchVectorDatabase
121121

122122
def _validate_index(self, index: str):
123-
# TODO
124-
pass
123+
"""
124+
Verify whether the string conforms to the naming rules of index_name in OpenSearch.
125+
https://docs.opensearch.org/2.8/api-reference/index-apis/create-index/
126+
"""
127+
if not (
128+
isinstance(index, str)
129+
and not index.startswith(("_", "-"))
130+
and index.islower()
131+
and re.match(r"^[a-z0-9_\-.]+$", index)
132+
):
133+
raise ValueError(
134+
"The index name does not conform to the naming rules of OpenSearch"
135+
)
125136

126137
def add(self, data: list[str], index: str):
127138
self._validate_index(index)
@@ -133,9 +144,6 @@ def add(self, data: list[str], index: str):
133144
self.client.add(data, collection_name=index)
134145

135146
def query(self, query: str, index: str, top_k: int) -> list[str]:
136-
# FIXME: confirm
137-
self._validate_index(index)
138-
139147
logger.debug(
140148
f"Querying vector database: collection_name={index} query={query} top_k={top_k}"
141149
)
@@ -153,19 +161,34 @@ class VikingDatabaseAdapter(BaseModel):
153161
client: VikingDatabase
154162

155163
def _validate_index(self, index: str):
156-
# TODO
157-
pass
164+
"""
165+
Only English letters, numbers, and underscores (_) are allowed.
166+
It must start with an English letter and cannot be empty. Length requirement: [1, 128].
167+
For details, please see: https://www.volcengine.com/docs/84313/1254542?lang=zh
168+
"""
169+
if not (
170+
isinstance(index, str)
171+
and 0 < len(index) <= 128
172+
and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index)
173+
):
174+
raise ValueError(
175+
"The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
176+
)
158177

159178
def get_or_create_collection(self, collection_name: str):
160179
if not self.client.collection_exists(collection_name):
180+
logger.warning(
181+
f"Collection {collection_name} does not exist, creating a new collection."
182+
)
161183
self.client.create_collection(collection_name)
162184

163-
# FIXME
185+
# After creation, it is necessary to wait for a while.
164186
count = 0
165187
while not self.client.collection_exists(collection_name):
188+
print("here")
166189
time.sleep(1)
167190
count += 1
168-
if count > 50:
191+
if count > 60:
169192
raise TimeoutError(
170193
f"Collection {collection_name} not created after 50 seconds"
171194
)
@@ -185,9 +208,8 @@ def query(self, query: str, index: str, top_k: int) -> list[str]:
185208

186209
logger.debug(f"Querying Viking database: collection_name={index} query={query}")
187210

188-
# FIXME(): maybe do not raise, but just return []
189211
if not self.client.collection_exists(index):
190-
raise ValueError(f"Collection {index} does not exist")
212+
return []
191213

192214
return self.client.query(query, collection_name=index, top_k=top_k)
193215

@@ -198,8 +220,14 @@ class VikingMemoryDatabaseAdapter(BaseModel):
198220
client: VikingMemoryDatabase
199221

200222
def _validate_index(self, index: str):
201-
# TODO
202-
pass
223+
if not (
224+
isinstance(index, str)
225+
and 1 <= len(index) <= 128
226+
and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index)
227+
):
228+
raise ValueError(
229+
"The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
230+
)
203231

204232
def add(self, data: list[str], index: str, **kwargs):
205233
self._validate_index(index)
@@ -208,17 +236,16 @@ def add(self, data: list[str], index: str, **kwargs):
208236
f"Adding documents to Viking database memory: collection_name={index} data_len={len(data)}"
209237
)
210238

211-
# TODO: parse user_id
212-
self.client.add(data, collection_name=index)
239+
self.client.add(data, collection_name=index, **kwargs)
213240

214-
def query(self, query: str, index: str, top_k: int):
241+
def query(self, query: str, index: str, top_k: int, **kwargs):
215242
self._validate_index(index)
216243

217244
logger.debug(
218245
f"Querying Viking database memory: collection_name={index} query={query} top_k={top_k}"
219246
)
220247

221-
result = self.client.query(query, collection_name=index, top_k=top_k)
248+
result = self.client.query(query, collection_name=index, top_k=top_k, **kwargs)
222249
return result
223250

224251

@@ -245,8 +272,8 @@ def query(self, query: str, **kwargs):
245272

246273

247274
def get_knowledgebase_database_adapter(database_client: BaseDatabase):
248-
return MAPPING[type(database_client)](database_client=database_client)
275+
return MAPPING[type(database_client)](client=database_client)
249276

250277

251278
def get_long_term_memory_database_adapter(database_client: BaseDatabase):
252-
return MAPPING[type(database_client)](database_client=database_client)
279+
return MAPPING[type(database_client)](client=database_client)

veadk/database/database_factory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,12 @@ def create(backend: str, config=None) -> BaseDatabase:
6969
return VikingDatabase() if config is None else VikingDatabase(config=config)
7070

7171
if backend == DatabaseBackend.VIKING_MEM:
72-
from .viking.viking_memory_db import VikingDatabaseMemory
72+
from .viking.viking_memory_db import VikingMemoryDatabase
7373

7474
return (
75-
VikingDatabaseMemory()
75+
VikingMemoryDatabase()
7676
if config is None
77-
else VikingDatabaseMemory(config=config)
77+
else VikingMemoryDatabase(config=config)
7878
)
7979
else:
8080
raise ValueError(f"Unsupported database type: {backend}")

veadk/database/viking/viking_database.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
get_collections_path = "/api/knowledge/collection/info"
4141
doc_add_path = "/api/knowledge/doc/add"
4242
doc_info_path = "/api/knowledge/doc/info"
43+
doc_del_path = "/api/collection/drop"
4344

4445

4546
class VolcengineTOSConfig(BaseModel):
@@ -246,9 +247,23 @@ def add(
246247
}
247248

248249
def delete(self, **kwargs: Any):
249-
# collection_name = kwargs.get("collection_name")
250-
# todo: delete vikingdb
251-
...
250+
collection_name = kwargs.get("collection_name")
251+
resource_id = kwargs.get("resource_id")
252+
request_param = {"collection_name": collection_name, "resource_id": resource_id}
253+
doc_del_req = prepare_request(
254+
method="POST", path=doc_del_path, config=self.config, data=request_param
255+
)
256+
rsp = requests.request(
257+
method=doc_del_req.method,
258+
url="http://{}{}".format(g_knowledge_base_domain, doc_del_req.path),
259+
headers=doc_del_req.headers,
260+
data=doc_del_req.body,
261+
)
262+
result = rsp.json()
263+
if result["code"] != 0:
264+
logger.error(f"Error in add_doc: {result['message']}")
265+
return {"error": result["message"]}
266+
return {}
252267

253268
def query(self, query: str, **kwargs: Any) -> list[str]:
254269
"""

veadk/database/viking/viking_memory_db.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
logger = get_logger(__name__)
3535

3636

37-
# FIXME
3837
class VikingMemConfig(BaseModel):
3938
volcengine_ak: Optional[str] = Field(
4039
default=getenv("VOLCENGINE_ACCESS_KEY"),
@@ -54,8 +53,8 @@ class VikingMemConfig(BaseModel):
5453
)
5554

5655

57-
# ======= adapted from ... =======
58-
class VikingDBMemoryException(Exception):
56+
# ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py =======
57+
class VikingMemoryException(Exception):
5958
def __init__(self, code, request_id, message=None):
6059
self.code = code
6160
self.request_id = request_id
@@ -67,15 +66,15 @@ def __str__(self):
6766
return self.message
6867

6968

70-
class VikingDBMemoryService(Service):
69+
class VikingMemoryService(Service):
7170
_instance_lock = threading.Lock()
7271

7372
def __new__(cls, *args, **kwargs):
74-
if not hasattr(VikingDBMemoryService, "_instance"):
75-
with VikingDBMemoryService._instance_lock:
76-
if not hasattr(VikingDBMemoryService, "_instance"):
77-
VikingDBMemoryService._instance = object.__new__(cls)
78-
return VikingDBMemoryService._instance
73+
if not hasattr(VikingMemoryService, "_instance"):
74+
with VikingMemoryService._instance_lock:
75+
if not hasattr(VikingMemoryService, "_instance"):
76+
VikingMemoryService._instance = object.__new__(cls)
77+
return VikingMemoryService._instance
7978

8079
def __init__(
8180
self,
@@ -88,11 +87,11 @@ def __init__(
8887
connection_timeout=30,
8988
socket_timeout=30,
9089
):
91-
self.service_info = VikingDBMemoryService.get_service_info(
90+
self.service_info = VikingMemoryService.get_service_info(
9291
host, region, scheme, connection_timeout, socket_timeout
9392
)
94-
self.api_info = VikingDBMemoryService.get_api_info()
95-
super(VikingDBMemoryService, self).__init__(self.service_info, self.api_info)
93+
self.api_info = VikingMemoryService.get_api_info()
94+
super(VikingMemoryService, self).__init__(self.service_info, self.api_info)
9695
if ak:
9796
self.set_ak(ak)
9897
if sk:
@@ -102,12 +101,12 @@ def __init__(
102101
try:
103102
self.get_body("Ping", {}, json.dumps({}))
104103
except Exception as e:
105-
raise VikingDBMemoryException(
104+
raise VikingMemoryException(
106105
1000028, "missed", "host or region is incorrect: {}".format(str(e))
107106
) from None
108107

109108
def setHeader(self, header):
110-
api_info = VikingDBMemoryService.get_api_info()
109+
api_info = VikingMemoryService.get_api_info()
111110
for key in api_info:
112111
for item in header:
113112
api_info[key].header[item] = header[item]
@@ -213,17 +212,17 @@ def get_body_exception(self, api, params, body):
213212
try:
214213
res_json = json.loads(e.args[0].decode("utf-8"))
215214
except Exception:
216-
raise VikingDBMemoryException(
215+
raise VikingMemoryException(
217216
1000028, "missed", "json load res error, res:{}".format(str(e))
218217
) from None
219218
code = res_json.get("code", 1000028)
220219
request_id = res_json.get("request_id", 1000028)
221220
message = res_json.get("message", None)
222221

223-
raise VikingDBMemoryException(code, request_id, message)
222+
raise VikingMemoryException(code, request_id, message)
224223

225224
if res == "":
226-
raise VikingDBMemoryException(
225+
raise VikingMemoryException(
227226
1000028,
228227
"missed",
229228
"empty response due to unknown error, please contact customer service",
@@ -237,15 +236,15 @@ def get_exception(self, api, params):
237236
try:
238237
res_json = json.loads(e.args[0].decode("utf-8"))
239238
except Exception:
240-
raise VikingDBMemoryException(
239+
raise VikingMemoryException(
241240
1000028, "missed", "json load res error, res:{}".format(str(e))
242241
) from None
243242
code = res_json.get("code", 1000028)
244243
request_id = res_json.get("request_id", 1000028)
245244
message = res_json.get("message", None)
246-
raise VikingDBMemoryException(code, request_id, message)
245+
raise VikingMemoryException(code, request_id, message)
247246
if res == "":
248-
raise VikingDBMemoryException(
247+
raise VikingMemoryException(
249248
1000028,
250249
"missed",
251250
"empty response due to unknown error, please contact customer service",
@@ -365,7 +364,7 @@ def format_milliseconds(timestamp_ms):
365364
return dt.strftime("%Y%m%d %H:%M:%S")
366365

367366

368-
# ======= adapted from ... =======
367+
# ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py =======
369368

370369

371370
class VikingMemoryDatabase(BaseModel, BaseDatabase):
@@ -375,7 +374,7 @@ class VikingMemoryDatabase(BaseModel, BaseDatabase):
375374
)
376375

377376
def model_post_init(self, context: Any, /) -> None:
378-
self._vm = VikingDBMemoryService(
377+
self._vm = VikingMemoryService(
379378
ak=self.config.volcengine_ak, sk=self.config.volcengine_sk
380379
)
381380

@@ -516,8 +515,8 @@ def query(self, query: str, **kwargs: Any) -> list[str]:
516515
assert collection_name is not None, "collection_name is required"
517516
user_id = kwargs.get("user_id")
518517
assert user_id is not None, "user_id is required"
519-
520-
resp = self.search_memory(collection_name, query, user_id=user_id)
518+
top_k = kwargs.get("top_k", 5)
519+
resp = self.search_memory(collection_name, query, user_id=user_id, top_k=top_k)
521520
return resp
522521

523522
def delete(self, **kwargs: Any):

veadk/knowledgebase/knowledgebase.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,12 @@ def add(
5858
In addition, if you upload data of the bytes type,
5959
for example, if you read the file stream of a pdf, then you need to pass an additional parameter file_ext = '.pdf'.
6060
"""
61-
# TODO: add check for data type
62-
...
61+
if self.backend != "viking" and not (
62+
isinstance(data, str) or isinstance(data, list)
63+
):
64+
raise ValueError(
65+
"Only vikingdb supports uploading files or file characters."
66+
)
6367

6468
index = build_knowledgebase_index(app_name)
6569

@@ -73,7 +77,8 @@ def search(self, query: str, app_name: str, top_k: int = None) -> list[str]:
7377
logger.info(
7478
f"Searching knowledgebase: app_name={app_name} query={query} top_k={top_k}"
7579
)
76-
result = self.adapter.query(query=query, app_name=app_name, top_k=top_k)
80+
index = build_knowledgebase_index(app_name)
81+
result = self.adapter.query(query=query, index=index, top_k=top_k)
7782
if len(result) == 0:
7883
logger.warning(f"No documents found in knowledgebase. Query: {query}")
7984
return result

veadk/memory/long_term_memory.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,10 @@ async def add_session_to_memory(
9797
)
9898

9999
# check if viking memory database, should give a user id: if/else
100-
self.adapter.add(data=event_strings, index=index)
100+
if self.backend == "viking_mem":
101+
self.adapter.add(data=event_strings, index=index, user_id=session.user_id)
102+
else:
103+
self.adapter.add(data=event_strings, index=index)
101104

102105
logger.info(
103106
f"Added {len(event_strings)} events to long term memory: index={index}"
@@ -112,15 +115,14 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str):
112115
)
113116

114117
# user id if viking memory db
115-
memory_chunks = self.adapter.query(query=query, index=index, top_k=self.top_k)
116-
117-
# if len(memory_chunks) == 0:
118-
# logger.info(f"Found no memory chunks for query: {query} index={index}")
119-
# return SearchMemoryResponse()
120-
121-
# logger.info(
122-
# f"Found {len(memory_chunks)} memory chunks for query: {query} index={index}"
123-
# )
118+
if self.backend == "viking_mem":
119+
memory_chunks = self.adapter.query(
120+
query=query, index=index, top_k=self.top_k, user_id=user_id
121+
)
122+
else:
123+
memory_chunks = self.adapter.query(
124+
query=query, index=index, top_k=self.top_k
125+
)
124126

125127
memory_events = []
126128
for memory in memory_chunks:

0 commit comments

Comments
 (0)