Skip to content

Commit df97368

Browse files
feature(datasources): support for accessing sqlalchemy provided data engines with a single SQLDataSource and DatabaseReader (#189)
1 parent 26b111d commit df97368

File tree

18 files changed

+415
-213
lines changed

18 files changed

+415
-213
lines changed

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"editor.tabSize": 4,
55
"editor.defaultFormatter": "ms-python.black-formatter",
66
"editor.codeActionsOnSave": {
7-
"source.organizeImports": true
7+
"source.organizeImports": "explicit"
88
}
99
},
1010
"isort.args": ["--profile", "black"],
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from enum import StrEnum
2+
3+
4+
class DatabaseEngineType(StrEnum):
5+
POSTGRESQL = "postgresql"
6+
MYSQL = "mysql"
7+
SQLITE = "sqlite"

llmstack/common/blocks/data/store/postgres/read.py renamed to llmstack/common/blocks/data/store/database/database_reader.py

Lines changed: 33 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,23 @@
1+
import collections
2+
import datetime
13
import json
2-
from collections import defaultdict
3-
from datetime import datetime
4-
from uuid import UUID
4+
import uuid
55

6+
import sqlalchemy
67
from psycopg2.extras import Range
78

89
from llmstack.common.blocks.base.processor import ProcessorInterface
910
from llmstack.common.blocks.base.schema import BaseSchema
1011
from llmstack.common.blocks.data import DataDocument
11-
from llmstack.common.blocks.data.store.postgres import (
12-
PostgresConfiguration,
13-
PostgresOutput,
14-
get_pg_connection,
12+
from llmstack.common.blocks.data.store.database.utils import (
13+
DatabaseConfiguration,
14+
DatabaseConfigurationType,
15+
DatabaseOutput,
16+
get_database_connection,
1517
)
1618

1719

18-
class PostgresReaderInput(BaseSchema):
19-
sql: str
20-
21-
22-
types_map = {
23-
20: "integer",
24-
21: "integer",
25-
23: "integer",
26-
700: "float",
27-
1700: "float",
28-
701: "float",
29-
16: "boolean",
30-
1082: "date",
31-
1182: "date",
32-
1114: "datetime",
33-
1184: "datetime",
34-
1115: "datetime",
35-
1185: "datetime",
36-
1014: "string",
37-
1015: "string",
38-
1008: "string",
39-
1009: "string",
40-
2951: "string",
41-
1043: "string",
42-
1002: "string",
43-
1003: "string",
44-
}
45-
46-
47-
class PostgreSQLJSONEncoder(json.JSONEncoder):
20+
class DatabaseJSONEncoder(json.JSONEncoder):
4821
def default(self, o):
4922
if isinstance(o, Range):
5023
# From: https://github.com/psycopg/psycopg2/pull/779
@@ -64,20 +37,24 @@ def default(self, o):
6437
]
6538

6639
return "".join(items)
67-
elif isinstance(o, UUID):
40+
elif isinstance(o, uuid.UUID):
6841
return str(o.hex)
69-
elif isinstance(o, datetime):
42+
elif isinstance(o, (datetime.date, datetime.datetime)):
7043
return o.isoformat()
7144

72-
return super(PostgreSQLJSONEncoder, self).default(o)
45+
return super().default(o)
7346

7447

75-
class PostgresReader(
76-
ProcessorInterface[PostgresReaderInput, PostgresOutput, PostgresConfiguration],
48+
class DatabaseReaderInput(BaseSchema):
    """Input schema for the database reader: the SQL statement to execute."""

    # Raw SQL text supplied by the caller.
    sql: str
50+
51+
52+
class DatabaseReader(
53+
ProcessorInterface[DatabaseReaderInput, DatabaseOutput, DatabaseConfiguration],
7754
):
7855
def fetch_columns(self, columns):
7956
column_names = set()
80-
duplicates_counters = defaultdict(int)
57+
duplicates_counters = collections.defaultdict(int)
8158
new_columns = []
8259

8360
for col in columns:
@@ -90,35 +67,35 @@ def fetch_columns(self, columns):
9067
)
9168

9269
column_names.add(column_name)
93-
new_columns.append(
94-
{"name": column_name, "friendly_name": column_name, "type": col[1]},
95-
)
70+
new_columns.append({"name": column_name, "type": col[1]})
9671

9772
return new_columns
9873

