-
Notifications
You must be signed in to change notification settings - Fork 146
Expand file tree
/
Copy pathbase_driver.py
More file actions
316 lines (286 loc) · 12.1 KB
/
base_driver.py
File metadata and controls
316 lines (286 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
from enum import Enum
import datetime
from typing import Dict, List, Callable, Any, Optional, TYPE_CHECKING
from snowflake.connector.options import pandas as pd
from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
from snowflake.snowpark._internal.data_source.datasource_typing import (
Connection,
Cursor,
)
from snowflake.snowpark._internal.utils import (
get_sorted_key_for_version,
measure_time,
random_name_for_temp_object,
TempObjectType,
)
from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
from snowflake.snowpark.types import (
StructType,
StructField,
VariantType,
TimestampType,
IntegerType,
BinaryType,
DateType,
BooleanType,
)
import snowflake.snowpark
import logging
# Column of the partition table whose value is handed to the ingestion UDTF
# as its single argument (see BaseDriver.udtf_ingestion).
PARTITION_TABLE_COLUMN_NAME = "partition"

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Needed only for type annotations; not imported at runtime
    # (presumably to avoid an import cycle with the session module).
    from snowflake.snowpark.session import Session
    from snowflake.snowpark.dataframe import DataFrame
class BaseDriver:
    """Base driver for ingesting data from an external DBMS into Snowpark.

    A driver wraps a user-supplied ``create_connection`` factory and provides
    the hooks used to infer a Snowpark schema from a table or query, to ingest
    data through a Snowflake UDTF, and to convert fetched rows into a pandas
    DataFrame. Subclasses must implement :meth:`to_snow_type` and may override
    the remaining hooks for DBMS-specific behavior.
    """

    def __init__(
        self,
        create_connection: Callable[..., "Connection"],
        dbms_type: Enum,
        connection_parameters: Optional[dict] = None,
    ) -> None:
        """Store the connection factory and the source DBMS type.

        Args:
            create_connection: Factory that returns a new connection to the source.
            dbms_type: Enum member identifying the source DBMS; used as the key
                into ``UDTF_PACKAGE_MAP`` when registering the ingestion UDTF.
            connection_parameters: Optional keyword arguments forwarded to
                ``create_connection``.
        """
        self.create_connection = create_connection
        self.dbms_type = dbms_type
        self.connection_parameters = connection_parameters
        # Filled in by get_raw_schema() with the cursor description.
        self.raw_schema = None

    def _call_create_connection(self) -> "Connection":
        """Call create_connection with connection_parameters if provided."""
        if self.connection_parameters:
            return self.create_connection(**self.connection_parameters)
        return self.create_connection()

    def to_snow_type(self, schema: List[Any]) -> StructType:
        """Map a raw cursor description to a Snowpark ``StructType``.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in the base driver.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} has not implemented to_snow_type function"
        )

    def non_retryable_error_checker(self, error: Exception) -> bool:
        """Return True if *error* must not be retried; the base driver never says so."""
        return False

    @staticmethod
    def prepare_connection(
        conn: "Connection",
        query_timeout: int = 0,
    ) -> "Connection":
        """Hook for configuring a fresh connection (e.g. applying *query_timeout*).

        The base implementation returns *conn* unchanged.
        """
        return conn

    @staticmethod
    def generate_infer_schema_sql(
        table_or_query: str, is_query: bool, query_input_alias: str
    ) -> str:
        """Build a query that returns no rows but exposes the source's columns.

        A query input is wrapped as a subquery aliased with *query_input_alias*;
        a table name is selected from directly. ``WHERE 1 = 0`` keeps the result empty.
        """
        return (
            f"SELECT * FROM ({table_or_query}) {query_input_alias} WHERE 1 = 0"
            if is_query
            else f"SELECT * FROM {table_or_query} WHERE 1 = 0"
        )

    def get_raw_schema(
        self,
        table_or_query: str,
        cursor: "Cursor",
        is_query: bool,
        query_input_alias: str,
    ) -> None:
        """Run the schema-probe query on *cursor* and store its description in ``self.raw_schema``."""
        cursor.execute(
            self.generate_infer_schema_sql(table_or_query, is_query, query_input_alias)
        )
        self.raw_schema = cursor.description

    def infer_schema_from_description(
        self,
        table_or_query: str,
        cursor: "Cursor",
        is_query: bool,
        query_input_alias: str,
    ) -> StructType:
        """Probe the source with *cursor* and convert the description via :meth:`to_snow_type`."""
        self.get_raw_schema(table_or_query, cursor, is_query, query_input_alias)
        return self.to_snow_type(self.raw_schema)

    @staticmethod
    def _close_quietly(resource: Any, resource_name: str) -> None:
        """Best-effort ``close()`` of *resource*; failures are only logged at debug level."""
        try:
            resource.close()
        except Exception as exc:
            logger.debug(
                "Failed to close %s after inferring schema from description due to error: %r",
                resource_name,
                exc,
            )

    def infer_schema_from_description_with_error_control(
        self, table_or_query: str, is_query: bool, query_input_alias: str
    ) -> StructType:
        """Infer the Snowpark schema on a fresh connection, wrapping probe failures.

        A new connection and cursor are opened for the probe and are always
        closed afterwards (cursor first, then connection), even when creating
        the cursor fails. Close failures are logged at debug level and ignored.

        Raises:
            SnowparkDataframeReaderException: if running the probe query or
                mapping its description to a Snowpark schema fails.
        """
        conn = self._call_create_connection()
        try:
            # Created inside the outer try so the connection is still closed
            # if cursor creation itself raises.
            cursor = conn.cursor()
            try:
                return self.infer_schema_from_description(
                    table_or_query, cursor, is_query, query_input_alias
                )
            except Exception as exc:
                infer_schema_sql = self.generate_infer_schema_sql(
                    table_or_query, is_query, query_input_alias
                )
                raise SnowparkDataframeReaderException(
                    f"Auto infer schema failure: {exc!r}. "
                    f"A query: {infer_schema_sql} "
                    "is used to infer Snowpark DataFrame schema from "
                    f"{table_or_query}. "
                    "But it failed with above exception"
                ) from exc
            finally:
                self._close_quietly(cursor, "cursor")
        finally:
            self._close_quietly(conn, "connection")

    def udtf_ingestion(
        self,
        session: "snowflake.snowpark.Session",
        schema: StructType,
        partition_table: str,
        external_access_integrations: str,
        fetch_size: int = 1000,
        imports: Optional[List[str]] = None,
        packages: Optional[List[str]] = None,
        session_init_statement: Optional[List[str]] = None,
        query_timeout: Optional[int] = 0,
        statement_params: Optional[Dict[str, str]] = None,
        _emit_ast: bool = True,
    ) -> "snowflake.snowpark.DataFrame":
        """Ingest data by running a temporary UDTF over every row of *partition_table*.

        The UDTF, built by :meth:`udtf_class_builder`, emits every column as
        VARIANT. The result is cast back to *schema* by
        :meth:`to_result_snowpark_df_udtf`. When *packages* is not given, the
        package list registered for ``self.dbms_type`` in ``UDTF_PACKAGE_MAP``
        is used.
        """
        from snowflake.snowpark._internal.data_source.utils import UDTF_PACKAGE_MAP

        udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
        with measure_time() as udtf_register_time:
            session.udtf.register(
                self.udtf_class_builder(
                    fetch_size=fetch_size,
                    schema=schema,
                    session_init_statement=session_init_statement,
                    query_timeout=query_timeout,
                ),
                name=udtf_name,
                output_schema=StructType(
                    [
                        StructField(field.name, VariantType(), field.nullable)
                        for field in schema.fields
                    ]
                ),
                external_access_integrations=[external_access_integrations],
                packages=packages or UDTF_PACKAGE_MAP.get(self.dbms_type),
                imports=imports,
                statement_params=statement_params,
            )
        logger.debug("register ingestion udtf takes: %s seconds", udtf_register_time())
        call_udtf_sql = f"""
select * from {partition_table}, table({udtf_name}({PARTITION_TABLE_COLUMN_NAME}))
"""
        res = session.sql(call_udtf_sql, _emit_ast=_emit_ast)
        return self.to_result_snowpark_df_udtf(res, schema, _emit_ast=_emit_ast)

    def udtf_class_builder(
        self,
        fetch_size: int = 1000,
        schema: Optional[StructType] = None,
        session_init_statement: Optional[List[str]] = None,
        query_timeout: int = 0,
    ) -> type:
        """Build the handler class registered as the ingestion UDTF.

        Each ``process(query)`` call opens its own connection, prepares it with
        :meth:`prepare_connection`, runs the optional *session_init_statement*
        statements followed by *query*, and yields rows in batches of
        *fetch_size*. The cursor and connection are closed best-effort once the
        generator finishes or is closed. *schema* is unused by the base
        implementation.

        Only plain values and functions are captured (never ``self``), so the
        class can be serialized for the UDTF without dragging the driver along.
        """
        create_connection = self.create_connection
        prepare_connection = self.prepare_connection
        connection_parameters = self.connection_parameters

        def close_quietly(resource) -> None:
            # Best effort: rows were already emitted (or an error is already
            # propagating), so a close failure must not fail the UDTF call.
            try:
                resource.close()
            except Exception:
                pass

        class UDTFIngestion:
            def process(self, query: str):
                conn_result = (
                    create_connection(**connection_parameters)
                    if connection_parameters
                    else create_connection()
                )
                conn = prepare_connection(conn_result, query_timeout)
                try:
                    cursor = conn.cursor()
                    try:
                        if session_init_statement is not None:
                            for statement in session_init_statement:
                                cursor.execute(statement)
                        cursor.execute(query)
                        while True:
                            rows = cursor.fetchmany(fetch_size)
                            if not rows:
                                break
                            yield from rows
                    finally:
                        close_quietly(cursor)
                finally:
                    close_quietly(conn)

        return UDTFIngestion

    @staticmethod
    def validate_numeric_precision_scale(
        precision: Optional[int], scale: Optional[int]
    ) -> bool:
        """Return whether *precision*/*scale* form a valid numeric type.

        The precision must lie in [0, 38] and the scale in [0, precision]. A
        scale without a precision is rejected; both missing is accepted.
        """
        if precision is not None:
            if not (0 <= precision <= 38):
                return False
            if scale is not None and not (0 <= scale <= precision):
                return False
        elif scale is not None:
            return False
        return True

    # convert timestamp and date to string to work around SNOW-1911989
    # https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.map.html
    # 'map' is introduced in pandas 2.1.0, before that it is 'applymap'
    @staticmethod
    def df_map_method(pandas_df):
        """Return ``pandas_df.applymap`` on pandas < 2.1.0, else ``pandas_df.map``."""
        return (
            pandas_df.applymap
            if get_sorted_key_for_version(str(pd.__version__)) < (2, 1, 0)
            else pandas_df.map
        )

    @staticmethod
    def data_source_data_to_pandas_df(
        data: List[Any], schema: StructType
    ) -> "pd.DataFrame":
        """Convert fetched rows into a pandas DataFrame shaped by *schema*.

        Columns are created as object dtype and then coerced per Snowpark type:
        IntegerType becomes nullable ``Int64``, Timestamp/Date values become
        ISO-8601 strings, Binary values become hex strings of ``string`` dtype,
        and BooleanType becomes nullable ``boolean``. Column names are unquoted.
        """
        # unquote column name because double quotes stored in parquet file create column mismatch during copy into table
        columns = [unquote_if_quoted(col.name) for col in schema.fields]
        # this way handles both list of object and list of tuples and avoid implicit pandas type conversion
        df = pd.DataFrame([list(row) for row in data], columns=columns, dtype=object)
        for field in schema.fields:
            name = unquote_if_quoted(field.name)
            if isinstance(field.datatype, IntegerType):
                # 'Int64' is a pandas dtype while 'int64' is a numpy dtype, as stated here:
                # https://github.com/pandas-dev/pandas/issues/27731
                # https://pandas.pydata.org/docs/reference/api/pandas.Int64Dtype.html
                # https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.int64
                df[name] = df[name].astype("Int64")
            elif isinstance(field.datatype, (TimestampType, DateType)):
                df[name] = df[name].map(
                    lambda x: x.isoformat()
                    if isinstance(x, (datetime.datetime, datetime.date))
                    else x
                )
            # astype below is meant to address copy into failure when the column contain only None value,
            # pandas would infer wrong type for that column in that situation, thus we convert them to corresponding type.
            elif isinstance(field.datatype, BinaryType):
                # we convert all binary to hex, so it is safe to astype to string
                df[name] = (
                    df[name]
                    .map(lambda x: x.hex() if isinstance(x, (bytearray, bytes)) else x)
                    .astype("string")
                )
            elif isinstance(field.datatype, BooleanType):
                df[name] = df[name].astype("boolean")
        return df

    @staticmethod
    def to_result_snowpark_df(
        session: "Session", table_name: str, schema: StructType, _emit_ast: bool = True
    ) -> "DataFrame":
        """Return the ingested table as a DataFrame; *schema* is unused by the base driver."""
        return session.table(table_name, _emit_ast=_emit_ast)

    @staticmethod
    def to_result_snowpark_df_udtf(
        res_df: "DataFrame",
        schema: StructType,
        _emit_ast: bool = True,
    ) -> "DataFrame":
        """Cast the UDTF's VARIANT output columns back to the types in *schema*.

        The nullability of each resulting plan attribute is also copied from
        the corresponding field of *schema*.
        """
        cols = [
            res_df[field.name].cast(field.datatype).alias(field.name)
            for field in schema.fields
        ]
        selected_df = res_df.select(cols, _emit_ast=_emit_ast)
        for attr, source_field in zip(selected_df._plan.attributes, schema.fields):
            attr.nullable = source_field.nullable
        return selected_df

    def get_server_cursor_if_supported(self, conn: "Connection") -> "Cursor":
        """
        This method is used to get a server cursor if the driver and the DBMS support it.
        It can be overridden by the driver to return a server cursor if supported.
        Otherwise, it will return the default cursor supported by the driver and the DBMS.
        - databricks-sql-connector: no concept of client/server cursor, no need to override
        - python-oracledb: default to the server cursor, no need to override
        - psycopg2: default to the client cursor which needs to be overridden to return the server cursor
        - pymysql: default to the client cursor which needs to be overridden to return the server cursor
        TODO:
        - pyodbc: This is a Python wrapper on top of ODBC drivers, the ODBC driver and the DBMS may or may not support server cursor
          and if they do support, the way to get the server cursor may vary across different DBMS. We need to document pyodbc.
        """
        return conn.cursor()