Skip to content

Commit fb2a6d2

Browse files
authored
Support to config the cache storage data size (#383)
Signed-off-by: SimFG <bang.fu@zilliz.com>
1 parent f91a8da commit fb2a6d2

File tree

7 files changed

+129
-48
lines changed

7 files changed

+129
-48
lines changed

docs/_exts/docgen.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def generate(self, lib_name):
114114
f.write(t)
115115

116116
# Iterate the modules, render the function templates and write rendered output to files
117-
print("modules:", modules)
118117
for module in modules:
119118
module_name = module[0]
120119

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ sphinx_copybutton
1414
pydata-sphinx-theme
1515
m2r2
1616
sphinx_toolbox
17+
protobuf==3.20.2

gptcache/manager/scalar_data/__init__.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,26 @@ def CacheBase(name: str, **kwargs):
1414
:param name: the name of the cache storage, it is support 'sqlite', 'postgresql', 'mysql', 'mariadb', 'sqlserver' and 'oracle' now.
1515
:type name: str
1616
:param sql_url: the url of the sql database for cache, such as '<db_type>+<db_driver>://<username>:<password>@<host>:<port>/<database>',
17-
and the default value is related to the `cache_store` parameter,
18-
'sqlite:///./sqlite.db' for 'sqlite',
19-
'duckdb:///./duck.db' for 'duckdb',
20-
'postgresql+psycopg2://postgres:123456@127.0.0.1:5432/postgres' for 'postgresql',
21-
'mysql+pymysql://root:123456@127.0.0.1:3306/mysql' for 'mysql',
22-
'mariadb+pymysql://root:123456@127.0.0.1:3307/mysql' for 'mariadb',
23-
'mssql+pyodbc://sa:Strongpsw_123@127.0.0.1:1434/msdb?driver=ODBC+Driver+17+for+SQL+Server' for 'sqlserver',
24-
'oracle+cx_oracle://oracle:123456@127.0.0.1:1521/?service_name=helowin&encoding=UTF-8&nencoding=UTF-8' for 'oracle'.
17+
and the default value is related to the `cache_store` parameter,
18+
19+
- 'sqlite:///./sqlite.db' for 'sqlite',
20+
- 'duckdb:///./duck.db' for 'duckdb',
21+
- 'postgresql+psycopg2://postgres:123456@127.0.0.1:5432/postgres' for 'postgresql',
22+
- 'mysql+pymysql://root:123456@127.0.0.1:3306/mysql' for 'mysql',
23+
- 'mariadb+pymysql://root:123456@127.0.0.1:3307/mysql' for 'mariadb',
24+
- 'mssql+pyodbc://sa:Strongpsw_123@127.0.0.1:1434/msdb?driver=ODBC+Driver+17+for+SQL+Server' for 'sqlserver',
25+
- 'oracle+cx_oracle://oracle:123456@127.0.0.1:1521/?service_name=helowin&encoding=UTF-8&nencoding=UTF-8' for 'oracle'.
2526
:type sql_url: str
2627
:param table_name: the table name for sql database, defaults to 'gptcache'.
2728
:type table_name: str
29+
:param table_len_config: the table length config for sql database, defaults to {}. the key includes:
30+
31+
- 'question_question': the question column size in the question table, default to 3000.
32+
- 'answer_answer': the answer column size in the answer table, default to 3000.
33+
- 'session_id': the session id column size in the session table, default to 1000.
34+
- 'dep_name': the name column size in the dep table, default to 1000.
35+
- 'dep_data': the data column size in the dep table, default to 3000.
36+
:type table_len_config: dict
2837
2938
:return: CacheStorage.
3039

gptcache/manager/scalar_data/manager.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,27 @@ def __init__(self):
2727

2828
@staticmethod
2929
def get(name, **kwargs):
30-
if name in ["sqlite", "postgresql", "mysql", "mariadb", "sqlserver", "oracle"]:
30+
if name in [
31+
"sqlite",
32+
"duckdb",
33+
"postgresql",
34+
"mysql",
35+
"mariadb",
36+
"sqlserver",
37+
"oracle",
38+
]:
3139
from gptcache.manager.scalar_data.sql_storage import SQLStorage
3240

3341
sql_url = kwargs.get("sql_url", SQL_URL[name])
3442
table_name = kwargs.get("table_name", TABLE_NAME)
43+
table_len_config = kwargs.get("table_len_config", {})
3544
import_sql_client(name)
36-
cache_base = SQLStorage(db_type=name, url=sql_url, table_name=table_name)
45+
cache_base = SQLStorage(
46+
db_type=name,
47+
url=sql_url,
48+
table_name=table_name,
49+
table_len_config=table_len_config,
50+
)
3751
else:
3852
raise NotFoundError("cache store", name)
3953
return cache_base

gptcache/manager/scalar_data/sql_storage.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import List, Optional
2+
from typing import List, Optional, Dict
33

44
import numpy as np
55

@@ -24,8 +24,22 @@
2424
from sqlalchemy.orm import sessionmaker # pylint: disable=C0413
2525
from sqlalchemy.ext.declarative import declarative_base # pylint: disable=C0413
2626

27+
DEFAULT_LEN_DOCT = {
28+
"question_question": 3000,
29+
"answer_answer": 3000,
30+
"session_id": 1000,
31+
"dep_name": 1000,
32+
"dep_data": 3000,
33+
}
2734

28-
def get_models(table_prefix, db_type):
35+
36+
def _get_table_len(config: Dict, column_alias: str) -> int:
37+
if config and column_alias in config and config[column_alias] > 0:
38+
return config[column_alias]
39+
return DEFAULT_LEN_DOCT.get(column_alias, 1000)
40+
41+
42+
def get_models(table_prefix, db_type, table_len_config):
2943
DynamicBase = declarative_base(class_registry={}) # pylint: disable=C0103
3044

3145
class QuestionTable(DynamicBase):
@@ -41,7 +55,10 @@ class QuestionTable(DynamicBase):
4155
id = Column(Integer, question_id_seq, primary_key=True, autoincrement=True)
4256
else:
4357
id = Column(Integer, primary_key=True, autoincrement=True)
44-
question = Column(String(1000), nullable=False)
58+
question = Column(
59+
String(_get_table_len(table_len_config, "question_question")),
60+
nullable=False,
61+
)
4562
create_on = Column(DateTime, default=datetime.now)
4663
last_access = Column(DateTime, default=datetime.now)
4764
embedding_data = Column(LargeBinary, nullable=True)
@@ -61,7 +78,9 @@ class AnswerTable(DynamicBase):
6178
else:
6279
id = Column(Integer, primary_key=True, autoincrement=True)
6380
question_id = Column(Integer, nullable=False)
64-
answer = Column(String(2000), nullable=False)
81+
answer = Column(
82+
String(_get_table_len(table_len_config, "answer_answer")), nullable=False
83+
)
6584
answer_type = Column(Integer, nullable=False)
6685

6786
class SessionTable(DynamicBase):
@@ -83,8 +102,13 @@ class SessionTable(DynamicBase):
83102
else:
84103
id = Column(Integer, primary_key=True, autoincrement=True)
85104
question_id = Column(Integer, nullable=False)
86-
session_id = Column(String(500), nullable=False)
87-
session_question = Column(String(2000), nullable=False)
105+
session_id = Column(
106+
String(_get_table_len(table_len_config, "session_id")), nullable=False
107+
)
108+
session_question = Column(
109+
String(_get_table_len(table_len_config, "question_question")),
110+
nullable=False,
111+
)
88112

89113
class QuestionDepTable(DynamicBase):
90114
"""
@@ -102,8 +126,12 @@ class QuestionDepTable(DynamicBase):
102126
else:
103127
id = Column(Integer, primary_key=True, autoincrement=True)
104128
question_id = Column(Integer, nullable=False)
105-
dep_name = Column(String(255), nullable=False)
106-
dep_data = Column(String(1000), nullable=False)
129+
dep_name = Column(
130+
String(_get_table_len(table_len_config, "dep_name")), nullable=False
131+
)
132+
dep_data = Column(
133+
String(_get_table_len(table_len_config, "dep_data")), nullable=False
134+
)
107135
dep_type = Column(Integer, nullable=False)
108136

109137
return QuestionTable, AnswerTable, QuestionDepTable, SessionTable
@@ -134,10 +162,13 @@ def __init__(
134162
db_type: str = "sqlite",
135163
url: str = "sqlite:///./sqlite.db",
136164
table_name: str = "gptcache",
165+
table_len_config=None,
137166
):
167+
if table_len_config is None:
168+
table_len_config = {}
138169
self._url = url
139170
self._ques, self._answer, self._ques_dep, self._session = get_models(
140-
table_name, db_type
171+
table_name, db_type, table_len_config
141172
)
142173
self._engine = create_engine(self._url)
143174
self.Session = sessionmaker(bind=self._engine) # pylint: disable=invalid-name

tests/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,5 @@ mock
2020
pexpect
2121
spacy
2222
safetensors
23-
protobuf==3.20.0
2423
typing_extensions<4.6.0
24+
protobuf==3.20.0

tests/unit_tests/manager/test_sql_scalar.py

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111

1212
class TestSQLStore(unittest.TestCase):
13-
1413
def test_sqlite(self):
1514
self._inner_test_normal("sqlite")
1615
self._inner_test_with_deps("sqlite")
@@ -21,26 +20,38 @@ def test_duckdb(self):
2120
self._inner_test_with_deps("duckdb")
2221

2322
def _inner_test_normal(self, db_name: str):
24-
with TemporaryDirectory(dir='./') as root:
25-
db_path = Path(root) / f'{db_name}1.db'
26-
db = SQLStorage(db_type=db_name, url=f"{db_name}:///" + str(db_path))
23+
with TemporaryDirectory(dir="./") as root:
24+
db_path = Path(root) / f"{db_name}1.db"
25+
db = SQLStorage(
26+
db_type=db_name,
27+
url=f"{db_name}:///" + str(db_path),
28+
table_len_config={"question_question": 500},
29+
)
2730
db.create()
2831
data = []
2932
for i in range(1, 10):
30-
data.append(CacheData('question_' + str(i), ['answer_' + str(i)] * i, np.random.rand(5)))
33+
data.append(
34+
CacheData(
35+
"question_" + str(i),
36+
["answer_" + str(i)] * i,
37+
np.random.rand(5),
38+
)
39+
)
3140

3241
db.batch_insert(data)
3342
data = db.get_data_by_id(1)
34-
self.assertEqual(data.question, 'question_1')
35-
self.assertEqual(data.answers[0].answer, 'answer_1')
43+
self.assertEqual(data.question, "question_1")
44+
self.assertEqual(data.answers[0].answer, "answer_1")
3645
data = db.get_data_by_id(2)
37-
self.assertEqual(data.question, 'question_2')
38-
self.assertEqual(data.answers[0].answer, 'answer_2')
39-
self.assertEqual(data.answers[1].answer, 'answer_2')
40-
q_id = db.batch_insert([CacheData('question_single', 'answer_singel', np.random.rand(5))])[0]
46+
self.assertEqual(data.question, "question_2")
47+
self.assertEqual(data.answers[0].answer, "answer_2")
48+
self.assertEqual(data.answers[1].answer, "answer_2")
49+
q_id = db.batch_insert(
50+
[CacheData("question_single", "answer_singel", np.random.rand(5))]
51+
)[0]
4152
data = db.get_data_by_id(q_id)
42-
self.assertEqual(data.question, 'question_single')
43-
self.assertEqual(data.answers[0].answer, 'answer_singel')
53+
self.assertEqual(data.question, "question_single")
54+
self.assertEqual(data.answers[0].answer, "answer_singel")
4455

4556
# test deleted
4657
self.assertEqual(len(db.get_ids(True)), 0)
@@ -54,26 +65,42 @@ def _inner_test_normal(self, db_name: str):
5465
self.assertEqual(db.count(is_all=True), 7)
5566

5667
def _inner_test_with_deps(self, db_name: str):
57-
with TemporaryDirectory(dir='./') as root:
58-
db_path = Path(root) / f'{db_name}2.db'
68+
with TemporaryDirectory(dir="./") as root:
69+
db_path = Path(root) / f"{db_name}2.db"
5970
db = SQLStorage(db_type=db_name, url=f"{db_name}:///" + str(db_path))
6071
db.create()
61-
data_id = db.batch_insert([
62-
CacheData(
63-
Question.from_dict({
64-
"content": "test_question",
65-
"deps": [
66-
{"name": "text", "data": "how many people in this picture", "dep_type": 0},
67-
{"name": "image", "data": "object_name", "dep_type": 1}
68-
]
69-
}),
70-
'test_answer', np.random.rand(5))
71-
])[0]
72+
data_id = db.batch_insert(
73+
[
74+
CacheData(
75+
Question.from_dict(
76+
{
77+
"content": "test_question",
78+
"deps": [
79+
{
80+
"name": "text",
81+
"data": "how many people in this picture",
82+
"dep_type": 0,
83+
},
84+
{
85+
"name": "image",
86+
"data": "object_name",
87+
"dep_type": 1,
88+
},
89+
],
90+
}
91+
),
92+
"test_answer",
93+
np.random.rand(5),
94+
)
95+
]
96+
)[0]
7297

7398
ret = db.get_data_by_id(data_id)
7499
self.assertEqual(ret.question.content, "test_question")
75100
self.assertEqual(ret.question.deps[0].name, "text")
76-
self.assertEqual(ret.question.deps[0].data, "how many people in this picture")
101+
self.assertEqual(
102+
ret.question.deps[0].data, "how many people in this picture"
103+
)
77104
self.assertEqual(ret.question.deps[0].dep_type, 0)
78105
self.assertEqual(ret.question.deps[1].name, "image")
79106
self.assertEqual(ret.question.deps[1].data, "object_name")

0 commit comments

Comments
 (0)