9974
def process(
10075
self,
101-
input: PostgresReaderInput,
102-
configuration: PostgresConfiguration,
103-
) -> PostgresOutput:
104-
connection = get_pg_connection(configuration.dict())
105-
cursor = connection.cursor()
76+
input: DatabaseReaderInput,
77+
configuration: DatabaseConfigurationType,
78+
) -> DatabaseOutput:
79+
connection = get_database_connection(configuration=configuration)
10680
try:
107-
cursor.execute(input.sql)
81+
result = connection.execute(sqlalchemy.text(input.sql))
82+
cursor = result.cursor
83+
10884
if cursor.description is not None:
10985
columns = self.fetch_columns(
110-
[(i[0], types_map.get(i[1], None)) for i in cursor.description],
86+
[(i[0], None) for i in cursor.description],
11187
)
11288
rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor]
11389

11490
data = {"columns": columns, "rows": rows}
115-
json_data = json.dumps(data, cls=PostgreSQLJSONEncoder)
91+
json_data = json.dumps(data, cls=DatabaseJSONEncoder)
11692
else:
11793
raise Exception("Query completed but it returned no data.")
11894
except Exception as e:
119-
connection.cancel()
95+
connection.close()
96+
connection.engine.dispose()
12097
raise e
121-
return PostgresOutput(
98+
return DatabaseOutput(
12299
documents=[
123100
DataDocument(
124101
content=json_data,
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from enum import Enum
2+
from typing import List, Optional
3+
4+
from typing_extensions import Literal
5+
6+
from llmstack.common.blocks.base.schema import BaseSchema
7+
from llmstack.common.blocks.data import DataDocument
8+
from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType
9+
10+
try:
11+
import MySQLdb
12+
13+
enabled = True
14+
except ImportError:
15+
enabled = False
16+
17+
18+
class SSLMode(str, Enum):
    """SSL negotiation modes selectable for a MySQL connection.

    Each member is also a ``str`` whose value is the upper-case mode name.
    """

    # No verification of the server certificate.
    disabled = "DISABLED"
    preferred = "PREFERRED"
    required = "REQUIRED"

    # Modes that additionally verify the server certificate.
    verify_ca = "VERIFY_CA"
    verify_identity = "VERIFY_IDENTITY"
24+
25+
26+
class MySQLConfiguration(BaseSchema):
    """Connection settings for a MySQL data source.

    ``engine`` is pinned to ``DatabaseEngineType.MYSQL`` so the model can be
    told apart from the other engine configurations by that field alone.
    """

    engine: Literal[DatabaseEngineType.MYSQL] = DatabaseEngineType.MYSQL
    user: Optional[str]
    password: Optional[str]
    host: str = "127.0.0.1"
    port: int = 3306
    dbname: str
    use_ssl: bool = False
    # Default to the enum member itself: the bare string "preferred" is not one
    # of SSLMode's (upper-case) values, so it disagreed with the field's type.
    sslmode: SSLMode = SSLMode.preferred
    ssl_ca: Optional[str] = None
    ssl_cert: Optional[str] = None
    ssl_key: Optional[str] = None

    class Config:
        # Extra schema metadata: field display order, required fields, fields
        # to treat as secrets, and fields grouped as extra options.
        # NOTE(review): presumably consumed by the configuration form/UI - confirm.
        schema_extra = {
            "order": ["host", "port", "user", "password"],
            "required": ["dbname"],
            "secret": ["password", "ssl_ca", "ssl_cert", "ssl_key"],
            "extra_options": ["sslmode", "ssl_ca", "ssl_cert", "ssl_key"],
        }
46+
47+
48+
class MySQLOutput(BaseSchema):
    """Output schema of a MySQL read: the list of resulting documents."""

    # Documents produced from the query result.
    documents: List[DataDocument]
50+
51+
52+
def get_mysql_ssl_config(configuration: dict):
    """Build the SSL options mapping for a MySQL connection.

    Args:
        configuration: Plain-dict form of the MySQL configuration. The keys
            read here are ``use_ssl``, ``sslmode``, ``ssl_ca``, ``ssl_cert``
            and ``ssl_key`` (plus the legacy ``ssl_cacert`` spelling).

    Returns:
        An empty dict when ``use_ssl`` is falsy or missing. Otherwise a dict
        holding ``sslmode`` and, for each certificate setting that has a
        non-empty value, the corresponding ``ca`` / ``cert`` / ``key`` entry.
    """
    if not configuration.get("use_ssl"):
        return {}

    # NOTE(review): the "prefer" fallback is the PostgreSQL spelling; the MySQL
    # SSLMode enum uses "PREFERRED". Left unchanged - confirm the intended value.
    ssl_config = {"sslmode": configuration.get("sslmode", "prefer")}

    # Configuration field -> key expected in the driver's ssl mapping.
    # The previous map looked up "ssl_cacert" (not a field of the MySQL
    # configuration schema, which names it "ssl_ca") and mapped "ssl_mode" to
    # a key literally called "preferred". The legacy "ssl_cacert" name is still
    # honoured, but the schema's "ssl_ca" wins when both are present.
    config_map = {
        "ssl_cacert": "ca",
        "ssl_ca": "ca",
        "ssl_cert": "cert",
        "ssl_key": "key",
    }
    for key, cfg in config_map.items():
        val = configuration.get(key)
        if val:
            ssl_config[cfg] = val

    return ssl_config
66+
67+
68+
def get_mysql_connection(configuration: dict):
    """Open a MySQLdb connection from a plain-dict configuration.

    Args:
        configuration: Connection settings. ``host``, ``user``, ``password``
            and ``dbname`` are read as-is; ``port``, ``charset``,
            ``use_unicode``, ``connect_timeout`` and ``autocommit`` fall back
            to 3306, ``"utf8"``, ``True``, 60 and ``True`` respectively.

    Returns:
        The connection object returned by ``MySQLdb.connect``.
    """
    params = {
        "host": configuration.get("host"),
        "user": configuration.get("user"),
        "passwd": configuration.get("password"),
        "db": configuration.get("dbname"),
        "port": configuration.get("port", 3306),
        "charset": configuration.get("charset", "utf8"),
        "use_unicode": configuration.get("use_unicode", True),
        "connect_timeout": configuration.get("connect_timeout", 60),
        "autocommit": configuration.get("autocommit", True),
    }

    # The helper requires the configuration; calling it without arguments
    # raised TypeError before any connection attempt was made.
    ssl_options = get_mysql_ssl_config(configuration)

    if ssl_options:
        # NOTE(review): the mapping also carries an "sslmode" entry - confirm
        # that MySQLdb accepts/ignores it inside the ``ssl`` argument.
        params["ssl"] = ssl_options

    return MySQLdb.connect(**params)

llmstack/common/blocks/data/store/postgres/__init__.py renamed to llmstack/common/blocks/data/store/database/postgresql.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from typing import List, Optional
55

66
import psycopg2
7+
from typing_extensions import Literal
78

89
from llmstack.common.blocks.base.schema import BaseSchema
910
from llmstack.common.blocks.data import DataDocument
11+
from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType
1012

1113

1214
class SSLMode(str, Enum):
@@ -19,6 +21,7 @@ class SSLMode(str, Enum):
1921

2022

2123
class PostgresConfiguration(BaseSchema):
24+
engine: Literal[DatabaseEngineType.POSTGRESQL] = DatabaseEngineType.POSTGRESQL
2225
user: Optional[str]
2326
password: Optional[str]
2427
host: str = "127.0.0.1"
@@ -53,7 +56,10 @@ def _create_cert_file(configuration, key, ssl_config):
5356
ssl_config[key] = cert_file.name
5457

5558

56-
def _get_ssl_config(configuration: dict):
59+
def get_pg_ssl_config(configuration: dict):
60+
if not configuration.get("use_ssl"):
61+
return {}
62+
5763
ssl_config = {"sslmode": configuration.get("sslmode", "prefer")}
5864

5965
_create_cert_file(configuration, "sslrootcert", ssl_config)
@@ -65,7 +71,7 @@ def _get_ssl_config(configuration: dict):
6571

6672
def get_pg_connection(configuration: dict):
6773
ssl_config = (
68-
_get_ssl_config(
74+
get_pg_ssl_config(
6975
configuration,
7076
)
7177
if configuration.get("use_ssl")

llmstack/common/blocks/data/store/sqlite/__init__.py renamed to llmstack/common/blocks/data/store/database/sqlite.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from typing import List
22

3+
from typing_extensions import Literal
4+
35
from llmstack.common.blocks.base.schema import BaseSchema
46
from llmstack.common.blocks.data import DataDocument
7+
from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType
58

69

710
class SQLiteConfiguration(BaseSchema):
11+
engine: Literal[DatabaseEngineType.SQLITE] = DatabaseEngineType.SQLITE
812
dbpath: str
913

1014

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import List, TypeVar
2+
3+
import sqlalchemy
4+
5+
from llmstack.common.blocks.base.schema import BaseSchema
6+
from llmstack.common.blocks.data import DataDocument
7+
from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType
8+
from llmstack.common.blocks.data.store.database.mysql import (
9+
MySQLConfiguration,
10+
get_mysql_ssl_config,
11+
)
12+
from llmstack.common.blocks.data.store.database.postgresql import (
13+
PostgresConfiguration,
14+
get_pg_ssl_config,
15+
)
16+
from llmstack.common.blocks.data.store.database.sqlite import SQLiteConfiguration
17+
18+
DATABASES = {
19+
DatabaseEngineType.POSTGRESQL: {
20+
"name": "PostgreSQL",
21+
"driver": "postgresql+psycopg2",
22+
},
23+
DatabaseEngineType.MYSQL: {
24+
"name": "MySQL",
25+
"driver": "mysql+mysqldb",
26+
},
27+
DatabaseEngineType.SQLITE: {
28+
"name": "SQLite",
29+
"driver": "sqlite+pysqlite",
30+
},
31+
}
32+
33+
DatabaseConfiguration = MySQLConfiguration | PostgresConfiguration | SQLiteConfiguration
34+
35+
DatabaseConfigurationType = TypeVar("DatabaseConfigurationType", bound=DatabaseConfiguration)
36+
37+
38+
class DatabaseOutput(BaseSchema):
    """Engine-agnostic output schema: the documents produced by a database read."""

    # Documents produced from the query result.
    documents: List[DataDocument]
40+
41+
42+
def get_database_configuration_class(engine: DatabaseEngineType) -> DatabaseConfigurationType:
    """Return the configuration model class that corresponds to ``engine``.

    Raises:
        ValueError: If ``engine`` is not PostgreSQL, MySQL or SQLite.
    """
    # Checked in order with ``==`` so plain strings equal to a member match too.
    known_engines = (
        (DatabaseEngineType.POSTGRESQL, PostgresConfiguration),
        (DatabaseEngineType.MYSQL, MySQLConfiguration),
        (DatabaseEngineType.SQLITE, SQLiteConfiguration),
    )
    for engine_type, configuration_class in known_engines:
        if engine == engine_type:
            return configuration_class

    raise ValueError(f"Unsupported database engine: {engine}")
51+
52+
53+
def get_ssl_config(configuration: DatabaseConfigurationType) -> dict:
    """Return the engine-specific SSL options for ``configuration``.

    PostgreSQL and MySQL configurations are delegated to their own helpers
    (given the configuration as a plain dict); any other engine yields an
    empty dict.
    """
    engine_type = configuration.engine

    if engine_type == DatabaseEngineType.POSTGRESQL:
        return get_pg_ssl_config(configuration.dict())

    if engine_type == DatabaseEngineType.MYSQL:
        return get_mysql_ssl_config(configuration.dict())

    return {}
60+
61+
62+
def get_database_connection(
    configuration: DatabaseConfigurationType,
    ssl_config: dict = None,
) -> sqlalchemy.engine.Connection:
    """Create a SQLAlchemy engine for ``configuration`` and open a connection.

    Args:
        configuration: Engine-specific configuration model. Its ``engine``
            field selects the driver from ``DATABASES``; ``dbpath`` (SQLite)
            or ``dbname`` (other engines) is used as the database name.
        ssl_config: Optional pre-built SSL options. When empty or ``None``
            they are derived from ``configuration`` via ``get_ssl_config``.

    Returns:
        An open ``sqlalchemy.engine.Connection``. The engine backing it is
        created per call and is reachable through ``connection.engine``.

    Raises:
        ValueError: If ``configuration.engine`` is not a key of ``DATABASES``.
    """
    if configuration.engine not in DATABASES:
        raise ValueError(f"Unsupported database engine: {configuration.engine}")

    if not ssl_config:
        ssl_config = get_ssl_config(configuration)

    if configuration.engine == DatabaseEngineType.SQLITE:
        database_name = configuration.dbpath
    else:
        database_name = configuration.dbname

    connect_args: dict = {}

    if ssl_config:
        if configuration.engine == DatabaseEngineType.POSTGRESQL:
            # psycopg2 takes libpq options (sslmode, sslrootcert, ...) as
            # top-level connection keywords; nesting them under "ssl" is
            # rejected as an invalid connection option.
            connect_args.update(ssl_config)
        else:
            # mysqlclient expects the certificate settings as a mapping under
            # the "ssl" keyword.
            connect_args["ssl"] = ssl_config

    # SQLite configurations have no user/password/host/port attributes, so
    # missing ones fall back to None.
    db_url = sqlalchemy.engine.URL.create(
        drivername=DATABASES[configuration.engine]["driver"],
        username=getattr(configuration, "user", None),
        password=getattr(configuration, "password", None),
        host=getattr(configuration, "host", None),
        port=getattr(configuration, "port", None),
        database=database_name,
    )

    engine = sqlalchemy.create_engine(db_url, connect_args=connect_args)

    try:
        return engine.connect()
    except Exception:
        # Nothing else holds a reference to this engine; release its pool
        # before propagating the connection failure.
        engine.dispose()
        raise

0 commit comments

Comments
 (0)