diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index e9125315..3663c524 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -583,6 +583,16 @@ def visit_identity_column(self, identity, **kw): class SnowflakeTypeCompiler(compiler.GenericTypeCompiler): + def _render_string_type(self, type_, name): + + text = name + if type_.length: + text += f"({type_.length})" + if type_.collation: + # note: whitespace before the statement is important here + text += f" COLLATE '{type_.collation}'" + return text + def visit_BYTEINT(self, type_, **kw): return "BYTEINT" diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 43820623..b389a8f6 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -2,10 +2,12 @@ # Copyright (c) 2012-2022 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import Integer, String, and_, select +import pytest +from sqlalchemy import Integer, MetaData, String, and_, select, Table, Column from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table from sqlalchemy.testing import AssertsCompiledSQL +from .conftest import CONNECTION_PARAMETERS table1 = table( "table1", column("id", Integer), column("name", String), column("value", Integer) @@ -99,3 +101,24 @@ def test_quoted_name_label(engine_testaccount): sel_from_tbl = select(col).group_by(col).select_from(table("abc")) compiled_result = sel_from_tbl.compile() assert str(compiled_result) == t["output"] + +@pytest.mark.parametrize("collation", ["en", "latin1"]) +def test_string_collation(engine_testaccount, collation): + # create a table with a string column with a certain collation + metadata = MetaData(bind=engine_testaccount) + table = Table(f"collation_test_table_{collation}", + metadata, + Column("chars_col", String(collation=collation)), + schema=CONNECTION_PARAMETERS["schema"] + ) + table.create(engine_testaccount) + insert_stmt = table.insert([ + {"chars_col": "a"}, + {"chars_col": "A"}, + {"chars_col": "b"}, + ]) + engine_testaccount.execute(insert_stmt) + # retrieve values and check if collation was used properly + column_type = engine_testaccount.execute(f"DESCRIBE TABLE {table.schema}.{table.name}").fetchone()["type"] + assert f"COLLATE '{collation}'" in column_type +