Skip to content

Commit 15f1af9

Browse files
authored
SNOW-619615: SQLAlchemy dialect compliance (#314)
1 parent 725e926 commit 15f1af9

File tree

9 files changed

+605
-424
lines changed

9 files changed

+605
-424
lines changed

src/snowflake/sqlalchemy/base.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,50 @@ def update_from_clause(
359359
for t in extra_froms
360360
)
361361

362+
def _get_regexp_args(self, binary, kw):
363+
string = self.process(binary.left, **kw)
364+
pattern = self.process(binary.right, **kw)
365+
flags = binary.modifiers["flags"]
366+
if flags is not None:
367+
flags = self.process(flags, **kw)
368+
return string, pattern, flags
369+
370+
def visit_regexp_match_op_binary(self, binary, operator, **kw):
371+
string, pattern, flags = self._get_regexp_args(binary, kw)
372+
if flags is None:
373+
return f"REGEXP_LIKE({string}, {pattern})"
374+
else:
375+
return f"REGEXP_LIKE({string}, {pattern}, {flags})"
376+
377+
def visit_regexp_replace_op_binary(self, binary, operator, **kw):
378+
string, pattern, flags = self._get_regexp_args(binary, kw)
379+
replacement = self.process(binary.modifiers["replacement"], **kw)
380+
if flags is None:
381+
return "REGEXP_REPLACE({}, {}, {})".format(
382+
string,
383+
pattern,
384+
replacement,
385+
)
386+
else:
387+
return "REGEXP_REPLACE({}, {}, {}, {})".format(
388+
string,
389+
pattern,
390+
replacement,
391+
flags,
392+
)
393+
394+
def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
395+
return f"NOT {self.visit_regexp_match_op_binary(binary, operator, **kw)}"
396+
397+
def render_literal_value(self, value, type_):
398+
# escape backslash
399+
return super().render_literal_value(value, type_).replace("\\", "\\\\")
400+
362401

363402
class SnowflakeExecutionContext(default.DefaultExecutionContext):
364403
def fire_sequence(self, seq, type_):
365404
return self._execute_scalar(
366-
"SELECT "
367-
+ self.dialect.identifier_preparer.format_sequence(seq)
368-
+ ".nextval",
405+
f"SELECT {self.identifier_preparer.format_sequence(seq)}.nextval",
369406
type_,
370407
)
371408

@@ -387,6 +424,35 @@ def should_autocommit(self):
387424
else:
388425
return autocommit and not self.isddl
389426

427+
def pre_exec(self):
428+
if self.compiled:
429+
# for compiled statements, percent is doubled for escape, we turn on _interpolate_empty_sequences
430+
if hasattr(self._dbapi_connection, "driver_connection"):
431+
# _dbapi_connection is a _ConnectionFairy which proxies raw SnowflakeConnection
432+
self._dbapi_connection.driver_connection._interpolate_empty_sequences = (
433+
True
434+
)
435+
else:
436+
# _dbapi_connection is a raw SnowflakeConnection
437+
self._dbapi_connection._interpolate_empty_sequences = True
438+
439+
def post_exec(self):
440+
if self.compiled:
441+
# for compiled statements, percent is doubled for escapeafter execution
442+
# we reset _interpolate_empty_sequences to false which is turned on in pre_exec
443+
if hasattr(self._dbapi_connection, "driver_connection"):
444+
# _dbapi_connection is a _ConnectionFairy which proxies raw SnowflakeConnection
445+
self._dbapi_connection.driver_connection._interpolate_empty_sequences = (
446+
False
447+
)
448+
else:
449+
# _dbapi_connection is a raw SnowflakeConnection
450+
self._dbapi_connection._interpolate_empty_sequences = False
451+
452+
@property
453+
def rowcount(self):
454+
return self.cursor.rowcount
455+
390456

391457
class SnowflakeDDLCompiler(compiler.DDLCompiler):
392458
def denormalize_column_name(self, name):
@@ -406,6 +472,10 @@ def get_column_specification(self, column, **kwargs):
406472
self.dialect.type_compiler.process(column.type, type_expression=column),
407473
]
408474

475+
has_identity = (
476+
column.identity is not None and self.dialect.supports_identity_columns
477+
)
478+
409479
if not column.nullable:
410480
colspec.append("NOT NULL")
411481

@@ -422,10 +492,15 @@ def get_column_specification(self, column, **kwargs):
422492
and column.server_default is None
423493
):
424494
if isinstance(column.default, Sequence):
425-
colspec.append(f"DEFAULT {column.default.name}.nextval")
495+
colspec.append(
496+
f"DEFAULT {self.dialect.identifier_preparer.format_sequence(column.default)}.nextval"
497+
)
426498
else:
427499
colspec.append("AUTOINCREMENT")
428500

501+
if has_identity:
502+
colspec.append(self.process(column.identity))
503+
429504
return " ".join(colspec)
430505

