Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.util.compat import string_types

from snowflake.sqlalchemy.custom_types import OBJECT

from .custom_commands import AWSBucket, AzureContainer, ExternalStage

RESERVED_WORDS = frozenset(
Expand Down Expand Up @@ -150,6 +152,36 @@ def _split_schema_by_dot(self, schema):


class SnowflakeCompiler(compiler.SQLCompiler):
def visit_insert(self, stmt, **kw):
# https://github.com/sqlalchemy/sqlalchemy/discussions/7894#discussioncomment-2520337
insert_sql = super().visit_insert(stmt, **kw)

columns = self.column_keys
if columns is None:
columns = stmt.table.columns.keys()

# look in the columns being inserted, see if there's
# JSON being inserted. also can just look at the INSERT string
# and look for the json function

use_json = any(isinstance(stmt.table.c[key].type, OBJECT) for key in columns)

if not use_json:
return insert_sql

stmt_reg = re.match(
r"^INSERT INTO (.+?) \((.+?)\) VALUES \((.+)\)$", insert_sql
)
if not stmt_reg:
return insert_sql

# rewrite INSERT as per
# https://docs.snowflake.com/en/sql-reference/sql/insert.html#usage-notes
return (
f"INSERT INTO {stmt_reg.group(1)} "
f"({stmt_reg.group(2)}) SELECT {stmt_reg.group(3)}"
)

def visit_sequence(self, sequence, **kw):
return self.dialect.identifier_preparer.format_sequence(sequence) + ".nextval"

Expand Down