diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 4d3386ba..fa120569 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -9,7 +9,7 @@ from sqlalchemy.engine import default from sqlalchemy.schema import Sequence, Table from sqlalchemy.sql import compiler, expression -from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.elements import BinaryExpression, quoted_name from sqlalchemy.util.compat import string_types from .custom_commands import AWSBucket, AzureContainer, ExternalStage @@ -153,6 +153,10 @@ class SnowflakeCompiler(compiler.SQLCompiler): def visit_sequence(self, sequence, **kw): return self.dialect.identifier_preparer.format_sequence(sequence) + ".nextval" + def visit_json_getitem_op_binary(self, a: BinaryExpression, b, **kw): + """Render keys selected from OBJECTs.""" + return self.process(a.left, **kw) + "[" + self.process(a.right, **kw) + "]" + def visit_merge_into(self, merge_into, **kw): clauses = " ".join( clause._compiler_dispatch(self, **kw) for clause in merge_into.clauses diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index 938d7883..5d3c4ec6 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -3,10 +3,13 @@ # import datetime import decimal +import json import re import sqlalchemy.types as sqltypes import sqlalchemy.util as util +from sqlalchemy import sql +from sqlalchemy.sql import expression TEXT = sqltypes.VARCHAR CHARACTER = sqltypes.CHAR @@ -40,8 +43,43 @@ class VARIANT(SnowflakeType): __visit_name__ = "VARIANT" -class OBJECT(SnowflakeType): +class OBJECT(sqltypes.Indexable, SnowflakeType): __visit_name__ = "OBJECT" + comparator_factory = sqltypes.JSON.Comparator + + def bind_expression(self, bindvalue: expression.BindParameter): + """Build the SQL string compoenent when inserted into a statement. + + The OBJECT must be sent as a string and passed to the `parse_json` Snowflake + function when INSERTing or UPDATE-ing. + """ + return sql.func.parse_json(bindvalue) + + def bind_processor(self, dialect): + def process(value): + """Process data before sending to connector as the value to bind.""" + if value is not None: + value = json.dumps(value) + + return value + + return process + + def literal_processor(self, dialect): + def process(value) -> str: + """Process data when binding literal string directly into statement.""" + return f"'{self.bind_processor(dialect)(value)}'" + + return process + + def result_processor(self, dialect, coltype): + def process(value): + """Process the value recieved from the connector.""" + if value is not None: + value = json.loads(value) + return value + + return process class ARRAY(SnowflakeType):