431506
def post_create_table(self, table):
@@ -512,6 +587,14 @@ def visit_drop_column_comment(self, drop, **kw):
512587
self.preparer.format_column(drop.element),
513588
)
514589

590+
def visit_identity_column(self, identity, **kw):
591+
text = " IDENTITY"
592+
if identity.start is not None or identity.increment is not None:
593+
start = 1 if identity.start is None else identity.start
594+
increment = 1 if identity.increment is None else identity.increment
595+
text += f"({start},{increment})"
596+
return text
597+
515598

516599
class SnowflakeTypeCompiler(compiler.GenericTypeCompiler):
517600
def visit_BYTEINT(self, type_, **kw):

src/snowflake/sqlalchemy/custom_types.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#
22
# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
33
#
4+
import datetime
5+
import decimal
6+
import re
47

58
import sqlalchemy.types as sqltypes
69
import sqlalchemy.util as util
@@ -17,6 +20,16 @@
1720
VARBINARY = sqltypes.BINARY
1821

1922

23+
def _process_float(value):
24+
if value == float("inf"):
25+
return "inf"
26+
elif value == float("-inf"):
27+
return "-inf"
28+
elif value is not None:
29+
return float(value)
30+
return value
31+
32+
2033
class SnowflakeType(sqltypes.TypeEngine):
2134
def _default_dialect(self):
2235
# Get around circular import
@@ -51,7 +64,98 @@ class GEOGRAPHY(SnowflakeType):
5164
__visit_name__ = "GEOGRAPHY"
5265

5366

67+
class _CUSTOM_Date(SnowflakeType, sqltypes.Date):
68+
def literal_processor(self, dialect):
69+
def process(value):
70+
if value is not None:
71+
return f"'{value.isoformat()}'"
72+
73+
return process
74+
75+
_reg = re.compile(r"(\d+)-(\d+)-(\d+)")
76+
77+
def result_processor(self, dialect, coltype):
78+
def process(value):
79+
if isinstance(value, str):
80+
m = self._reg.match(value)
81+
if not m:
82+
raise ValueError(f"could not parse {value!r} as a date value")
83+
return datetime.date(*[int(x or 0) for x in m.groups()])
84+
else:
85+
return value
86+
87+
return process
88+
89+
90+
class _CUSTOM_DateTime(SnowflakeType, sqltypes.DateTime):
91+
def literal_processor(self, dialect):
92+
def process(value):
93+
if value is not None:
94+
datetime_str = value.isoformat(" ", timespec="microseconds")
95+
return f"'{datetime_str}'"
96+
97+
return process
98+
99+
_reg = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d{0,6}))?")
100+
101+
def result_processor(self, dialect, coltype):
102+
def process(value):
103+
if isinstance(value, str):
104+
m = self._reg.match(value)
105+
if not m:
106+
raise ValueError(f"could not parse {value!r} as a datetime value")
107+
return datetime.datetime(*[int(x or 0) for x in m.groups()])
108+
else:
109+
return value
110+
111+
return process
112+
113+
114+
class _CUSTOM_Time(SnowflakeType, sqltypes.Time):
115+
def literal_processor(self, dialect):
116+
def process(value):
117+
if value is not None:
118+
time_str = value.isoformat(timespec="microseconds")
119+
return f"'{time_str}'"
120+
121+
return process
122+
123+
_reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?")
124+
125+
def result_processor(self, dialect, coltype):
126+
def process(value):
127+
if isinstance(value, str):
128+
m = self._reg.match(value)
129+
if not m:
130+
raise ValueError(f"could not parse {value!r} as a time value")
131+
return datetime.time(*[int(x or 0) for x in m.groups()])
132+
else:
133+
return value
134+
135+
return process
136+
137+
138+
class _CUSTOM_Float(SnowflakeType, sqltypes.Float):
139+
def bind_processor(self, dialect):
140+
return _process_float
141+
142+
54143
class _CUSTOM_DECIMAL(SnowflakeType, sqltypes.DECIMAL):
55144
@util.memoized_property
56145
def _type_affinity(self):
57146
return sqltypes.INTEGER if self.scale == 0 else sqltypes.DECIMAL
147+
148+
149+
class _CUSTOM_Numeric(SnowflakeType, sqltypes.Numeric):
150+
def result_processor(self, dialect, coltype):
151+
if self.asdecimal:
152+
153+
def process(value):
154+
if value:
155+
return decimal.Decimal(value)
156+
else:
157+
return None
158+
159+
return process
160+
else:
161+
return _process_float

src/snowflake/sqlalchemy/provision.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#
2+
# Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved.
3+
#
4+
from sqlalchemy.testing.provision import set_default_schema_on_connection
5+
6+
7+
# This is only for test purpose required by Requirement "default_schema_name_switch"
8+
@set_default_schema_on_connection.for_db("snowflake")
9+
def _snowflake_set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
10+
cursor = dbapi_connection.cursor()
11+
cursor.execute(f"USE SCHEMA {dbapi_connection.database}.{schema_name};")
12+
cursor.close()

0 commit comments

Comments
 (0)