|
16 | 16 | import pytest
|
17 | 17 | import sqlalchemy as sqla
|
18 | 18 | from sqlalchemy.sql import and_, not_, or_
|
| 19 | +from sqlalchemy.types import ARRAY |
19 | 20 |
|
20 | 21 | from tests.integration.conftest import trino_version
|
21 | 22 | from tests.unit.conftest import sqlalchemy_version
|
22 |
| -from trino.sqlalchemy.datatype import JSON, MAP |
| 23 | +from trino.sqlalchemy.datatype import JSON, MAP, ROW |
23 | 24 |
|
24 | 25 |
|
25 | 26 | @pytest.fixture
|
@@ -528,6 +529,109 @@ def test_map_column(trino_connection, map_object, sqla_type):
|
528 | 529 | metadata.drop_all(engine)
|
529 | 530 |
|
530 | 531 |
|
| 532 | +@pytest.mark.skipif( |
| 533 | + sqlalchemy_version() < "1.4", |
| 534 | + reason="columns argument to select() must be a Python list or other iterable" |
| 535 | +) |
| 536 | +@pytest.mark.parametrize( |
| 537 | + 'trino_connection,array_object,sqla_type', |
| 538 | + [ |
| 539 | + ('memory', None, ARRAY(sqla.sql.sqltypes.String)), |
| 540 | + ('memory', [], ARRAY(sqla.sql.sqltypes.String)), |
| 541 | + ('memory', [True, False, True], ARRAY(sqla.sql.sqltypes.Boolean)), |
| 542 | + ('memory', [1, 2, None], ARRAY(sqla.sql.sqltypes.Integer)), |
| 543 | + ('memory', [1.4, 2.3, math.inf], ARRAY(sqla.sql.sqltypes.Float)), |
| 544 | + ('memory', [Decimal("1.2"), Decimal("2.3")], ARRAY(sqla.sql.sqltypes.DECIMAL(2, 1))), |
| 545 | + ('memory', ["hello", "world"], ARRAY(sqla.sql.sqltypes.String)), |
| 546 | + ('memory', ["a ", "null"], ARRAY(sqla.sql.sqltypes.CHAR(4))), |
| 547 | + ('memory', [b'eh?', None, b'\x00'], ARRAY(sqla.sql.sqltypes.BINARY)), |
| 548 | + ], |
| 549 | + indirect=['trino_connection'] |
| 550 | +) |
| 551 | +def test_array_column(trino_connection, array_object, sqla_type): |
| 552 | + engine, conn = trino_connection |
| 553 | + |
| 554 | + if not engine.dialect.has_schema(conn, "test"): |
| 555 | + with engine.begin() as connection: |
| 556 | + connection.execute(sqla.schema.CreateSchema("test")) |
| 557 | + metadata = sqla.MetaData() |
| 558 | + |
| 559 | + try: |
| 560 | + table_with_array = sqla.Table( |
| 561 | + 'table_with_array', |
| 562 | + metadata, |
| 563 | + sqla.Column('id', sqla.Integer), |
| 564 | + sqla.Column('array_column', sqla_type), |
| 565 | + schema="test" |
| 566 | + ) |
| 567 | + metadata.create_all(engine) |
| 568 | + ins = table_with_array.insert() |
| 569 | + conn.execute(ins, {"id": 1, "array_column": array_object}) |
| 570 | + query = sqla.select(table_with_array) |
| 571 | + result = conn.execute(query) |
| 572 | + rows = result.fetchall() |
| 573 | + assert len(rows) == 1 |
| 574 | + assert rows[0] == (1, array_object) |
| 575 | + finally: |
| 576 | + metadata.drop_all(engine) |
| 577 | + |
| 578 | + |
| 579 | +@pytest.mark.skipif( |
| 580 | + sqlalchemy_version() < "1.4", |
| 581 | + reason="columns argument to select() must be a Python list or other iterable" |
| 582 | +) |
| 583 | +@pytest.mark.parametrize( |
| 584 | + 'trino_connection,row_object,sqla_type', |
| 585 | + [ |
| 586 | + ('memory', None, ROW([('field1', sqla.sql.sqltypes.String), |
| 587 | + ('field2', sqla.sql.sqltypes.String)])), |
| 588 | + ('memory', ('hello', 'world'), ROW([('field1', sqla.sql.sqltypes.String), |
| 589 | + ('field2', sqla.sql.sqltypes.String)])), |
| 590 | + ('memory', (True, False), ROW([('field1', sqla.sql.sqltypes.Boolean), |
| 591 | + ('field2', sqla.sql.sqltypes.Boolean)])), |
| 592 | + ('memory', (1, 2), ROW([('field1', sqla.sql.sqltypes.Integer), |
| 593 | + ('field2', sqla.sql.sqltypes.Integer)])), |
| 594 | + ('memory', (1.4, float('inf')), ROW([('field1', sqla.sql.sqltypes.Float), |
| 595 | + ('field2', sqla.sql.sqltypes.Float)])), |
| 596 | + ('memory', (Decimal("1.2"), Decimal("2.3")), ROW([('field1', sqla.sql.sqltypes.DECIMAL(2, 1)), |
| 597 | + ('field2', sqla.sql.sqltypes.DECIMAL(3, 1))])), |
| 598 | + ('memory', ("hello", "world"), ROW([('field1', sqla.sql.sqltypes.String), |
| 599 | + ('field2', sqla.sql.sqltypes.String)])), |
| 600 | + ('memory', ("a ", "null"), ROW([('field1', sqla.sql.sqltypes.CHAR(4)), |
| 601 | + ('field2', sqla.sql.sqltypes.CHAR(4))])), |
| 602 | + ('memory', (b'eh?', b'oh?'), ROW([('field1', sqla.sql.sqltypes.BINARY), |
| 603 | + ('field2', sqla.sql.sqltypes.BINARY)])), |
| 604 | + ], |
| 605 | + indirect=['trino_connection'] |
| 606 | +) |
| 607 | +def test_row_column(trino_connection, row_object, sqla_type): |
| 608 | + engine, conn = trino_connection |
| 609 | + |
| 610 | + if not engine.dialect.has_schema(conn, "test"): |
| 611 | + with engine.begin() as connection: |
| 612 | + connection.execute(sqla.schema.CreateSchema("test")) |
| 613 | + metadata = sqla.MetaData() |
| 614 | + |
| 615 | + try: |
| 616 | + table_with_row = sqla.Table( |
| 617 | + 'table_with_row', |
| 618 | + metadata, |
| 619 | + sqla.Column('id', sqla.Integer), |
| 620 | + sqla.Column('row_column', sqla_type), |
| 621 | + schema="test" |
| 622 | + ) |
| 623 | + metadata.create_all(engine) |
| 624 | + ins = table_with_row.insert() |
| 625 | + conn.execute(ins, {"id": 1, "row_column": row_object}) |
| 626 | + query = sqla.select(table_with_row) |
| 627 | + result = conn.execute(query) |
| 628 | + rows = result.fetchall() |
| 629 | + assert len(rows) == 1 |
| 630 | + assert rows[0] == (1, row_object) |
| 631 | + finally: |
| 632 | + metadata.drop_all(engine) |
| 633 | + |
| 634 | + |
531 | 635 | @pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
|
532 | 636 | def test_get_catalog_names(trino_connection):
|
533 | 637 | engine, conn = trino_connection
|
|
0 commit comments