Skip to content

Commit 05f3e40

Browse files
committed
fix: fix database dpendency issues
1 parent 66ffe4f commit 05f3e40

File tree

1 file changed

+35
-35
lines changed

1 file changed

+35
-35
lines changed

veadk/database/database_adapter.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,18 @@
1414
import re
1515
import time
1616
from typing import BinaryIO, TextIO
17-
18-
from pydantic import BaseModel, ConfigDict
19-
2017
from veadk.database.base_database import BaseDatabase
21-
from veadk.database.kv.redis_database import RedisDatabase
22-
from veadk.database.local_database import LocalDataBase
23-
from veadk.database.relational.mysql_database import MysqlDatabase
24-
from veadk.database.vector.opensearch_vector_database import OpenSearchVectorDatabase
25-
from veadk.database.viking.viking_database import VikingDatabase
26-
from veadk.database.viking.viking_memory_db import VikingMemoryDatabase
18+
2719
from veadk.utils.logger import get_logger
2820

2921
logger = get_logger(__name__)
3022

3123

32-
class KVDatabaseAdapter(BaseModel):
33-
model_config = ConfigDict(arbitrary_types_allowed=True)
24+
class KVDatabaseAdapter:
25+
def __init__(self, client):
26+
from veadk.database.kv.redis_database import RedisDatabase
3427

35-
client: RedisDatabase
28+
self.client: RedisDatabase = client
3629

3730
def add(self, data: list[str], index: str):
3831
logger.debug(f"Adding documents to Redis database: index={index}")
@@ -61,10 +54,11 @@ def query(self, query: str, index: str, top_k: int = 0) -> list[str]:
6154
raise e
6255

6356

64-
class RelationalDatabaseAdapter(BaseModel):
65-
model_config = ConfigDict(arbitrary_types_allowed=True)
57+
class RelationalDatabaseAdapter:
58+
def __init__(self, client):
59+
from veadk.database.relational.mysql_database import MysqlDatabase
6660

67-
client: MysqlDatabase
61+
self.client: MysqlDatabase = client
6862

6963
def create_table(self, table_name: str):
7064
logger.debug(f"Creating table for SQL database: table_name={table_name}")
@@ -114,10 +108,13 @@ def query(self, query: str, index: str, top_k: int) -> list[str]:
114108
return [item["data"] for item in results]
115109

116110

117-
class VectorDatabaseAdapter(BaseModel):
118-
model_config = ConfigDict(arbitrary_types_allowed=True)
111+
class VectorDatabaseAdapter:
112+
def __init__(self, client):
113+
from veadk.database.vector.opensearch_vector_database import (
114+
OpenSearchVectorDatabase,
115+
)
119116

120-
client: OpenSearchVectorDatabase
117+
self.client: OpenSearchVectorDatabase = client
121118

122119
def _validate_index(self, index: str):
123120
"""
@@ -155,10 +152,11 @@ def query(self, query: str, index: str, top_k: int) -> list[str]:
155152
)
156153

157154

158-
class VikingDatabaseAdapter(BaseModel):
159-
model_config = ConfigDict(arbitrary_types_allowed=True)
155+
class VikingDatabaseAdapter:
156+
def __init__(self, client):
157+
from veadk.database.viking.viking_database import VikingDatabase
160158

161-
client: VikingDatabase
159+
self.client: VikingDatabase = client
162160

163161
def _validate_index(self, index: str):
164162
"""
@@ -214,10 +212,11 @@ def query(self, query: str, index: str, top_k: int) -> list[str]:
214212
return self.client.query(query, collection_name=index, top_k=top_k)
215213

216214

217-
class VikingMemoryDatabaseAdapter(BaseModel):
218-
model_config = ConfigDict(arbitrary_types_allowed=True)
215+
class VikingMemoryDatabaseAdapter:
216+
def __init__(self, client):
217+
from veadk.database.viking.viking_memory_db import VikingMemoryDatabase
219218

220-
client: VikingMemoryDatabase
219+
self.client: VikingMemoryDatabase = client
221220

222221
def _validate_index(self, index: str):
223222
if not (
@@ -249,10 +248,11 @@ def query(self, query: str, index: str, top_k: int, **kwargs):
249248
return result
250249

251250

252-
class LocalDatabaseAdapter(BaseModel):
253-
model_config = ConfigDict(arbitrary_types_allowed=True)
251+
class LocalDatabaseAdapter:
252+
def __init__(self, client):
253+
from veadk.database.local_database import LocalDataBase
254254

255-
client: LocalDataBase
255+
self.client: LocalDataBase = client
256256

257257
def add(self, data: list[str], **kwargs):
258258
self.client.add(data)
@@ -262,18 +262,18 @@ def query(self, query: str, **kwargs):
262262

263263

264264
MAPPING = {
265-
RedisDatabase: KVDatabaseAdapter,
266-
MysqlDatabase: RelationalDatabaseAdapter,
267-
LocalDataBase: LocalDatabaseAdapter,
268-
VikingDatabase: VikingDatabaseAdapter,
269-
OpenSearchVectorDatabase: VectorDatabaseAdapter,
270-
VikingMemoryDatabase: VikingMemoryDatabaseAdapter,
265+
"RedisDatabase": KVDatabaseAdapter,
266+
"MysqlDatabase": RelationalDatabaseAdapter,
267+
"LocalDataBase": LocalDatabaseAdapter,
268+
"VikingDatabase": VikingDatabaseAdapter,
269+
"OpenSearchVectorDatabase": VectorDatabaseAdapter,
270+
"VikingMemoryDatabase": VikingMemoryDatabaseAdapter,
271271
}
272272

273273

274274
def get_knowledgebase_database_adapter(database_client: BaseDatabase):
275-
return MAPPING[type(database_client)](client=database_client)
275+
return MAPPING[type(database_client).__name__](client=database_client)
276276

277277

278278
def get_long_term_memory_database_adapter(database_client: BaseDatabase):
279-
return MAPPING[type(database_client)](client=database_client)
279+
return MAPPING[type(database_client).__name__](client=database_client)

0 commit comments

Comments
 (0)