Skip to content

Commit 670d5f7

Browse files
hovaescohashhar
authored andcommitted
Add support for ROW and ARRAY in TrinoTypeCompiler
1 parent bb35f1c commit 670d5f7

File tree

2 files changed

+111
-1
lines changed

2 files changed

+111
-1
lines changed

tests/integration/test_sqlalchemy_integration.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import pytest
1717
import sqlalchemy as sqla
1818
from sqlalchemy.sql import and_, not_, or_
19+
from sqlalchemy.types import ARRAY
1920

2021
from tests.integration.conftest import trino_version
2122
from tests.unit.conftest import sqlalchemy_version
22-
from trino.sqlalchemy.datatype import JSON, MAP
23+
from trino.sqlalchemy.datatype import JSON, MAP, ROW
2324

2425

2526
@pytest.fixture
@@ -528,6 +529,109 @@ def test_map_column(trino_connection, map_object, sqla_type):
528529
metadata.drop_all(engine)
529530

530531

532+
@pytest.mark.skipif(
533+
sqlalchemy_version() < "1.4",
534+
reason="columns argument to select() must be a Python list or other iterable"
535+
)
536+
@pytest.mark.parametrize(
537+
'trino_connection,array_object,sqla_type',
538+
[
539+
('memory', None, ARRAY(sqla.sql.sqltypes.String)),
540+
('memory', [], ARRAY(sqla.sql.sqltypes.String)),
541+
('memory', [True, False, True], ARRAY(sqla.sql.sqltypes.Boolean)),
542+
('memory', [1, 2, None], ARRAY(sqla.sql.sqltypes.Integer)),
543+
('memory', [1.4, 2.3, math.inf], ARRAY(sqla.sql.sqltypes.Float)),
544+
('memory', [Decimal("1.2"), Decimal("2.3")], ARRAY(sqla.sql.sqltypes.DECIMAL(2, 1))),
545+
('memory', ["hello", "world"], ARRAY(sqla.sql.sqltypes.String)),
546+
('memory', ["a ", "null"], ARRAY(sqla.sql.sqltypes.CHAR(4))),
547+
('memory', [b'eh?', None, b'\x00'], ARRAY(sqla.sql.sqltypes.BINARY)),
548+
],
549+
indirect=['trino_connection']
550+
)
551+
def test_array_column(trino_connection, array_object, sqla_type):
552+
engine, conn = trino_connection
553+
554+
if not engine.dialect.has_schema(conn, "test"):
555+
with engine.begin() as connection:
556+
connection.execute(sqla.schema.CreateSchema("test"))
557+
metadata = sqla.MetaData()
558+
559+
try:
560+
table_with_array = sqla.Table(
561+
'table_with_array',
562+
metadata,
563+
sqla.Column('id', sqla.Integer),
564+
sqla.Column('array_column', sqla_type),
565+
schema="test"
566+
)
567+
metadata.create_all(engine)
568+
ins = table_with_array.insert()
569+
conn.execute(ins, {"id": 1, "array_column": array_object})
570+
query = sqla.select(table_with_array)
571+
result = conn.execute(query)
572+
rows = result.fetchall()
573+
assert len(rows) == 1
574+
assert rows[0] == (1, array_object)
575+
finally:
576+
metadata.drop_all(engine)
577+
578+
579+
@pytest.mark.skipif(
580+
sqlalchemy_version() < "1.4",
581+
reason="columns argument to select() must be a Python list or other iterable"
582+
)
583+
@pytest.mark.parametrize(
584+
'trino_connection,row_object,sqla_type',
585+
[
586+
('memory', None, ROW([('field1', sqla.sql.sqltypes.String),
587+
('field2', sqla.sql.sqltypes.String)])),
588+
('memory', ('hello', 'world'), ROW([('field1', sqla.sql.sqltypes.String),
589+
('field2', sqla.sql.sqltypes.String)])),
590+
('memory', (True, False), ROW([('field1', sqla.sql.sqltypes.Boolean),
591+
('field2', sqla.sql.sqltypes.Boolean)])),
592+
('memory', (1, 2), ROW([('field1', sqla.sql.sqltypes.Integer),
593+
('field2', sqla.sql.sqltypes.Integer)])),
594+
('memory', (1.4, float('inf')), ROW([('field1', sqla.sql.sqltypes.Float),
595+
('field2', sqla.sql.sqltypes.Float)])),
596+
('memory', (Decimal("1.2"), Decimal("2.3")), ROW([('field1', sqla.sql.sqltypes.DECIMAL(2, 1)),
597+
('field2', sqla.sql.sqltypes.DECIMAL(3, 1))])),
598+
('memory', ("hello", "world"), ROW([('field1', sqla.sql.sqltypes.String),
599+
('field2', sqla.sql.sqltypes.String)])),
600+
('memory', ("a ", "null"), ROW([('field1', sqla.sql.sqltypes.CHAR(4)),
601+
('field2', sqla.sql.sqltypes.CHAR(4))])),
602+
('memory', (b'eh?', b'oh?'), ROW([('field1', sqla.sql.sqltypes.BINARY),
603+
('field2', sqla.sql.sqltypes.BINARY)])),
604+
],
605+
indirect=['trino_connection']
606+
)
607+
def test_row_column(trino_connection, row_object, sqla_type):
608+
engine, conn = trino_connection
609+
610+
if not engine.dialect.has_schema(conn, "test"):
611+
with engine.begin() as connection:
612+
connection.execute(sqla.schema.CreateSchema("test"))
613+
metadata = sqla.MetaData()
614+
615+
try:
616+
table_with_row = sqla.Table(
617+
'table_with_row',
618+
metadata,
619+
sqla.Column('id', sqla.Integer),
620+
sqla.Column('row_column', sqla_type),
621+
schema="test"
622+
)
623+
metadata.create_all(engine)
624+
ins = table_with_row.insert()
625+
conn.execute(ins, {"id": 1, "row_column": row_object})
626+
query = sqla.select(table_with_row)
627+
result = conn.execute(query)
628+
rows = result.fetchall()
629+
assert len(rows) == 1
630+
assert rows[0] == (1, row_object)
631+
finally:
632+
metadata.drop_all(engine)
633+
634+
531635
@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
532636
def test_get_catalog_names(trino_connection):
533637
engine, conn = trino_connection

trino/sqlalchemy/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ def visit_MAP(self, type_, **kw):
253253
value_type = self.process(type_.value_type, **kw)
254254
return f'MAP({key_type}, {value_type})'
255255

256+
def visit_ARRAY(self, type_, **kw):
257+
return f'ARRAY({self.process(type_.item_type, **kw)})'
258+
259+
def visit_ROW(self, type_, **kw):
260+
return f'ROW({", ".join(f"{name} {self.process(attr_type, **kw)}" for name, attr_type in type_.attr_types)})'
261+
256262

257263
class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
258264
reserved_words = RESERVED_WORDS

0 commit comments

Comments
 (0)