From e3e06ea8ff6277171b6b717bff93a9f4775d5ce3 Mon Sep 17 00:00:00 2001 From: Simon Hewitt Date: Tue, 10 Jan 2023 13:20:28 -0800 Subject: [PATCH 01/74] call _compiler_dispatch for merge_into and copy_into clauses --- src/snowflake/sqlalchemy/base.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index e9125315..0211e713 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -158,8 +158,11 @@ def visit_merge_into(self, merge_into, **kw): clauses = " ".join( clause._compiler_dispatch(self, **kw) for clause in merge_into.clauses ) + target = merge_into.target._compiler_dispatch(self, asfrom=True, **kw) + source = merge_into.source._compiler_dispatch(self, asfrom=True, **kw) + on = merge_into.on._compiler_dispatch(self, **kw) return ( - f"MERGE INTO {merge_into.target} USING {merge_into.source} ON {merge_into.on}" + f"MERGE INTO {target} USING {source} ON {on}" + (" " + clauses if clauses else "") ) @@ -207,11 +210,7 @@ def visit_copy_into(self, copy_into, **kw): formatter = copy_into.formatter._compiler_dispatch(self, **kw) else: formatter = "" - into = ( - copy_into.into - if isinstance(copy_into.into, Table) - else copy_into.into._compiler_dispatch(self, **kw) - ) + into = copy_into.into._compiler_dispatch(self, asfrom=True, **kw) from_ = None if isinstance(copy_into.from_, Table): from_ = copy_into.from_ From 8733872fd59cb2023129492789e246cc79d091c1 Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Wed, 19 Apr 2023 14:59:01 -0700 Subject: [PATCH 02/74] SNOW-518659 changelog bot public repositories (#407) Committed via https://github.com/asottile/all-repos --- .github/workflows/changelog.yml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 .github/workflows/changelog.yml diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml new file mode 100644 index 00000000..2e197168 --- /dev/null +++ b/.github/workflows/changelog.yml @@ -0,0 +1,20 @@ +name: Changelog Check + +on: + pull_request: + types: [opened, synchronize, labeled, unlabeled] + branches: + - main + +jobs: + check_change_log: + runs-on: ubuntu-latest + if: ${{!contains(github.event.pull_request.labels.*.name, 'NO-CHANGELOG-UPDATES')}} + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Ensure DESCRIPTION.md is updated + run: git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -wq "DESCRIPTION.md" From b48e420653a9b2d5022536d42c8e1f0e2dc615df Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Mon, 24 Apr 2023 17:59:23 -0700 Subject: [PATCH 03/74] Update cla_bot.yml (#409) --- .github/workflows/cla_bot.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/cla_bot.yml b/.github/workflows/cla_bot.yml index 5574667a..cbc3cc82 100644 --- a/.github/workflows/cla_bot.yml +++ b/.github/workflows/cla_bot.yml @@ -8,6 +8,11 @@ on: jobs: CLAssistant: runs-on: ubuntu-latest + permissions: + actions: write + contents: write + pull-requests: write + statuses: write steps: - name: "CLA Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' From c563913f9f2e7bbd46d47b1eeee52f456d5659cb Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Tue, 25 Apr 2023 10:56:20 -0700 Subject: [PATCH 04/74] Update jira_issue.yml (#410) --- 
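Patch 01 above stops rendering the MERGE INTO target, source, and ON clause with plain string interpolation and routes them through `_compiler_dispatch` (with `asfrom=True` for the relations). A minimal compile-only sketch of the construct this affects, assuming the `MergeInto` command exported by `snowflake.sqlalchemy`; the table and column names are illustrative, not from the patch:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table

from snowflake.sqlalchemy import MergeInto, snowdialect

metadata = MetaData()
target = Table("t", metadata, Column("id", Integer), Column("val", String))
source = Table("s", metadata, Column("id", Integer), Column("val", String))

merge = MergeInto(target=target, source=source, on=target.c.id == source.c.id)
merge.when_matched_then_update().values(val=source.c.val)
merge.when_not_matched_then_insert().values(id=source.c.id, val=source.c.val)

# After the patch, target/source/on are rendered by the compiler rather
# than str(), so constructs such as aliased sources compile correctly.
print(merge.compile(dialect=snowdialect.dialect()))
```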
.github/workflows/jira_issue.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 3683bbba..74e58454 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -9,6 +9,8 @@ on: jobs: create-issue: runs-on: ubuntu-latest + permissions: + issues: write if: ((github.event_name == 'issue_comment' && github.event.comment.body == 'recreate jira' && github.event.comment.user.login == 'sfc-gh-mkeller') || (github.event_name == 'issues' && github.event.pull_request.user.login != 'whitesource-for-github-com[bot]')) steps: - name: Checkout From c200f971a745eb404ef68b9f8a248c7255056dc9 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Mon, 5 Jun 2023 19:22:36 -0700 Subject: [PATCH 05/74] Test pr 418 (#419) * feat: Add temporary option to the create stage command. * update changelog * Update tests/test_create.py Co-authored-by: Adam Ling --------- Co-authored-by: DanCardin Co-authored-by: Adam Ling --- DESCRIPTION.md | 4 ++++ src/snowflake/sqlalchemy/base.py | 5 +++-- src/snowflake/sqlalchemy/custom_commands.py | 3 ++- tests/test_create.py | 10 ++++++++++ 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 5219de1e..b0281eca 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,10 @@ Source code is also available at: # Release Notes +- v1.4.8(TBD) + + - Added opiton to create a temporary stage command. + - v1.4.7(Mar 22, 2023) - Re-applied the application name of driver connection `SnowflakeConnection` to `SnowflakeSQLAlchemy`. diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index d87b78c1..45a51265 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -534,11 +534,12 @@ def visit_create_stage(self, create_stage, **kw): """ This visitor will create the SQL representation for a CREATE STAGE command. 
""" - return "CREATE {}STAGE {}{} URL={}".format( - "OR REPLACE " if create_stage.replace_if_exists else "", + return "CREATE {or_replace}{temporary}STAGE {}{} URL={}".format( create_stage.stage.namespace, create_stage.stage.name, repr(create_stage.container), + or_replace="OR REPLACE " if create_stage.replace_if_exists else "", + temporary="TEMPORARY " if create_stage.temporary else "", ) def visit_create_file_format(self, file_format, **kw): diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 9cc14389..9bb60916 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -482,9 +482,10 @@ class CreateStage(DDLElement): __visit_name__ = "create_stage" - def __init__(self, container, stage, replace_if_exists=False): + def __init__(self, container, stage, replace_if_exists=False, *, temporary=False): super().__init__() self.container = container + self.temporary = temporary self.stage = stage self.replace_if_exists = replace_if_exists diff --git a/tests/test_create.py b/tests/test_create.py index 5271118f..0b8b48fa 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -54,6 +54,16 @@ def test_create_stage(sql_compiler): ) assert actual == expected + create_stage = CreateStage(stage=stage, container=container, temporary=True) + # validate that the resulting SQL is as expected + actual = sql_compiler(create_stage) + expected = ( + "CREATE TEMPORARY STAGE MY_DB.MY_SCHEMA.AZURE_STAGE " + "URL='azure://myaccount.blob.core.windows.net/my-container' " + "CREDENTIALS=(AZURE_SAS_TOKEN='saas_token')" + ) + assert actual == expected + def test_create_csv_format(sql_compiler): """ From a72fa3929943adf837f6330b654e17c1333a2a5e Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 6 Jul 2023 15:10:28 -0700 Subject: [PATCH 06/74] SNOW-857216: fix compatibility issue with sqlalchemy 1.4.49 (#423) --- DESCRIPTION.md | 3 ++- src/snowflake/sqlalchemy/base.py | 9 ++++++++- tests/connector_regression | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index b0281eca..10c0f9f1 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -11,7 +11,8 @@ Source code is also available at: - v1.4.8(TBD) - - Added opiton to create a temporary stage command. + - Added option to create a temporary stage command. + - Fixed a compatibility issue of regex expression with SQLAlchemy 1.4.49. 
- v1.4.7(Mar 22, 2023) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 45a51265..9835ced2 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -376,7 +376,14 @@ def visit_regexp_match_op_binary(self, binary, operator, **kw): def visit_regexp_replace_op_binary(self, binary, operator, **kw): string, pattern, flags = self._get_regexp_args(binary, kw) - replacement = self.process(binary.modifiers["replacement"], **kw) + try: + replacement = self.process(binary.modifiers["replacement"], **kw) + except KeyError: + # in sqlalchemy 1.4.49, the internal structure of the expression is changed + # that binary.modifiers doesn't have "replacement": + # https://docs.sqlalchemy.org/en/20/changelog/changelog_14.html#change-1.4.49 + return f"REGEXP_REPLACE({string}, {pattern}{'' if flags is None else f', {flags}'})" + if flags is None: return f"REGEXP_REPLACE({string}, {pattern}, {replacement})" else: diff --git a/tests/connector_regression b/tests/connector_regression index ec95c563..6c365ab5 160000 --- a/tests/connector_regression +++ b/tests/connector_regression @@ -1 +1 @@ -Subproject commit ec95c563ded4694f69e8bde4eb2f010f92681e58 +Subproject commit 6c365ab5e4d11621e78194c723b45528eaa807e2 From e3f675e45456b68cc5de89771de72716e674aae9 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Thu, 20 Jul 2023 19:05:44 -0400 Subject: [PATCH 07/74] Update cla_bot.yml (#422) --- .github/workflows/cla_bot.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cla_bot.yml b/.github/workflows/cla_bot.yml index cbc3cc82..2c87fc92 100644 --- a/.github/workflows/cla_bot.yml +++ b/.github/workflows/cla_bot.yml @@ -22,7 +22,7 @@ jobs: PERSONAL_ACCESS_TOKEN : ${{ secrets.CLA_BOT_TOKEN }} with: path-to-signatures: 'signatures/version1.json' - path-to-document: 'https://github.com/Snowflake-Labs/CLA/blob/main/README.md' + path-to-document: 'https://github.com/snowflakedb/CLA/blob/main/README.md' branch: 'main' allowlist: 'dependabot[bot],github-actions, sfc-gh-snyk-sca-sa' remote-organization-name: 'snowflakedb' From de68d864af14200643b74625db4480db87c92172 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 27 Jul 2023 19:34:20 -0700 Subject: [PATCH 08/74] SNOW-871115: add geometry support (#429) --- DESCRIPTION.md | 1 + src/snowflake/sqlalchemy/__init__.py | 2 + src/snowflake/sqlalchemy/base.py | 3 ++ src/snowflake/sqlalchemy/custom_types.py | 4 ++ src/snowflake/sqlalchemy/snowdialect.py | 2 + tests/test_core.py | 10 +++- tests/test_custom_types.py | 1 + tests/test_geometry.py | 67 ++++++++++++++++++++++++ tests/util.py | 2 + 9 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 tests/test_geometry.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 10c0f9f1..2b2fc44e 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -12,6 +12,7 @@ Source code is also available at: - v1.4.8(TBD) - Added option to create a temporary stage command. + - Add support for geometry type. - Fixed a compatibility issue of regex expression with SQLAlchemy 1.4.49. 
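Patch 06's `visit_regexp_replace_op_binary` guard, shown above, can be exercised without a connection. A small sketch (column name and pattern are illustrative); on SQLAlchemy 1.4.49 the `KeyError` fallback simply omits the relocated replacement operand:

```python
from sqlalchemy import column

from snowflake.sqlalchemy import snowdialect

# regexp_replace is SQLAlchemy's generic 1.4+ operator; the Snowflake
# compiler renders it as REGEXP_REPLACE(string, pattern, replacement[, flags]).
expr = column("body").regexp_replace("[0-9]+", "#", flags="i")
print(
    expr.compile(
        dialect=snowdialect.dialect(), compile_kwargs={"literal_binds": True}
    )
)
# Roughly: REGEXP_REPLACE(body, '[0-9]+', '#', 'i')
```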
- v1.4.7(Mar 22, 2023) diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 063910fe..9df6aaa2 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -49,6 +49,7 @@ DOUBLE, FIXED, GEOGRAPHY, + GEOMETRY, NUMBER, OBJECT, STRING, @@ -90,6 +91,7 @@ "DOUBLE", "FIXED", "GEOGRAPHY", + "GEOMETRY", "OBJECT", "NUMBER", "STRING", diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 9835ced2..525522ea 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -654,5 +654,8 @@ def visit_TIMESTAMP(self, type_, **kw): def visit_GEOGRAPHY(self, type_, **kw): return "GEOGRAPHY" + def visit_GEOMETRY(self, type_, **kw): + return "GEOMETRY" + construct_arguments = [(Table, {"clusterby": None})] diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index 3f42f034..802d1ce1 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -61,6 +61,10 @@ class GEOGRAPHY(SnowflakeType): __visit_name__ = "GEOGRAPHY" +class GEOMETRY(SnowflakeType): + __visit_name__ = "GEOMETRY" + + class _CUSTOM_Date(SnowflakeType, sqltypes.Date): def literal_processor(self, dialect): def process(value): diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index dc1dea1b..350027f4 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -51,6 +51,7 @@ _CUSTOM_DECIMAL, ARRAY, GEOGRAPHY, + GEOMETRY, OBJECT, TIMESTAMP_LTZ, TIMESTAMP_NTZ, @@ -105,6 +106,7 @@ "OBJECT": OBJECT, "ARRAY": ARRAY, "GEOGRAPHY": GEOGRAPHY, + "GEOMETRY": GEOMETRY, } diff --git a/tests/test_core.py b/tests/test_core.py index 157889ff..8206e43d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1612,7 +1612,8 @@ def test_column_type_schema(engine_testaccount): C1 BIGINT, C2 BINARY, C3 BOOLEAN, C4 CHAR, C5 CHARACTER, C6 DATE, C7 DATETIME, C8 DEC, C9 DECIMAL, C10 DOUBLE, C11 FLOAT, C12 INT, C13 INTEGER, C14 NUMBER, C15 REAL, C16 BYTEINT, C17 SMALLINT, C18 STRING, C19 TEXT, C20 TIME, C21 TIMESTAMP, C22 TIMESTAMP_TZ, C23 TIMESTAMP_LTZ, - C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY + C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY, + C32 GEOMETRY ) """ ) @@ -1635,7 +1636,8 @@ def test_result_type_and_value(engine_testaccount): C1 BIGINT, C2 BINARY, C3 BOOLEAN, C4 CHAR, C5 CHARACTER, C6 DATE, C7 DATETIME, C8 DEC(12,3), C9 DECIMAL(12,3), C10 DOUBLE, C11 FLOAT, C12 INT, C13 INTEGER, C14 NUMBER, C15 REAL, C16 BYTEINT, C17 SMALLINT, C18 STRING, C19 TEXT, C20 TIME, C21 TIMESTAMP, C22 TIMESTAMP_TZ, C23 TIMESTAMP_LTZ, - C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY + C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY, + C32 GEOMETRY ) """ ) @@ -1661,6 +1663,8 @@ def test_result_type_and_value(engine_testaccount): CHAR_VALUE = "A" GEOGRAPHY_VALUE = "POINT(-122.35 37.55)" GEOGRAPHY_RESULT_VALUE = '{"coordinates": [-122.35,37.55],"type": "Point"}' + GEOMETRY_VALUE = "POINT(-94.58473 39.08985)" + GEOMETRY_RESULT_VALUE = '{"coordinates": [-94.58473,39.08985],"type": "Point"}' ins = table_reflected.insert().values( c1=MAX_INT_VALUE, # BIGINT @@ -1694,6 +1698,7 @@ def test_result_type_and_value(engine_testaccount): 
c29=None, # OBJECT, currently snowflake-sqlalchemy/connector does not support binding variant c30=None, # ARRAY, currently snowflake-sqlalchemy/connector does not support binding variant c31=GEOGRAPHY_VALUE, # GEOGRAPHY + c32=GEOMETRY_VALUE, # GEOMETRY ) conn.execute(ins) @@ -1732,6 +1737,7 @@ def test_result_type_and_value(engine_testaccount): and result[28] is None and result[29] is None and json.loads(result[30]) == json.loads(GEOGRAPHY_RESULT_VALUE) + and json.loads(result[31]) == json.loads(GEOMETRY_RESULT_VALUE) ) sql = f""" diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index b7962199..a997ffe8 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -15,6 +15,7 @@ def test_string_conversions(): "TIMESTAMP_LTZ", "TIMESTAMP_NTZ", "GEOGRAPHY", + "GEOMETRY", ] sf_types = [ "TEXT", diff --git a/tests/test_geometry.py b/tests/test_geometry.py new file mode 100644 index 00000000..742b518e --- /dev/null +++ b/tests/test_geometry.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from json import loads + +from sqlalchemy import Column, Integer, MetaData, Table +from sqlalchemy.sql import select + +from snowflake.sqlalchemy import GEOMETRY + + +def test_create_table_geometry_datatypes(engine_testaccount): + """ + Create table including geometry data types + """ + metadata = MetaData() + table_name = "test_geometry0" + test_geometry = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + metadata.create_all(engine_testaccount) + try: + assert test_geometry is not None + finally: + test_geometry.drop(engine_testaccount) + + +def test_inspect_geometry_datatypes(engine_testaccount): + """ + Create table including geometry data types + """ + metadata = MetaData() + table_name = "test_geometry0" + test_geometry = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom1", GEOMETRY), + Column("geom2", GEOMETRY), + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + test_point = "POINT(-94.58473 39.08985)" + test_point1 = '{"coordinates": [-94.58473, 39.08985],"type": "Point"}' + + ins = test_geometry.insert().values( + id=1, geom1=test_point, geom2=test_point1 + ) + + with conn.begin(): + results = conn.execute(ins) + results.close() + + s = select(test_geometry) + results = conn.execute(s) + rows = results.fetchone() + results.close() + assert rows[0] == 1 + assert rows[1] == rows[2] + assert loads(rows[2]) == loads(test_point1) + finally: + test_geometry.drop(engine_testaccount) diff --git a/tests/util.py b/tests/util.py index f53333c4..db0b0c9c 100644 --- a/tests/util.py +++ b/tests/util.py @@ -28,6 +28,7 @@ from snowflake.sqlalchemy.custom_types import ( ARRAY, GEOGRAPHY, + GEOMETRY, OBJECT, TIMESTAMP_LTZ, TIMESTAMP_NTZ, @@ -70,6 +71,7 @@ "OBJECT": OBJECT, "ARRAY": ARRAY, "GEOGRAPHY": GEOGRAPHY, + "GEOMETRY": GEOMETRY, } From 67f4152e5e54b2015046dbcb6c0cb05d617bd621 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Tue, 22 Aug 2023 17:17:17 -0700 Subject: [PATCH 09/74] SNOW-897870: remove connector submodule in sqlalchemy (#442) --- .github/workflows/build_test.yml | 43 +------------------------------- .gitmodules | 3 --- tests/connector_regression | 1 - tox.ini | 10 +------- 4 files changed, 2 insertions(+), 55 deletions(-) delete mode 160000 tests/connector_regression diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 9c648c73..eae3afe3 
100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -97,51 +97,10 @@ jobs: .tox/.coverage .tox/coverage.xml - test_connector_regression: - name: Connector Regression Test ${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - needs: lint - runs-on: ${{ matrix.os.image_name }} - strategy: - fail-fast: false - matrix: - os: - - image_name: ubuntu-latest - download_name: manylinux_x86_64 - python-version: ["3.8"] - cloud-provider: [aws] - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Setup parameters file - shell: bash - env: - PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} - run: | - gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ - .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/connector_regression/test/parameters.py - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: List installed packages - run: python -m pip freeze - - name: Run tests - run: python -m tox -e connector_regression --skip-missing-interpreters false - env: - PYTEST_ADDOPTS: -vvv --color=yes --tb=short - TOX_PARALLEL_NO_SPINNER: 1 - combine-coverage: if: ${{ success() || failure() }} name: Combine coverage - needs: [test, test_connector_regression] + needs: [test] runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 diff --git a/.gitmodules b/.gitmodules index 6f8702e6..e69de29b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "tests/connector_regression"] - path = tests/connector_regression - url = git@github.com:snowflakedb/snowflake-connector-python diff --git a/tests/connector_regression b/tests/connector_regression deleted file mode 160000 index 6c365ab5..00000000 --- a/tests/connector_regression +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6c365ab5e4d11621e78194c723b45528eaa807e2 diff --git a/tox.ini b/tox.ini index 99891d22..0c1cb483 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,6 @@ min_version = 4.0.0 envlist = fix_lint, py{37,38,39,310,311}{,-pandas}, coverage, - connector_regression skip_missing_interpreters = true [testenv] @@ -52,13 +51,6 @@ commands = pytest \ --run_v20_sqlalchemy \ {posargs:tests} -[testenv:connector_regression] -deps = pendulum -commands = pytest \ - {env:SNOWFLAKE_PYTEST_OPTS:} \ - -m "not gcp and not azure" \ - {posargs:tests/connector_regression/test} - [testenv:.pkg_external] deps = build package_glob = {toxinidir}{/}dist{/}*.whl @@ -94,7 +86,7 @@ commands = pre-commit run --all-files python -c 'import pathlib; print("hint: run \{\} install to add checks as pre-commit hook".format(pathlib.Path(r"{envdir}") / "bin" / "pre-commit"))' [pytest] -addopts = -ra --strict-markers --ignore=tests/sqlalchemy_test_suite --ignore=tests/connector_regression +addopts = -ra --strict-markers --ignore=tests/sqlalchemy_test_suite junit_family = legacy log_level = info markers = From ab1269dbf2fe7907487529e135ad3897f7af9fc0 Mon Sep 17 00:00:00 2001 From: Angel Antonio Avalos Cisneros Date: Wed, 23 Aug 2023 16:34:46 -0700 Subject: [PATCH 10/74] SNOW-897060 - Release 1.5.0 (#443) * Bump up version to 1.4.8 * update version * update release version to 1.5.0 --- DESCRIPTION.md | 4 ++-- src/snowflake/sqlalchemy/version.py | 
2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 2b2fc44e..c4c755bf 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,10 +9,10 @@ Source code is also available at: # Release Notes -- v1.4.8(TBD) +- v1.5.0(Aug 23, 2023) - Added option to create a temporary stage command. - - Add support for geometry type. + - Added support for geometry type. - Fixed a compatibility issue of regex expression with SQLAlchemy 1.4.49. - v1.4.7(Mar 22, 2023) diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index f2b7d15d..b45b8c09 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = (1, 4, 7, None) +VERSION = (1, 5, 0, None) From 83f69b847ac849a6c8d072606a6235e2c1e260d1 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Tue, 10 Oct 2023 11:26:37 -0700 Subject: [PATCH 11/74] SNOW-876389: compatibility with bcr-1057 on outer lateral table join (#444) --- DESCRIPTION.md | 4 + src/snowflake/sqlalchemy/base.py | 377 ++++++++++++++++++++++++++++++- src/snowflake/sqlalchemy/util.py | 216 +++++++++++++++++- tests/test_compiler.py | 13 ++ tests/test_orm.py | 87 ++++++- 5 files changed, 694 insertions(+), 3 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index c4c755bf..da226b07 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,10 @@ Source code is also available at: # Release Notes +- v1.5.1(Unreleased) + + - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057. + - v1.5.0(Aug 23, 2023) - Added option to create a temporary stage command. diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 525522ea..f229fb93 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -2,18 +2,30 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# +import itertools import operator import re +from sqlalchemy import exc as sa_exc +from sqlalchemy import inspect, sql from sqlalchemy import util as sa_util from sqlalchemy.engine import default +from sqlalchemy.orm import context +from sqlalchemy.orm.context import _MapperEntity from sqlalchemy.schema import Sequence, Table from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql.base import CompileState from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.selectable import Lateral, SelectState from sqlalchemy.util.compat import string_types from .custom_commands import AWSBucket, AzureContainer, ExternalStage -from .util import _set_connection_interpolate_empty_sequences +from .util import ( + _find_left_clause_to_join_from, + _set_connection_interpolate_empty_sequences, + _Snowflake_ORMJoin, + _Snowflake_Selectable_Join, +) RESERVED_WORDS = frozenset( [ @@ -88,6 +100,330 @@ ) +""" +Overwrite methods to handle Snowflake BCR change: +https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 +- _join_determine_implicit_left_side +- _join_left_to_right +""" + + +# handle Snowflake BCR bcr-1057 +@CompileState.plugin_for("default", "select") +class SnowflakeSelectState(SelectState): + def _setup_joins(self, args, raw_columns): + for (right, onclause, left, flags) in args: + isouter = flags["isouter"] + full = flags["full"] + + if left is None: + ( + left, + replace_from_obj_index, + ) = self._join_determine_implicit_left_side( + raw_columns, left, right, onclause + ) + else: + (replace_from_obj_index) = self._join_place_explicit_left_side(left) + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + ( + _Snowflake_Selectable_Join( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=isouter, + full=full, + ), + ) + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + self.from_clauses = self.from_clauses + ( + # handle Snowflake BCR bcr-1057 + _Snowflake_Selectable_Join( + left, right, onclause, isouter=isouter, full=full + ), + ) + + @sa_util.preload_module("sqlalchemy.sql.util") + def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + replace_from_obj_index = None + + from_clauses = self.from_clauses + + if from_clauses: + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from(from_clauses, right, onclause) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = from_clauses[replace_from_obj_index] + else: + potential = {} + statement = self.statement + + for from_clause in itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in raw_columns] + ), + itertools.chain.from_iterable( + [element._from_objects for element in statement._where_criteria] + ), + ): + + potential[from_clause] = () + + all_clauses = list(potential.keys()) + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from(all_clauses, right, onclause) + + if len(indexes) == 1: + left = all_clauses[indexes[0]] + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. 
Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already to " + "help resolve the ambiguity." + ) + elif not indexes: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already to " + "help resolve the ambiguity." % (right,) + ) + return left, replace_from_obj_index + + +# handle Snowflake BCR bcr-1057 +@sql.base.CompileState.plugin_for("orm", "select") +class SnowflakeORMSelectCompileState(context.ORMSelectCompileState): + def _join_determine_implicit_left_side( + self, entities_collection, left, right, onclause + ): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self.from_clauses: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from( + self.from_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self.from_clauses[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + + elif entities_collection: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. + + potential = {} + for entity_index, ent in enumerate(entities_collection): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, context._MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + + all_clauses = list(potential.keys()) + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. 
Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + else: + raise sa_exc.InvalidRequestError( + "No entities to join from; please use " + "select_from() to establish the left " + "entity/selectable of this join" + ) + + return left, replace_from_obj_index, use_entity_index + + def _join_left_to_right( + self, + entities_collection, + left, + right, + onclause, + prop, + create_aliases, + aliased_generation, + outerjoin, + full, + ): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ + + if left is None: + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + ( + left, + replace_from_obj_index, + use_entity_index, + ) = self._join_determine_implicit_left_side( + entities_collection, left, right, onclause + ) + else: + # left is given via a relationship/name, or as explicit left side. + # Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + ( + replace_from_obj_index, + use_entity_index, + ) = self._join_place_explicit_left_side(entities_collection, left) + + if left is right and not create_aliases: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they " + "are the same entity" % (left, right) + ) + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. 
handle all that and + # get back the new effective "right" side + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, create_aliases, aliased_generation + ) + + if not r_info.is_selectable: + extra_criteria = self._get_extra_criteria(r_info) + else: + extra_criteria = () + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + [ + _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + # add a new element to the self._from_obj list + if use_entity_index is not None: + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance(entities_collection[use_entity_index], _MapperEntity) + left_clause = entities_collection[use_entity_index].selectable + else: + left_clause = left + + self.from_clauses = self.from_clauses + [ + _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 + left_clause, + r_info, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} @@ -392,6 +728,45 @@ def visit_regexp_replace_op_binary(self, binary, operator, **kw): def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return f"NOT {self.visit_regexp_match_op_binary(binary, operator, **kw)}" + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.update( + itertools.product(join.left._from_objects, join.right._from_objects) + ) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " JOIN " + + join_statement = ( + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + join_type + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + ) + + if join.onclause is None and isinstance(join.right, Lateral): + # in snowflake, onclause is not accepted for lateral due to BCR change: + # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 + # sqlalchemy only allows join with on condition. + # to adapt to snowflake syntax change, + # we make the change such that when onclause is None and the right part is + # Lateral, we do not append the on condition + return join_statement + + return ( + join_statement + + " ON " + # TODO: likely need asfrom=True here?
+ + join.onclause._compiler_dispatch(self, from_linter=from_linter, **kwargs) ) def render_literal_value(self, value, type_): # escape backslash return super().render_literal_value(value, type_).replace("\\", "\\\\") diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 56b5de5b..631ceaee 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -3,10 +3,19 @@ # import re +from itertools import chain from typing import Any from urllib.parse import quote_plus -from sqlalchemy import exc +from sqlalchemy import exc, inspection, sql, util +from sqlalchemy.exc import NoForeignKeysError +from sqlalchemy.orm.interfaces import MapperProperty +from sqlalchemy.orm.util import _ORMJoin as sa_orm_util_ORMJoin +from sqlalchemy.orm.util import attributes +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql.base import _expand_cloned, _from_objects +from sqlalchemy.sql.elements import _find_columns +from sqlalchemy.sql.selectable import Join, Lateral, coercions, operators, roles from snowflake.connector.compat import IS_STR from snowflake.connector.connection import SnowflakeConnection @@ -104,3 +113,208 @@ def _update_connection_application_name(**conn_kwargs: Any) -> Any: if PARAM_INTERNAL_APPLICATION_VERSION not in conn_kwargs: conn_kwargs[PARAM_INTERNAL_APPLICATION_VERSION] = SNOWFLAKE_SQLALCHEMY_VERSION return conn_kwargs + + +# handle Snowflake BCR bcr-1057 +# the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState +# which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that +# can not handle the BCR change, we implement it in a way that lateral join does not need onclause +def _find_left_clause_to_join_from(clauses, join_to, onclause): + """Given a list of FROM clauses, a selectable, + and optional ON clause, return a list of integer indexes from the + clauses list indicating the clauses that can be joined from. + + The presence of an "onclause" indicates that at least one clause can + definitely be joined from; if the list of clauses is of length one + and the onclause is given, returns that index. If the list of clauses + is more than length one, and the onclause is given, attempts to locate + which clauses contain the same columns. + + """ + idx = [] + selectables = set(_from_objects(join_to)) + + # if we are given more than one target clause to join + # from, use the onclause to provide a more specific answer. + # otherwise, don't try to limit, after all, "ON TRUE" is a valid + # on clause + if len(clauses) > 1 and onclause is not None: + resolve_ambiguity = True + cols_in_onclause = _find_columns(onclause) + else: + resolve_ambiguity = False + cols_in_onclause = None + + for i, f in enumerate(clauses): + for s in selectables.difference([f]): + if resolve_ambiguity: + if set(f.c).union(s.c).issuperset(cols_in_onclause): + idx.append(i) + break + elif onclause is not None or Join._can_join(f, s): + idx.append(i) + break + elif onclause is None and isinstance(s, Lateral): + # in snowflake, onclause is not accepted for lateral due to BCR change: + # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 + # sqlalchemy only allows join with on condition. + # to adapt to snowflake syntax change, + # we make the change such that when onclause is None and the right part is + # Lateral, we append the index indicating the Lateral clause can be joined from without an onclause.
+ idx.append(i) + break + + if len(idx) > 1: + # this is the same "hide froms" logic from + # Selectable._get_display_froms + toremove = set(chain(*[_expand_cloned(f._hide_froms) for f in clauses])) + idx = [i for i in idx if clauses[i] not in toremove] + + # onclause was given and none of them resolved, so assume + # all indexes can match + if not idx and onclause is not None: + return range(len(clauses)) + else: + return idx + + +class _Snowflake_ORMJoin(sa_orm_util_ORMJoin): + def __init__( + self, + left, + right, + onclause=None, + isouter=False, + full=False, + _left_memo=None, + _right_memo=None, + _extra_criteria=(), + ): + left_info = inspection.inspect(left) + + right_info = inspection.inspect(right) + adapt_to = right_info.selectable + + # used by joined eager loader + self._left_memo = _left_memo + self._right_memo = _right_memo + + # legacy, for string attr name ON clause. if that's removed + # then the "_joined_from_info" concept can go + left_orm_info = getattr(left, "_joined_from_info", left_info) + self._joined_from_info = right_info + if isinstance(onclause, util.string_types): + onclause = getattr(left_orm_info.entity, onclause) + # #### + + if isinstance(onclause, attributes.QueryableAttribute): + on_selectable = onclause.comparator._source_selectable() + prop = onclause.property + _extra_criteria += onclause._extra_criteria + elif isinstance(onclause, MapperProperty): + # used internally by joined eager loader...possibly not ideal + prop = onclause + on_selectable = prop.parent.selectable + else: + prop = None + + if prop: + left_selectable = left_info.selectable + + if sql_util.clause_is_present(on_selectable, left_selectable): + adapt_from = on_selectable + else: + adapt_from = left_selectable + + (pj, sj, source, dest, secondary, target_adapter,) = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + of_type_entity=right_info, + alias_secondary=True, + extra_criteria=_extra_criteria, + ) + + if sj is not None: + if isouter: + # note this is an inner join from secondary->right + right = sql.join(secondary, right, sj) + onclause = pj + else: + left = sql.join(left, secondary, pj, isouter) + onclause = sj + else: + onclause = pj + + self._target_adapter = target_adapter + + # we don't use the normal coercions logic for _ORMJoin + # (probably should), so do some gymnastics to get the entity. + # logic here is for #8721, which was a major bug in 1.4 + # for almost two years, not reported/fixed until 1.4.43 (!) + if left_info.is_selectable: + parententity = left_selectable._annotations.get("parententity", None) + elif left_info.is_mapper or left_info.is_aliased_class: + parententity = left_info + else: + parententity = None + + if parententity is not None: + self._annotations = self._annotations.union( + {"parententity": parententity} + ) + + augment_onclause = onclause is None and _extra_criteria + # handle Snowflake BCR bcr-1057 + _Snowflake_Selectable_Join.__init__(self, left, right, onclause, isouter, full) + + if augment_onclause: + self.onclause &= sql.and_(*_extra_criteria) + + if ( + not prop + and getattr(right_info, "mapper", None) + and right_info.mapper.single + ): + # if single inheritance target and we are using a manual + # or implicit ON clause, augment it the same way we'd augment the + # WHERE. 
+ single_crit = right_info.mapper._single_table_criterion + if single_crit is not None: + if right_info.is_aliased_class: + single_crit = right_info._adapter.traverse(single_crit) + self.onclause = self.onclause & single_crit + + +class _Snowflake_Selectable_Join(Join): + def __init__(self, left, right, onclause=None, isouter=False, full=False): + """Construct a new :class:`_expression.Join`. + + The usual entrypoint here is the :func:`_expression.join` + function or the :meth:`_expression.FromClause.join` method of any + :class:`_expression.FromClause` object. + + """ + self.left = coercions.expect(roles.FromClauseRole, left, deannotate=True) + self.right = coercions.expect( + roles.FromClauseRole, right, deannotate=True + ).self_group() + + if onclause is None: + try: + self.onclause = self._match_primaries(self.left, self.right) + except NoForeignKeysError: + # handle Snowflake BCR bcr-1057 + if isinstance(self.right, Lateral): + self.onclause = None + else: + raise + else: + # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba + # not merged yet + self.onclause = coercions.expect(roles.OnClauseRole, onclause).self_group( + against=operators._asbool + ) + + self.isouter = isouter + self.full = full diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 4098f915..0fd75c38 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -7,6 +7,8 @@ from sqlalchemy.sql import column, quoted_name, table from sqlalchemy.testing import AssertsCompiledSQL +from snowflake.sqlalchemy import snowdialect + table1 = table( "table1", column("id", Integer), column("name", String), column("value", Integer) ) @@ -107,3 +109,14 @@ def test_quoted_name_label(engine_testaccount): sel_from_tbl = select(col).group_by(col).select_from(table("abc")) compiled_result = sel_from_tbl.compile() assert str(compiled_result) == t["output"] + + +def test_outer_lateral_join(): + col = column("colname").label("label") + col2 = column("colname2").label("label2") + lateral_table = func.flatten(func.PARSE_JSON(col2), outer=True).lateral() + stmt = select(col).select_from(table("abc")).join(lateral_table).group_by(col) + assert ( + str(stmt.compile(dialect=snowdialect.dialect())) + == "SELECT colname AS label \nFROM abc JOIN LATERAL flatten(PARSE_JSON(colname2)) AS anon_1 GROUP BY colname" + ) diff --git a/tests/test_orm.py b/tests/test_orm.py index 363da671..e485d737 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -3,9 +3,20 @@ # import enum +import logging import pytest -from sqlalchemy import Column, Enum, ForeignKey, Integer, Sequence, String, text +from sqlalchemy import ( + Column, + Enum, + ForeignKey, + Integer, + Sequence, + String, + func, + select, + text, +) from sqlalchemy.orm import Session, declarative_base, relationship @@ -326,3 +337,77 @@ class User(Base): assert user.fullname == "test_user" finally: Base.metadata.drop_all(con) + + +def test_outer_lateral_join(engine_testaccount, caplog): + Base = declarative_base() + + class Employee(Base): + __tablename__ = "employees" + + employee_id = Column(Integer, primary_key=True) + last_name = Column(String) + + class Department(Base): + __tablename__ = "departments" + + department_id = Column(Integer, primary_key=True) + name = Column(String) + + Base.metadata.create_all(engine_testaccount) + session = Session(bind=engine_testaccount) + e1 = Employee(employee_id=101, last_name="Richards") + d1 = Department(department_id=1, name="Engineering") + session.add_all([e1, d1]) + session.commit() + + sub = select(Department).lateral() + 
query = ( + select(Employee.employee_id, Department.department_id) + .select_from(Employee) + .outerjoin(sub) + ) + assert ( + str(query.compile(engine_testaccount)).replace("\n", "") + == "SELECT employees.employee_id, departments.department_id " + "FROM departments, employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1" + ) + with caplog.at_level(logging.DEBUG): + assert [res for res in session.execute(query)] + assert ( + "SELECT employees.employee_id, departments.department_id FROM departments" + in caplog.text + ) + + +def test_lateral_join_without_condition(engine_testaccount, caplog): + Base = declarative_base() + + class Employee(Base): + __tablename__ = "Employee" + + pkey = Column(String, primary_key=True) + uid = Column(Integer) + content = Column(String) + + Base.metadata.create_all(engine_testaccount) + lateral_table = func.flatten( + func.PARSE_JSON(Employee.content), outer=False + ).lateral() + query = ( + select( + Employee.uid, + ) + .select_from(Employee) + .join(lateral_table) + .where(Employee.uid == "123") + ) + session = Session(bind=engine_testaccount) + with caplog.at_level(logging.DEBUG): + session.execute(query) + assert ( + '[SELECT "Employee".uid FROM "Employee" JOIN LATERAL flatten(PARSE_JSON("Employee"' + in caplog.text + ) From 8b701a80da0890dbc19c3a68a1dc920d368690b4 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Thu, 19 Oct 2023 10:34:19 -0400 Subject: [PATCH 12/74] [Test PR] SNOW-892284: Fix boolean parameter parsing from URL query (#446) * SNOW-892284: Fix boolean parameter parsing from URL query * Fix cache_column_metadata being parsed as connector argument cache_column_metadata is a snowflake-sqlalchemy argument, not a snowflake-connector-python argument, so it should be set and omitted from the arguments list before the call to the connector is made. * Add link to GitHub issue for dealing with other URL param types * Add patch note * Handle case when cache_column_metadata isn't set * Load cache_column_metadata from query * Add fallthrough for unknown parameters * Fix lint * Update DESCRIPTION.md * Add more inline documentation --------- Co-authored-by: Saulius Beinorius --- DESCRIPTION.md | 2 ++ src/snowflake/sqlalchemy/snowdialect.py | 42 +++++++++++++++++++++---- src/snowflake/sqlalchemy/util.py | 9 ++++++ tests/test_core.py | 21 +++++++++++++ 4 files changed, 68 insertions(+), 6 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index da226b07..47133034 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -12,6 +12,8 @@ Source code is also available at: - v1.5.1(Unreleased) - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057. + - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters. + - This fixes other boolean parameter passing to driver as well.
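Patch 12's effect is easiest to see from the engine-URL side before reading the `create_connect_args` diff below. A sketch using the `URL` helper exercised by the project's own tests; the account and credentials are placeholders:

```python
from sqlalchemy import create_engine

from snowflake.sqlalchemy import URL

engine = create_engine(
    URL(
        account="myaccount",  # placeholder credentials
        user="myuser",
        password="mypassword",
        validate_default_parameters=True,
    )
)
# URL() serializes the flag into the query string as the string "True";
# with this patch, create_connect_args parses it back into a real bool via
# parse_url_boolean instead of handing the literal string to the connector.
```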
- v1.5.0(Aug 23, 2023) diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 350027f4..4fefa07f 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -11,7 +11,7 @@ from sqlalchemy import event as sa_vnt from sqlalchemy import exc as sa_exc from sqlalchemy import util as sa_util -from sqlalchemy.engine import default, reflection +from sqlalchemy.engine import URL, default, reflection from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name @@ -38,6 +38,7 @@ ) from snowflake.connector import errors as sf_errors +from snowflake.connector.connection import DEFAULT_CONFIGURATION from snowflake.connector.constants import UTF8 from .base import ( @@ -62,7 +63,7 @@ _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name +from .util import _update_connection_application_name, parse_url_boolean colspecs = { Date: _CUSTOM_Date, @@ -109,7 +110,6 @@ "GEOMETRY": GEOMETRY, } - _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True @@ -199,7 +199,7 @@ def dbapi(cls): return connector - def create_connect_args(self, url): + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: name_spaces = [unquote_plus(e) for e in opts["database"].split("/")] @@ -226,10 +226,40 @@ def create_connect_args(self, url): opts["host"] = opts["host"] + ".snowflakecomputing.com" opts["port"] = "443" opts["autocommit"] = False # autocommit is disabled by default - opts.update(url.query) + + query = dict(**url.query) # make mutable + cache_column_metadata = query.pop("cache_column_metadata", None) self._cache_column_metadata = ( - opts.get("cache_column_metadata", "false").lower() == "true" + parse_url_boolean(cache_column_metadata) if cache_column_metadata else False ) + + # URL sets the query parameter values as strings, we need to cast to expected types when necessary + for name, value in query.items(): + maybe_type_configuration = DEFAULT_CONFIGURATION.get(name) + if ( + not maybe_type_configuration + ): # if the parameter is not found in the type mapping, pass it through as a string + opts[name] = value + continue + + (_, expected_type) = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance( + value, expected_type + ): # if the expected type is str, pass it through as a string + opts[name] = value + + elif ( + bool in expected_type + ): # if the expected type is bool, parse it and pass as a boolean + opts[name] = parse_url_boolean(value) + else: + # TODO: other types like int are still passed through as string + # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447 + opts[name] = value + return ([], opts) def has_table(self, connection, table_name, schema=None): diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 631ceaee..54044349 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -115,6 +115,15 @@ def _update_connection_application_name(**conn_kwargs: Any) -> Any: return conn_kwargs +def parse_url_boolean(value: str) -> bool: + if value == "True": + return True + elif value == "False": + return False + else: + raise ValueError(f"Invalid boolean value detected: '{value}'") + + # handle Snowflake BCR bcr-1057 # the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState # which used the
'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that diff --git a/tests/test_core.py b/tests/test_core.py index 8206e43d..29c55ae9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -173,6 +173,27 @@ def test_connect_args(): engine.dispose() +def test_boolean_query_argument_parsing(): + engine = create_engine( + URL( + user=CONNECTION_PARAMETERS["user"], + password=CONNECTION_PARAMETERS["password"], + account=CONNECTION_PARAMETERS["account"], + host=CONNECTION_PARAMETERS["host"], + port=CONNECTION_PARAMETERS["port"], + protocol=CONNECTION_PARAMETERS["protocol"], + validate_default_parameters=True, + ) + ) + try: + verify_engine_connection(engine) + connection = engine.raw_connection() + assert connection.validate_default_parameters is True + finally: + connection.close() + engine.dispose() + + def test_create_dialect(): """ Tests getting only dialect object through create_engine From 9bbd4ccf738d66eae3e2d68bf4a6d13cac7045b0 Mon Sep 17 00:00:00 2001 From: Jiazhen Fan <52474868+sfc-gh-jfan@users.noreply.github.com> Date: Mon, 23 Oct 2023 13:47:46 -0700 Subject: [PATCH 13/74] PRODSEC-3611 fix GHA parsing (#455) --- .github/workflows/snyk-issue.yml | 5 +++++ .github/workflows/snyk-pr.yml | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/.github/workflows/snyk-issue.yml b/.github/workflows/snyk-issue.yml index 7098b01e..c8f5d90b 100644 --- a/.github/workflows/snyk-issue.yml +++ b/.github/workflows/snyk-issue.yml @@ -4,6 +4,11 @@ on: schedule: - cron: '* */12 * * *' +permissions: + contents: read + issues: write + pull-requests: write + concurrency: snyk-issue jobs: diff --git a/.github/workflows/snyk-pr.yml b/.github/workflows/snyk-pr.yml index 51e531f4..b951af65 100644 --- a/.github/workflows/snyk-pr.yml +++ b/.github/workflows/snyk-pr.yml @@ -3,6 +3,12 @@ on: pull_request: branches: - main + +permissions: + contents: read + issues: write + pull-requests: write + jobs: snyk: runs-on: ubuntu-latest From 61975bf607434fe8427aaae926549e8ffd3e9bb8 Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Wed, 1 Nov 2023 23:48:45 +0000 Subject: [PATCH 14/74] SNOW-XXX: Bumped up SQLAlchemy PATCH version from 1.5.0 to 1.5.1 (#460) --- DESCRIPTION.md | 2 +- src/snowflake/sqlalchemy/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 47133034..5751354c 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,7 @@ Source code is also available at: # Release Notes -- v1.5.1(Unreleased) +- v1.5.1(November 03, 2023) - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057. - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters. 
authentication not caching due to incorrect parsing of boolean query parameters.
diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index b45b8c09..6aea4f54 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = (1, 5, 0, None) +VERSION = (1, 5, 1, None) From 8ac9f3355ef3075d8cb80ee97f07d5f1bc9b8121 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Fri, 16 Feb 2024 10:42:54 +0100 Subject: [PATCH 15/74] SNOW-1056848-missing-test-action-output: remove TOX_PARALLEL_NO_SPINNER env variable (#468) --- .github/workflows/build_test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index eae3afe3..5284a088 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -86,7 +86,6 @@ jobs: env: PYTHON_VERSION: ${{ matrix.python-version }} PYTEST_ADDOPTS: -vvv --color=yes --tb=short - TOX_PARALLEL_NO_SPINNER: 1 - name: Combine coverages run: python -m tox -e coverage --skip-missing-interpreters false shell: bash From f384432e3a62a0c5a10f7aeee09eadeaa7431e8d Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Mon, 11 Mar 2024 17:30:17 +0100 Subject: [PATCH 16/74] SNOW-1212541-ordered-sequence: add support for creating sequences order (#473) --- .pre-commit-config.yaml | 12 +- DESCRIPTION.md | 4 + pyproject.toml | 5 + src/snowflake/sqlalchemy/base.py | 65 ++++++-- src/snowflake/sqlalchemy/custom_commands.py | 6 +- src/snowflake/sqlalchemy/snowdialect.py | 40 +++-- src/snowflake/sqlalchemy/util.py | 9 +- src/snowflake/sqlalchemy/version.py | 2 +- tests/test_core.py | 79 +++++---- tests/test_sequence.py | 173 +++++++++++++------- 10 files changed, 254 insertions(+), 141 deletions(-) create mode 100644 pyproject.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3be42964..70b75ce8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ exclude: '^(.*egg.info.*|.*/parameters.py).*$' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -9,23 +9,23 @@ repos: exclude: .github/repo_meta.yaml - id: debug-statements - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v2.37.3 + rev: v3.15.1 hooks: - id: pyupgrade args: [--py37-plus] - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 24.2.0 hooks: - id: black args: - --safe language_version: python3 - repo: https://github.com/Lucas-C/pre-commit-hooks.git - rev: v1.3.0 + rev: v1.5.5 hooks: - id: insert-license name: insert-py-license @@ -39,7 +39,7 @@ repos: - --license-filepath - license_header.txt - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 7.0.0 hooks: - id: flake8 additional_dependencies: diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 5751354c..422fe807 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,10 @@ Source code is also available at: # Release Notes +- 1.5.2 + + - Add support for sequence ordering in tests + - v1.5.1(November 03, 2023) - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057. 
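The 1.5.2 entry above corresponds to the `get_identity_options` override in the base.py hunk below, which renders Snowflake's ORDER/NOORDER qualifier. A compile-only sketch, assuming SQLAlchemy's standard `order` flag on `Sequence`; the sequence name is illustrative:

```python
from sqlalchemy import Sequence
from sqlalchemy.schema import CreateSequence

from snowflake.sqlalchemy import snowdialect

# order=True should append ORDER; order=False would yield NOORDER.
seq = Sequence("order_id_seq", start=1, increment=1, order=True)
print(CreateSequence(seq).compile(dialect=snowdialect.dialect()))
# Roughly: CREATE SEQUENCE order_id_seq INCREMENT BY 1 START WITH 1 ORDER
```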
diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..907176c3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[tool.ruff] +line-length = 88 + +[tool.black] +line-length = 88 diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index f229fb93..2a1bb51a 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -112,7 +112,7 @@ @CompileState.plugin_for("default", "select") class SnowflakeSelectState(SelectState): def _setup_joins(self, args, raw_columns): - for (right, onclause, left, flags) in args: + for right, onclause, left, flags in args: isouter = flags["isouter"] full = flags["full"] @@ -579,9 +579,11 @@ def visit_copy_into(self, copy_into, **kw): [ "{} = {}".format( n, - v._compiler_dispatch(self, **kw) - if getattr(v, "compiler_dispatch", False) - else str(v), + ( + v._compiler_dispatch(self, **kw) + if getattr(v, "compiler_dispatch", False) + else str(v) + ), ) for n, v in options_list ] @@ -604,20 +606,24 @@ def visit_copy_formatter(self, formatter, **kw): return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})" return "FILE_FORMAT=(TYPE={}{})".format( formatter.file_format, - " " - + " ".join( - [ - "{}={}".format( - name, - value._compiler_dispatch(self, **kw) - if hasattr(value, "_compiler_dispatch") - else formatter.value_repr(name, value), - ) - for name, value in options_list - ] - ) - if formatter.options - else "", + ( + " " + + " ".join( + [ + "{}={}".format( + name, + ( + value._compiler_dispatch(self, **kw) + if hasattr(value, "_compiler_dispatch") + else formatter.value_repr(name, value) + ), + ) + for name, value in options_list + ] + ) + if formatter.options + else "" + ), ) def visit_aws_bucket(self, aws_bucket, **kw): @@ -967,6 +973,29 @@ def visit_identity_column(self, identity, **kw): text += f"({start},{increment})" return text + def get_identity_options(self, identity_options): + text = [] + if identity_options.increment is not None: + text.append(f"INCREMENT BY {identity_options.increment:d}") + if identity_options.start is not None: + text.append(f"START WITH {identity_options.start:d}") + if identity_options.minvalue is not None: + text.append(f"MINVALUE {identity_options.minvalue:d}") + if identity_options.maxvalue is not None: + text.append(f"MAXVALUE {identity_options.maxvalue:d}") + if identity_options.nominvalue is not None: + text.append("NO MINVALUE") + if identity_options.nomaxvalue is not None: + text.append("NO MAXVALUE") + if identity_options.cache is not None: + text.append(f"CACHE {identity_options.cache:d}") + if identity_options.cycle is not None: + text.append("CYCLE" if identity_options.cycle else "NO CYCLE") + if identity_options.order is not None: + text.append("ORDER" if identity_options.order else "NOORDER") + + return " ".join(text) + class SnowflakeTypeCompiler(compiler.GenericTypeCompiler): def visit_BYTEINT(self, type_, **kw): diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 9bb60916..cec16673 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -259,7 +259,8 @@ def field_delimiter(self, deli_type): def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. 
The user is - responsible for specifying a valid file extension that can be read by the desired software or service.""" + responsible for specifying a valid file extension that can be read by the desired software or service. + """ if not isinstance(ext, (NoneType, string_types)): raise TypeError("File extension should be a string") self.options["FILE_EXTENSION"] = ext @@ -386,7 +387,8 @@ def compression(self, comp_type): def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service.""" + responsible for specifying a valid file extension that can be read by the desired software or service. + """ if not isinstance(ext, (NoneType, string_types)): raise TypeError("File extension should be a string") self.options["FILE_EXTENSION"] = ext diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 4fefa07f..2e40d03c 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -595,11 +595,13 @@ def _get_schema_columns(self, connection, schema, **kw): "autoincrement": is_identity == "YES", "comment": comment, "primary_key": ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False, + ( + column_name + in schema_primary_keys[table_name]["constrained_columns"] + ) + if current_table_pks + else False + ), } ) if is_identity == "YES": @@ -688,11 +690,13 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw): "autoincrement": is_identity == "YES", "comment": comment if comment != "" else None, "primary_key": ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False, + ( + column_name + in schema_primary_keys[table_name]["constrained_columns"] + ) + if current_table_pks + else False + ), } ) @@ -876,18 +880,22 @@ def get_table_comment(self, connection, table_name, schema=None, **kw): result = self._get_view_comment(connection, table_name, schema) return { - "text": result._mapping["comment"] - if result and result._mapping["comment"] - else None + "text": ( + result._mapping["comment"] + if result and result._mapping["comment"] + else None + ) } def connect(self, *cargs, **cparams): return ( super().connect( *cargs, - **_update_connection_application_name(**cparams) - if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME - else cparams, + **( + _update_connection_application_name(**cparams) + if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME + else cparams + ), ) if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME else super().connect(*cargs, **cparams) diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 54044349..32e07373 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -235,7 +235,14 @@ def __init__( else: adapt_from = left_selectable - (pj, sj, source, dest, secondary, target_adapter,) = prop._create_joins( + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = prop._create_joins( source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index 6aea4f54..d4318b86 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from 
None -VERSION = (1, 5, 1, None) +VERSION = (1, 5, 2, None) diff --git a/tests/test_core.py b/tests/test_core.py index 29c55ae9..60b4fea4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -29,6 +29,7 @@ Table, UniqueConstraint, dialects, + insert, inspect, text, ) @@ -1424,47 +1425,51 @@ def test_special_schema_character(db_parameters, on_public_ci): def test_autoincrement(engine_testaccount): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. + https://docs.snowflake.com/en/user-guide/querying-sequences + """ metadata = MetaData() users = Table( "users", metadata, - Column("uid", Integer, Sequence("id_seq"), primary_key=True), + Column("uid", Integer, Sequence("id_seq", order=True), primary_key=True), Column("name", String(39)), ) try: - users.create(engine_testaccount) - - with engine_testaccount.connect() as connection: - with connection.begin(): - connection.execute(users.insert(), [{"name": "sf1"}]) - assert connection.execute(select(users)).fetchall() == [(1, "sf1")] - connection.execute(users.insert(), [{"name": "sf2"}, {"name": "sf3"}]) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - ] - connection.execute(users.insert(), {"name": "sf4"}) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - (4, "sf4"), - ] - - seq = Sequence("id_seq") - nextid = connection.execute(seq) - connection.execute(users.insert(), [{"uid": nextid, "name": "sf5"}]) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - (4, "sf4"), - (5, "sf5"), - ] + metadata.create_all(engine_testaccount) + + with engine_testaccount.begin() as connection: + connection.execute(insert(users), [{"name": "sf1"}]) + assert connection.execute(select(users)).fetchall() == [(1, "sf1")] + connection.execute(insert(users), [{"name": "sf2"}, {"name": "sf3"}]) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + ] + connection.execute(insert(users), {"name": "sf4"}) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + (4, "sf4"), + ] + + seq = Sequence("id_seq") + nextid = connection.execute(seq) + connection.execute(insert(users), [{"uid": nextid, "name": "sf5"}]) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + (4, "sf4"), + (5, "sf5"), + ] finally: - users.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) @pytest.mark.skip( @@ -1869,10 +1874,16 @@ def test_snowflake_sqlalchemy_as_valid_client_type(): ) snowflake.connector.connection.DEFAULT_CONFIGURATION[ "internal_application_name" - ] = ("PythonConnector", (type(None), str)) + ] = ( + "PythonConnector", + (type(None), str), + ) snowflake.connector.connection.DEFAULT_CONFIGURATION[ "internal_application_version" - ] = ("3.0.0", (type(None), str)) + ] = ( + "3.0.0", + (type(None), str), + ) engine = create_engine( URL( user=CONNECTION_PARAMETERS["user"], diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 78658012..e428b9d7 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,89 +2,136 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# -from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table, select +from sqlalchemy import ( + Column, + Integer, + MetaData, + Sequence, + String, + Table, + insert, + select, +) +from sqlalchemy.sql import text def test_table_with_sequence(engine_testaccount, db_parameters): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. + https://docs.snowflake.com/en/user-guide/querying-sequences + """ # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124 test_table_name = "sequence" test_sequence_name = f"{test_table_name}_id_seq" + metadata = MetaData() + sequence_table = Table( test_table_name, - MetaData(), - Column("id", Integer, Sequence(test_sequence_name), primary_key=True), + metadata, + Column( + "id", Integer, Sequence(test_sequence_name, order=True), primary_key=True + ), Column("data", String(39)), ) - sequence_table.create(engine_testaccount) - seq = Sequence(test_sequence_name) + + autoload_metadata = MetaData() + try: - with engine_testaccount.connect() as conn: - with conn.begin(): - conn.execute(sequence_table.insert(), [{"data": "test_insert_1"}]) - select_stmt = select(sequence_table).order_by("id") - result = conn.execute(select_stmt).fetchall() - assert result == [(1, "test_insert_1")] - autoload_sequence_table = Table( - test_table_name, MetaData(), autoload_with=engine_testaccount - ) - conn.execute( - autoload_sequence_table.insert(), - [{"data": "multi_insert_1"}, {"data": "multi_insert_2"}], - ) - conn.execute( - autoload_sequence_table.insert(), [{"data": "test_insert_2"}] - ) - nextid = conn.execute(seq) - conn.execute( - autoload_sequence_table.insert(), - [{"id": nextid, "data": "test_insert_seq"}], - ) - result = conn.execute(select_stmt).fetchall() - assert result == [ - (1, "test_insert_1"), - (2, "multi_insert_1"), - (3, "multi_insert_2"), - (4, "test_insert_2"), - (5, "test_insert_seq"), - ] + metadata.create_all(engine_testaccount) + + with engine_testaccount.begin() as conn: + conn.execute(insert(sequence_table), ({"data": "test_insert_1"})) + result = conn.execute(select(sequence_table)).fetchall() + assert result == [(1, "test_insert_1")], result + + autoload_sequence_table = Table( + test_table_name, + autoload_metadata, + autoload_with=engine_testaccount, + ) + seq = Sequence(test_sequence_name, order=True) + + conn.execute( + insert(autoload_sequence_table), + ( + {"data": "multi_insert_1"}, + {"data": "multi_insert_2"}, + ), + ) + conn.execute( + insert(autoload_sequence_table), + ({"data": "test_insert_2"},), + ) + + nextid = conn.execute(seq) + conn.execute( + insert(autoload_sequence_table), + ({"id": nextid, "data": "test_insert_seq"}), + ) + + result = conn.execute(select(sequence_table)).fetchall() + + assert result == [ + (1, "test_insert_1"), + (2, "multi_insert_1"), + (3, "multi_insert_2"), + (4, "test_insert_2"), + (5, "test_insert_seq"), + ], result + finally: - sequence_table.drop(engine_testaccount) - seq.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) -def test_table_with_autoincrement(engine_testaccount, db_parameters): +def test_table_with_autoincrement(engine_testaccount): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. 
+ https://docs.snowflake.com/en/user-guide/querying-sequences + """ # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124 test_table_name = "sequence" + metadata = MetaData() autoincrement_table = Table( test_table_name, - MetaData(), + metadata, Column("id", Integer, autoincrement=True, primary_key=True), Column("data", String(39)), ) - autoincrement_table.create(engine_testaccount) + + select_stmt = select(autoincrement_table).order_by("id") + try: - with engine_testaccount.connect() as conn: - with conn.begin(): - conn.execute(autoincrement_table.insert(), [{"data": "test_insert_1"}]) - select_stmt = select(autoincrement_table).order_by("id") - result = conn.execute(select_stmt).fetchall() - assert result == [(1, "test_insert_1")] - autoload_sequence_table = Table( - test_table_name, MetaData(), autoload_with=engine_testaccount - ) - conn.execute( - autoload_sequence_table.insert(), - [{"data": "multi_insert_1"}, {"data": "multi_insert_2"}], - ) - conn.execute( - autoload_sequence_table.insert(), [{"data": "test_insert_2"}] - ) - result = conn.execute(select_stmt).fetchall() - assert result == [ - (1, "test_insert_1"), - (2, "multi_insert_1"), - (3, "multi_insert_2"), - (4, "test_insert_2"), - ] + with engine_testaccount.begin() as conn: + conn.execute(text("ALTER SESSION SET NOORDER_SEQUENCE_AS_DEFAULT = FALSE")) + metadata.create_all(conn) + + conn.execute(insert(autoincrement_table), ({"data": "test_insert_1"})) + result = conn.execute(select_stmt).fetchall() + assert result == [(1, "test_insert_1")] + + autoload_sequence_table = Table( + test_table_name, MetaData(), autoload_with=engine_testaccount + ) + conn.execute( + insert(autoload_sequence_table), + [ + {"data": "multi_insert_1"}, + {"data": "multi_insert_2"}, + ], + ) + conn.execute( + insert(autoload_sequence_table), + [{"data": "test_insert_2"}], + ) + result = conn.execute(select_stmt).fetchall() + assert result == [ + (1, "test_insert_1"), + (2, "multi_insert_1"), + (3, "multi_insert_2"), + (4, "test_insert_2"), + ], result + finally: - autoincrement_table.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) From ba51bc49d5497bd725ddb6dd20eab32b6a8df066 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Mon, 25 Mar 2024 10:19:41 +0100 Subject: [PATCH 17/74] Snow 1065172 gh workflow optimization (#477) * SNOW-1065172-gh-workflow-optimization: workflow optimisation --------- Co-authored-by: Adam Stus Co-authored-by: Tomasz Urbaszek --- .github/workflows/build_test.yml | 293 ++++++++++++++++--------- pyproject.toml | 114 ++++++++++ setup.cfg | 77 ------- setup.py | 17 -- src/snowflake/sqlalchemy/_constants.py | 2 +- src/snowflake/sqlalchemy/version.py | 2 +- tests/conftest.py | 104 +++++---- tests/test_core.py | 43 +--- tests/test_pandas.py | 104 ++------- tests/test_qmark.py | 75 ++----- 10 files changed, 400 insertions(+), 431 deletions(-) delete mode 100644 setup.py diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 5284a088..f232e669 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -24,122 +24,209 @@ jobs: name: Check linting runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: '3.8' - - name: Display Python version - run: python -c "import sys; import os; print(\"\n\".join(os.environ[\"PATH\"].split(os.pathsep))); print(sys.version); 
print(sys.executable);" - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox + - name: Upgrade and install tools + run: | + python -m pip install -U pip + python -m pip install -U hatch + python -m hatch env create default - name: Set PY - run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV - - uses: actions/cache@v1 + run: echo "PY=$(hatch run gh-cache-sum)" >> $GITHUB_ENV + - uses: actions/cache@v4 with: path: ~/.cache/pre-commit key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }} - - name: Run fix_lint - run: python -m tox -e fix_lint + - name: Run lint checks + run: hatch run check + + test-dialect: + name: Test dialect ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-latest, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and prepare environment + run: | + python -m pip install -U pip + python -m pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run test-dialect + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml - test: - name: Test ${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - needs: lint - runs-on: ${{ matrix.os.image_name }} - strategy: - fail-fast: false - matrix: - os: - - image_name: ubuntu-latest - download_name: manylinux_x86_64 - - image_name: macos-latest - download_name: macosx_x86_64 - - image_name: windows-2019 - download_name: win_amd64 - python-version: ["3.8"] - cloud-provider: [aws, azure, gcp] - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Setup parameters file - shell: bash - env: - PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} - run: | - gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ - .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: List installed packages - run: python -m pip freeze - - name: Run tests - run: python -m tox -e "py${PYTHON_VERSION/\./}" --skip-missing-interpreters false - env: - PYTHON_VERSION: ${{ matrix.python-version }} - PYTEST_ADDOPTS: -vvv --color=yes --tb=short - - name: Combine coverages - run: python -m tox -e coverage --skip-missing-interpreters false - shell: bash - - uses: actions/upload-artifact@v2 - with: - name: coverage_${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ 
matrix.cloud-provider }} - path: | - .tox/.coverage - .tox/coverage.xml + test-dialect-compatibility: + name: Test dialect compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-latest, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U pip + python -m pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml + + test-dialect-run-v20: + name: Test dialect run v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-latest, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Upgrade pip and install hatch + run: | + python -m pip install -U pip + python -m pip install -U hatch + python -m hatch env create default + - name: Run tests + run: hatch run test-run_v20 + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-run-20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml combine-coverage: - if: ${{ success() || failure() }} name: Combine coverage - needs: [test] + if: ${{ success() || failure() }} + needs: [test-dialect, test-dialect-compatibility, test-dialect-run-v20] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/download-artifact@v2 - with: - path: artifacts - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: - python-version: '3.8' - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Upgrade setuptools and pip - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: Collect all coverages to one dir + python-version: "3.8" + - name: Prepare environment + run: | + pip install -U pip + pip install -U hatch + hatch env create default + - uses: actions/checkout@v4 + with: + persist-credentials: false + - uses: actions/download-artifact@v4 + with: + path: artifacts/ + - name: Combine coverage files run: | - python -c 
'
-          from pathlib import Path
-          import shutil
-          src_dir = Path("artifacts")
-          dst_dir = Path(".") / ".tox"
-          dst_dir.mkdir()
-          for src_file in src_dir.glob("*/.coverage"):
-            dst_file = dst_dir / f".coverage.{src_file.parent.name[9:]}"
-            print(f"{src_file} copy to {dst_file}")
-            shutil.copy(str(src_file), str(dst_file))'
-      - name: Combine coverages
-        run: python -m tox -e coverage
-      - name: Publish html coverage
-        uses: actions/upload-artifact@v2
-        with:
-          name: overall_cov_html
-          path: .tox/htmlcov
-      - name: Publish xml coverage
-        uses: actions/upload-artifact@v2
-        with:
-          name: overall_cov_xml
-          path: .tox/coverage.xml
-      - uses: codecov/codecov-action@v1
-        with:
-          file: .tox/coverage.xml
+          hatch run coverage combine -a artifacts/coverage.xml_*/coverage.xml
+          hatch run coverage report -m
+          hatch run coverage xml -o combined_coverage.xml
+          hatch run coverage html -d htmlcov
+      - name: Store coverage reports
+        uses: actions/upload-artifact@v4
+        with:
+          name: combined_coverage.xml
+          path: combined_coverage.xml
+      - name: Store htmlcov report
+        uses: actions/upload-artifact@v4
+        with:
+          name: combined_htmlcov
+          path: htmlcov
+      - name: Upload to codecov
+        uses: codecov/codecov-action@v4
+        with:
+          file: combined_coverage.xml
+          env_vars: OS,PYTHON
+          fail_ci_if_error: false
+          flags: unittests
+          token: ${{ secrets.CODECOV_TOKEN }}
+          verbose: true
+          url: https://snowflake.codecov.io/
diff --git a/pyproject.toml b/pyproject.toml
index 907176c3..8707fae3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,119 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "snowflake-sqlalchemy"
+dynamic = ["version"]
+description = "Snowflake SQLAlchemy Dialect"
+readme = "README.md"
+license = "Apache-2.0"
+requires-python = ">=3.8"
+authors = [
+  { name = "Snowflake Inc.", email = "triage-snowpark-python-api-dl@snowflake.com" },
+]
+keywords = ["Snowflake", "analytics", "cloud", "database", "db", "warehouse"]
+classifiers = [
+  "Development Status :: 5 - Production/Stable",
+  "Environment :: Console",
+  "Environment :: Other Environment",
+  "Intended Audience :: Developers",
+  "Intended Audience :: Education",
+  "Intended Audience :: Information Technology",
+  "Intended Audience :: System Administrators",
+  "License :: OSI Approved :: Apache Software License",
+  "Operating System :: OS Independent",
+  "Programming Language :: Python :: 3",
+  "Programming Language :: Python :: 3 :: Only",
+  "Programming Language :: Python :: 3.8",
+  "Programming Language :: Python :: 3.9",
+  "Programming Language :: Python :: 3.10",
+  "Programming Language :: Python :: 3.11",
+  "Programming Language :: Python :: 3.12",
+  "Programming Language :: SQL",
+  "Topic :: Database",
+  "Topic :: Scientific/Engineering :: Information Analysis",
+  "Topic :: Software Development",
+  "Topic :: Software Development :: Libraries",
+  "Topic :: Software Development :: Libraries :: Application Frameworks",
+  "Topic :: Software Development :: Libraries :: Python Modules",
+]
+dependencies = ["snowflake-connector-python", "SQLAlchemy"]
+
+[tool.hatch.version]
+path = "src/snowflake/sqlalchemy/version.py"
+
+[project.optional-dependencies]
+development = [
+  "pre-commit",
+  "pytest",
+  "pytest-cov",
+  "pytest-timeout",
+  "pytest-rerunfailures",
+  "pytz",
+  "numpy",
+  "mock",
+]
+pandas = ["snowflake-connector-python[pandas]"]
+
+[project.entry-points."sqlalchemy.dialects"]
+snowflake = "snowflake.sqlalchemy:dialect"
+
+[project.urls]
+Changelog =
"https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/DESCRIPTION.md" +Documentation = "https://docs.snowflake.com/en/user-guide/sqlalchemy.html" +Homepage = "https://www.snowflake.com/" +Issues = "https://github.com/snowflakedb/snowflake-sqlalchemy/issues" +Source = "https://github.com/snowflakedb/snowflake-sqlalchemy" + +[tool.hatch.build.targets.sdist] +exclude = ["/.github"] + +[tool.hatch.build.targets.wheel] +packages = ["src/snowflake"] + +[tool.hatch.envs.default] +extra-dependencies = ["SQLAlchemy<2.0.0,>=1.4.19"] +features = ["development", "pandas"] +python = "3.8" + +[tool.hatch.envs.sa20] +extra-dependencies = ["SQLAlchemy>=2.0.0"] +python = "3.8" + +[tool.hatch.envs.default.env-vars] +COVERAGE_FILE = "coverage.xml" +SQLACHEMY_WARN_20 = "1" + +[tool.hatch.envs.default.scripts] +check = "pre-commit run --all-files" +test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite" +test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" +test-run_v20 = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite --run_v20_sqlalchemy" +gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" + [tool.ruff] line-length = 88 [tool.black] line-length = 88 + +[tool.pytest.ini_options] +markers = [ + # Optional dependency groups markers + "lambda: AWS lambda tests", + "pandas: tests for pandas integration", + "sso: tests for sso optional dependency integration", + # Cloud provider markers + "aws: tests for Amazon Cloud storage", + "azure: tests for Azure Cloud storage", + "gcp: tests for Google Cloud storage", + # Test type markers + "integ: integration tests", + "unit: unit tests", + "skipolddriver: skip for old driver tests", + # Other markers + "timeout: tests that need a timeout time", + "internal: tests that could but should only run on our internal CI", + "external: tests that could but should only run on our external CI", +] diff --git a/setup.cfg b/setup.cfg index 04011f04..7924cc57 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,80 +1,3 @@ -[bdist_wheel] -universal = 1 - -[metadata] -name = snowflake-sqlalchemy -description = Snowflake SQLAlchemy Dialect -long_description = file: DESCRIPTION.md -long_description_content_type = text/markdown -url = https://www.snowflake.com/ -author = Snowflake, Inc -author_email = triage-snowpark-python-api-dl@snowflake.com -license = Apache-2.0 -license_files = LICENSE.txt -classifiers = - Development Status :: 5 - Production/Stable - Environment :: Console - Environment :: Other Environment - Intended Audience :: Developers - Intended Audience :: Education - Intended Audience :: Information Technology - Intended Audience :: System Administrators - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Programming Language :: Python :: 3 - Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Programming Language :: SQL - Topic :: Database - Topic :: Scientific/Engineering :: Information Analysis - Topic :: Software Development - Topic :: Software Development :: Libraries - Topic :: Software Development :: Libraries :: Application Frameworks - Topic :: Software Development :: 
Libraries :: Python Modules -keywords = Snowflake db database cloud analytics warehouse -project_urls = - Documentation=https://docs.snowflake.com/en/user-guide/sqlalchemy.html - Source=https://github.com/snowflakedb/snowflake-sqlalchemy - Issues=https://github.com/snowflakedb/snowflake-sqlalchemy/issues - Changelog=https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/DESCRIPTION.md - -[options] -python_requires = >=3.7 -packages = find_namespace: -install_requires = - importlib-metadata;python_version<"3.8" - sqlalchemy<2.0.0,>=1.4.0 -; Keep in sync with extras dependency - snowflake-connector-python<4.0.0 -include_package_data = True -package_dir = - =src -zip_safe = False - -[options.packages.find] -where = src -include = snowflake.* - -[options.entry_points] -sqlalchemy.dialects = - snowflake=snowflake.sqlalchemy:dialect - -[options.extras_require] -development = - pytest - pytest-cov - pytest-rerunfailures - pytest-timeout - mock - pytz - numpy -pandas = - snowflake-connector-python[pandas]<4.0.0 - [sqla_testing] requirement_cls=snowflake.sqlalchemy.requirements:Requirements profile_file=tests/profiles.txt diff --git a/setup.py b/setup.py deleted file mode 100644 index 0ec32717..00000000 --- a/setup.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - -import os - -from setuptools import setup - -SQLALCHEMY_SRC_DIR = os.path.join("src", "snowflake", "sqlalchemy") -VERSION = (1, 1, 1, None) # Default -with open(os.path.join(SQLALCHEMY_SRC_DIR, "version.py"), encoding="utf-8") as f: - exec(f.read()) - version = ".".join([str(v) for v in VERSION if v is not None]) - -setup( - version=version, -) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index dad5b19b..46af4454 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -9,4 +9,4 @@ PARAM_INTERNAL_APPLICATION_VERSION = "internal_application_version" APPLICATION_NAME = "SnowflakeSQLAlchemy" -SNOWFLAKE_SQLALCHEMY_VERSION = ".".join([str(v) for v in VERSION if v is not None]) +SNOWFLAKE_SQLALCHEMY_VERSION = VERSION diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index d4318b86..24f188c2 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = (1, 5, 2, None) +VERSION = "1.5.2" diff --git a/tests/conftest.py b/tests/conftest.py index e22e4d42..a9c2560a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,17 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
#
+from __future__ import annotations

 import os
 import sys
 import time
 import uuid
-from functools import partial
 from logging import getLogger

 import pytest
 from sqlalchemy import create_engine
+from sqlalchemy.pool import NullPool

 import snowflake.connector
 import snowflake.connector.connection
@@ -44,8 +45,6 @@

 TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}"

-create_engine_with_future_flag = create_engine
-

 def pytest_addoption(parser):
     parser.addoption(
@@ -57,6 +56,11 @@ def pytest_addoption(parser):
     )


+@pytest.fixture(scope="session")
+def run_v20_sqlalchemy(pytestconfig):
+    return pytestconfig.option.run_v20_sqlalchemy
+
+
 @pytest.fixture(scope="session")
 def on_travis():
     return os.getenv("TRAVIS", "").lower() == "true"
@@ -102,10 +106,10 @@ def help():

 @pytest.fixture(scope="session")
 def db_parameters():
-    return get_db_parameters()
+    yield get_db_parameters()


-def get_db_parameters():
+def get_db_parameters() -> dict:
     """
     Sets the db connection parameters
     """
@@ -113,12 +117,9 @@ def get_db_parameters():
     os.environ["TZ"] = "UTC"
     if not IS_WINDOWS:
         time.tzset()
-    for k, v in CONNECTION_PARAMETERS.items():
-        ret[k] = v
-    for k, v in DEFAULT_PARAMETERS.items():
-        if k not in ret:
-            ret[k] = v
+    ret.update(DEFAULT_PARAMETERS)
+    ret.update(CONNECTION_PARAMETERS)

     if "account" in ret and ret["account"] == DEFAULT_PARAMETERS["account"]:
         help()
@@ -153,43 +154,49 @@ def get_db_parameters():
     return ret


-def get_engine(user=None, password=None, account=None, schema=None):
-    """
-    Creates a connection using the parameters defined in JDBC connect string
-    """
-    ret = get_db_parameters()
-
-    if user is not None:
-        ret["user"] = user
-    if password is not None:
-        ret["password"] = password
-    if account is not None:
-        ret["account"] = account
-
-    from sqlalchemy.pool import NullPool
-
-    engine = create_engine_with_future_flag(
-        URL(
-            user=ret["user"],
-            password=ret["password"],
-            host=ret["host"],
-            port=ret["port"],
-            database=ret["database"],
-            schema=TEST_SCHEMA if not schema else schema,
-            account=ret["account"],
-            protocol=ret["protocol"],
-        ),
-        poolclass=NullPool,
-    )
+def url_factory(**kwargs) -> URL:
+    url_params = get_db_parameters()
+    url_params.update(kwargs)
+    return URL(**url_params)

-    return engine, ret
+
+def get_engine(url: URL, run_v20_sqlalchemy=False, **engine_kwargs):
+    engine_params = {
+        "poolclass": NullPool,
+        "future": run_v20_sqlalchemy,
+    }
+    engine_params.update(engine_kwargs)
+    engine = create_engine(url, **engine_params)
+    return engine


 @pytest.fixture()
-def engine_testaccount(request):
-    engine, _ = get_engine()
+def engine_testaccount(request, run_v20_sqlalchemy):
+    url = url_factory()
+    engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy)
     request.addfinalizer(engine.dispose)
-    return engine
+    yield engine
+
+
+@pytest.fixture()
+def engine_testaccount_with_numpy(request, run_v20_sqlalchemy):
+    url = url_factory(numpy=True)
+    engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy)
+    request.addfinalizer(engine.dispose)
+    yield engine
+
+
+@pytest.fixture()
+def engine_testaccount_with_qmark(request, run_v20_sqlalchemy):
+    snowflake.connector.paramstyle = "qmark"
+
+    url = url_factory()
+    engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy)
+    request.addfinalizer(engine.dispose)
+
+    yield engine
+
+    snowflake.connector.paramstyle = "pyformat"


 @pytest.fixture(scope="session", autouse=True)
@@ -232,19 +239,6 @@ def sql_compiler():
     ).replace("\n", "")


-@pytest.fixture(scope="session")
-def run_v20_sqlalchemy(pytestconfig):
-
return pytestconfig.option.run_v20_sqlalchemy - - -def pytest_sessionstart(session): - # patch the create_engine with future flag - global create_engine_with_future_flag - create_engine_with_future_flag = partial( - create_engine, future=session.config.option.run_v20_sqlalchemy - ) - - def running_on_public_ci() -> bool: """Whether or not tests are currently running on one of our public CIs.""" return os.getenv("GITHUB_ACTIONS") == "true" diff --git a/tests/test_core.py b/tests/test_core.py index 60b4fea4..6c8d7416 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -28,13 +28,13 @@ String, Table, UniqueConstraint, + create_engine, dialects, insert, inspect, text, ) from sqlalchemy.exc import DBAPIError, NoSuchTableError -from sqlalchemy.pool import NullPool from sqlalchemy.sql import and_, not_, or_, select import snowflake.connector.errors @@ -47,8 +47,7 @@ ) from snowflake.sqlalchemy.snowdialect import SnowflakeDialect -from .conftest import create_engine_with_future_flag as create_engine -from .conftest import get_engine +from .conftest import get_engine, url_factory from .parameters import CONNECTION_PARAMETERS from .util import ischema_names_baseline, random_string @@ -937,37 +936,6 @@ class Appointment(Base): assert str(t.columns["real_data"].type) == "FLOAT" -def _get_engine_with_columm_metadata_cache( - db_parameters, user=None, password=None, account=None -): - """ - Creates a connection with column metadata cache - """ - if user is not None: - db_parameters["user"] = user - if password is not None: - db_parameters["password"] = password - if account is not None: - db_parameters["account"] = account - - engine = create_engine( - URL( - user=db_parameters["user"], - password=db_parameters["password"], - host=db_parameters["host"], - port=db_parameters["port"], - database=db_parameters["database"], - schema=db_parameters["schema"], - account=db_parameters["account"], - protocol=db_parameters["protocol"], - cache_column_metadata=True, - ), - poolclass=NullPool, - ) - - return engine - - def test_many_table_column_metadta(db_parameters): """ Get dozens of table metadata with column metadata cache. @@ -975,7 +943,9 @@ def test_many_table_column_metadta(db_parameters): cache_column_metadata=True will cache all column metadata for all tables in the schema. 
""" - engine = _get_engine_with_columm_metadata_cache(db_parameters) + url = url_factory(cache_column_metadata=True) + engine = get_engine(url) + RE_SUFFIX_NUM = re.compile(r".*(\d+)$") metadata = MetaData() total_objects = 10 @@ -1344,7 +1314,8 @@ def test_comment_sqlalchemy(db_parameters, engine_testaccount, on_public_ci): column_comment1 = random_string(10, choices=string.ascii_uppercase) table_comment2 = random_string(10, choices=string.ascii_uppercase) column_comment2 = random_string(10, choices=string.ascii_uppercase) - engine2, _ = get_engine(schema=new_schema) + + engine2 = get_engine(url_factory(schema=new_schema)) con2 = None if not on_public_ci: con2 = engine2.connect() diff --git a/tests/test_pandas.py b/tests/test_pandas.py index aadfb9cf..ef64d65e 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -24,13 +24,9 @@ select, text, ) -from sqlalchemy.pool import NullPool from snowflake.connector import ProgrammingError from snowflake.connector.pandas_tools import make_pd_writer, pd_writer -from snowflake.sqlalchemy import URL - -from .conftest import create_engine_with_future_flag as create_engine def _create_users_addresses_tables(engine_testaccount, metadata): @@ -113,40 +109,8 @@ def test_a_simple_read_sql(engine_testaccount): users.drop(engine_testaccount) -def get_engine_with_numpy(db_parameters, user=None, password=None, account=None): - """ - Creates a connection using the parameters defined in JDBC connect string - """ - from snowflake.sqlalchemy import URL - - if user is not None: - db_parameters["user"] = user - if password is not None: - db_parameters["password"] = password - if account is not None: - db_parameters["account"] = account - - engine = create_engine( - URL( - user=db_parameters["user"], - password=db_parameters["password"], - host=db_parameters["host"], - port=db_parameters["port"], - database=db_parameters["database"], - schema=db_parameters["schema"], - account=db_parameters["account"], - protocol=db_parameters["protocol"], - numpy=True, - ), - poolclass=NullPool, - ) - - return engine - - -def test_numpy_datatypes(db_parameters): - engine = get_engine_with_numpy(db_parameters) - with engine.connect() as conn: +def test_numpy_datatypes(engine_testaccount_with_numpy, db_parameters): + with engine_testaccount_with_numpy.connect() as conn: try: specific_date = np.datetime64("2016-03-04T12:03:05.123456789") with conn.begin(): @@ -163,12 +127,11 @@ def test_numpy_datatypes(db_parameters): assert df.c1.values[0] == specific_date finally: conn.exec_driver_sql(f"DROP TABLE IF EXISTS {db_parameters['name']}") - engine.dispose() + engine_testaccount_with_numpy.dispose() -def test_to_sql(db_parameters): - engine = get_engine_with_numpy(db_parameters) - with engine.connect() as conn: +def test_to_sql(engine_testaccount_with_numpy, db_parameters): + with engine_testaccount_with_numpy.connect() as conn: total_rows = 10000 conn.exec_driver_sql( textwrap.dedent( @@ -182,7 +145,13 @@ def test_to_sql(db_parameters): conn.exec_driver_sql("create or replace table dst(c1 float)") tbl = pd.read_sql_query(text("select * from src"), conn) - tbl.to_sql("dst", engine, if_exists="append", chunksize=1000, index=False) + tbl.to_sql( + "dst", + engine_testaccount_with_numpy, + if_exists="append", + chunksize=1000, + index=False, + ) df = pd.read_sql_query(text("select count(*) as cnt from dst"), conn) assert df.cnt.values[0] == total_rows @@ -202,38 +171,12 @@ def test_no_indexes(engine_testaccount, db_parameters): assert str(exc.value) == "Snowflake does not support 
indexes" -def test_timezone(db_parameters): +def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_numpy): test_table_name = "".join([random.choice(string.ascii_letters) for _ in range(5)]) - sa_engine = create_engine( - URL( - account=db_parameters["account"], - password=db_parameters["password"], - database=db_parameters["database"], - port=db_parameters["port"], - user=db_parameters["user"], - host=db_parameters["host"], - protocol=db_parameters["protocol"], - schema=db_parameters["schema"], - numpy=True, - ) - ) - - sa_engine2_raw_conn = create_engine( - URL( - account=db_parameters["account"], - password=db_parameters["password"], - database=db_parameters["database"], - port=db_parameters["port"], - user=db_parameters["user"], - host=db_parameters["host"], - protocol=db_parameters["protocol"], - schema=db_parameters["schema"], - timezone="America/Los_Angeles", - numpy="", - ) - ).raw_connection() + sa_engine = engine_testaccount_with_numpy + sa_engine2_raw_conn = engine_testaccount.raw_connection() with sa_engine.connect() as conn: @@ -316,18 +259,13 @@ def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy): sf_connector_version_df = pd.DataFrame( sf_connector_version_data, columns=["NAME", "NEWEST_VERSION"] ) - sf_connector_version_df.to_sql(table_name, conn, index=False, method=pd_writer) - - assert ( - ( - pd.read_sql_table(table_name, conn).rename( - columns={"newest_version": "NEWEST_VERSION", "name": "NAME"} - ) - == sf_connector_version_df - ) - .all() - .all() + sf_connector_version_df.to_sql( + table_name, conn, index=False, method=pd_writer, if_exists="replace" + ) + results = pd.read_sql_table(table_name, conn).rename( + columns={"newest_version": "NEWEST_VERSION", "name": "NAME"} ) + assert results.equals(sf_connector_version_df) @pytest.mark.parametrize("chunk_size", [5, 1]) diff --git a/tests/test_qmark.py b/tests/test_qmark.py index fe50fae9..f98fa7d3 100644 --- a/tests/test_qmark.py +++ b/tests/test_qmark.py @@ -5,43 +5,14 @@ import os import sys +import pandas as pd import pytest from sqlalchemy import text -from snowflake.sqlalchemy import URL - -from .conftest import create_engine_with_future_flag as create_engine - THIS_DIR = os.path.dirname(os.path.realpath(__file__)) -def _get_engine_with_qmark(db_parameters, user=None, password=None, account=None): - """ - Creates a connection with column metadata cache - """ - if user is not None: - db_parameters["user"] = user - if password is not None: - db_parameters["password"] = password - if account is not None: - db_parameters["account"] = account - - engine = create_engine( - URL( - user=db_parameters["user"], - password=db_parameters["password"], - host=db_parameters["host"], - port=db_parameters["port"], - database=db_parameters["database"], - schema=db_parameters["schema"], - account=db_parameters["account"], - protocol=db_parameters["protocol"], - ) - ) - return engine - - -def test_qmark_bulk_insert(db_parameters, run_v20_sqlalchemy): +def test_qmark_bulk_insert(run_v20_sqlalchemy, engine_testaccount_with_qmark): """ Bulk insert using qmark paramstyle """ @@ -50,31 +21,19 @@ def test_qmark_bulk_insert(db_parameters, run_v20_sqlalchemy): "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." 
) - import snowflake.connector - - snowflake.connector.paramstyle = "qmark" - - engine = _get_engine_with_qmark(db_parameters) - import pandas as pd - - with engine.connect() as con: - try: - with con.begin(): - con.exec_driver_sql( - """ - create or replace table src(c1 int, c2 string) as select seq8(), - randstr(100, random()) from table(generator(rowcount=>100000)) - """ + with engine_testaccount_with_qmark.connect() as con: + with con.begin(): + con.exec_driver_sql( + """ + create or replace table src(c1 int, c2 string) as select seq8(), + randstr(100, random()) from table(generator(rowcount=>100000)) + """ + ) + con.exec_driver_sql("create or replace table dst like src") + + for data in pd.read_sql_query( + text("select * from src"), con, chunksize=16000 + ): + data.to_sql( + "dst", con, if_exists="append", index=False, index_label=None ) - con.exec_driver_sql("create or replace table dst like src") - - for data in pd.read_sql_query( - text("select * from src"), con, chunksize=16000 - ): - data.to_sql( - "dst", con, if_exists="append", index=False, index_label=None - ) - - finally: - engine.dispose() - snowflake.connector.paramstyle = "pyformat" From c205e3415c77d81fc97c59d549e28489bfc65bdf Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Mon, 25 Mar 2024 14:17:34 +0100 Subject: [PATCH 18/74] SNOW-1058245-fix-failing-gh-action-syntax: use gh variable for setting repository name (#476) --- .github/workflows/snyk-issue.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/snyk-issue.yml b/.github/workflows/snyk-issue.yml index c8f5d90b..d7a3e0e9 100644 --- a/.github/workflows/snyk-issue.yml +++ b/.github/workflows/snyk-issue.yml @@ -22,7 +22,7 @@ jobs: token: ${{ secrets.whitesource_action_token }} path: whitesource-actions - name: Set Env - run: echo "repo=$(basename $github_repository)" >> $github_env + run: echo "repo=${{ github.event.repository.name }}" >> $GITHUB_ENV - name: Jira Creation uses: ./whitesource-actions/snyk-issue with: From 337e0f87399c1191d7e3b7bf94693eca5c57ab04 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Wed, 10 Apr 2024 15:52:17 +0200 Subject: [PATCH 19/74] mraba/update-description (#483) --- DESCRIPTION.md | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 422fe807..cb9782a0 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -11,6 +11,7 @@ Source code is also available at: - 1.5.2 + - Bump min SQLAlchemy to 1.4.19 for outer lateral join - Add support for sequence ordering in tests - v1.5.1(November 03, 2023) From 4db4b9521d42d43a6619bf5ca307e1eaa2c6aa09 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 11 Apr 2024 10:31:35 +0200 Subject: [PATCH 20/74] mraba/description-file-update: add release date and format version string (#484) --- DESCRIPTION.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index cb9782a0..f41ee797 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,7 @@ Source code is also available at: # Release Notes -- 1.5.2 +- v1.5.2(April 11, 2024) - Bump min SQLAlchemy to 1.4.19 for outer lateral join - Add support for sequence ordering in tests From 8ea9ab92663b2df40543f391a2f44d3ace70b626 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Mon, 15 Apr 2024 15:46:25 +0200 Subject: [PATCH 21/74] SNOW-1324105-dependency-pinning: move versions pinning from project env to metadata (#486) * SNOW-1324105-dependency-pinning: move versions pinning from project env to metadata * SNOW-1324105-dependency-pinning: update 
version and description

* SNOW-1324105-dependency-pinning: remove duplicated command in CI job
---
 .github/workflows/build_test.yml    | 31 ++++++++++++++++++++++++++---
 DESCRIPTION.md                      |  4 ++++
 pyproject.toml                      |  7 +------
 src/snowflake/sqlalchemy/version.py |  2 +-
 4 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml
index f232e669..be19f1f1 100644
--- a/.github/workflows/build_test.yml
+++ b/.github/workflows/build_test.yml
@@ -45,9 +45,34 @@ jobs:
       - name: Run lint checks
         run: hatch run check

+  build-install:
+    name: Test package build and installation
+    runs-on: ubuntu-latest
+    needs: lint
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          persist-credentials: false
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.8'
+      - name: Upgrade and install tools
+        run: |
+          python -m pip install -U pip
+          python -m pip install -U hatch
+      - name: Build package
+        run: |
+          python -m hatch clean
+          python -m hatch build
+      - name: Install and check import
+        run: |
+          python -m pip install dist/snowflake_sqlalchemy-*.whl
+          python -c "import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)"
+
   test-dialect:
     name: Test dialect ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
-    needs: lint
+    needs: [ lint, build-install ]
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
@@ -93,7 +118,7 @@ jobs:
   test-dialect-compatibility:
     name: Test dialect compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
-    needs: lint
+    needs: [ lint, build-install ]
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
@@ -139,7 +164,7 @@ jobs:
   test-dialect-run-v20:
     name: Test dialect run v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
-    needs: lint
+    needs: [ lint, build-install ]
    runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index f41ee797..e826b42a 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,6 +9,10 @@ Source code is also available at:

 # Release Notes

+- v1.5.3(Unreleased)
+
+  - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0
+
 - v1.5.2(April 11, 2024)

   - Bump min SQLAlchemy to 1.4.19 for outer lateral join
diff --git a/pyproject.toml b/pyproject.toml
index 8707fae3..d0c31cb8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ classifiers = [
   "Topic :: Software Development :: Libraries :: Application Frameworks",
   "Topic :: Software Development :: Libraries :: Python Modules",
 ]
-dependencies = ["snowflake-connector-python", "SQLAlchemy"]
+dependencies = ["snowflake-connector-python<4.0.0", "SQLAlchemy>=1.4.19,<2.0.0"]

 [tool.hatch.version]
 path = "src/snowflake/sqlalchemy/version.py"
@@ -73,14 +73,9 @@ exclude = ["/.github"]
 packages = ["src/snowflake"]

 [tool.hatch.envs.default]
-extra-dependencies = ["SQLAlchemy<2.0.0,>=1.4.19"]
 features = ["development", "pandas"]
 python = "3.8"

-[tool.hatch.envs.sa20]
-extra-dependencies = ["SQLAlchemy>=2.0.0"]
-python = "3.8"
-
 [tool.hatch.envs.default.env-vars]
 COVERAGE_FILE = "coverage.xml"
 SQLACHEMY_WARN_20 = "1"
diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py
index 24f188c2..61c9fc41 100644
--- a/src/snowflake/sqlalchemy/version.py
+++ b/src/snowflake/sqlalchemy/version.py
@@ -3,4 +3,4 @@
 #
 # Update this for the versions
 # Don't change the forth version number from None
-VERSION = "1.5.2"
+VERSION = "1.5.3"

From
d2b333417ce1faf682465ba9a53f7490d6bfe5a4 Mon Sep 17 00:00:00 2001
From: Marcin Raba
Date: Mon, 22 Apr 2024 10:35:05 +0200
Subject: [PATCH 22/74] mraba/codeowners-update: replace snowpark-python-api with snowcli (#490)

---
 .github/CODEOWNERS | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index d8130079..836e0136 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1 +1 @@
-@snowflakedb/snowpark-python-api
+* @snowflakedb/snowcli

From 8f06fd4cdc735276dd9904f2789da376c9c16e21 Mon Sep 17 00:00:00 2001
From: Marcin Raba
Date: Mon, 22 Apr 2024 15:34:35 +0200
Subject: [PATCH 23/74] mraba/snyk-support-requirements: add requirements.txt file for snyk scanning (#491)

* mraba/snyk-support-requirements: add requirements.txt file for snyk scanning

* mraba/snyk-support-requirements: add script for generating requirements and filename fix

* mraba/snyk-support-requirements: finish last line with new line

* mraba/snyk-support-requirements: add pre-commit hook for auto generating requirements.txt
---
 .pre-commit-config.yaml     |  7 +++++++
 pyproject.toml              |  2 +-
 snyk/requirements.txt       |  2 ++
 snyk/update_requirements.py | 17 +++++++++++++++++
 4 files changed, 27 insertions(+), 1 deletion(-)
 create mode 100644 snyk/requirements.txt
 create mode 100644 snyk/update_requirements.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 70b75ce8..83172eb8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -44,3 +44,10 @@ repos:
     - id: flake8
       additional_dependencies:
       - flake8-bugbear
+- repo: local
+  hooks:
+  - id: requirements-update
+    name: "Update dependencies from pyproject.toml to snyk/requirements.txt"
+    language: system
+    entry: python snyk/update_requirements.py
+    files: ^pyproject.toml$
diff --git a/pyproject.toml b/pyproject.toml
index d0c31cb8..3f95df46 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ classifiers = [
   "Topic :: Software Development :: Libraries :: Application Frameworks",
   "Topic :: Software Development :: Libraries :: Python Modules",
 ]
-dependencies = ["snowflake-connector-python<4.0.0", "SQLAlchemy>=1.4.19,<2.0.0"]
+dependencies = ["SQLAlchemy>=1.4.19,<2.0.0", "snowflake-connector-python<4.0.0"]

 [tool.hatch.version]
 path = "src/snowflake/sqlalchemy/version.py"
diff --git a/snyk/requirements.txt b/snyk/requirements.txt
new file mode 100644
index 00000000..3a77e0f9
--- /dev/null
+++ b/snyk/requirements.txt
@@ -0,0 +1,2 @@
+SQLAlchemy>=1.4.19,<2.0.0
+snowflake-connector-python<4.0.0
diff --git a/snyk/update_requirements.py b/snyk/update_requirements.py
new file mode 100644
index 00000000..e0771fbd
--- /dev/null
+++ b/snyk/update_requirements.py
@@ -0,0 +1,17 @@
+from pathlib import Path
+
+import tomlkit
+
+
+def sync():
+    pyproject = tomlkit.loads(Path("pyproject.toml").read_text())
+    snyk_requirements = Path("snyk/requirements.txt")
+    dependencies = pyproject.get("project", {}).get("dependencies", [])
+
+    with snyk_requirements.open("w") as fh:
+        fh.write("\n".join(dependencies))
+        fh.write("\n")
+
+
+if __name__ == "__main__":
+    sync()

From 107b0b13a28f19760be670d33304e2e88421d0c2 Mon Sep 17 00:00:00 2001
From: Marcin Raba
Date: Wed, 24 Apr 2024 11:06:32 +0200
Subject: [PATCH 24/74] mraba/gh-actions-update: bump github actions versions to latest (#492)

* mraba/gh-actions-update: bump github actions versions to latest

* mraba/gh-actions-update: set persist-credentials to false in checkout actions
---
 .github/workflows/changelog.yml        |  3 ++-
 .github/workflows/create_req_files.yml |
13 ++++++++----- .github/workflows/jira_close.yml | 3 ++- .github/workflows/jira_issue.yml | 3 ++- .github/workflows/python-publish.yml | 6 ++++-- .github/workflows/snyk-issue.yml | 3 ++- .github/workflows/snyk-pr.yml | 6 ++++-- 7 files changed, 24 insertions(+), 13 deletions(-) diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index 2e197168..252405fd 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -12,8 +12,9 @@ jobs: if: ${{!contains(github.event.pull_request.labels.*.name, 'NO-CHANGELOG-UPDATES')}} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false fetch-depth: 0 - name: Ensure DESCRIPTION.md is updated diff --git a/.github/workflows/create_req_files.yml b/.github/workflows/create_req_files.yml index 57f7efb8..618b3024 100644 --- a/.github/workflows/create_req_files.yml +++ b/.github/workflows/create_req_files.yml @@ -11,9 +11,11 @@ jobs: matrix: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Display Python version @@ -37,7 +39,7 @@ jobs: - name: Show created req file shell: bash run: cat ${{ env.requirements_file }} - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: path: temp_requirement @@ -46,11 +48,12 @@ jobs: name: Commit and push files runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: + persist-credentials: false token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets - name: Download requirement files - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v4 with: name: artifact path: tested_requirements diff --git a/.github/workflows/jira_close.yml b/.github/workflows/jira_close.yml index dfcb8bc7..5b170d75 100644 --- a/.github/workflows/jira_close.yml +++ b/.github/workflows/jira_close.yml @@ -9,8 +9,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/gh-actions ref: jira_v1 token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 74e58454..31b93aae 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -14,8 +14,9 @@ jobs: if: ((github.event_name == 'issue_comment' && github.event.comment.body == 'recreate jira' && github.event.comment.user.login == 'sfc-gh-mkeller') || (github.event_name == 'issues' && github.event.pull_request.user.login != 'whitesource-for-github-com[bot]')) steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/gh-actions ref: jira_v1 token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index dd1e1ba6..ab4be45b 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -21,9 +21,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v3 + uses: 
actions/setup-python@v5 with: python-version: '3.x' - name: Install dependencies diff --git a/.github/workflows/snyk-issue.yml b/.github/workflows/snyk-issue.yml index d7a3e0e9..94dfeb53 100644 --- a/.github/workflows/snyk-issue.yml +++ b/.github/workflows/snyk-issue.yml @@ -16,8 +16,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Action - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/whitesource-actions token: ${{ secrets.whitesource_action_token }} path: whitesource-actions diff --git a/.github/workflows/snyk-pr.yml b/.github/workflows/snyk-pr.yml index b951af65..cc5e8644 100644 --- a/.github/workflows/snyk-pr.yml +++ b/.github/workflows/snyk-pr.yml @@ -15,14 +15,16 @@ jobs: if: ${{ github.event.pull_request.user.login == 'sfc-gh-snyk-sca-sa' }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false ref: ${{ github.event.pull_request.head.ref }} fetch-depth: 0 - name: Checkout Action - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/whitesource-actions token: ${{ secrets.whitesource_action_token }} path: whitesource-actions From 9b2f325f76681e3526a6824b4d2710a7edd8d7fe Mon Sep 17 00:00:00 2001 From: Daniel Tatarkin Date: Fri, 10 May 2024 08:16:51 -0400 Subject: [PATCH 25/74] Add ability to set ORDER / NOORDER sequence on columns with IDENTITY (#493) * Update base.py to add ability to set ORDER / NOORDER sequence * Set ORDER / NOORDER only if the argument was set. Add test_table_with_identity unit test. --- src/snowflake/sqlalchemy/base.py | 5 ++++- tests/test_sequence.py | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 2a1bb51a..e008c92f 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -966,11 +966,14 @@ def visit_drop_column_comment(self, drop, **kw): ) def visit_identity_column(self, identity, **kw): - text = " IDENTITY" + text = "IDENTITY" if identity.start is not None or identity.increment is not None: start = 1 if identity.start is None else identity.start increment = 1 if identity.increment is None else identity.increment text += f"({start},{increment})" + if identity.order is not None: + order = "ORDER" if identity.order else "NOORDER" + text += f" {order}" return text def get_identity_options(self, identity_options): diff --git a/tests/test_sequence.py b/tests/test_sequence.py index e428b9d7..32fc390e 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -4,6 +4,7 @@ from sqlalchemy import ( Column, + Identity, Integer, MetaData, Sequence, @@ -13,6 +14,7 @@ select, ) from sqlalchemy.sql import text +from sqlalchemy.sql.ddl import CreateTable def test_table_with_sequence(engine_testaccount, db_parameters): @@ -135,3 +137,27 @@ def test_table_with_autoincrement(engine_testaccount): finally: metadata.drop_all(engine_testaccount) + + +def test_table_with_identity(sql_compiler): + test_table_name = "identity" + metadata = MetaData() + identity_autoincrement_table = Table( + test_table_name, + metadata, + Column( + "id", Integer, Identity(start=1, increment=1, order=True), primary_key=True + ), + Column("identity_col_unordered", Integer, Identity(order=False)), + Column("identity_col", Integer, Identity()), + ) + create_table = CreateTable(identity_autoincrement_table) + actual = sql_compiler(create_table) + expected = 
( + "CREATE TABLE identity (" + "\tid INTEGER NOT NULL IDENTITY(1,1) ORDER, " + "\tidentity_col_unordered INTEGER NOT NULL IDENTITY NOORDER, " + "\tidentity_col INTEGER NOT NULL IDENTITY, " + "\tPRIMARY KEY (id))" + ) + assert actual == expected From 58fb1bd556cddbd2b7cf2b8ccebf0ead3363e113 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Fri, 10 May 2024 14:55:30 +0200 Subject: [PATCH 26/74] Update DESCRIPTION.md (#501) * Update DESCRIPTION.md * mraba/description_update: drop trailing space --- DESCRIPTION.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index e826b42a..2f228781 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,11 @@ Source code is also available at: # Release Notes -- v1.5.3(Unrelased) +- v1.5.4 + + - Add ability to set ORDER / NOORDER sequence on columns with IDENTITY + +- v1.5.3(April 16, 2024) - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 From d78f0c07c1701fa9889350b9cee31ae188b7fd71 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Tue, 2 Jul 2024 15:24:48 +0200 Subject: [PATCH 27/74] Snow 1058245 SqlAlchemy 2.0 support (#469) SNOW-1058245-sqlalchemy-20-support: Add support for installation SQLAlchemy 2.0 --- .github/workflows/build_test.yml | 108 +++++++--- .github/workflows/create_req_files.yml | 6 +- .github/workflows/jira_close.yml | 2 +- .github/workflows/jira_comment.yml | 4 +- .github/workflows/jira_issue.yml | 4 +- .github/workflows/python-publish.yml | 2 +- .github/workflows/stale_issue_bot.yml | 2 +- DESCRIPTION.md | 6 +- pyproject.toml | 14 +- snyk/requirements.txt | 2 +- snyk/requiremtnts.txt | 2 + src/snowflake/sqlalchemy/base.py | 44 ++-- src/snowflake/sqlalchemy/compat.py | 36 ++++ src/snowflake/sqlalchemy/custom_commands.py | 3 +- src/snowflake/sqlalchemy/functions.py | 16 ++ src/snowflake/sqlalchemy/requirements.py | 16 ++ src/snowflake/sqlalchemy/snowdialect.py | 67 +++--- src/snowflake/sqlalchemy/util.py | 12 +- src/snowflake/sqlalchemy/version.py | 2 +- tests/conftest.py | 32 +-- tests/sqlalchemy_test_suite/conftest.py | 7 + tests/sqlalchemy_test_suite/test_suite.py | 4 + tests/sqlalchemy_test_suite/test_suite_20.py | 205 +++++++++++++++++++ tests/test_compiler.py | 2 +- tests/test_core.py | 85 +++----- tests/test_custom_functions.py | 25 +++ tests/test_orm.py | 42 ++-- tests/test_pandas.py | 11 +- tests/test_qmark.py | 4 +- tox.ini | 10 +- 30 files changed, 558 insertions(+), 217 deletions(-) create mode 100644 snyk/requiremtnts.txt create mode 100644 src/snowflake/sqlalchemy/compat.py create mode 100644 src/snowflake/sqlalchemy/functions.py create mode 100644 tests/sqlalchemy_test_suite/test_suite_20.py create mode 100644 tests/test_custom_functions.py diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index be19f1f1..3baa6a0d 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -33,8 +33,8 @@ jobs: python-version: '3.8' - name: Upgrade and install tools run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Set PY run: echo "PY=$(hatch run gh-cache-sum)" >> $GITHUB_ENV @@ -49,6 +49,10 @@ jobs: name: Test package build and installation runs-on: ubuntu-latest needs: lint + strategy: + fail-fast: true + matrix: + hatch-env: [default, sa20] steps: - uses: actions/checkout@v4 with: @@ -59,15 +63,14 @@ jobs: python-version: '3.8' - name: Upgrade and install tools run: | 
- python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package run: | - python -m hatch clean - python -m hatch build + python -m hatch -e ${{ matrix.hatch-env }} build --clean - name: Install and check import run: | - python -m pip install dist/snowflake_sqlalchemy-*.whl + python -m uv pip install dist/snowflake_sqlalchemy-*.whl python -c "import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)" test-dialect: @@ -79,7 +82,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -98,8 +101,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Upgrade pip and prepare environment run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Setup parameters file shell: bash @@ -125,7 +128,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -144,8 +147,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Upgrade pip and install hatch run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Setup parameters file shell: bash @@ -162,8 +165,8 @@ jobs: path: | ./coverage.xml - test-dialect-run-v20: - name: Test dialect run v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + test-dialect-v20: + name: Test dialect v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} needs: [ lint, build-install ] runs-on: ${{ matrix.os }} strategy: @@ -171,7 +174,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -197,21 +200,67 @@ jobs: .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - name: Upgrade pip and install hatch run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Run tests - run: hatch run test-run_v20 + run: hatch run sa20:test-dialect - uses: actions/upload-artifact@v4 with: - name: coverage.xml_dialect-run-20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + name: coverage.xml_dialect-v20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml + + test-dialect-compatibility-v20: + name: Test dialect v20 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ 
matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run sa20:test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-v20-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} path: | ./coverage.xml combine-coverage: name: Combine coverage if: ${{ success() || failure() }} - needs: [test-dialect, test-dialect-compatibility, test-dialect-run-v20] + needs: [test-dialect, test-dialect-compatibility, test-dialect-v20, test-dialect-compatibility-v20] runs-on: ubuntu-latest steps: - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.8" - name: Prepare environment run: | - pip install -U pip - pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch hatch env create default - uses: actions/checkout@v4 with: persist-credentials: false - name: Download coverage results uses: actions/download-artifact@v4 with: path: artifacts - name: Combine coverage run: | hatch run coverage combine -a artifacts/coverage.xml_*/coverage.xml hatch run coverage report -m - hatch run coverage xml -o combined_coverage.xml - hatch run coverage html -d htmlcov - name: Store coverage reports uses: actions/upload-artifact@v4 with: - name: combined_coverage.xml - path: combined_coverage.xml - - name: Store htmlcov report - uses: actions/upload-artifact@v4 - with: - name: combined_htmlcov - path: htmlcov + name: coverage.xml + path: coverage.xml - name: Upload to codecov uses: codecov/codecov-action@v4 with: - file: combined_coverage.xml + file: coverage.xml env_vars: OS,PYTHON fail_ci_if_error: false flags: unittests diff --git a/.github/workflows/create_req_files.yml b/.github/workflows/create_req_files.yml index 618b3024..2cb7a371 100644 --- a/.github/workflows/create_req_files.yml +++ b/.github/workflows/create_req_files.yml @@ -21,10 +21,10 @@ jobs: - name: Display Python version run: python -c "import sys; print(sys.version)" - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel + run: python -m pip install -U setuptools pip wheel uv - name: Install Snowflake SQLAlchemy shell: bash - run: python -m pip install . + run: python -m uv pip install . - name: Generate reqs file name shell: bash run: echo "requirements_file=temp_requirement/requirements_$(python -c 'from sys import version_info;print(str(version_info.major)+str(version_info.minor))').reqs" >> $GITHUB_ENV @@ -34,7 +34,7 @@ jobs: mkdir temp_requirement echo "# Generated on: $(python --version)" >${{ env.requirements_file }} python -m pip freeze | grep -v snowflake-sqlalchemy 1>>${{ env.requirements_file }} 2>/dev/null - echo "snowflake-sqlalchemy==$(python -m pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} + echo "snowflake-sqlalchemy==$(python -m uv pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} id: create-reqs-file - name: Show created req file shell: bash diff --git a/.github/workflows/jira_close.yml b/.github/workflows/jira_close.yml index 5b170d75..7862f483 100644 --- a/.github/workflows/jira_close.yml +++ b/.github/workflows/jira_close.yml @@ -17,7 +17,7 @@ jobs: token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets path: . 
- name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} diff --git a/.github/workflows/jira_comment.yml b/.github/workflows/jira_comment.yml index 954929fa..8533c14c 100644 --- a/.github/workflows/jira_comment.yml +++ b/.github/workflows/jira_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} @@ -22,7 +22,7 @@ jobs: jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://') echo ::set-output name=jira::$jira - name: Comment on issue - uses: atlassian/gajira-comment@master + uses: atlassian/gajira-comment@v3 if: startsWith(steps.extract.outputs.jira, 'SNOW-') with: issue: "${{ steps.extract.outputs.jira }}" diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 31b93aae..85c774ca 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -23,7 +23,7 @@ jobs: path: . - name: Login - uses: atlassian/gajira-login@v2.0.0 + uses: atlassian/gajira-login@v3 env: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} @@ -31,7 +31,7 @@ jobs: - name: Create JIRA Ticket id: create - uses: atlassian/gajira-create@v2.0.1 + uses: atlassian/gajira-create@v3 with: project: SNOW issuetype: Bug diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index ab4be45b..23116e7a 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -35,7 +35,7 @@ jobs: - name: Build package run: python -m build - name: Publish package - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/stale_issue_bot.yml b/.github/workflows/stale_issue_bot.yml index 6d76e9f4..4ee56ff8 100644 --- a/.github/workflows/stale_issue_bot.yml +++ b/.github/workflows/stale_issue_bot.yml @@ -10,7 +10,7 @@ jobs: stale: runs-on: ubuntu-latest steps: - - uses: actions/stale@v7 + - uses: actions/stale@v9 with: close-issue-message: 'To clean up and re-prioritize bugs and feature requests we are closing all issues older than 6 months as of Apr 1, 2023. If there are any issues or feature requests that you would like us to address, please re-create them. 
For urgent issues, opening a support case with this link [Snowflake Community](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge) is the fastest way to get a response' days-before-issue-stale: ${{ inputs.staleDays }} diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 2f228781..8b4dcd37 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,13 +9,17 @@ Source code is also available at: # Release Notes +- v1.6.0(Not released) + + - support for installing with SQLAlchemy 2.0.x + - v1.5.4 - Add ability to set ORDER / NOORDER sequence on columns with IDENTITY - v1.5.3(April 16, 2024) - - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 + - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 - v1.5.2(April 11, 2024) diff --git a/pyproject.toml b/pyproject.toml index 3f95df46..d2316a44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Application Frameworks", "Topic :: Software Development :: Libraries :: Python Modules", ] -dependencies = ["SQLAlchemy>=1.4.19,<2.0.0", "snowflake-connector-python<4.0.0"] +dependencies = ["SQLAlchemy>=1.4.19", "snowflake-connector-python<4.0.0"] [tool.hatch.version] path = "src/snowflake/sqlalchemy/version.py" @@ -73,8 +73,14 @@ exclude = ["/.github"] packages = ["src/snowflake"] [tool.hatch.envs.default] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] features = ["development", "pandas"] python = "3.8" +installer = "uv" + +[tool.hatch.envs.sa20] +extra-dependencies = ["SQLAlchemy>=1.4.19,<=2.1.0"] +python = "3.8" [tool.hatch.envs.default.env-vars] COVERAGE_FILE = "coverage.xml" @@ -82,10 +88,10 @@ SQLACHEMY_WARN_20 = "1" [tool.hatch.envs.default.scripts] check = "pre-commit run --all-files" -test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite" -test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" -test-run_v20 = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite --run_v20_sqlalchemy" +test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" +test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite tests/" gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" +check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" [tool.ruff] line-length = 88 diff --git a/snyk/requirements.txt b/snyk/requirements.txt index 3a77e0f9..0166d751 100644 --- a/snyk/requirements.txt +++ b/snyk/requirements.txt @@ -1,2 +1,2 @@ -SQLAlchemy>=1.4.19,<2.0.0 +SQLAlchemy>=1.4.19 snowflake-connector-python<4.0.0 diff --git a/snyk/requiremtnts.txt b/snyk/requiremtnts.txt new file mode 100644 index 00000000..a92c527e --- /dev/null +++ b/snyk/requiremtnts.txt @@ -0,0 +1,2 @@ +snowflake-connector-python<4.0.0 +SQLAlchemy>=1.4.19,<2.1.0 diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index e008c92f..1aaa881e 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -13,13 +13,14 @@ from sqlalchemy.orm import context from sqlalchemy.orm.context import _MapperEntity from sqlalchemy.schema 
import Sequence, Table -from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import compiler, expression, functions from sqlalchemy.sql.base import CompileState from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.selectable import Lateral, SelectState -from sqlalchemy.util.compat import string_types +from .compat import IS_VERSION_20, args_reducer, string_types from .custom_commands import AWSBucket, AzureContainer, ExternalStage +from .functions import flatten from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -324,17 +325,9 @@ def _join_determine_implicit_left_side( return left, replace_from_obj_index, use_entity_index + @args_reducer(positions_to_drop=(6, 7)) def _join_left_to_right( - self, - entities_collection, - left, - right, - onclause, - prop, - create_aliases, - aliased_generation, - outerjoin, - full, + self, entities_collection, left, right, onclause, prop, outerjoin, full ): """given raw "left", "right", "onclause" parameters consumed from a particular key within _join(), add a real ORMJoin object to @@ -364,7 +357,7 @@ def _join_left_to_right( use_entity_index, ) = self._join_place_explicit_left_side(entities_collection, left) - if left is right and not create_aliases: + if left is right: raise sa_exc.InvalidRequestError( "Can't construct a join from %s to %s, they " "are the same entity" % (left, right) @@ -373,9 +366,15 @@ def _join_left_to_right( # the right side as given often needs to be adapted. additionally # a lot of things can be wrong with it. handle all that and # get back the new effective "right" side - r_info, right, onclause = self._join_check_and_adapt_right_side( - left, right, onclause, prop, create_aliases, aliased_generation - ) + + if IS_VERSION_20: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + else: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, False, False + ) if not r_info.is_selectable: extra_criteria = self._get_extra_criteria(r_info) @@ -979,24 +978,23 @@ def visit_identity_column(self, identity, **kw): def get_identity_options(self, identity_options): text = [] if identity_options.increment is not None: - text.append(f"INCREMENT BY {identity_options.increment:d}") + text.append("INCREMENT BY %d" % identity_options.increment) if identity_options.start is not None: - text.append(f"START WITH {identity_options.start:d}") + text.append("START WITH %d" % identity_options.start) if identity_options.minvalue is not None: - text.append(f"MINVALUE {identity_options.minvalue:d}") + text.append("MINVALUE %d" % identity_options.minvalue) if identity_options.maxvalue is not None: - text.append(f"MAXVALUE {identity_options.maxvalue:d}") + text.append("MAXVALUE %d" % identity_options.maxvalue) if identity_options.nominvalue is not None: text.append("NO MINVALUE") if identity_options.nomaxvalue is not None: text.append("NO MAXVALUE") if identity_options.cache is not None: - text.append(f"CACHE {identity_options.cache:d}") + text.append("CACHE %d" % identity_options.cache) if identity_options.cycle is not None: text.append("CYCLE" if identity_options.cycle else "NO CYCLE") if identity_options.order is not None: text.append("ORDER" if identity_options.order else "NOORDER") - return " ".join(text) @@ -1066,3 +1064,5 @@ def visit_GEOMETRY(self, type_, **kw): construct_arguments = [(Table, {"clusterby": None})] + +functions.register_function("flatten", flatten) diff --git 
a/src/snowflake/sqlalchemy/compat.py b/src/snowflake/sqlalchemy/compat.py new file mode 100644 index 00000000..9e97e574 --- /dev/null +++ b/src/snowflake/sqlalchemy/compat.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +from __future__ import annotations + +import functools +from typing import Callable + +from sqlalchemy import __version__ as SA_VERSION +from sqlalchemy import util + +string_types = (str,) +returns_unicode = util.symbol("RETURNS_UNICODE") + +IS_VERSION_20 = tuple(int(v) for v in SA_VERSION.split(".")) >= (2, 0, 0) + + +def args_reducer(positions_to_drop: tuple): + """Removes args at positions provided in tuple positions_to_drop. + + For example, the tuple (3, 5) will remove the items at the third and fifth position. + Keep in mind that on class methods the first position is cls or self. + """ + + def fn_wrapper(fn: Callable): + @functools.wraps(fn) + def wrapper(*args): + reduced_args = args + if not IS_VERSION_20: + reduced_args = tuple( + arg for idx, arg in enumerate(args) if idx not in positions_to_drop + ) + return fn(*reduced_args) + + return wrapper + + return fn_wrapper diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index cec16673..15585bd5 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -10,7 +10,8 @@ from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.roles import FromClauseRole -from sqlalchemy.util.compat import string_types + +from .compat import string_types NoneType = type(None) diff --git a/src/snowflake/sqlalchemy/functions.py b/src/snowflake/sqlalchemy/functions.py new file mode 100644 index 00000000..c08aa734 --- /dev/null +++ b/src/snowflake/sqlalchemy/functions.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import warnings + +from sqlalchemy.sql import functions as sqlfunc + +FLATTEN_WARNING = "For backward compatibility params are not rendered." + + +class flatten(sqlfunc.GenericFunction): + name = "flatten" + + def __init__(self, *args, **kwargs): + warnings.warn(FLATTEN_WARNING, DeprecationWarning, stacklevel=2) + super().__init__(*args, **kwargs) diff --git a/src/snowflake/sqlalchemy/requirements.py b/src/snowflake/sqlalchemy/requirements.py index ea30a823..f2844804 100644 --- a/src/snowflake/sqlalchemy/requirements.py +++ b/src/snowflake/sqlalchemy/requirements.py @@ -289,9 +289,25 @@ def datetime_implicit_bound(self): # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. return exclusions.closed() + @property + def date_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + + @property + def time_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + @property def timestamp_microseconds_implicit_bound(self): # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding # parameters in string forms of timestamp values. # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
return exclusions.closed() + + @property + def array_type(self): + return exclusions.closed() diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 2e40d03c..04305a00 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -5,6 +5,7 @@ import operator from collections import defaultdict from functools import reduce +from typing import Any from urllib.parse import unquote_plus import sqlalchemy.types as sqltypes @@ -15,7 +16,6 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.sqltypes import String from sqlalchemy.types import ( BIGINT, BINARY, @@ -40,6 +40,7 @@ from snowflake.connector import errors as sf_errors from snowflake.connector.connection import DEFAULT_CONFIGURATION from snowflake.connector.constants import UTF8 +from snowflake.sqlalchemy.compat import returns_unicode from .base import ( SnowflakeCompiler, @@ -63,7 +64,11 @@ _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name, parse_url_boolean +from .util import ( + _update_connection_application_name, + parse_url_boolean, + parse_url_integer, +) colspecs = { Date: _CUSTOM_Date, @@ -134,7 +139,7 @@ class SnowflakeDialect(default.DefaultDialect): # unicode strings supports_unicode_statements = True supports_unicode_binds = True - returns_unicode_strings = String.RETURNS_UNICODE + returns_unicode_strings = returns_unicode description_encoding = None # No lastrowid support. See SNOW-11155 @@ -195,10 +200,34 @@ class SnowflakeDialect(default.DefaultDialect): @classmethod def dbapi(cls): + return cls.import_dbapi() + + @classmethod + def import_dbapi(cls): from snowflake import connector return connector + @staticmethod + def parse_query_param_type(name: str, value: Any) -> Any: + """Cast param value if possible to type defined in connector-python.""" + if not (maybe_type_configuration := DEFAULT_CONFIGURATION.get(name)): + return value + + _, expected_type = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance(value, expected_type): + return value + + elif bool in expected_type: + return parse_url_boolean(value) + elif int in expected_type: + return parse_url_integer(value) + else: + return value + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: @@ -235,47 +264,25 @@ def create_connect_args(self, url: URL): # URL sets the query parameter values as strings, we need to cast to expected types when necessary for name, value in query.items(): - maybe_type_configuration = DEFAULT_CONFIGURATION.get(name) - if ( - not maybe_type_configuration - ): # if the parameter is not found in the type mapping, pass it through as a string - opts[name] = value - continue - - (_, expected_type) = maybe_type_configuration - if not isinstance(expected_type, tuple): - expected_type = (expected_type,) - - if isinstance( - value, expected_type - ): # if the expected type is str, pass it through as a string - opts[name] = value - - elif ( - bool in expected_type - ): # if the expected type is bool, parse it and pass as a boolean - opts[name] = parse_url_boolean(value) - else: - # TODO: other types like int are stil passed through as string - # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447 - opts[name] = value + opts[name] = self.parse_query_param_type(name, value) return ([], opts) - def 
has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): """ Checks if the table exists """ return self._has_object(connection, "TABLE", table_name, schema) - def has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): """ Checks if the sequence exists """ return self._has_object(connection, "SEQUENCE", sequence_name, schema) def _has_object(self, connection, object_type, object_name, schema=None): - full_name = self._denormalize_quote_join(schema, object_name) try: results = connection.execute( diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 32e07373..a1aefff9 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -7,7 +7,7 @@ from typing import Any from urllib.parse import quote_plus -from sqlalchemy import exc, inspection, sql, util +from sqlalchemy import exc, inspection, sql from sqlalchemy.exc import NoForeignKeysError from sqlalchemy.orm.interfaces import MapperProperty from sqlalchemy.orm.util import _ORMJoin as sa_orm_util_ORMJoin @@ -19,6 +19,7 @@ from snowflake.connector.compat import IS_STR from snowflake.connector.connection import SnowflakeConnection +from snowflake.sqlalchemy import compat from ._constants import ( APPLICATION_NAME, @@ -124,6 +125,13 @@ def parse_url_boolean(value: str) -> bool: raise ValueError(f"Invalid boolean value detected: '{value}'") +def parse_url_integer(value: str) -> int: + try: + return int(value) + except ValueError as e: + raise ValueError(f"Invalid int value detected: '{value}'") from e + + # handle Snowflake BCR bcr-1057 # the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState # which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that @@ -212,7 +220,7 @@ def __init__( # then the "_joined_from_info" concept can go left_orm_info = getattr(left, "_joined_from_info", left_info) self._joined_from_info = right_info - if isinstance(onclause, util.string_types): + if isinstance(onclause, compat.string_types): onclause = getattr(left_orm_info.entity, onclause) # #### diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index 61c9fc41..56509b7d 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the fourth version number from None -VERSION = "1.5.3" +VERSION = "1.6.0" diff --git a/tests/conftest.py b/tests/conftest.py index a9c2560a..d4dab3d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,21 +46,6 @@ TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}" -def pytest_addoption(parser): - parser.addoption( - "--run_v20_sqlalchemy", - help="Use only 2.0 SQLAlchemy APIs, any legacy features (< 2.0) will not be supported."
- "Turning on this option will set future flag to True on Engine and Session objects according to" - "the migration guide: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html", - action="store_true", - ) - - -@pytest.fixture(scope="session") -def run_v20_sqlalchemy(pytestconfig): - return pytestconfig.option.run_v20_sqlalchemy - - @pytest.fixture(scope="session") def on_travis(): return os.getenv("TRAVIS", "").lower() == "true" @@ -160,20 +145,21 @@ def url_factory(**kwargs) -> URL: return URL(**url_params) -def get_engine(url: URL, run_v20_sqlalchemy=False, **engine_kwargs): +def get_engine(url: URL, **engine_kwargs): engine_params = { "poolclass": NullPool, - "future": run_v20_sqlalchemy, + "future": True, + "echo": True, } engine_params.update(engine_kwargs) - engine = create_engine(url, **engine_kwargs) + engine = create_engine(url, **engine_params) return engine @pytest.fixture() -def engine_testaccount(request, run_v20_sqlalchemy): +def engine_testaccount(request): url = url_factory() - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine @@ -181,17 +167,17 @@ def engine_testaccount(request, run_v20_sqlalchemy): @pytest.fixture() def engine_testaccount_with_numpy(request): url = url_factory(numpy=True) - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine @pytest.fixture() -def engine_testaccount_with_qmark(request, run_v20_sqlalchemy): +def engine_testaccount_with_qmark(request): snowflake.connector.paramstyle = "qmark" url = url_factory() - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine diff --git a/tests/sqlalchemy_test_suite/conftest.py b/tests/sqlalchemy_test_suite/conftest.py index 31cd7c5c..f0464c7d 100644 --- a/tests/sqlalchemy_test_suite/conftest.py +++ b/tests/sqlalchemy_test_suite/conftest.py @@ -15,6 +15,7 @@ import snowflake.connector from snowflake.sqlalchemy import URL +from snowflake.sqlalchemy.compat import IS_VERSION_20 from ..conftest import get_db_parameters from ..util import random_string @@ -25,6 +26,12 @@ TEST_SCHEMA_2 = f"{TEST_SCHEMA}_2" +if IS_VERSION_20: + collect_ignore_glob = ["test_suite.py"] +else: + collect_ignore_glob = ["test_suite_20.py"] + + # patch sqlalchemy.testing.config.Confi.__init__ for schema name randomization # same schema name would result in conflict as we're running tests in parallel in the CI def config_patched__init__(self, db, db_opts, options, file_config): diff --git a/tests/sqlalchemy_test_suite/test_suite.py b/tests/sqlalchemy_test_suite/test_suite.py index d79e511e..643d1559 100644 --- a/tests/sqlalchemy_test_suite/test_suite.py +++ b/tests/sqlalchemy_test_suite/test_suite.py @@ -69,6 +69,10 @@ def test_empty_insert(self, connection): def test_empty_insert_multiple(self, connection): pass + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + # 2. Patched Tests diff --git a/tests/sqlalchemy_test_suite/test_suite_20.py b/tests/sqlalchemy_test_suite/test_suite_20.py new file mode 100644 index 00000000..1f79c4e9 --- /dev/null +++ b/tests/sqlalchemy_test_suite/test_suite_20.py @@ -0,0 +1,205 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import Integer, testing +from sqlalchemy.schema import Column, Sequence, Table +from sqlalchemy.testing import config +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.suite import ( + BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest, +) +from sqlalchemy.testing.suite import ( + CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite import DateTimeHistoricTest as _DateTimeHistoricTest +from sqlalchemy.testing.suite import FetchLimitOffsetTest as _FetchLimitOffsetTest +from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest +from sqlalchemy.testing.suite import InsertBehaviorTest as _InsertBehaviorTest +from sqlalchemy.testing.suite import LikeFunctionsTest as _LikeFunctionsTest +from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest +from sqlalchemy.testing.suite import SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest +from sqlalchemy.testing.suite import TimeMicrosecondsTest as _TimeMicrosecondsTest +from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest +from sqlalchemy.testing.suite import * # noqa + +# 1. Unsupported by snowflake db + +del ComponentReflectionTest # require indexes not supported by snowflake +del HasIndexTest # require indexes not supported by snowflake +del QuotedNameArgumentTest # require indexes not supported by snowflake + + +class LongNameBlowoutTest(_LongNameBlowoutTest): + # The combination ("ix",) is removed due to Snowflake not supporting indexes + def ix(self, metadata, connection): + pytest.skip("ix required index feature not supported by Snowflake") + + +class FetchLimitOffsetTest(_FetchLimitOffsetTest): + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_bound_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_limit_expr_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset_zero(self, connection): + pass + + +class InsertBehaviorTest(_InsertBehaviorTest): + @pytest.mark.skip( + "Snowflake does not support inserting empty values, the value may be a literal or an expression." + ) + def test_empty_insert(self, connection): + pass + + @pytest.mark.skip( + "Snowflake does not support inserting empty values, The value may be a literal or an expression." + ) + def test_empty_insert_multiple(self, connection): + pass + + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + + +# road to 2.0 +class TrueDivTest(_TrueDivTest): + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer_bound(self, connection): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer(self, connection, left, right, expected): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. 
+ """ + pass + + +class TimeMicrosecondsTest(_TimeMicrosecondsTest): + def __init__(self): + super().__init__() + + +class DateTimeHistoricTest(_DateTimeHistoricTest): + def __init__(self): + super().__init__() + + +# 2. Patched Tests + + +class HasSequenceTest(_HasSequenceTest): + # Override the define_tables method as snowflake does not support 'nomaxvalue'/'nominvalue' + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + # Replace Sequence("other_seq") creation as in the original test suite, + # the Sequence created with 'nomaxvalue' and 'nominvalue' + # which snowflake does not support: + # Sequence("other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True) + Sequence("other_seq", metadata=metadata) + if testing.requires.schemas.enabled: + Sequence("user_id_seq", schema=config.test_schema, metadata=metadata) + Sequence("schema_seq", schema=config.test_schema, metadata=metadata) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + +class LikeFunctionsTest(_LikeFunctionsTest): + @testing.requires.regexp_match + @testing.combinations( + ("a.cde.*", {1, 5, 6, 9}), + ("abc.*", {1, 5, 6, 9, 10}), + ("^abc.*", {1, 5, 6, 9, 10}), + (".*9cde.*", {8}), + ("^a.*", set(range(1, 11))), + (".*(b|c).*", set(range(1, 11))), + ("^(b|c).*", set()), + ) + def test_regexp_match(self, text, expected): + super().test_regexp_match(text, expected) + + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde.*"), {2, 3, 4, 7, 8, 10}) + + +class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.update().where(t.c.id == 2), dict(data="d2_new")) + assert not r.is_insert + # snowflake returns a row with numbers of rows updated and number of multi-joined rows updated + assert r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + # snowflake returns a row with number of rows deleted + assert r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + +class CompositeKeyReflectionTest(_CompositeKeyReflectionTest): + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_fk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. + super().test_fk_column_order() + + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_pk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
+ super().test_pk_column_order() + + +class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + ("(2)",), + ("[brackets]",), + argnames="tablename", + ) + def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname): + super().test_fk_ref(connection, metadata, use_composite, tablename, columnname) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0fd75c38..40207b41 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -5,7 +5,7 @@ from sqlalchemy import Integer, String, and_, func, select from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table -from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing.assertions import AssertsCompiledSQL from snowflake.sqlalchemy import snowdialect diff --git a/tests/test_core.py b/tests/test_core.py index 6c8d7416..179133c8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -34,7 +34,7 @@ inspect, text, ) -from sqlalchemy.exc import DBAPIError, NoSuchTableError +from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError from sqlalchemy.sql import and_, not_, or_, select import snowflake.connector.errors @@ -406,16 +406,6 @@ def test_insert_tables(engine_testaccount): str(users.join(addresses)) == "users JOIN addresses ON " "users.id = addresses.user_id" ) - assert ( - str( - users.join( - addresses, - addresses.c.email_address.like(users.c.name + "%"), - ) - ) - == "users JOIN addresses " - "ON addresses.email_address LIKE users.name || :name_1" - ) s = select(users.c.fullname).select_from( users.join( @@ -444,7 +434,7 @@ def test_table_does_not_exist(engine_testaccount): """ meta = MetaData() with pytest.raises(NoSuchTableError): - Table("does_not_exist", meta, autoload=True, autoload_with=engine_testaccount) + Table("does_not_exist", meta, autoload_with=engine_testaccount) @pytest.mark.skip( @@ -470,9 +460,7 @@ def test_reflextion(engine_testaccount): ) try: meta = MetaData() - user_reflected = Table( - "user", meta, autoload=True, autoload_with=engine_testaccount - ) + user_reflected = Table("user", meta, autoload_with=engine_testaccount) assert user_reflected.c == ["user.id", "user.name", "user.fullname"] finally: conn.execute("DROP TABLE IF EXISTS user") @@ -1071,28 +1059,15 @@ def harass_inspector(): assert outcome -@pytest.mark.timeout(15) -def test_region(): - engine = create_engine( - URL( - user="testuser", - password="testpassword", - account="testaccount", - region="eu-central-1", - login_timeout=5, - ) - ) - try: - engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.eu-central-1.snowflakecomputing.com" in ex.orig.msg - - -@pytest.mark.timeout(15) -def test_azure(): +@pytest.mark.timeout(10) +@pytest.mark.parametrize( + "region", + ( + pytest.param("eu-central-1", id="region"), + pytest.param("east-us-2.azure", id="azure"), + ), +) +def test_connection_timeout_error(region): engine = create_engine( URL( user="testuser", @@ -1102,13 +1077,13 @@ def test_azure(): login_timeout=5, ) ) - try: + + with pytest.raises(OperationalError) as excinfo: engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert 
"Failed to connect to DB" in ex.orig.msg - assert "testaccount.east-us-2.azure.snowflakecomputing.com" in ex.orig.msg + + assert excinfo.value.orig.errno == 250001 + assert "Could not connect to Snowflake backend" in excinfo.value.orig.msg + assert region not in excinfo.value.orig.msg def test_load_dialect(): @@ -1535,11 +1510,16 @@ def test_too_many_columns_detection(engine_testaccount, db_parameters): metadata.create_all(engine_testaccount) inspector = inspect(engine_testaccount) # Do test - original_execute = inspector.bind.execute + connection = inspector.bind.connect() + original_execute = connection.execute + + too_many_columns_was_raised = False def mock_helper(command, *args, **kwargs): - if "_get_schema_columns" in command: + if "_get_schema_columns" in command.text: # Creating exception exactly how SQLAlchemy does + nonlocal too_many_columns_was_raised + too_many_columns_was_raised = True raise DBAPIError.instance( """ SELECT /* sqlalchemy:_get_schema_columns */ @@ -1571,9 +1551,12 @@ def mock_helper(command, *args, **kwargs): else: return original_execute(command, *args, **kwargs) - with patch.object(inspector.bind, "execute", side_effect=mock_helper): - column_metadata = inspector.get_columns("users", db_parameters["schema"]) + with patch.object(engine_testaccount, "connect") as conn: + conn.return_value = connection + with patch.object(connection, "execute", side_effect=mock_helper): + column_metadata = inspector.get_columns("users", db_parameters["schema"]) assert len(column_metadata) == 4 + assert too_many_columns_was_raised # Clean up metadata.drop_all(engine_testaccount) @@ -1615,9 +1598,7 @@ def test_column_type_schema(engine_testaccount): """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) columns = table_reflected.columns assert ( len(columns) == len(ischema_names_baseline) - 1 @@ -1638,9 +1619,7 @@ def test_result_type_and_value(engine_testaccount): ) """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) current_date = date.today() current_utctime = datetime.utcnow() current_localtime = pytz.utc.localize(current_utctime, is_dst=False).astimezone( diff --git a/tests/test_custom_functions.py b/tests/test_custom_functions.py new file mode 100644 index 00000000..2a1e1cb5 --- /dev/null +++ b/tests/test_custom_functions.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import pytest +from sqlalchemy import func + +from snowflake.sqlalchemy import snowdialect + + +def test_flatten_does_not_render_params(): + """This behavior is for backward compatibility. + + In previous version params were not rendered. + In future this behavior will change. + """ + flat = func.flatten("[1, 2]", outer=True) + res = flat.compile(dialect=snowdialect.dialect()) + + assert str(res) == "flatten(%(flatten_1)s)" + + +def test_flatten_emits_warning(): + expected_warning = "For backward compatibility params are not rendered." 
+ with pytest.warns(DeprecationWarning, match=expected_warning): + func.flatten().compile(dialect=snowdialect.dialect()) diff --git a/tests/test_orm.py b/tests/test_orm.py index e485d737..f53cd708 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Session, declarative_base, relationship -def test_basic_orm(engine_testaccount, run_v20_sqlalchemy): +def test_basic_orm(engine_testaccount): """ Tests declarative """ @@ -46,7 +46,6 @@ def __repr__(self): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) our_user = session.query(User).filter_by(name="ed").first() @@ -56,7 +55,7 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount, run_v20_sqlalchemy): +def test_orm_one_to_many_relationship(engine_testaccount): """ Tests One to Many relationship """ @@ -97,7 +96,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -124,7 +122,7 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_delete_cascade(engine_testaccount, run_v20_sqlalchemy): +def test_delete_cascade(engine_testaccount): """ Test delete cascade """ @@ -169,7 +167,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -189,7 +186,7 @@ def __repr__(self): WIP """, ) -def test_orm_query(engine_testaccount, run_v20_sqlalchemy): +def test_orm_query(engine_testaccount): """ Tests ORM query """ @@ -210,7 +207,6 @@ def __repr__(self): # TODO: insert rows session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy # TODO: query.all() for name, fullname in session.query(User.name, User.fullname): @@ -220,7 +216,7 @@ def __repr__(self): # MultipleResultsFound if not one result -def test_schema_including_db(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_db(engine_testaccount, db_parameters): """ Test schema parameter including database separated by a dot. """ @@ -243,7 +239,6 @@ class User(Base): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) ret_user = session.query(User.id, User.name).first() @@ -255,7 +250,7 @@ class User(Base): Base.metadata.drop_all(engine_testaccount) -def test_schema_including_dot(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_dot(engine_testaccount, db_parameters): """ Tests pseudo schema name including dot. 
""" @@ -276,7 +271,6 @@ class User(Base): fullname = Column(String) session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy query = session.query(User.id) assert str(query).startswith( 'SELECT {db}."{schema}.{schema}".{db}.users.id'.format( @@ -285,9 +279,7 @@ class User(Base): ) -def test_schema_translate_map( - engine_testaccount, db_parameters, sql_compiler, run_v20_sqlalchemy -): +def test_schema_translate_map(engine_testaccount, db_parameters): """ Test schema translate map execution option works replaces schema correctly """ @@ -310,7 +302,6 @@ class User(Base): schema_translate_map={schema_map: db_parameters["schema"]} ) as con: session = Session(bind=con) - session.future = run_v20_sqlalchemy with con.begin(): Base.metadata.create_all(con) try: @@ -367,18 +358,29 @@ class Department(Base): .select_from(Employee) .outerjoin(sub) ) - assert ( - str(query.compile(engine_testaccount)).replace("\n", "") - == "SELECT employees.employee_id, departments.department_id " + compiled_stmts = ( + # v1.x + "SELECT employees.employee_id, departments.department_id " "FROM departments, employees LEFT OUTER JOIN LATERAL " "(SELECT departments.department_id AS department_id, departments.name AS name " - "FROM departments) AS anon_1" + "FROM departments) AS anon_1", + # v2.x + "SELECT employees.employee_id, departments.department_id " + "FROM employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1, departments", ) + compiled_stmt = str(query.compile(engine_testaccount)).replace("\n", "") + assert compiled_stmt in compiled_stmts + with caplog.at_level(logging.DEBUG): assert [res for res in session.execute(query)] assert ( "SELECT employees.employee_id, departments.department_id FROM departments" in caplog.text + ) or ( + "SELECT employees.employee_id, departments.department_id FROM employees" + in caplog.text ) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index ef64d65e..63cd6d0e 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -27,6 +27,7 @@ from snowflake.connector import ProgrammingError from snowflake.connector.pandas_tools import make_pd_writer, pd_writer +from snowflake.sqlalchemy.compat import IS_VERSION_20 def _create_users_addresses_tables(engine_testaccount, metadata): @@ -240,8 +241,8 @@ def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_num conn.exec_driver_sql(f"DROP TABLE {test_table_name};") -def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_pandas_writeback(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." ) @@ -352,8 +353,8 @@ def test_pandas_invalid_make_pd_writer(engine_testaccount): ) -def test_percent_signs(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_percent_signs(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." 
        )
@@ -376,7 +377,7 @@ def test_percent_signs(engine_testaccount, run_v20_sqlalchemy):
     not_like_sql = f"select * from {table_name} where c2 not like '%b%'"
     like_sql = f"select * from {table_name} where c2 like '%b%'"
     calculate_sql = "SELECT 1600 % 400 AS a, 1599 % 400 as b"
-    if run_v20_sqlalchemy:
+    if IS_VERSION_20:
         not_like_sql = sqlalchemy.text(not_like_sql)
         like_sql = sqlalchemy.text(like_sql)
         calculate_sql = sqlalchemy.text(calculate_sql)
diff --git a/tests/test_qmark.py b/tests/test_qmark.py
index f98fa7d3..3761181a 100644
--- a/tests/test_qmark.py
+++ b/tests/test_qmark.py
@@ -12,11 +12,11 @@
 THIS_DIR = os.path.dirname(os.path.realpath(__file__))


-def test_qmark_bulk_insert(run_v20_sqlalchemy, engine_testaccount_with_qmark):
+def test_qmark_bulk_insert(engine_testaccount_with_qmark):
     """
     Bulk insert using qmark paramstyle
     """
-    if run_v20_sqlalchemy and sys.version_info < (3, 8):
+    if sys.version_info < (3, 8):
         pytest.skip(
             "In Python 3.7, this test depends on pandas features whose implementation is incompatible with sqlalchemy 2.0, and pandas does not support Python 3.7 anymore."
         )
diff --git a/tox.ini b/tox.ini
index 0c1cb483..7f605627 100644
--- a/tox.ini
+++ b/tox.ini
@@ -34,7 +34,7 @@ passenv =
 setenv =
     COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}}
     SQLALCHEMY_WARN_20 = 1
-    ci: SNOWFLAKE_PYTEST_OPTS = -vvv
+    ci: SNOWFLAKE_PYTEST_OPTS = -vvv --tb=long
 commands = pytest \
            {env:SNOWFLAKE_PYTEST_OPTS:} \
            --cov "snowflake.sqlalchemy" \
            --junitxml {toxworkdir}/junit_{envname}.xml \
            {posargs:tests}
@@ -44,12 +44,6 @@ commands = pytest \
     pytest {env:SNOWFLAKE_PYTEST_OPTS:} \
            --cov "snowflake.sqlalchemy" --cov-append \
            --junitxml {toxworkdir}/junit_{envname}.xml \
            {posargs:tests/sqlalchemy_test_suite}
-    pytest \
-        {env:SNOWFLAKE_PYTEST_OPTS:} \
-        --cov "snowflake.sqlalchemy" --cov-append \
-        --junitxml {toxworkdir}/junit_{envname}.xml \
-        --run_v20_sqlalchemy \
-        {posargs:tests}

[testenv:.pkg_external]
deps = build
@@ -86,7 +80,7 @@ commands = pre-commit run --all-files
    python -c 'import pathlib; print("hint: run \{\} install to add checks as pre-commit hook".format(pathlib.Path(r"{envdir}") / "bin" / "pre-commit"))'

[pytest]
-addopts = -ra --strict-markers --ignore=tests/sqlalchemy_test_suite
+addopts = -ra --ignore=tests/sqlalchemy_test_suite
 junit_family = legacy
 log_level = info
 markers =

From 423b8c13ec23d6d63d39f4019bc0f1caa97909ac Mon Sep 17 00:00:00 2001
From: Marcin Raba
Date: Wed, 3 Jul 2024 14:46:35 +0200
Subject: [PATCH 28/74] SNOW-1516075: use SQLAlchemy 2.0 as default dependency
 (#511)

* SNOW-1516075: use SQLAlchemy 2.0 as default dependency
---
 .github/workflows/build_test.yml | 20 ++++++++++----------
 DESCRIPTION.md                   |  3 ++-
 README.md                        | 31 +++++++++++++++++++++----------
 ci/build.sh                      | 13 ++++++++-----
 ci/test_linux.sh                 | 18 +++++++++---------
 pyproject.toml                   | 15 ++++++++++++---
 6 files changed, 62 insertions(+), 38 deletions(-)

diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml
index 3baa6a0d..d7f7832b 100644
--- a/.github/workflows/build_test.yml
+++ b/.github/workflows/build_test.yml
@@ -52,7 +52,7 @@ jobs:
     strategy:
       fail-fast: true
       matrix:
-        hatch-env: [default, sa20]
+        hatch-env: [default, sa14]
     steps:
       - uses: actions/checkout@v4
         with:
@@ -165,8 +165,8 @@
         path: |
           ./coverage.xml

-  test-dialect-v20:
-    name: Test dialect v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
+  test-dialect-v14:
+    name: Test dialect v14 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
     needs: [ lint, build-install ]
     runs-on: ${{ matrix.os }}
     strategy:
@@ -204,15 +204,15 @@
jobs: python -m uv pip install -U hatch python -m hatch env create default - name: Run tests - run: hatch run sa20:test-dialect + run: hatch run sa14:test-dialect - uses: actions/upload-artifact@v4 with: - name: coverage.xml_dialect-v20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + name: coverage.xml_dialect-v14-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} path: | ./coverage.xml - test-dialect-compatibility-v20: - name: Test dialect v20 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + test-dialect-compatibility-v14: + name: Test dialect v14 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} needs: lint runs-on: ${{ matrix.os }} strategy: @@ -250,17 +250,17 @@ jobs: gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - name: Run tests - run: hatch run sa20:test-dialect-compatibility + run: hatch run sa14:test-dialect-compatibility - uses: actions/upload-artifact@v4 with: - name: coverage.xml_dialect-v20-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + name: coverage.xml_dialect-v14-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} path: | ./coverage.xml combine-coverage: name: Combine coverage if: ${{ success() || failure() }} - needs: [test-dialect, test-dialect-compatibility, test-dialect-v20, test-dialect-compatibility-v20] + needs: [test-dialect, test-dialect-compatibility, test-dialect-v14, test-dialect-compatibility-v14] runs-on: ubuntu-latest steps: - name: Set up Python diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 8b4dcd37..782c426d 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,9 +9,10 @@ Source code is also available at: # Release Notes -- v1.6.0(Not released) +- v1.6.0(July 4, 2024) - support for installing with SQLAlchemy 2.0.x + - use `hatch` & `uv` for managing project virtual environments - v1.5.4 diff --git a/README.md b/README.md index 0c75854e..c428353f 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ containing special characters need to be URL encoded to be parsed correctly. Thi characters could lead to authentication failure. The encoding for the password can be generated using `urllib.parse`: + ```python import urllib.parse urllib.parse.quote("kx@% jj5/g") @@ -111,6 +112,7 @@ urllib.parse.quote("kx@% jj5/g") To create an engine with the proper encodings, either manually constructing the url string by formatting or taking advantage of the `snowflake.sqlalchemy.URL` helper method: + ```python import urllib.parse from snowflake.sqlalchemy import URL @@ -191,14 +193,23 @@ engine = create_engine(...) engine.execute() engine.dispose() -# Do this. +# Better. engine = create_engine(...) 
 connection = engine.connect()
 try:
-    connection.execute()
+    connection.execute(text(<SQL>))
 finally:
     connection.close()
 engine.dispose()
+
+# Best
+try:
+    with engine.connect() as connection:
+        connection.execute(text(<SQL>))
+        # or
+        connection.exec_driver_sql(<SQL>)
+finally:
+    engine.dispose()
 ```

 ### Auto-increment Behavior
@@ -242,14 +253,14 @@ engine = create_engine(URL(

 specific_date = np.datetime64('2016-03-04T12:03:05.123456789Z')

-connection = engine.connect()
-connection.execute(
-    "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)")
-connection.execute(
-    "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,)
-)
-df = pd.read_sql_query("SELECT * FROM ts_tbl", engine)
-assert df.c1.values[0] == specific_date
+with engine.connect() as connection:
+    connection.exec_driver_sql(
+        "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)")
+    connection.exec_driver_sql(
+        "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,)
+    )
+    df = pd.read_sql_query("SELECT * FROM ts_tbl", connection)
+    assert df.c1.values[0] == specific_date
 ```

 The following `NumPy` data types are supported:
diff --git a/ci/build.sh b/ci/build.sh
index 4229506d..85d67df7 100755
--- a/ci/build.sh
+++ b/ci/build.sh
@@ -3,7 +3,7 @@
 # Build snowflake-sqlalchemy
 set -o pipefail

-PYTHON="python3.7"
+PYTHON="python3.8"
 THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")"
 DIST_DIR="${SQLALCHEMY_DIR}/dist"

 cd "$SQLALCHEMY_DIR"
 # Clean up previously built DIST_DIR
 if [ -d "${DIST_DIR}" ]; then
-  echo "[WARN] ${DIST_DIR} already existing, deleting it..."
-  rm -rf "${DIST_DIR}"
+    echo "[WARN] ${DIST_DIR} already existing, deleting it..."
+    rm -rf "${DIST_DIR}"
 fi

 # Constants and setup
@@ -20,5 +20,8 @@ fi
 echo "[Info] Building snowflake-sqlalchemy with $PYTHON"
 # Clean up possible build artifacts
 rm -rf build generated_version.py
-${PYTHON} -m pip install --upgrade pip setuptools wheel build
-${PYTHON} -m build --outdir ${DIST_DIR} .
+# ${PYTHON} -m pip install --upgrade pip setuptools wheel build
+# ${PYTHON} -m build --outdir ${DIST_DIR} .
+export UV_NO_CACHE=true +${PYTHON} -m pip install uv hatch +${PYTHON} -m hatch build diff --git a/ci/test_linux.sh b/ci/test_linux.sh index 695251e6..f5afc4fb 100755 --- a/ci/test_linux.sh +++ b/ci/test_linux.sh @@ -6,9 +6,9 @@ # - This script assumes that ../dist/repaired_wheels has the wheel(s) built for all versions to be tested # - This is the script that test_docker.sh runs inside of the docker container -PYTHON_VERSIONS="${1:-3.7 3.8 3.9 3.10 3.11}" -THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -SQLALCHEMY_DIR="$( dirname "${THIS_DIR}")" +PYTHON_VERSIONS="${1:-3.8 3.9 3.10 3.11}" +THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")" # Install one copy of tox python3 -m pip install -U tox @@ -16,10 +16,10 @@ python3 -m pip install -U tox # Run tests cd $SQLALCHEMY_DIR for PYTHON_VERSION in ${PYTHON_VERSIONS}; do - echo "[Info] Testing with ${PYTHON_VERSION}" - SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") - SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py2.py3-none-any.whl | sort -r | head -n 1) - TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage - echo "[Info] Running tox for ${TEST_ENVLIST}" - python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} + echo "[Info] Testing with ${PYTHON_VERSION}" + SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") + SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py3-none-any.whl | sort -r | head -n 1) + TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage + echo "[Info] Running tox for ${TEST_ENVLIST}" + python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} done diff --git a/pyproject.toml b/pyproject.toml index d2316a44..58544017 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,13 +73,14 @@ exclude = ["/.github"] packages = ["src/snowflake"] [tool.hatch.envs.default] -extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"] features = ["development", "pandas"] python = "3.8" installer = "uv" -[tool.hatch.envs.sa20] -extra-dependencies = ["SQLAlchemy>=1.4.19,<=2.1.0"] +[tool.hatch.envs.sa14] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] +features = ["development", "pandas"] python = "3.8" [tool.hatch.envs.default.env-vars] @@ -93,6 +94,14 @@ test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalch gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" +[[tool.hatch.envs.release.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +features = ["development", "pandas"] + +[tool.hatch.envs.release.scripts] +test-dialect = "pytest -ra -vvv --tb=short --ignore=tests/sqlalchemy_test_suite tests/" +test-compatibility = "pytest -ra -vvv --tb=short tests/sqlalchemy_test_suite tests/" + [tool.ruff] line-length = 88 From 71308ce07465e106222923a31435e29dd022f2f5 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 4 Jul 2024 10:31:43 +0200 Subject: [PATCH 29/74] SNOW-1516075: set release date to July 8th (#514) --- DESCRIPTION.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 782c426d..79971c53 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,7 @@ Source code is also available at: # Release Notes -- v1.6.0(July 4, 2024) +- v1.6.0(July 8, 2024) - support for installing with SQLAlchemy 2.0.x - use `hatch` 
& `uv` for managing project virtual environments From bde2372c3a79ac799a318369fec34c48112758ec Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 4 Jul 2024 12:29:28 +0200 Subject: [PATCH 30/74] SNOW-1519492: add export PATH in build.sh script (#516) --- ci/build.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ci/build.sh b/ci/build.sh index 85d67df7..b63c8e01 100755 --- a/ci/build.sh +++ b/ci/build.sh @@ -16,12 +16,11 @@ if [ -d "${DIST_DIR}" ]; then fi # Constants and setup +export PATH=$PATH:$HOME/.local/bin echo "[Info] Building snowflake-sqlalchemy with $PYTHON" # Clean up possible build artifacts rm -rf build generated_version.py -# ${PYTHON} -m pip install --upgrade pip setuptools wheel build -# ${PYTHON} -m build --outdir ${DIST_DIR} . export UV_NO_CACHE=true ${PYTHON} -m pip install uv hatch ${PYTHON} -m hatch build From 411ae559f5d6402e4d01d2b07a5e1011153292ed Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 4 Jul 2024 15:59:00 +0200 Subject: [PATCH 31/74] SNOW-1519635: skip dialect tests in snowflake tests (#517) --- pyproject.toml | 2 +- tox.ini | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 58544017..9cdd9fb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ SQLACHEMY_WARN_20 = "1" [tool.hatch.envs.default.scripts] check = "pre-commit run --all-files" test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" -test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite tests/" +test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" diff --git a/tox.ini b/tox.ini index 7f605627..2f7360a6 100644 --- a/tox.ini +++ b/tox.ini @@ -39,6 +39,7 @@ commands = pytest \ {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" \ --junitxml {toxworkdir}/junit_{envname}.xml \ + --ignore=tests/sqlalchemy_test_suite \ {posargs:tests} pytest {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" --cov-append \ @@ -74,6 +75,7 @@ passenv = PROGRAMDATA deps = {[testenv]deps} + tomlkit >= 1.12.0 pre-commit >= 2.9.0 skip_install = True commands = pre-commit run --all-files From 64fafbb7e94c5c256f51c918182d4e70412d4195 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 4 Jul 2024 16:58:50 +0200 Subject: [PATCH 32/74] SNOW-1519766: drop tomlkit version for fix_lint (#518) --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 2f7360a6..102e2273 100644 --- a/tox.ini +++ b/tox.ini @@ -75,7 +75,7 @@ passenv = PROGRAMDATA deps = {[testenv]deps} - tomlkit >= 1.12.0 + tomlkit pre-commit >= 2.9.0 skip_install = True commands = pre-commit run --all-files From 305d2980cec33d37dbb9418684adcef72a3eaf76 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Tue, 9 Jul 2024 12:38:23 +0200 Subject: [PATCH 33/74] mraba/update-python_publish-workflow (#520) SNOW-1519875: update publish branch workflow for v1.6.1 --- .github/workflows/python-publish.yml | 8 ++++---- DESCRIPTION.md | 4 ++++ src/snowflake/sqlalchemy/version.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python-publish.yml 
b/.github/workflows/python-publish.yml index 23116e7a..0a9f22bd 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -30,12 +30,12 @@ jobs: python-version: '3.x' - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install build + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package - run: python -m build + run: python -m hatch build --clean - name: Publish package - uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 79971c53..38cd70f7 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,10 @@ Source code is also available at: # Release Notes +- v1.6.1(July 9, 2024) + + - Update internal project workflow with pypi publishing + - v1.6.0(July 8, 2024) - support for installing with SQLAlchemy 2.0.x diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index 56509b7d..d90f706b 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = "1.6.0" +VERSION = "1.6.1" From dd7fc8aca7460fc669c7bb6667e45c83f615865e Mon Sep 17 00:00:00 2001 From: Angel Antonio Avalos Cisneros Date: Thu, 8 Aug 2024 14:53:13 -0700 Subject: [PATCH 34/74] sign artifacts before publish (#522) --- .github/workflows/python-publish.yml | 43 +++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 0a9f22bd..a1eb1a0c 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -13,7 +13,8 @@ on: types: [published] permissions: - contents: read + contents: write + id-token: write jobs: deploy: @@ -34,6 +35,46 @@ jobs: python -m uv pip install -U hatch - name: Build package run: python -m hatch build --clean + - name: List artifacts + run: ls ./dist + - name: Install sigstore + run: python -m pip install sigstore + - name: Signing + run: | + for dist in dist/*; do + dist_base="$(basename "${dist}")" + echo "dist: ${dist}" + echo "dist_base: ${dist_base}" + python -m \ + sigstore sign "${dist}" \ + --output-signature "${dist_base}.sig" \ + --output-certificate "${dist_base}.crt" \ + --bundle "${dist_base}.sigstore" + + # Verify using `.sig` `.crt` pair; + python -m \ + sigstore verify identity "${dist}" \ + --signature "${dist_base}.sig" \ + --cert "${dist_base}.crt" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + + # Verify using `.sigstore` bundle; + python -m \ + sigstore verify identity "${dist}" \ + --bundle "${dist_base}.sigstore" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + done + - name: List artifacts after sign + run: ls ./dist + - name: Copy files to release + run: | + gh release upload ${{ github.event.release.tag_name }} *.sigstore + gh release upload ${{ github.event.release.tag_name }} *.sig + gh release upload ${{ github.event.release.tag_name }} *.crt + env: + GITHUB_TOKEN: ${{ github.TOKEN }} - name: Publish package uses: 
pypa/gh-action-pypi-publish@release/v1 with: From fd8c29a08696feab256d51fc2c42773cfd74c590 Mon Sep 17 00:00:00 2001 From: Angel Antonio Avalos Cisneros Date: Mon, 19 Aug 2024 13:08:35 -0700 Subject: [PATCH 35/74] Update python-publish.yml (#526) --- .github/workflows/python-publish.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index a1eb1a0c..52f43106 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -57,14 +57,14 @@ jobs: --signature "${dist_base}.sig" \ --cert "${dist_base}.crt" \ --cert-oidc-issuer https://token.actions.githubusercontent.com \ - --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} # Verify using `.sigstore` bundle; python -m \ sigstore verify identity "${dist}" \ --bundle "${dist_base}.sigstore" \ --cert-oidc-issuer https://token.actions.githubusercontent.com \ - --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} done - name: List artifacts after sign run: ls ./dist From 957a4699cf21151070071969cd4996735e33001b Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 9 Sep 2024 03:38:57 -0600 Subject: [PATCH 36/74] Add tests to try max lob size in memory feature (#529) * Add test with large object --- pyproject.toml | 2 ++ tests/test_custom_types.py | 33 ++++++++++++++++++++++++++++++++- tests/test_orm.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9cdd9fb4..99aacbee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ line-length = 88 line-length = 88 [tool.pytest.ini_options] +addopts = "-m 'not feature_max_lob_size'" markers = [ # Optional dependency groups markers "lambda: AWS lambda tests", @@ -126,4 +127,5 @@ markers = [ "timeout: tests that need a timeout time", "internal: tests that could but should only run on our internal CI", "external: tests that could but should only run on our external CI", + "feature_max_lob_size: tests that could but should only run on our external CI", ] diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index a997ffe8..3961a5d3 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -2,7 +2,10 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# -from snowflake.sqlalchemy import custom_types +import pytest +from sqlalchemy import Column, Integer, MetaData, Table, text + +from snowflake.sqlalchemy import TEXT, custom_types def test_string_conversions(): @@ -34,3 +37,31 @@ def test_string_conversions(): sample = getattr(custom_types, type_)() if type_ in sf_custom_types: assert type_ == str(sample) + + +@pytest.mark.feature_max_lob_size +def test_create_table_with_text_type(engine_testaccount): + metadata = MetaData() + table_name = "test_max_lob_size_0" + test_max_lob_size = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("full_name", TEXT(), server_default=text("id::varchar")), + ) + + metadata.create_all(engine_testaccount) + try: + assert test_max_lob_size is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{table_name}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + test_max_lob_size.drop(engine_testaccount) diff --git a/tests/test_orm.py b/tests/test_orm.py index f53cd708..f51c9a90 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -7,6 +7,7 @@ import pytest from sqlalchemy import ( + TEXT, Column, Enum, ForeignKey, @@ -413,3 +414,34 @@ class Employee(Base): '[SELECT "Employee".uid FROM "Employee" JOIN LATERAL flatten(PARSE_JSON("Employee"' in caplog.text ) + + +@pytest.mark.feature_max_lob_size +def test_basic_table_with_large_lob_size_in_memory(engine_testaccount, sql_compiler): + Base = declarative_base() + + class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + full_name = Column(TEXT(), server_default=text("id::varchar")) + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + assert User.__table__ is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{User.__tablename__}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + Base.metadata.drop_all(engine_testaccount) From 16ad10fbb90d2fc98d3ab7218fe41e2ac708db33 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Wed, 2 Oct 2024 10:49:48 +0200 Subject: [PATCH 37/74] SNOW-1655751: register overwritten functions under `snowflake` namespace (#532) * SNOW-1655751: register overwritten functions under `snowflake` namespace --- DESCRIPTION.md | 7 +++++-- src/snowflake/sqlalchemy/base.py | 3 +-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 38cd70f7..67b50ab0 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,9 @@ Source code is also available at: # Release Notes +- 1.6.2 + - Fixed SAWarning when registering functions with existing name in default namespace + - v1.6.1(July 9, 2024) - Update internal project workflow with pypi publishing @@ -24,7 +27,7 @@ Source code is also available at: - v1.5.3(April 16, 2024) - - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 + - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 - v1.5.2(April 11, 2024) @@ -33,7 +36,7 @@ Source code is also available at: - v1.5.1(November 03, 2023) - - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check 
https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057.
+  - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check <https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057>.
   - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters.
     - This fixes other boolean parameter passing to driver as well.
diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index 1aaa881e..3e504f7b 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@ -184,7 +184,6 @@ def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause)
                 [element._from_objects for element in statement._where_criteria]
             ),
         ):
-
             potential[from_clause] = ()

         all_clauses = list(potential.keys())
@@ -1065,4 +1064,4 @@ def visit_GEOMETRY(self, type_, **kw):
 construct_arguments = [(Table, {"clusterby": None})]


-functions.register_function("flatten", flatten)
+functions.register_function("flatten", flatten, "snowflake")

From b5af4e31611b4ac9e4467eee4ad6235a4b6d8d57 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Wed, 2 Oct 2024 09:49:44 -0600
Subject: [PATCH 38/74] Adding support for snowflake dynamic tables to
 SqlAlchemy Core (#531)

* Add support for dynamic tables

* Update DESCRIPTION.md

* Remove unnecessary code to support dynamic tables in sqlalchemy 1.4

* Fix bug to support sqlalchemy v1.4

* Add syrupy

* Remove unnecessary parameter

* Add snapshots
---
 DESCRIPTION.md                                |   5 +-
 pyproject.toml                                |   1 +
 src/snowflake/sqlalchemy/__init__.py          |   7 +
 src/snowflake/sqlalchemy/_constants.py        |   1 +
 src/snowflake/sqlalchemy/base.py              |  30 ++-
 src/snowflake/sqlalchemy/snowdialect.py       |   3 +-
 src/snowflake/sqlalchemy/sql/__init__.py      |   3 +
 .../sqlalchemy/sql/custom_schema/__init__.py  |   6 +
 .../sql/custom_schema/custom_table_base.py    |  51 +++++
 .../sql/custom_schema/dynamic_table.py        |  86 +++++++++
 .../sql/custom_schema/options/__init__.py     |   9 +
 .../sql/custom_schema/options/as_query.py     |  62 ++++++
 .../sql/custom_schema/options/table_option.py |  26 +++
 .../options/table_option_base.py              |  30 +++
 .../sql/custom_schema/options/target_lag.py   |  60 ++++++
 .../sql/custom_schema/options/warehouse.py    |  51 +++++
 .../sql/custom_schema/table_from_query.py     |  60 ++++++
 .../test_compile_dynamic_table.ambr           |  13 ++
 .../test_reflect_dynamic_table.ambr           |   4 +
 tests/test_compile_dynamic_table.py           | 177 ++++++++++++++++++
 tests/test_create_dynamic_table.py            |  93 +++++++++
 tests/test_reflect_dynamic_table.py           |  88 +++++++++
 22 files changed, 860 insertions(+), 33 deletions(-)
 create mode 100644 src/snowflake/sqlalchemy/sql/__init__.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py
 create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
 create mode 100644 tests/__snapshots__/test_compile_dynamic_table.ambr
 create mode 100644
tests/__snapshots__/test_reflect_dynamic_table.ambr create mode 100644 tests/test_compile_dynamic_table.py create mode 100644 tests/test_create_dynamic_table.py create mode 100644 tests/test_reflect_dynamic_table.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 67b50ab0..205685f1 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,10 @@ Source code is also available at: # Release Notes -- 1.6.2 +- (Unreleased) + + - Add support for dynamic tables and required options + - Fixed SAWarning when registering functions with existing name in default namespace - Fixed SAWarning when registering functions with existing name in default namespace - v1.6.1(July 9, 2024) diff --git a/pyproject.toml b/pyproject.toml index 99aacbee..4fe06a9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ development = [ "pytz", "numpy", "mock", + "syrupy==4.6.1", ] pandas = ["snowflake-connector-python[pandas]"] diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 9df6aaa2..30cd140c 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -61,6 +61,8 @@ VARBINARY, VARIANT, ) +from .sql.custom_schema import DynamicTable +from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse from .util import _url as URL base.dialect = dialect = snowdialect.dialect @@ -113,4 +115,9 @@ "ExternalStage", "CreateStage", "CreateFileFormat", + "DynamicTable", + "AsQuery", + "TargetLag", + "TimeUnit", + "Warehouse", ) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 46af4454..839745ee 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -10,3 +10,4 @@ APPLICATION_NAME = "SnowflakeSQLAlchemy" SNOWFLAKE_SQLALCHEMY_VERSION = VERSION +DIALECT_NAME = "snowflake" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 3e504f7b..56631728 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -18,9 +18,16 @@ from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.selectable import Lateral, SelectState -from .compat import IS_VERSION_20, args_reducer, string_types -from .custom_commands import AWSBucket, AzureContainer, ExternalStage +from snowflake.sqlalchemy._constants import DIALECT_NAME +from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer, string_types +from snowflake.sqlalchemy.custom_commands import ( + AWSBucket, + AzureContainer, + ExternalStage, +) + from .functions import flatten +from .sql.custom_schema.options.table_option_base import TableOptionBase from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -878,7 +885,7 @@ def get_column_specification(self, column, **kwargs): return " ".join(colspec) - def post_create_table(self, table): + def handle_cluster_by(self, table): """ Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. 
@@ -908,7 +915,7 @@ def post_create_table(self, table): """ text = "" - info = table.dialect_options["snowflake"] + info = table.dialect_options[DIALECT_NAME] cluster = info.get("clusterby") if cluster: text += " CLUSTER BY ({})".format( @@ -916,6 +923,21 @@ def post_create_table(self, table): ) return text + def post_create_table(self, table): + text = self.handle_cluster_by(table) + options = [ + option + for _, option in table.dialect_options[DIALECT_NAME].items() + if isinstance(option, TableOptionBase) + ] + options.sort( + key=lambda x: (x.__priority__.value, x.__option_name__), reverse=True + ) + for option in options: + text += "\t" + option.render_option(self) + + return text + def visit_create_stage(self, create_stage, **kw): """ This visitor will create the SQL representation for a CREATE STAGE command. diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 04305a00..b0472eb6 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -42,6 +42,7 @@ from snowflake.connector.constants import UTF8 from snowflake.sqlalchemy.compat import returns_unicode +from ._constants import DIALECT_NAME from .base import ( SnowflakeCompiler, SnowflakeDDLCompiler, @@ -119,7 +120,7 @@ class SnowflakeDialect(default.DefaultDialect): - name = "snowflake" + name = DIALECT_NAME driver = "snowflake" max_identifier_length = 255 cte_follows_insert = True diff --git a/src/snowflake/sqlalchemy/sql/__init__.py b/src/snowflake/sqlalchemy/sql/__init__.py new file mode 100644 index 00000000..ef416f64 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py new file mode 100644 index 00000000..4bbac246 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from .dynamic_table import DynamicTable + +__all__ = ["DynamicTable"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py new file mode 100644 index 00000000..0c04f33f --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+#
+import typing
+from typing import Any
+
+from sqlalchemy.exc import ArgumentError
+from sqlalchemy.sql.schema import MetaData, SchemaItem, Table
+
+from ..._constants import DIALECT_NAME
+from ...compat import IS_VERSION_20
+from ...custom_commands import NoneType
+from .options.table_option import TableOption
+
+
+class CustomTableBase(Table):
+    __table_prefix__ = ""
+    _support_primary_and_foreign_keys = True
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        if self.__table_prefix__ != "":
+            prefixes = kw.get("prefixes", []) + [self.__table_prefix__]
+            kw.update(prefixes=prefixes)
+        if not IS_VERSION_20 and hasattr(super(), "_init"):
+            super()._init(name, metadata, *args, **kw)
+        else:
+            super().__init__(name, metadata, *args, **kw)
+
+        if not kw.get("autoload_with", False):
+            self._validate_table()
+
+    def _validate_table(self):
+        if not self._support_primary_and_foreign_keys and (
+            self.primary_key or self.foreign_keys
+        ):
+            raise ArgumentError(
+                f"Primary key and foreign keys are not supported in {self.__table_prefix__} TABLE."
+            )
+
+        return True
+
+    def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]:
+        if option_name in self.dialect_options[DIALECT_NAME]:
+            return self.dialect_options[DIALECT_NAME][option_name]
+        return NoneType
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
new file mode 100644
index 00000000..7d0a02e6
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
@@ -0,0 +1,86 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+import typing
+from typing import Any
+
+from sqlalchemy.exc import ArgumentError
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .options.target_lag import TargetLag
+from .options.warehouse import Warehouse
+from .table_from_query import TableFromQueryBase
+
+
+class DynamicTable(TableFromQueryBase):
+    """
+    A class representing a dynamic table with configurable options and settings.
+
+    The `DynamicTable` class allows for the creation and querying of tables with
+    specific options, such as `Warehouse` and `TargetLag`.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing dynamic tables.
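+
+    Example usage (an illustrative sketch mirroring the tests added in this
+    patch; the warehouse name is a placeholder):
+
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            TargetLag(20, TimeUnit.MINUTES),
+            Warehouse("my_warehouse_name"),
+            AsQuery("SELECT id, name FROM test_table_1"),
+        )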
+ + """ + + __table_prefix__ = "DYNAMIC" + + _support_primary_and_foreign_keys = False + + @property + def warehouse(self) -> typing.Optional[Warehouse]: + return self._get_dialect_option(Warehouse.__option_name__) + + @property + def target_lag(self) -> typing.Optional[TargetLag]: + return self._get_dialect_option(TargetLag.__option_name__) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + super().__init__(name, metadata, *args, **kw) + + def _validate_table(self): + missing_attributes = [] + if self.target_lag is NoneType: + missing_attributes.append("TargetLag") + if self.warehouse is NoneType: + missing_attributes.append("Warehouse") + if self.as_query is NoneType: + missing_attributes.append("AsQuery") + if missing_attributes: + raise ArgumentError( + "DYNAMIC TABLE must have the following arguments: %s" + % ", ".join(missing_attributes) + ) + super()._validate_table() + + def __repr__(self) -> str: + return "DynamicTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [repr(self.target_lag)] + + [repr(self.warehouse)] + + [repr(self.as_query)] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py new file mode 100644 index 00000000..052e2d96 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py @@ -0,0 +1,9 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from .as_query import AsQuery +from .target_lag import TargetLag, TimeUnit +from .warehouse import Warehouse + +__all__ = ["Warehouse", "AsQuery", "TargetLag", "TimeUnit"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py new file mode 100644 index 00000000..68076af9 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py @@ -0,0 +1,62 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Union + +from sqlalchemy.sql import Selectable + +from .table_option import TableOption +from .table_option_base import Priority + + +class AsQuery(TableOption): + """Class to represent an AS clause in tables. + This configuration option is used to specify the query from which the table is created. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + + AsQuery example usage using an input string: + DynamicTable( + "sometable", metadata, + Column("name", String(50)), + Column("address", String(100)), + AsQuery('select name, address from existing_table where name = "test"') + ) + + AsQuery example usage using a selectable statement: + DynamicTable( + "sometable", + Base.metadata, + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery(select(test_table_1).where(test_table_1.c.id == 23)) + ) + + """ + + __option_name__ = "as_query" + __priority__ = Priority.LOWEST + + def __init__(self, query: Union[str, Selectable]) -> None: + r"""Construct an as_query object. 
+
+        :param \*expressions:
+            AS <query>
+
+        """
+        self.query = query
+
+    @staticmethod
+    def template() -> str:
+        return "AS %s"
+
+    def get_expression(self):
+        if isinstance(self.query, Selectable):
+            return self.query.compile(compile_kwargs={"literal_binds": True})
+        return self.query
+
+    def render_option(self, compiler) -> str:
+        return AsQuery.template() % (self.get_expression())
+
+    def __repr__(self) -> str:
+        return "Query(%s)" % self.get_expression()
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
new file mode 100644
index 00000000..7ac27575
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
@@ -0,0 +1,26 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from typing import Any
+
+from sqlalchemy import exc
+from sqlalchemy.sql.base import SchemaEventTarget
+from sqlalchemy.sql.schema import SchemaItem, Table
+
+from snowflake.sqlalchemy._constants import DIALECT_NAME
+
+from .table_option_base import TableOptionBase
+
+
+class TableOption(TableOptionBase, SchemaItem):
+    def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+        if self.__option_name__ == "default":
+            raise exc.SQLAlchemyError(f"{self.__class__.__name__} does not have a name")
+        if not isinstance(parent, Table):
+            raise exc.SQLAlchemyError(
+                f"{self.__class__.__name__} option can only be applied to Table"
+            )
+        parent.dialect_options[DIALECT_NAME][self.__option_name__] = self
+
+    def _set_table_option_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+        pass
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py
new file mode 100644
index 00000000..54008ec8
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py
@@ -0,0 +1,30 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from enum import Enum
+
+
+class Priority(Enum):
+    LOWEST = 0
+    VERY_LOW = 1
+    LOW = 2
+    MEDIUM = 4
+    HIGH = 6
+    VERY_HIGH = 7
+    HIGHEST = 8
+
+
+class TableOptionBase:
+    __option_name__ = "default"
+    __visit_name__ = __option_name__
+    __priority__ = Priority.MEDIUM
+
+    @staticmethod
+    def template() -> str:
+        raise NotImplementedError
+
+    def get_expression(self):
+        raise NotImplementedError
+
+    def render_option(self, compiler) -> str:
+        raise NotImplementedError
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py
new file mode 100644
index 00000000..4331a4cb
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py
@@ -0,0 +1,60 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from enum import Enum
+from typing import Optional
+
+from .table_option import TableOption
+from .table_option_base import Priority
+
+
+class TimeUnit(Enum):
+    SECONDS = "seconds"
+    MINUTES = "minutes"
+    HOURS = "hour"
+    DAYS = "days"
+
+
+class TargetLag(TableOption):
+    """Class to represent the target lag clause.
+    This configuration option is used to specify the target lag time for the dynamic table.
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
+
+
+    Target lag example usage:
+    DynamicTable("sometable", metadata,
+                Column("name", String(50)),
+                Column("address", String(100)),
+                TargetLag(20, TimeUnit.MINUTES),
+                )
+    """
+
+    __option_name__ = "target_lag"
+    __priority__ = Priority.HIGH
+
+    def __init__(
+        self,
+        time: Optional[int] = 0,
+        unit: Optional[TimeUnit] = TimeUnit.MINUTES,
+        down_stream: Optional[bool] = False,
+    ) -> None:
+        self.time = time
+        self.unit = unit
+        self.down_stream = down_stream
+
+    @staticmethod
+    def template() -> str:
+        return "TARGET_LAG = %s"
+
+    def get_expression(self):
+        return (
+            f"'{str(self.time)} {str(self.unit.value)}'"
+            if not self.down_stream
+            else "DOWNSTREAM"
+        )
+
+    def render_option(self, compiler) -> str:
+        return TargetLag.template() % (self.get_expression())
+
+    def __repr__(self) -> str:
+        return "TargetLag(%s)" % self.get_expression()
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py
new file mode 100644
index 00000000..a5b8cce0
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py
@@ -0,0 +1,51 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from typing import Optional
+
+from .table_option import TableOption
+from .table_option_base import Priority
+
+
+class Warehouse(TableOption):
+    """Class to represent the warehouse clause.
+    This configuration option is used to specify the warehouse for the dynamic table.
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
+
+
+    Warehouse example usage:
+    DynamicTable("sometable", metadata,
+                Column("name", String(50)),
+                Column("address", String(100)),
+                Warehouse('my_warehouse_name')
+                )
+    """
+
+    __option_name__ = "warehouse"
+    __priority__ = Priority.HIGH
+
+    def __init__(
+        self,
+        name: Optional[str],
+    ) -> None:
+        r"""Construct a Warehouse object.
+
+        :param \*expressions:
+            Dynamic table warehouse option.
+            WAREHOUSE = <warehouse_name>
+
+        """
+        self.name = name
+
+    @staticmethod
+    def template() -> str:
+        return "WAREHOUSE = %s"
+
+    def get_expression(self):
+        return self.name
+
+    def render_option(self, compiler) -> str:
+        return Warehouse.template() % (self.get_expression())
+
+    def __repr__(self) -> str:
+        return "Warehouse(%s)" % self.get_expression()
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
new file mode 100644
index 00000000..60e8995f
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
@@ -0,0 +1,60 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+# +import typing +from typing import Any, Optional + +from sqlalchemy.sql import Selectable +from sqlalchemy.sql.schema import Column, MetaData, SchemaItem +from sqlalchemy.util import NoneType + +from .custom_table_base import CustomTableBase +from .options.as_query import AsQuery + + +class TableFromQueryBase(CustomTableBase): + + @property + def as_query(self): + return self._get_dialect_option(AsQuery.__option_name__) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + items = [item for item in args] + as_query: AsQuery = self.__get_as_query_from_items(items) + if ( + as_query is not NoneType + and isinstance(as_query.query, Selectable) + and not self.__has_defined_columns(items) + ): + columns = self.__create_columns_from_selectable(as_query.query) + args = items + columns + super().__init__(name, metadata, *args, **kw) + + def __get_as_query_from_items( + self, items: typing.List[SchemaItem] + ) -> Optional[AsQuery]: + for item in items: + if isinstance(item, AsQuery): + return item + return NoneType + + def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool: + for item in items: + if isinstance(item, Column): + return True + + def __create_columns_from_selectable( + self, selectable: Selectable + ) -> Optional[typing.List[Column]]: + if not isinstance(selectable, Selectable): + return + columns: typing.List[Column] = [] + for _, c in selectable.exported_columns.items(): + columns += [Column(c.name, c.type)] + return columns diff --git a/tests/__snapshots__/test_compile_dynamic_table.ambr b/tests/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..81c7f90f --- /dev/null +++ b/tests/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/__snapshots__/test_reflect_dynamic_table.ambr b/tests/__snapshots__/test_reflect_dynamic_table.ambr new file mode 100644 index 00000000..d4cc22b5 --- /dev/null +++ b/tests/__snapshots__/test_reflect_dynamic_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- diff --git a/tests/test_compile_dynamic_table.py b/tests/test_compile_dynamic_table.py new file mode 100644 index 00000000..73ce54aa --- /dev/null +++ b/tests/test_compile_dynamic_table.py @@ -0,0 +1,177 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + Table, + exc, + select, +) +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, DynamicTable +from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( + TargetLag, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse + + +def test_compile_dynamic_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_without_required_args(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="DYNAMIC TABLE must have the following arguments: TargetLag, " + "Warehouse, AsQuery", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + + +def test_compile_dynamic_table_with_primary_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + +def test_compile_dynamic_table_with_foreign_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ForeignKeyConstraint(["id"], ["table.id"]), + ) + + +def test_compile_dynamic_table_orm(sql_compiler, snapshot): + Base = declarative_base() + metadata = MetaData() + table_name = "test_dynamic_table_orm" + test_dynamic_table_orm = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + class TestDynamicTableOrm(Base): + __table__ = test_dynamic_table_orm + __mapper_args__ = { + "primary_key": [test_dynamic_table_orm.c.id, test_dynamic_table_orm.c.name] + } + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): + Base = declarative_base() + + class TestDynamicTableOrm(Base): + __tablename__ = "test_dynamic_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return DynamicTable(name, metadata, *arg, **kw) + + __table_args__ = ( + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + id = Column(Integer) + name = Column(String) + + __mapper_args__ = {"primary_key": [id, name]} + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + 
assert actual == snapshot + + +def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = Table( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + dynamic_test_table = DynamicTable( + "dynamic_test_table_1", + Base.metadata, + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery(select(test_table_1).where(test_table_1.c.id == 23)), + ) + + value = CreateTable(dynamic_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/test_create_dynamic_table.py b/tests/test_create_dynamic_table.py new file mode 100644 index 00000000..4e6c48ca --- /dev/null +++ b/tests/test_create_dynamic_table.py @@ -0,0 +1,93 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( + TargetLag, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse + + +def test_create_dynamic_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(1, TimeUnit.HOURS), + Warehouse(warehouse), + AsQuery("SELECT id, name from test_table_1;"), + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_dynamic_table_without_dynamictable_class( + engine_testaccount, db_parameters +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(1, TimeUnit.HOURS), + Warehouse(warehouse), + AsQuery("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_reflect_dynamic_table.py b/tests/test_reflect_dynamic_table.py new file mode 100644 index 00000000..8a4a8445 --- /dev/null +++ b/tests/test_reflect_dynamic_table.py 
@@ -0,0 +1,88 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.custom_commands import NoneType + + +def test_simple_reflection_dynamic_table_as_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = Table( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_without_options_loading(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = DynamicTable( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + # TODO: Add support for loading options when table is reflected + assert dynamic_test_table.warehouse is NoneType + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) From 43c6b563e462884faf7b7063bbf7fe10de7a5f60 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 8 Oct 2024 07:34:47 -0600 Subject: [PATCH 39/74] Add support for hybrid tables and indexes (#533) * Add support for hybrid tables * Update DESCRIPTION.md and add support for indexes --- .github/workflows/build_test.yml | 6 + DESCRIPTION.md | 2 +- pyproject.toml | 3 +- src/snowflake/sqlalchemy/__init__.py | 3 +- src/snowflake/sqlalchemy/snowdialect.py | 138 +++++++++++++++- .../sqlalchemy/sql/custom_schema/__init__.py | 3 +- .../sql/custom_schema/custom_table_base.py | 23 ++- .../sql/custom_schema/custom_table_prefix.py | 13 ++ .../sql/custom_schema/dynamic_table.py | 3 +- .../sql/custom_schema/hybrid_table.py | 67 ++++++++ tests/__snapshots__/test_orm.ambr | 4 + tests/custom_tables/__init__.py | 2 + 
.../test_compile_dynamic_table.ambr | 13 ++ .../test_compile_hybrid_table.ambr | 7 + .../test_create_hybrid_table.ambr | 7 + .../test_reflect_hybrid_table.ambr | 4 + .../test_compile_dynamic_table.py | 4 +- .../test_compile_hybrid_table.py | 52 ++++++ .../test_create_dynamic_table.py | 0 .../custom_tables/test_create_hybrid_table.py | 95 +++++++++++ .../test_reflect_dynamic_table.py | 0 .../test_reflect_hybrid_table.py | 65 ++++++++ tests/test_core.py | 7 +- tests/test_index_reflection.py | 34 ++++ tests/test_orm.py | 155 +++++++++++++++++- tests/test_pandas.py | 2 +- 26 files changed, 679 insertions(+), 33 deletions(-) create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py create mode 100644 tests/__snapshots__/test_orm.ambr create mode 100644 tests/custom_tables/__init__.py create mode 100644 tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr rename tests/{ => custom_tables}/test_compile_dynamic_table.py (96%) create mode 100644 tests/custom_tables/test_compile_hybrid_table.py rename tests/{ => custom_tables}/test_create_dynamic_table.py (100%) create mode 100644 tests/custom_tables/test_create_hybrid_table.py rename tests/{ => custom_tables}/test_reflect_dynamic_table.py (100%) create mode 100644 tests/custom_tables/test_reflect_hybrid_table.py create mode 100644 tests/test_index_reflection.py diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index d7f7832b..5e9823f2 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -111,6 +111,9 @@ jobs: run: | gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run test for AWS + run: hatch run test-dialect-aws + if: matrix.cloud-provider == 'aws' - name: Run tests run: hatch run test-dialect - uses: actions/upload-artifact@v4 @@ -203,6 +206,9 @@ jobs: python -m pip install -U uv python -m uv pip install -U hatch python -m hatch env create default + - name: Run test for AWS + run: hatch run sa14:test-dialect-aws + if: matrix.cloud-provider == 'aws' - name: Run tests run: hatch run sa14:test-dialect - uses: actions/upload-artifact@v4 diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 205685f1..58c2dfe2 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -12,7 +12,7 @@ Source code is also available at: - (Unreleased) - Add support for dynamic tables and required options - - Fixed SAWarning when registering functions with existing name in default namespace + - Add support for hybrid tables - Fixed SAWarning when registering functions with existing name in default namespace - v1.6.1(July 9, 2024) diff --git a/pyproject.toml b/pyproject.toml index 4fe06a9b..6c72f683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,7 @@ SQLACHEMY_WARN_20 = "1" check = "pre-commit run --all-files" test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" 
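+# Tests marked "aws" (hybrid tables and their indexes) are excluded from the
+# default runs by the addopts marker filter below; this script is the entry
+# point the AWS jobs in build_test.yml use to run them explicitly.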
+test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" @@ -110,7 +111,7 @@ line-length = 88 line-length = 88 [tool.pytest.ini_options] -addopts = "-m 'not feature_max_lob_size'" +addopts = "-m 'not feature_max_lob_size and not aws'" markers = [ # Optional dependency groups markers "lambda: AWS lambda tests", diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 30cd140c..0afd44a5 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -61,7 +61,7 @@ VARBINARY, VARIANT, ) -from .sql.custom_schema import DynamicTable +from .sql.custom_schema import DynamicTable, HybridTable from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse from .util import _url as URL @@ -120,4 +120,5 @@ "TargetLag", "TimeUnit", "Warehouse", + "HybridTable", ) diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index b0472eb6..f2fb9b18 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -65,6 +65,7 @@ _CUSTOM_Float, _CUSTOM_Time, ) +from .sql.custom_schema.custom_table_prefix import CustomTablePrefix from .util import ( _update_connection_application_name, parse_url_boolean, @@ -352,14 +353,6 @@ def _map_name_to_idx(result): name_to_idx[col[0]] = idx return name_to_idx - @reflection.cache - def get_indexes(self, connection, table_name, schema=None, **kw): - """ - Gets all indexes - """ - # no index is supported by Snowflake - return [] - @reflection.cache def get_check_constraints(self, connection, table_name, schema, **kw): # check constraints are not supported by Snowflake @@ -895,6 +888,129 @@ def get_table_comment(self, connection, table_name, schema=None, **kw): ) } + def get_multi_indexes( + self, + connection, + *, + schema, + filter_names, + **kw, + ): + """ + Gets the indexes definition + """ + + table_prefixes = self.get_multi_prefixes( + connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name + ) + if len(table_prefixes) == 0: + return [] + schema = schema or self.default_schema_name + if not schema: + result = connection.execute( + text("SHOW /* sqlalchemy:get_multi_indexes */ INDEXES") + ) + else: + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}" + ) + ) + + n2i = self.__class__._map_name_to_idx(result) + indexes = {} + + for row in result.cursor.fetchall(): + table = self.normalize_name(str(row[n2i["table"]])) + if ( + row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY' + or table not in filter_names + or (schema, table) not in table_prefixes + or ( + (schema, table) in table_prefixes + and CustomTablePrefix.HYBRID.name + not in table_prefixes[(schema, table)] + ) + ): + continue + index = { + "name": row[n2i["name"]], + "unique": row[n2i["is_unique"]] == "Y", + "column_names": row[n2i["columns"]], + "include_columns": row[n2i["included_columns"]], + "dialect_options": {}, + } + if (schema, table) in indexes: + indexes[(schema, table)] = indexes[(schema, table)].append(index) + else: + indexes[(schema, table)] = [index] + + return list(indexes.items()) + + def _value_or_default(self, data, table, schema): + table = 
self.normalize_name(str(table)) + dic_data = dict(data) + if (schema, table) in dic_data: + return dic_data[(schema, table)] + else: + return [] + + def get_prefixes_from_data(self, n2i, row, **kw): + prefixes_found = [] + for valid_prefix in CustomTablePrefix: + key = f"is_{valid_prefix.name.lower()}" + if key in n2i and row[n2i[key]] == "Y": + prefixes_found.append(valid_prefix.name) + return prefixes_found + + @reflection.cache + def get_multi_prefixes( + self, connection, schema, table_name=None, filter_prefix=None, **kw + ): + """ + Gets all table prefixes + """ + schema = schema or self.default_schema_name + filter = f"LIKE '{table_name}'" if table_name else "" + if schema: + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}" + ) + ) + else: + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'" + ) + ) + + n2i = self.__class__._map_name_to_idx(result) + tables_prefixes = {} + for row in result.cursor.fetchall(): + table = self.normalize_name(str(row[n2i["name"]])) + table_prefixes = self.get_prefixes_from_data(n2i, row) + if filter_prefix and filter_prefix not in table_prefixes: + continue + if (schema, table) in tables_prefixes: + tables_prefixes[(schema, table)].append(table_prefixes) + else: + tables_prefixes[(schema, table)] = table_prefixes + + return tables_prefixes + + @reflection.cache + def get_indexes(self, connection, tablename, schema, **kw): + """ + Gets the indexes definition + """ + table_name = self.normalize_name(str(tablename)) + data = self.get_multi_indexes( + connection=connection, schema=schema, filter_names=[table_name], **kw + ) + + return self._value_or_default(data, table_name, schema) + def connect(self, *cargs, **cparams): return ( super().connect( @@ -912,8 +1028,12 @@ def connect(self, *cargs, **cparams): @sa_vnt.listens_for(Table, "before_create") def check_table(table, connection, _ddl_runner, **kw): + from .sql.custom_schema.hybrid_table import HybridTable + + if HybridTable.is_equal_type(table): # noqa + return True if isinstance(_ddl_runner.dialect, SnowflakeDialect) and table.indexes: - raise NotImplementedError("Snowflake does not support indexes") + raise NotImplementedError("Only Snowflake Hybrid Tables supports indexes") dialect = SnowflakeDialect diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py index 4bbac246..66b9270f 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py @@ -2,5 +2,6 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# from .dynamic_table import DynamicTable +from .hybrid_table import HybridTable -__all__ = ["DynamicTable"] +__all__ = ["DynamicTable", "HybridTable"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py index 0c04f33f..b61c270d 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -10,12 +10,17 @@ from ..._constants import DIALECT_NAME from ...compat import IS_VERSION_20 from ...custom_commands import NoneType +from .custom_table_prefix import CustomTablePrefix from .options.table_option import TableOption class CustomTableBase(Table): - __table_prefix__ = "" - _support_primary_and_foreign_keys = True + __table_prefixes__: typing.List[CustomTablePrefix] = [] + _support_primary_and_foreign_keys: bool = True + + @property + def table_prefixes(self) -> typing.List[str]: + return [prefix.name for prefix in self.__table_prefixes__] def __init__( self, @@ -24,8 +29,8 @@ def __init__( *args: SchemaItem, **kw: Any, ) -> None: - if self.__table_prefix__ != "": - prefixes = kw.get("prefixes", []) + [self.__table_prefix__] + if len(self.__table_prefixes__) > 0: + prefixes = kw.get("prefixes", []) + self.table_prefixes kw.update(prefixes=prefixes) if not IS_VERSION_20 and hasattr(super(), "_init"): super()._init(name, metadata, *args, **kw) @@ -40,7 +45,7 @@ def _validate_table(self): self.primary_key or self.foreign_keys ): raise ArgumentError( - f"Primary key and foreign keys are not supported in {self.__table_prefix__} TABLE." + f"Primary key and foreign keys are not supported in {' '.join(self.table_prefixes)} TABLE." ) return True @@ -49,3 +54,11 @@ def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]: if option_name in self.dialect_options[DIALECT_NAME]: return self.dialect_options[DIALECT_NAME][option_name] return NoneType + + @classmethod + def is_equal_type(cls, table: Table) -> bool: + for prefix in cls.__table_prefixes__: + if prefix.name not in table._prefixes: + return False + + return True diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py new file mode 100644 index 00000000..de7835d1 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
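+#
+# These prefixes are forwarded to SQLAlchemy's Table(prefixes=...) by
+# CustomTableBase, so they are rendered between CREATE and TABLE in the
+# emitted DDL, for example CREATE DYNAMIC TABLE or CREATE HYBRID TABLE.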
+
+from enum import Enum
+
+
+class CustomTablePrefix(Enum):
+    DEFAULT = 0
+    EXTERNAL = 1
+    EVENT = 2
+    HYBRID = 3
+    ICEBERG = 4
+    DYNAMIC = 5
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
index 7d0a02e6..1a2248fc 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
@@ -10,6 +10,7 @@

 from snowflake.sqlalchemy.custom_commands import NoneType

+from .custom_table_prefix import CustomTablePrefix
 from .options.target_lag import TargetLag
 from .options.warehouse import Warehouse
 from .table_from_query import TableFromQueryBase
@@ -27,7 +28,7 @@ class DynamicTable(TableFromQueryBase):

     """

-    __table_prefix__ = "DYNAMIC"
+    __table_prefixes__ = [CustomTablePrefix.DYNAMIC]

     _support_primary_and_foreign_keys = False

diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
new file mode 100644
index 00000000..bd49a420
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Any
+
+from sqlalchemy.exc import ArgumentError
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .custom_table_base import CustomTableBase
+from .custom_table_prefix import CustomTablePrefix
+
+
+class HybridTable(CustomTableBase):
+    """
+    A class representing a hybrid table with configurable options and settings.
+
+    The `HybridTable` class allows for the creation and querying of OLTP Snowflake tables.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing hybrid tables.
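+
+    Example (a minimal sketch; ``metadata`` is assumed to be an ordinary
+    SQLAlchemy MetaData instance):
+
+        HybridTable(
+            "test_hybrid_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+        )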
+ """ + + __table_prefixes__ = [CustomTablePrefix.HYBRID] + + _support_primary_and_foreign_keys = True + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + super().__init__(name, metadata, *args, **kw) + + def _validate_table(self): + missing_attributes = [] + if self.key is NoneType: + missing_attributes.append("Primary Key") + if missing_attributes: + raise ArgumentError( + "HYBRID TABLE must have the following arguments: %s" + % ", ".join(missing_attributes) + ) + super()._validate_table() + + def __repr__(self) -> str: + return "HybridTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/tests/__snapshots__/test_orm.ambr b/tests/__snapshots__/test_orm.ambr new file mode 100644 index 00000000..2116e9e9 --- /dev/null +++ b/tests/__snapshots__/test_orm.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_orm_one_to_many_relationship_with_hybrid_table + ProgrammingError('(snowflake.connector.errors.ProgrammingError) 200009 (22000): Foreign key constraint "SYS_INDEX_HB_TBL_ADDRESS_FOREIGN_KEY_USER_ID_HB_TBL_USER_ID" was violated.') +# --- diff --git a/tests/custom_tables/__init__.py b/tests/custom_tables/__init__.py new file mode 100644 index 00000000..d43f066c --- /dev/null +++ b/tests/custom_tables/__init__.py @@ -0,0 +1,2 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..81c7f90f --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr new file mode 100644 index 00000000..9412fb45 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_compile_hybrid_table + 'CREATE HYBRID TABLE test_hybrid_table (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tgeom GEOMETRY, \tPRIMARY KEY (id))' +# --- +# name: test_compile_hybrid_table_orm + 'CREATE HYBRID TABLE test_hybrid_table_orm (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY 
KEY (id))' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr new file mode 100644 index 00000000..696ff9c8 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_hybrid_table + "[(1, 'test')]" +# --- +# name: test_create_hybrid_table_with_multiple_index + ProgrammingError("(snowflake.connector.errors.ProgrammingError) 391480 (0A000): Another index is being built on table 'TEST_HYBRID_TABLE_WITH_MULTIPLE_INDEX'. Only one index can be built at a time. Either cancel the other index creation or wait until it is complete.") +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr new file mode 100644 index 00000000..6f6cd395 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_simple_reflection_hybrid_table_as_table + 'CREATE TABLE test_hybrid_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py similarity index 96% rename from tests/test_compile_dynamic_table.py rename to tests/custom_tables/test_compile_dynamic_table.py index 73ce54aa..16a039e7 100644 --- a/tests/test_compile_dynamic_table.py +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -121,11 +121,13 @@ def __repr__(self): assert actual == snapshot -def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, db_parameters, snapshot): Base = declarative_base() + schema = db_parameters["schema"] class TestDynamicTableOrm(Base): __tablename__ = "test_dynamic_table_orm_2" + __table_args__ = {"schema": schema} @classmethod def __table_cls__(cls, name, metadata, *arg, **kw): diff --git a/tests/custom_tables/test_compile_hybrid_table.py b/tests/custom_tables/test_compile_hybrid_table.py new file mode 100644 index 00000000..f1af6dc2 --- /dev/null +++ b/tests/custom_tables/test_compile_hybrid_table.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
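+#
+# These tests only compile DDL through the sql_compiler fixture and compare
+# it against the snapshots in __snapshots__/test_compile_hybrid_table.ambr;
+# no Snowflake connection is required.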
+# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, HybridTable + + +@pytest.mark.aws +def test_compile_hybrid_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_hybrid_table" + test_geometry = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + Column("geom", GEOMETRY), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +@pytest.mark.aws +def test_compile_hybrid_table_orm(sql_compiler, snapshot): + Base = declarative_base() + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"" + + value = CreateTable(TestHybridTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py similarity index 100% rename from tests/test_create_dynamic_table.py rename to tests/custom_tables/test_create_dynamic_table.py diff --git a/tests/custom_tables/test_create_hybrid_table.py b/tests/custom_tables/test_create_hybrid_table.py new file mode 100644 index 00000000..43ae3ab6 --- /dev/null +++ b/tests/custom_tables/test_create_hybrid_table.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +import sqlalchemy.exc +from sqlalchemy import Column, Index, Integer, MetaData, String, select +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import HybridTable + + +@pytest.mark.aws +def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot): + metadata = MetaData() + table_name = "test_create_hybrid_table" + + dynamic_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = dynamic_test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_multiple_index( + engine_testaccount, db_parameters, snapshot, sql_compiler +): + metadata = MetaData() + table_name = "test_hybrid_table_with_multiple_index" + + hybrid_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String, index=True), + Column("name2", String), + Column("name3", String), + ) + + metadata.create_all(engine_testaccount) + + index = Index("idx_col34", hybrid_test_table_1.c.name2, hybrid_test_table_1.c.name3) + + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + index.create(engine_testaccount) + try: + assert exc_info.value == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = 
Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py similarity index 100% rename from tests/test_reflect_dynamic_table.py rename to tests/custom_tables/test_reflect_dynamic_table.py diff --git a/tests/custom_tables/test_reflect_hybrid_table.py b/tests/custom_tables/test_reflect_hybrid_table.py new file mode 100644 index 00000000..4a777bf0 --- /dev/null +++ b/tests/custom_tables/test_reflect_hybrid_table.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import MetaData, Table +from sqlalchemy.sql.ddl import CreateTable + + +@pytest.mark.aws +def test_simple_reflection_hybrid_table_as_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_hybrid_table_reflection" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX index_name (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + hybrid_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + + constraint = hybrid_test_table.constraints.pop() + constraint.name = "demo_name" + hybrid_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(hybrid_test_table) + + actual = sql_compiler(value) + + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_reflect_hybrid_table_with_index( + engine_testaccount, db_parameters, sql_compiler +): + metadata = MetaData() + schema = db_parameters["schema"] + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + table = Table(table_name, metadata, schema=schema, autoload_with=engine_testaccount) + + try: + assert len(table.indexes) == 1 and table.indexes.pop().name == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_core.py b/tests/test_core.py index 179133c8..15840838 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -502,19 +502,20 @@ def test_inspect_column(engine_testaccount): users.drop(engine_testaccount) -def test_get_indexes(engine_testaccount): +def test_get_indexes(engine_testaccount, db_parameters): """ Tests get indexes - NOTE: Snowflake doesn't support indexes + NOTE: Only Snowflake Hybrid Tables support indexes """ + schema = db_parameters["schema"] metadata = MetaData() users, addresses = _create_users_addresses_tables_without_sequence( 
engine_testaccount, metadata ) try: inspector = inspect(engine_testaccount) - assert inspector.get_indexes("users") == [] + assert inspector.get_indexes("users", schema) == [] finally: addresses.drop(engine_testaccount) diff --git a/tests/test_index_reflection.py b/tests/test_index_reflection.py new file mode 100644 index 00000000..09f5cfe7 --- /dev/null +++ b/tests/test_index_reflection.py @@ -0,0 +1,34 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import MetaData +from sqlalchemy.engine import reflection + + +@pytest.mark.aws +def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler): + metadata = MetaData() + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + schema = db_parameters["schema"] + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + insp = reflection.Inspector.from_engine(engine_testaccount) + + try: + with engine_testaccount.connect(): + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + indexes = insp.get_indexes(table_name, schema) + assert len(indexes) == 1 + assert indexes[0].get("name") == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_orm.py b/tests/test_orm.py index f51c9a90..cb3a7768 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -14,12 +14,15 @@ Integer, Sequence, String, + exc, func, select, text, ) from sqlalchemy.orm import Session, declarative_base, relationship +from snowflake.sqlalchemy import HybridTable + def test_basic_orm(engine_testaccount): """ @@ -56,14 +59,15 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount): +def test_orm_one_to_many_relationship(engine_testaccount, db_parameters): """ Tests One to Many relationship """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -73,13 +77,13 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", backref="addresses") + user = relationship(User, backref="addresses") def __repr__(self): return f"" @@ -123,14 +127,79 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) +@pytest.mark.aws +def test_orm_one_to_many_relationship_with_hybrid_table(engine_testaccount, snapshot): + """ + Tests One to Many relationship + """ + Base = declarative_base() + + class User(Base): + __tablename__ = "hb_tbl_user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = "hb_tbl_address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("address_id_seq"), 
primary_key=True)
+        email_address = Column(String, nullable=False)
+        user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id"))
+
+        user = relationship(User, backref="addresses")
+
+        def __repr__(self):
+            return f""
+
+    Base.metadata.create_all(engine_testaccount)
+
+    try:
+        jack = User(name="jack", fullname="Jack Bean")
+        assert jack.addresses == [], "one to many record is empty list"
+
+        jack.addresses = [
+            Address(email_address="jack@gmail.com"),
+            Address(email_address="j25@yahoo.com"),
+            Address(email_address="jack@hotmail.com"),
+        ]
+
+        session = Session(bind=engine_testaccount)
+        session.add(jack)  # cascade each Address into the Session as well
+        session.commit()
+
+        session.delete(jack)
+
+        with pytest.raises(exc.ProgrammingError) as exc_info:
+            session.query(Address).all()
+
+        assert exc_info.value == snapshot, "Hybrid Table enforces FK constraint"
+
+    finally:
+        Base.metadata.drop_all(engine_testaccount)
+
+
 def test_delete_cascade(engine_testaccount):
     """
     Test delete cascade
     """
     Base = declarative_base()
+    prefix = "tbl_"

     class User(Base):
-        __tablename__ = "user"
+        __tablename__ = prefix + "user"

         id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
         name = Column(String)
@@ -144,13 +213,81 @@ def __repr__(self):
         return f""

     class Address(Base):
-        __tablename__ = "address"
+        __tablename__ = prefix + "address"
+
+        id = Column(Integer, Sequence("address_id_seq"), primary_key=True)
+        email_address = Column(String, nullable=False)
+        user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id"))
+
+        user = relationship(User, back_populates="addresses")
+
+        def __repr__(self):
+            return f""
+
+    Base.metadata.create_all(engine_testaccount)
+
+    try:
+        jack = User(name="jack", fullname="Jack Bean")
+        assert jack.addresses == [], "one to many record is empty list"
+
+        jack.addresses = [
+            Address(email_address="jack@gmail.com"),
+            Address(email_address="j25@yahoo.com"),
+            Address(email_address="jack@hotmail.com"),
+        ]
+
+        session = Session(bind=engine_testaccount)
+        session.add(jack)  # cascade each Address into the Session as well
+        session.commit()
+
+        got_jack = session.query(User).first()
+        assert got_jack == jack
+
+        session.delete(jack)
+        got_addresses = session.query(Address).all()
+        assert len(got_addresses) == 0, "no address record"
+    finally:
+        Base.metadata.drop_all(engine_testaccount)
+
+
+@pytest.mark.aws
+def test_delete_cascade_hybrid_table(engine_testaccount):
+    """
+    Test delete cascade
+    """
+    Base = declarative_base()
+    prefix = "hb_tbl_"
+
+    class User(Base):
+        __tablename__ = prefix + "user"
+
+        @classmethod
+        def __table_cls__(cls, name, metadata, *arg, **kw):
+            return HybridTable(name, metadata, *arg, **kw)
+
+        id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
+        name = Column(String)
+        fullname = Column(String)
+
+        addresses = relationship(
+            "Address", back_populates="user", cascade="all, delete, delete-orphan"
+        )
+
+        def __repr__(self):
+            return f""
+
+    class Address(Base):
+        __tablename__ = prefix + "address"
+
+        @classmethod
+        def __table_cls__(cls, name, metadata, *arg, **kw):
+            return HybridTable(name, metadata, *arg, **kw)

         id = Column(Integer, Sequence("address_id_seq"), primary_key=True)
         email_address = Column(String, nullable=False)
-        user_id = Column(Integer, ForeignKey("user.id"))
+        user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id"))

-        user = relationship("User", back_populates="addresses")
+        user = relationship(User, back_populates="addresses")

         def __repr__(self):
             return f""
diff --git a/tests/test_pandas.py 
b/tests/test_pandas.py index 63cd6d0e..2a6b9f1b 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -169,7 +169,7 @@ def test_no_indexes(engine_testaccount, db_parameters): con=conn, if_exists="replace", ) - assert str(exc.value) == "Snowflake does not support indexes" + assert str(exc.value) == "Only Snowflake Hybrid Tables supports indexes" def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_numpy): From e78319725d4b96ea205ef1264b744c65eb37853d Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 22 Oct 2024 09:18:29 -0600 Subject: [PATCH 40/74] Add generic options in order to support Iceberg Table in following PR (#537) * Add generic options and remove schema options --- .pre-commit-config.yaml | 1 + DESCRIPTION.md | 2 + src/snowflake/sqlalchemy/__init__.py | 28 +++- src/snowflake/sqlalchemy/base.py | 36 ++-- src/snowflake/sqlalchemy/exc.py | 74 +++++++++ .../sql/custom_schema/custom_table_base.py | 67 ++++++-- .../sql/custom_schema/dynamic_table.py | 84 +++++++--- .../sql/custom_schema/hybrid_table.py | 29 ++-- .../sql/custom_schema/options/__init__.py | 29 +++- .../sql/custom_schema/options/as_query.py | 62 ------- .../custom_schema/options/as_query_option.py | 63 +++++++ .../options/identifier_option.py | 63 +++++++ .../options/invalid_table_option.py | 25 +++ .../custom_schema/options/keyword_option.py | 65 ++++++++ .../sql/custom_schema/options/keywords.py | 14 ++ .../custom_schema/options/literal_option.py | 67 ++++++++ .../sql/custom_schema/options/table_option.py | 91 +++++++++-- .../options/table_option_base.py | 30 ---- .../sql/custom_schema/options/target_lag.py | 60 ------- .../options/target_lag_option.py | 94 +++++++++++ .../sql/custom_schema/options/warehouse.py | 51 ------ .../sql/custom_schema/table_from_query.py | 22 +-- .../test_compile_dynamic_table.ambr | 29 +++- .../test_create_dynamic_table.ambr | 7 + .../__snapshots__/test_generic_options.ambr | 13 ++ .../test_compile_dynamic_table.py | 154 ++++++++++++++---- .../test_create_dynamic_table.py | 75 ++++++--- tests/custom_tables/test_generic_options.py | 83 ++++++++++ .../test_reflect_dynamic_table.py | 2 +- tests/test_core.py | 1 + 30 files changed, 1058 insertions(+), 363 deletions(-) create mode 100644 src/snowflake/sqlalchemy/exc.py delete mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py delete mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py delete mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py delete mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py create mode 100644 tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_generic_options.ambr create mode 100644 tests/custom_tables/test_generic_options.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 
83172eb8..b7370b74 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,6 +4,7 @@ repos:
     rev: v4.5.0
     hooks:
       - id: trailing-whitespace
+        exclude: '\.ambr$'
       - id: end-of-file-fixer
       - id: check-yaml
         exclude: .github/repo_meta.yaml
diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 58c2dfe2..909d52cf 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -14,6 +14,8 @@ Source code is also available at:
   - Add support for dynamic tables and required options
   - Add support for hybrid tables
   - Fixed SAWarning when registering functions with existing name in default namespace
+  - Update options to be defined as keyword arguments instead of positional arguments.
+  - Add support for the refresh_mode option in DynamicTable

 - v1.6.1(July 9, 2024)

diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py
index 0afd44a5..e53f9b74 100644
--- a/src/snowflake/sqlalchemy/__init__.py
+++ b/src/snowflake/sqlalchemy/__init__.py
@@ -62,7 +62,16 @@
     VARIANT,
 )
 from .sql.custom_schema import DynamicTable, HybridTable
-from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse
+from .sql.custom_schema.options import (
+    AsQueryOption,
+    IdentifierOption,
+    KeywordOption,
+    LiteralOption,
+    SnowflakeKeyword,
+    TableOptionKey,
+    TargetLagOption,
+    TimeUnit,
+)
 from .util import _url as URL

 base.dialect = dialect = snowdialect.dialect
@@ -70,6 +79,7 @@
 __version__ = importlib_metadata.version("snowflake-sqlalchemy")

 __all__ = (
+    # Custom Types
     "BIGINT",
     "BINARY",
     "BOOLEAN",
@@ -104,6 +114,7 @@
     "TINYINT",
     "VARBINARY",
     "VARIANT",
+    # Custom Commands
     "MergeInto",
     "CSVFormatter",
     "JSONFormatter",
@@ -115,10 +126,17 @@
     "ExternalStage",
     "CreateStage",
     "CreateFileFormat",
+    # Custom Tables
+    "HybridTable",
     "DynamicTable",
-    "AsQuery",
-    "TargetLag",
+    # Custom Table Options
+    "AsQueryOption",
+    "TargetLagOption",
+    "LiteralOption",
+    "IdentifierOption",
+    "KeywordOption",
+    # Enums
     "TimeUnit",
-    "Warehouse",
-    "HybridTable",
+    "TableOptionKey",
+    "SnowflakeKeyword",
 )
diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index 56631728..023f7afb 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@ -5,6 +5,7 @@
 import itertools
 import operator
 import re
+from typing import List

 from sqlalchemy import exc as sa_exc
 from sqlalchemy import inspect, sql
@@ -26,8 +27,13 @@
     ExternalStage,
 )

+from .exc import (
+    CustomOptionsAreOnlySupportedOnSnowflakeTables,
+    UnexpectedOptionTypeError,
+)
 from .functions import flatten
-from .sql.custom_schema.options.table_option_base import TableOptionBase
+from .sql.custom_schema.custom_table_base import CustomTableBase
+from .sql.custom_schema.options.table_option import TableOption
 from .util import (
     _find_left_clause_to_join_from,
     _set_connection_interpolate_empty_sequences,
@@ -925,16 +931,24 @@ def handle_cluster_by(self, table):

     def post_create_table(self, table):
         text = self.handle_cluster_by(table)
-        options = [
-            option
-            for _, option in table.dialect_options[DIALECT_NAME].items()
-            if isinstance(option, TableOptionBase)
-        ]
-        options.sort(
-            key=lambda x: (x.__priority__.value, x.__option_name__), reverse=True
-        )
-        for option in options:
-            text += "\t" + option.render_option(self)
+        options = []
+        invalid_options: List[str] = []
+
+        for key, option in table.dialect_options[DIALECT_NAME].items():
+            if isinstance(option, TableOption):
+                options.append(option)
+            elif key not in ["clusterby", "*"]:
+                invalid_options.append(key)
+
+        if len(invalid_options) > 0:
+            raise 
UnexpectedOptionTypeError(sorted(invalid_options)) + + if isinstance(table, CustomTableBase): + options.sort(key=lambda x: (x.priority.value, x.option_name), reverse=True) + for option in options: + text += "\t" + option.render_option(self) + elif len(options) > 0: + raise CustomOptionsAreOnlySupportedOnSnowflakeTables() return text diff --git a/src/snowflake/sqlalchemy/exc.py b/src/snowflake/sqlalchemy/exc.py new file mode 100644 index 00000000..898de279 --- /dev/null +++ b/src/snowflake/sqlalchemy/exc.py @@ -0,0 +1,74 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from typing import List + +from sqlalchemy.exc import ArgumentError + + +class NoPrimaryKeyError(ArgumentError): + def __init__(self, target: str): + super().__init__(f"Table {target} required primary key.") + + +class UnsupportedPrimaryKeysAndForeignKeysError(ArgumentError): + def __init__(self, target: str): + super().__init__(f"Primary key and foreign keys are not supported in {target}.") + + +class RequiredParametersNotProvidedError(ArgumentError): + def __init__(self, target: str, parameters: List[str]): + super().__init__( + f"{target} requires the following parameters: %s." % ", ".join(parameters) + ) + + +class UnexpectedTableOptionKeyError(ArgumentError): + def __init__(self, expected: str, actual: str): + super().__init__(f"Expected table option {expected} but got {actual}.") + + +class OptionKeyNotProvidedError(ArgumentError): + def __init__(self, target: str): + super().__init__( + f"Expected option key in {target} option but got NoneType instead." + ) + + +class UnexpectedOptionParameterTypeError(ArgumentError): + def __init__(self, parameter_name: str, target: str, types: List[str]): + super().__init__( + f"Parameter {parameter_name} of {target} requires to be one" + f" of following types: {', '.join(types)}." + ) + + +class CustomOptionsAreOnlySupportedOnSnowflakeTables(ArgumentError): + def __init__(self): + super().__init__( + "Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables." + ) + + +class UnexpectedOptionTypeError(ArgumentError): + def __init__(self, options: List[str]): + super().__init__( + f"The following options are either unsupported or should be defined using a Snowflake table: {', '.join(options)}." + ) + + +class InvalidTableParameterTypeError(ArgumentError): + def __init__(self, name: str, input_type: str, expected_types: List[str]): + expected_types_str = "', '".join(expected_types) + super().__init__( + f"Invalid parameter type '{input_type}' provided for '{name}'. " + f"Expected one of the following types: '{expected_types_str}'.\n" + ) + + +class MultipleErrors(ArgumentError): + def __init__(self, errors): + self.errors = errors + + def __str__(self): + return "".join(str(e) for e in self.errors) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py index b61c270d..671c6957 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -2,21 +2,29 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# import typing -from typing import Any +from typing import Any, List -from sqlalchemy.exc import ArgumentError from sqlalchemy.sql.schema import MetaData, SchemaItem, Table from ..._constants import DIALECT_NAME from ...compat import IS_VERSION_20 from ...custom_commands import NoneType +from ...exc import ( + MultipleErrors, + NoPrimaryKeyError, + RequiredParametersNotProvidedError, + UnsupportedPrimaryKeysAndForeignKeysError, +) from .custom_table_prefix import CustomTablePrefix -from .options.table_option import TableOption +from .options.invalid_table_option import InvalidTableOption +from .options.table_option import TableOption, TableOptionKey class CustomTableBase(Table): __table_prefixes__: typing.List[CustomTablePrefix] = [] _support_primary_and_foreign_keys: bool = True + _enforce_primary_keys: bool = False + _required_parameters: List[TableOptionKey] = [] @property def table_prefixes(self) -> typing.List[str]: @@ -32,7 +40,9 @@ def __init__( if len(self.__table_prefixes__) > 0: prefixes = kw.get("prefixes", []) + self.table_prefixes kw.update(prefixes=prefixes) + if not IS_VERSION_20 and hasattr(super(), "_init"): + kw.pop("_no_init", True) super()._init(name, metadata, *args, **kw) else: super().__init__(name, metadata, *args, **kw) @@ -41,19 +51,56 @@ def __init__( self._validate_table() def _validate_table(self): + exceptions: List[Exception] = [] + + for _, option in self.dialect_options[DIALECT_NAME].items(): + if isinstance(option, InvalidTableOption): + exceptions.append(option.exception) + + if isinstance(self.key, NoneType) and self._enforce_primary_keys: + exceptions.append(NoPrimaryKeyError(self.__class__.__name__)) + missing_parameters: List[str] = [] + + for required_parameter in self._required_parameters: + if isinstance(self._get_dialect_option(required_parameter), NoneType): + missing_parameters.append(required_parameter.name.lower()) + if missing_parameters: + exceptions.append( + RequiredParametersNotProvidedError( + self.__class__.__name__, missing_parameters + ) + ) + if not self._support_primary_and_foreign_keys and ( self.primary_key or self.foreign_keys ): - raise ArgumentError( - f"Primary key and foreign keys are not supported in {' '.join(self.table_prefixes)} TABLE." 
+ exceptions.append( + UnsupportedPrimaryKeysAndForeignKeysError(self.__class__.__name__) ) - return True + if len(exceptions) > 1: + exceptions.sort(key=lambda e: str(e)) + raise MultipleErrors(exceptions) + elif len(exceptions) == 1: + raise exceptions[0] + + def _get_dialect_option( + self, option_name: TableOptionKey + ) -> typing.Optional[TableOption]: + if option_name.value in self.dialect_options[DIALECT_NAME]: + return self.dialect_options[DIALECT_NAME][option_name.value] + return None - def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]: - if option_name in self.dialect_options[DIALECT_NAME]: - return self.dialect_options[DIALECT_NAME][option_name] - return NoneType + def _as_dialect_options( + self, table_options: List[TableOption] + ) -> typing.Dict[str, TableOption]: + result = {} + for table_option in table_options: + if isinstance(table_option, TableOption) and isinstance( + table_option.option_name, str + ): + result[DIALECT_NAME + "_" + table_option.option_name] = table_option + return result @classmethod def is_equal_type(cls, table: Table) -> bool: diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py index 1a2248fc..6db4312d 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py @@ -3,16 +3,21 @@ # import typing -from typing import Any +from typing import Any, Union -from sqlalchemy.exc import ArgumentError from sqlalchemy.sql.schema import MetaData, SchemaItem -from snowflake.sqlalchemy.custom_commands import NoneType - from .custom_table_prefix import CustomTablePrefix -from .options.target_lag import TargetLag -from .options.warehouse import Warehouse +from .options import ( + IdentifierOption, + IdentifierOptionType, + KeywordOptionType, + LiteralOption, + TableOptionKey, + TargetLagOption, + TargetLagOptionType, +) +from .options.keyword_option import KeywordOption from .table_from_query import TableFromQueryBase @@ -26,29 +31,69 @@ class DynamicTable(TableFromQueryBase): While it does not support reflection at this time, it provides a flexible interface for creating dynamic tables and management. 
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
+
+    Example using option values:
+
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            target_lag=(1, TimeUnit.HOURS),
+            warehouse='warehouse_name',
+            refresh_mode=SnowflakeKeyword.AUTO,
+            as_query="SELECT id, name from test_table_1;"
+        )
+
+    Example using full options:
+
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            target_lag=TargetLagOption(1, TimeUnit.HOURS),
+            warehouse=IdentifierOption('warehouse_name'),
+            refresh_mode=KeywordOption(SnowflakeKeyword.AUTO),
+            as_query=AsQueryOption("SELECT id, name from test_table_1;")
+        )
     """

     __table_prefixes__ = [CustomTablePrefix.DYNAMIC]

-    _support_primary_and_foreign_keys = False
+    _required_parameters = [
+        TableOptionKey.WAREHOUSE,
+        TableOptionKey.AS_QUERY,
+        TableOptionKey.TARGET_LAG,
+    ]

     @property
-    def warehouse(self) -> typing.Optional[Warehouse]:
-        return self._get_dialect_option(Warehouse.__option_name__)
+    def warehouse(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.WAREHOUSE)

     @property
-    def target_lag(self) -> typing.Optional[TargetLag]:
-        return self._get_dialect_option(TargetLag.__option_name__)
+    def target_lag(self) -> typing.Optional[TargetLagOption]:
+        return self._get_dialect_option(TableOptionKey.TARGET_LAG)

     def __init__(
         self,
         name: str,
         metadata: MetaData,
         *args: SchemaItem,
+        warehouse: IdentifierOptionType = None,
+        target_lag: Union[TargetLagOptionType, KeywordOptionType] = None,
+        refresh_mode: KeywordOptionType = None,
         **kw: Any,
     ) -> None:
         if kw.get("_no_init", True):
             return
+
+        options = [
+            IdentifierOption.create(TableOptionKey.WAREHOUSE, warehouse),
+            TargetLagOption.create(target_lag),
+            KeywordOption.create(TableOptionKey.REFRESH_MODE, refresh_mode),
+        ]
+
+        kw.update(self._as_dialect_options(options))
         super().__init__(name, metadata, *args, **kw)

     def _init(
@@ -58,22 +103,7 @@ def _init(
         *args: SchemaItem,
         **kw: Any,
     ) -> None:
-        super().__init__(name, metadata, *args, **kw)
-
-    def _validate_table(self):
-        missing_attributes = []
-        if self.target_lag is NoneType:
-            missing_attributes.append("TargetLag")
-        if self.warehouse is NoneType:
-            missing_attributes.append("Warehouse")
-        if self.as_query is NoneType:
-            missing_attributes.append("AsQuery")
-        if missing_attributes:
-            raise ArgumentError(
-                "DYNAMIC TABLE must have the following arguments: %s"
-                % ", ".join(missing_attributes)
-            )
-        super()._validate_table()
+        self.__init__(name, metadata, *args, _no_init=False, **kw)

     def __repr__(self) -> str:
         return "DynamicTable(%s)" % ", ".join(
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
index bd49a420..b7c29e78 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -4,11 +4,8 @@

 from typing import Any

-from sqlalchemy.exc import ArgumentError
 from sqlalchemy.sql.schema import MetaData, SchemaItem

-from snowflake.sqlalchemy.custom_commands import NoneType
-
 from .custom_table_base import CustomTableBase
 from .custom_table_prefix import CustomTablePrefix

@@ -21,11 +18,20 @@ class HybridTable(CustomTableBase):

     While it does not support reflection at this time, it provides a flexible
     interface for creating and managing hybrid tables.
+ + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table + + Example usage: + HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String) + ) """ __table_prefixes__ = [CustomTablePrefix.HYBRID] - - _support_primary_and_foreign_keys = True + _enforce_primary_keys: bool = True def __init__( self, @@ -45,18 +51,7 @@ def _init( *args: SchemaItem, **kw: Any, ) -> None: - super().__init__(name, metadata, *args, **kw) - - def _validate_table(self): - missing_attributes = [] - if self.key is NoneType: - missing_attributes.append("Primary Key") - if missing_attributes: - raise ArgumentError( - "HYBRID TABLE must have the following arguments: %s" - % ", ".join(missing_attributes) - ) - super()._validate_table() + self.__init__(name, metadata, *args, _no_init=False, **kw) def __repr__(self) -> str: return "HybridTable(%s)" % ", ".join( diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py index 052e2d96..11b54c1a 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py @@ -2,8 +2,29 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from .as_query import AsQuery -from .target_lag import TargetLag, TimeUnit -from .warehouse import Warehouse +from .as_query_option import AsQueryOption, AsQueryOptionType +from .identifier_option import IdentifierOption, IdentifierOptionType +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .literal_option import LiteralOption, LiteralOptionType +from .table_option import TableOptionKey +from .target_lag_option import TargetLagOption, TargetLagOptionType, TimeUnit -__all__ = ["Warehouse", "AsQuery", "TargetLag", "TimeUnit"] +__all__ = [ + # Options + "IdentifierOption", + "LiteralOption", + "KeywordOption", + "AsQueryOption", + "TargetLagOption", + # Enums + "TimeUnit", + "SnowflakeKeyword", + "TableOptionKey", + # Types + "IdentifierOptionType", + "LiteralOptionType", + "AsQueryOptionType", + "TargetLagOptionType", + "KeywordOptionType", +] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py deleted file mode 100644 index 68076af9..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py +++ /dev/null @@ -1,62 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# -from typing import Union - -from sqlalchemy.sql import Selectable - -from .table_option import TableOption -from .table_option_base import Priority - - -class AsQuery(TableOption): - """Class to represent an AS clause in tables. - This configuration option is used to specify the query from which the table is created. 
- For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas - - - AsQuery example usage using an input string: - DynamicTable( - "sometable", metadata, - Column("name", String(50)), - Column("address", String(100)), - AsQuery('select name, address from existing_table where name = "test"') - ) - - AsQuery example usage using a selectable statement: - DynamicTable( - "sometable", - Base.metadata, - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery(select(test_table_1).where(test_table_1.c.id == 23)) - ) - - """ - - __option_name__ = "as_query" - __priority__ = Priority.LOWEST - - def __init__(self, query: Union[str, Selectable]) -> None: - r"""Construct an as_query object. - - :param \*expressions: - AS - - """ - self.query = query - - @staticmethod - def template() -> str: - return "AS %s" - - def get_expression(self): - if isinstance(self.query, Selectable): - return self.query.compile(compile_kwargs={"literal_binds": True}) - return self.query - - def render_option(self, compiler) -> str: - return AsQuery.template() % (self.get_expression()) - - def __repr__(self) -> str: - return "Query(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py new file mode 100644 index 00000000..93994abc --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from sqlalchemy.sql import Selectable + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class AsQueryOption(TableOption): + """Class to represent an AS clause in tables. 
+ For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + Example: + as_query=AsQueryOption('select name, address from existing_table where name = "test"') + + is equivalent to: + + as select name, address from existing_table where name = "test" + """ + + def __init__(self, query: Union[str, Selectable]) -> None: + super().__init__() + self._name: TableOptionKey = TableOptionKey.AS_QUERY + self.query = query + + @staticmethod + def create( + value: Optional[Union["AsQueryOption", str, Selectable]] + ) -> "TableOption": + if isinstance(value, (NoneType, AsQueryOption)): + return value + if isinstance(value, (str, Selectable)): + return AsQueryOption(value) + return TableOption._get_invalid_table_option( + TableOptionKey.AS_QUERY, + str(type(value).__name__), + [AsQueryOption.__name__, str.__name__, Selectable.__name__], + ) + + def template(self) -> str: + return "AS %s" + + @property + def priority(self) -> Priority: + return Priority.LOWEST + + def __get_expression(self): + if isinstance(self.query, Selectable): + return self.query.compile(compile_kwargs={"literal_binds": True}) + return self.query + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "AsQueryOption(%s)" % self.__get_expression() + + +AsQueryOptionType = Union[AsQueryOption, str, Selectable] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py new file mode 100644 index 00000000..b296898b --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class IdentifierOption(TableOption): + """Class to represent an identifier option in Snowflake Tables. 
+ + Example: + warehouse = IdentifierOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = my_warehouse + """ + + def __init__(self, value: Union[str]) -> None: + super().__init__() + self.value: str = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, "IdentifierOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + + if isinstance(value, str): + value = IdentifierOption(value) + + if isinstance(value, IdentifierOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, str(type(value).__name__), [IdentifierOption.__name__, str.__name__] + ) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"IdentifierOption(value='{self.value}'{option_name})" + + +IdentifierOptionType = Union[IdentifierOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py new file mode 100644 index 00000000..2bdc9dd3 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional + +from .table_option import TableOption, TableOptionKey + + +class InvalidTableOption(TableOption): + """Class to store errors and raise them after table initialization in order to avoid recursion error.""" + + def __init__(self, name: TableOptionKey, value: Exception) -> None: + super().__init__() + self.exception: Exception = value + self._name = name + + @staticmethod + def create(name: TableOptionKey, value: Exception) -> Optional[TableOption]: + return InvalidTableOption(name, value) + + def _render(self, compiler) -> str: + raise self.exception + + def __repr__(self) -> str: + return f"ErrorOption(value='{self.exception}')" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py new file mode 100644 index 00000000..ff6b444d --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class KeywordOption(TableOption): + """Class to represent a keyword option in Snowflake Tables. 
+
+    Example:
+        target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM)
+
+    is equivalent to:
+
+    TARGET_LAG = DOWNSTREAM
+    """
+
+    def __init__(self, value: Union[SnowflakeKeyword]) -> None:
+        super().__init__()
+        self.value: str = value.value
+
+    @property
+    def priority(self):
+        return Priority.HIGH
+
+    def template(self) -> str:
+        return f"{self.option_name.upper()} = %s"
+
+    def _render(self, compiler) -> str:
+        return self.template() % self.value.upper()
+
+    @staticmethod
+    def create(
+        name: TableOptionKey, value: Optional[Union[SnowflakeKeyword, "KeywordOption"]]
+    ) -> Optional[TableOption]:
+        if isinstance(value, NoneType):
+            return value
+        if isinstance(value, SnowflakeKeyword):
+            value = KeywordOption(value)
+
+        if isinstance(value, KeywordOption):
+            value._set_option_name(name)
+            return value
+
+        return TableOption._get_invalid_table_option(
+            name,
+            str(type(value).__name__),
+            [KeywordOption.__name__, SnowflakeKeyword.__name__],
+        )
+
+    def __repr__(self) -> str:
+        option_name = (
+            f", table_option_key={self.option_name}"
+            if not isinstance(self.option_name, NoneType)
+            else ""
+        )
+        return f"KeywordOption(value='{self.value}'{option_name})"
+
+
+KeywordOptionType = Union[KeywordOption, SnowflakeKeyword]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py
new file mode 100644
index 00000000..f4ba87ea
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py
@@ -0,0 +1,14 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+
+from enum import Enum
+
+
+class SnowflakeKeyword(Enum):
+    # TARGET_LAG
+    DOWNSTREAM = "DOWNSTREAM"
+
+    # REFRESH_MODE
+    AUTO = "AUTO"
+    FULL = "FULL"
+    INCREMENTAL = "INCREMENTAL"
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py
new file mode 100644
index 00000000..55dd7675
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from typing import Any, Optional, Union
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .table_option import Priority, TableOption, TableOptionKey
+
+
+class LiteralOption(TableOption):
+    """Class to represent a literal option in Snowflake Table.
+ + Example: + warehouse = LiteralOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = 'my_warehouse' + """ + + def __init__(self, value: Union[int, str]) -> None: + super().__init__() + self.value: Any = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, int, "LiteralOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + if isinstance(value, (str, int)): + value = LiteralOption(value) + + if isinstance(value, LiteralOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [LiteralOption.__name__, str.__name__, int.__name__], + ) + + def template(self) -> str: + if isinstance(self.value, int): + return f"{self.option_name.upper()} = %d" + else: + return f"{self.option_name.upper()} = '%s'" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"LiteralOption(value='{self.value}'{option_name})" + + +LiteralOptionType = Union[LiteralOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py index 7ac27575..14b91f2e 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py @@ -1,26 +1,83 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from typing import Any +from enum import Enum +from typing import List, Optional -from sqlalchemy import exc -from sqlalchemy.sql.base import SchemaEventTarget -from sqlalchemy.sql.schema import SchemaItem, Table +from snowflake.sqlalchemy import exc +from snowflake.sqlalchemy.custom_commands import NoneType -from snowflake.sqlalchemy._constants import DIALECT_NAME -from .table_option_base import TableOptionBase +class Priority(Enum): + LOWEST = 0 + VERY_LOW = 1 + LOW = 2 + MEDIUM = 4 + HIGH = 6 + VERY_HIGH = 7 + HIGHEST = 8 -class TableOption(TableOptionBase, SchemaItem): - def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: - if self.__option_name__ == "default": - raise exc.SQLAlchemyError(f"{self.__class__.__name__} does not has a name") - if not isinstance(parent, Table): - raise exc.SQLAlchemyError( - f"{self.__class__.__name__} option can only be applied to Table" - ) - parent.dialect_options[DIALECT_NAME][self.__option_name__] = self +class TableOption: - def _set_table_option_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: - pass + def __init__(self) -> None: + self._name: Optional[TableOptionKey] = None + + @property + def option_name(self) -> str: + if isinstance(self._name, NoneType): + return None + return str(self._name.value) + + def _set_option_name(self, name: Optional["TableOptionKey"]): + self._name = name + + @property + def priority(self) -> Priority: + return Priority.MEDIUM + + @staticmethod + def create(**kwargs) -> "TableOption": + raise NotImplementedError + + @staticmethod + def _get_invalid_table_option( + parameter_name: "TableOptionKey", input_type: str, expected_types: List[str] + ) -> "TableOption": + from .invalid_table_option import InvalidTableOption + + return InvalidTableOption( + parameter_name, + exc.InvalidTableParameterTypeError( + parameter_name.value, 
input_type, expected_types + ), + ) + + def _validate_option(self): + if isinstance(self.option_name, NoneType): + raise exc.OptionKeyNotProvidedError(self.__class__.__name__) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def render_option(self, compiler) -> str: + self._validate_option() + return self._render(compiler) + + def _render(self, compiler) -> str: + raise NotImplementedError + + +class TableOptionKey(Enum): + AS_QUERY = "as_query" + BASE_LOCATION = "base_location" + CATALOG = "catalog" + CATALOG_SYNC = "catalog_sync" + DATA_RETENTION_TIME_IN_DAYS = "data_retention_time_in_days" + DEFAULT_DDL_COLLATION = "default_ddl_collation" + EXTERNAL_VOLUME = "external_volume" + MAX_DATA_EXTENSION_TIME_IN_DAYS = "max_data_extension_time_in_days" + REFRESH_MODE = "refresh_mode" + STORAGE_SERIALIZATION_POLICY = "storage_serialization_policy" + TARGET_LAG = "target_lag" + WAREHOUSE = "warehouse" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py deleted file mode 100644 index 54008ec8..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# -from enum import Enum - - -class Priority(Enum): - LOWEST = 0 - VERY_LOW = 1 - LOW = 2 - MEDIUM = 4 - HIGH = 6 - VERY_HIGH = 7 - HIGHEST = 8 - - -class TableOptionBase: - __option_name__ = "default" - __visit_name__ = __option_name__ - __priority__ = Priority.MEDIUM - - @staticmethod - def template() -> str: - raise NotImplementedError - - def get_expression(self): - raise NotImplementedError - - def render_option(self, compiler) -> str: - raise NotImplementedError diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py deleted file mode 100644 index 4331a4cb..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py +++ /dev/null @@ -1,60 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# from enum import Enum -from enum import Enum -from typing import Optional - -from .table_option import TableOption -from .table_option_base import Priority - - -class TimeUnit(Enum): - SECONDS = "seconds" - MINUTES = "minutes" - HOURS = "hour" - DAYS = "days" - - -class TargetLag(TableOption): - """Class to represent the target lag clause. - This configuration option is used to specify the target lag time for the dynamic table. 
- For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table - - - Target lag example usage: - DynamicTable("sometable", metadata, - Column("name", String(50)), - Column("address", String(100)), - TargetLag(20, TimeUnit.MINUTES), - ) - """ - - __option_name__ = "target_lag" - __priority__ = Priority.HIGH - - def __init__( - self, - time: Optional[int] = 0, - unit: Optional[TimeUnit] = TimeUnit.MINUTES, - down_stream: Optional[bool] = False, - ) -> None: - self.time = time - self.unit = unit - self.down_stream = down_stream - - @staticmethod - def template() -> str: - return "TARGET_LAG = %s" - - def get_expression(self): - return ( - f"'{str(self.time)} {str(self.unit.value)}'" - if not self.down_stream - else "DOWNSTREAM" - ) - - def render_option(self, compiler) -> str: - return TargetLag.template() % (self.get_expression()) - - def __repr__(self) -> str: - return "TargetLag(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py new file mode 100644 index 00000000..7c1c0825 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# from enum import Enum +from enum import Enum +from typing import Optional, Tuple, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class TimeUnit(Enum): + SECONDS = "seconds" + MINUTES = "minutes" + HOURS = "hours" + DAYS = "days" + + +class TargetLagOption(TableOption): + """Class to represent the target lag clause in Dynamic Tables. 
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
+
+    Example using the time and unit parameters:
+
+        target_lag = TargetLagOption(10, TimeUnit.SECONDS)
+
+    is equivalent to:
+
+        TARGET_LAG = '10 SECONDS'
+
+    Example using keyword parameter:
+
+        target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM)
+
+    is equivalent to:
+
+        TARGET_LAG = DOWNSTREAM
+
+    """
+
+    def __init__(
+        self,
+        time: Optional[int] = 0,
+        unit: Optional[TimeUnit] = TimeUnit.MINUTES,
+    ) -> None:
+        super().__init__()
+        self.time = time
+        self.unit = unit
+        self._name: TableOptionKey = TableOptionKey.TARGET_LAG
+
+    @staticmethod
+    def create(
+        value: Union["TargetLagOption", Tuple[int, TimeUnit], KeywordOptionType]
+    ) -> Optional[TableOption]:
+        if isinstance(value, NoneType):
+            return value
+
+        if isinstance(value, Tuple):
+            time, unit = value
+            value = TargetLagOption(time, unit)
+
+        if isinstance(value, TargetLagOption):
+            return value
+
+        if isinstance(value, (KeywordOption, SnowflakeKeyword)):
+            return KeywordOption.create(TableOptionKey.TARGET_LAG, value)
+
+        return TableOption._get_invalid_table_option(
+            TableOptionKey.TARGET_LAG,
+            str(type(value).__name__),
+            [
+                TargetLagOption.__name__,
+                f"Tuple[int, {TimeUnit.__name__}]",
+                SnowflakeKeyword.__name__,
+            ],
+        )
+
+    def __get_expression(self):
+        return f"'{str(self.time)} {str(self.unit.value)}'"
+
+    @property
+    def priority(self) -> Priority:
+        return Priority.HIGH
+
+    def _render(self, compiler) -> str:
+        return self.template() % (self.__get_expression())
+
+    def __repr__(self) -> str:
+        return "TargetLagOption(%s)" % self.__get_expression()
+
+
+TargetLagOptionType = Union[TargetLagOption, Tuple[int, TimeUnit]]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py
deleted file mode 100644
index a5b8cce0..00000000
--- a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#
-# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
-#
-from typing import Optional
-
-from .table_option import TableOption
-from .table_option_base import Priority
-
-
-class Warehouse(TableOption):
-    """Class to represent the warehouse clause.
-    This configuration option is used to specify the warehouse for the dynamic table.
-    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
-
-
-    Warehouse example usage:
-        DynamicTable("sometable", metadata,
-            Column("name", String(50)),
-            Column("address", String(100)),
-            Warehouse('my_warehouse_name')
-        )
-    """
-
-    __option_name__ = "warehouse"
-    __priority__ = Priority.HIGH
-
-    def __init__(
-        self,
-        name: Optional[str],
-    ) -> None:
-        r"""Construct a Warehouse object.
-
-        :param \*expressions:
-            Dynamic table warehouse option.
-            WAREHOUSE = <warehouse_name>
-
-        """
-        self.name = name
-
-    @staticmethod
-    def template() -> str:
-        return "WAREHOUSE = %s"
-
-    def get_expression(self):
-        return self.name
-
-    def render_option(self, compiler) -> str:
-        return Warehouse.template() % (self.get_expression())
-
-    def __repr__(self) -> str:
-        return "Warehouse(%s)" % self.get_expression()
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
index 60e8995f..fccc7a0b 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
@@ -6,29 +6,31 @@
 from sqlalchemy.sql import Selectable
 from sqlalchemy.sql.schema import Column, MetaData, SchemaItem
-from sqlalchemy.util import NoneType
 
 from .custom_table_base import CustomTableBase
-from .options.as_query import AsQuery
+from .options.as_query_option import AsQueryOption, AsQueryOptionType
+from .options.table_option import TableOptionKey
 
 
 class TableFromQueryBase(CustomTableBase):
 
     @property
-    def as_query(self):
-        return self._get_dialect_option(AsQuery.__option_name__)
+    def as_query(self) -> Optional[AsQueryOption]:
+        return self._get_dialect_option(TableOptionKey.AS_QUERY)
 
     def __init__(
         self,
         name: str,
         metadata: MetaData,
         *args: SchemaItem,
+        as_query: AsQueryOptionType = None,
         **kw: Any,
     ) -> None:
         items = [item for item in args]
-        as_query: AsQuery = self.__get_as_query_from_items(items)
+        as_query = AsQueryOption.create(as_query)  # noqa
+        kw.update(self._as_dialect_options([as_query]))
         if (
-            as_query is not NoneType
+            isinstance(as_query, AsQueryOption)
             and isinstance(as_query.query, Selectable)
             and not self.__has_defined_columns(items)
         ):
@@ -36,14 +38,6 @@ def __init__(
             args = items + columns
         super().__init__(name, metadata, *args, **kw)
 
-    def __get_as_query_from_items(
-        self, items: typing.List[SchemaItem]
-    ) -> Optional[AsQuery]:
-        for item in items:
-            if isinstance(item, AsQuery):
-                return item
-        return NoneType
-
     def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool:
         for item in items:
             if isinstance(item, Column):
diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr
index 81c7f90f..66c8f98e 100644
--- a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr
+++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr
@@ -6,7 +6,34 @@
   "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
 # ---
 # name: test_compile_dynamic_table_orm_with_str_keys
-  "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
+  'CREATE DYNAMIC TABLE "SCHEMA_DB".test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = \'10 seconds\'\tAS SELECT * FROM table'
+# ---
+# name: test_compile_dynamic_table_with_multiple_wrong_option_types
+  '''
+  Invalid parameter type 'IdentifierOption' provided for 'refresh_mode'. Expected one of the following types: 'KeywordOption', 'SnowflakeKeyword'.
+  Invalid parameter type 'IdentifierOption' provided for 'target_lag'. Expected one of the following types: 'TargetLagOption', 'Tuple[int, TimeUnit]', 'SnowflakeKeyword'.
+  Invalid parameter type 'KeywordOption' provided for 'as_query'.
Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. + Invalid parameter type 'KeywordOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_one_wrong_option_types + ''' + Invalid parameter type 'LiteralOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_options_objects + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.AUTO] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.FULL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = FULL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.INCREMENTAL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = INCREMENTAL\tAS SELECT * FROM table" # --- # name: test_compile_dynamic_table_with_selectable "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" diff --git a/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr new file mode 100644 index 00000000..80201495 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_dynamic_table_without_dynamictable_and_defined_options + CustomOptionsAreOnlySupportedOnSnowflakeTables('Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables.') +# --- +# name: test_create_dynamic_table_without_dynamictable_class + UnexpectedOptionTypeError('The following options are either unsupported or should be defined using a Snowflake table: as_query, warehouse.') +# --- diff --git a/tests/custom_tables/__snapshots__/test_generic_options.ambr b/tests/custom_tables/__snapshots__/test_generic_options.ambr new file mode 100644 index 00000000..eef5e6fd --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_generic_options.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_identifier_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'.\n") +# --- +# name: test_identifier_option_without_name + OptionKeyNotProvidedError('Expected option key in IdentifierOption option but got NoneType instead.') +# --- +# name: test_invalid_as_query_option + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'.\n") +# --- +# name: test_literal_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'SnowflakeKeyword' provided for 'warehouse'. 
Expected one of the following types: 'LiteralOption', 'str', 'int'.\n") +# --- diff --git a/tests/custom_tables/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py index 16a039e7..935c61cd 100644 --- a/tests/custom_tables/test_compile_dynamic_table.py +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -12,16 +12,21 @@ exc, select, ) +from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import declarative_base from sqlalchemy.sql.ddl import CreateTable from snowflake.sqlalchemy import GEOMETRY, DynamicTable -from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery -from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( - TargetLag, +from snowflake.sqlalchemy.exc import MultipleErrors +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + TargetLagOption, TimeUnit, ) -from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword def test_compile_dynamic_table(sql_compiler, snapshot): @@ -32,9 +37,9 @@ def test_compile_dynamic_table(sql_compiler, snapshot): metadata, Column("id", Integer), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) value = CreateTable(test_geometry) @@ -44,11 +49,99 @@ def test_compile_dynamic_table(sql_compiler, snapshot): assert actual == snapshot +@pytest.mark.parametrize( + "refresh_mode_keyword", + [ + SnowflakeKeyword.AUTO, + SnowflakeKeyword.FULL, + SnowflakeKeyword.INCREMENTAL, + ], +) +def test_compile_dynamic_table_with_refresh_mode( + sql_compiler, snapshot, refresh_mode_keyword +): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + refresh_mode=refresh_mode_keyword, + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=IdentifierOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(ArgumentError) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=LiteralOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_with_multiple_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(MultipleErrors) as 
argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=IdentifierOption(SnowflakeKeyword.AUTO), + warehouse=KeywordOption(SnowflakeKeyword.AUTO), + as_query=KeywordOption(SnowflakeKeyword.AUTO), + refresh_mode=IdentifierOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + def test_compile_dynamic_table_without_required_args(sql_compiler): with pytest.raises( exc.ArgumentError, - match="DYNAMIC TABLE must have the following arguments: TargetLag, " - "Warehouse, AsQuery", + match="DynamicTable requires the following parameters: warehouse, " + "as_query, target_lag.", ): DynamicTable( "test_dynamic_table", @@ -61,33 +154,33 @@ def test_compile_dynamic_table_without_required_args(sql_compiler): def test_compile_dynamic_table_with_primary_key(sql_compiler): with pytest.raises( exc.ArgumentError, - match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + match="Primary key and foreign keys are not supported in DynamicTable.", ): DynamicTable( "test_dynamic_table", MetaData(), Column("id", Integer, primary_key=True), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) def test_compile_dynamic_table_with_foreign_key(sql_compiler): with pytest.raises( exc.ArgumentError, - match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + match="Primary key and foreign keys are not supported in DynamicTable.", ): DynamicTable( "test_dynamic_table", MetaData(), Column("id", Integer), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), ForeignKeyConstraint(["id"], ["table.id"]), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) @@ -100,9 +193,9 @@ def test_compile_dynamic_table_orm(sql_compiler, snapshot): metadata, Column("id", Integer), Column("name", String), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) class TestDynamicTableOrm(Base): @@ -121,23 +214,22 @@ def __repr__(self): assert actual == snapshot -def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, db_parameters, snapshot): +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): Base = declarative_base() - schema = db_parameters["schema"] class TestDynamicTableOrm(Base): __tablename__ = "test_dynamic_table_orm_2" - __table_args__ = {"schema": schema} @classmethod def __table_cls__(cls, name, metadata, *arg, **kw): return DynamicTable(name, metadata, *arg, **kw) - __table_args__ = ( - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), - ) + __table_args__ = { + "schema": "SCHEMA_DB", + "target_lag": (10, TimeUnit.SECONDS), + "warehouse": "warehouse", + "as_query": "SELECT * FROM table", + } id = Column(Integer) name = Column(String) @@ -167,9 +259,9 @@ def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): dynamic_test_table = DynamicTable( "dynamic_test_table_1", Base.metadata, - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery(select(test_table_1).where(test_table_1.c.id == 23)), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + 
as_query=select(test_table_1).where(test_table_1.c.id == 23), ) value = CreateTable(dynamic_test_table) diff --git a/tests/custom_tables/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py index 4e6c48ca..b583faad 100644 --- a/tests/custom_tables/test_create_dynamic_table.py +++ b/tests/custom_tables/test_create_dynamic_table.py @@ -1,15 +1,20 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +import pytest from sqlalchemy import Column, Integer, MetaData, String, Table, select -from snowflake.sqlalchemy import DynamicTable -from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery -from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( - TargetLag, +from snowflake.sqlalchemy import DynamicTable, exc +from snowflake.sqlalchemy.sql.custom_schema.options.as_query_option import AsQueryOption +from snowflake.sqlalchemy.sql.custom_schema.options.identifier_option import ( + IdentifierOption, +) +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword +from snowflake.sqlalchemy.sql.custom_schema.options.table_option import TableOptionKey +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag_option import ( + TargetLagOption, TimeUnit, ) -from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse def test_create_dynamic_table(engine_testaccount, db_parameters): @@ -32,9 +37,10 @@ def test_create_dynamic_table(engine_testaccount, db_parameters): metadata, Column("id", Integer), Column("name", String), - TargetLag(1, TimeUnit.HOURS), - Warehouse(warehouse), - AsQuery("SELECT id, name from test_table_1;"), + target_lag=(1, TimeUnit.HOURS), + warehouse=warehouse, + as_query="SELECT id, name from test_table_1;", + refresh_mode=SnowflakeKeyword.FULL, ) metadata.create_all(engine_testaccount) @@ -52,7 +58,7 @@ def test_create_dynamic_table(engine_testaccount, db_parameters): def test_create_dynamic_table_without_dynamictable_class( - engine_testaccount, db_parameters + engine_testaccount, db_parameters, snapshot ): warehouse = db_parameters.get("warehouse", "default") metadata = MetaData() @@ -68,26 +74,51 @@ def test_create_dynamic_table_without_dynamictable_class( conn.execute(ins) conn.commit() - dynamic_test_table_1 = Table( + Table( "dynamic_test_table_1", metadata, Column("id", Integer), Column("name", String), - TargetLag(1, TimeUnit.HOURS), - Warehouse(warehouse), - AsQuery("SELECT id, name from test_table_1;"), + snowflake_warehouse=warehouse, + snowflake_as_query="SELECT id, name from test_table_1;", prefixes=["DYNAMIC"], ) + with pytest.raises(exc.UnexpectedOptionTypeError) as exc_info: + metadata.create_all(engine_testaccount) + assert exc_info.value == snapshot + + +def test_create_dynamic_table_without_dynamictable_and_defined_options( + engine_testaccount, db_parameters, snapshot +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + metadata.create_all(engine_testaccount) - try: - with engine_testaccount.connect() as conn: - s = select(dynamic_test_table_1) - results_dynamic_table = conn.execute(s).fetchall() - s = select(test_table_1) - results_table = conn.execute(s).fetchall() - assert results_dynamic_table == results_table + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") - finally: - metadata.drop_all(engine_testaccount) + conn.execute(ins) + 
conn.commit() + + Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + snowflake_target_lag=TargetLagOption.create((1, TimeUnit.HOURS)), + snowflake_warehouse=IdentifierOption.create( + TableOptionKey.WAREHOUSE, warehouse + ), + snowflake_as_query=AsQueryOption.create("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + with pytest.raises(exc.CustomOptionsAreOnlySupportedOnSnowflakeTables) as exc_info: + metadata.create_all(engine_testaccount) + assert exc_info.value == snapshot diff --git a/tests/custom_tables/test_generic_options.py b/tests/custom_tables/test_generic_options.py new file mode 100644 index 00000000..916b94c6 --- /dev/null +++ b/tests/custom_tables/test_generic_options.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import pytest + +from snowflake.sqlalchemy import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + exc, +) +from snowflake.sqlalchemy.sql.custom_schema.options.invalid_table_option import ( + InvalidTableOption, +) + + +def test_identifier_option(): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert identifier.render_option(None) == "WAREHOUSE = xsmall" + + +def test_literal_option(): + literal = LiteralOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert literal.render_option(None) == "WAREHOUSE = 'xsmall'" + + +def test_identifier_option_without_name(snapshot): + identifier = IdentifierOption("xsmall") + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_identifier_option_with_wrong_type(snapshot): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, 23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_literal_option_with_wrong_type(snapshot): + literal = LiteralOption.create( + TableOptionKey.WAREHOUSE, SnowflakeKeyword.DOWNSTREAM + ) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + literal.render_option(None) + assert exc_info.value == snapshot + + +def test_invalid_as_query_option(snapshot): + as_query = AsQueryOption.create(23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + as_query.render_option(None) + assert exc_info.value == snapshot + + +@pytest.mark.parametrize( + "table_option", + [ + IdentifierOption, + LiteralOption, + KeywordOption, + ], +) +def test_generic_option_with_wrong_type(table_option): + literal = table_option.create(TableOptionKey.WAREHOUSE, 0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" + + +@pytest.mark.parametrize( + "table_option", + [ + TargetLagOption, + AsQueryOption, + ], +) +def test_non_generic_option_with_wrong_type(table_option): + literal = table_option.create(0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" diff --git a/tests/custom_tables/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py index 8a4a8445..52eb4457 100644 --- a/tests/custom_tables/test_reflect_dynamic_table.py +++ b/tests/custom_tables/test_reflect_dynamic_table.py @@ -74,7 +74,7 @@ def test_simple_reflection_without_options_loading(engine_testaccount, db_parame ) # TODO: Add support for loading options when table is reflected - assert dynamic_test_table.warehouse 
is NoneType + assert isinstance(dynamic_test_table.warehouse, NoneType) try: with engine_testaccount.connect() as conn: diff --git a/tests/test_core.py b/tests/test_core.py index 15840838..980db1d2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1060,6 +1060,7 @@ def harass_inspector(): assert outcome +@pytest.mark.skip(reason="Testaccount is not available, it returns 404 error.") @pytest.mark.timeout(10) @pytest.mark.parametrize( "region", From 14be28216fb477d10815fd15e8290c866de4e260 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 29 Oct 2024 07:03:43 -0600 Subject: [PATCH 41/74] Add support for iceberg table with snowflake catalog (#539) * Add support for Iceberg Table with Snowflake Catalog * Add support for Snowflake Table * Update DESCRIPTION.md --- DESCRIPTION.md | 2 + README.md | 2 +- src/snowflake/sqlalchemy/__init__.py | 48 +++-- src/snowflake/sqlalchemy/base.py | 13 +- .../sqlalchemy/sql/custom_schema/__init__.py | 4 +- .../sql/custom_schema/clustered_table.py | 37 ++++ .../sql/custom_schema/dynamic_table.py | 6 +- .../sql/custom_schema/hybrid_table.py | 2 +- .../sql/custom_schema/iceberg_table.py | 101 ++++++++++ .../sql/custom_schema/options/__init__.py | 3 + .../options/cluster_by_option.py | 58 ++++++ .../sql/custom_schema/options/table_option.py | 1 + .../sql/custom_schema/snowflake_table.py | 70 +++++++ .../sql/custom_schema/table_from_query.py | 4 +- tests/__snapshots__/test_core.ambr | 4 + .../test_compile_iceberg_table.ambr | 19 ++ .../test_compile_snowflake_table.ambr | 35 ++++ .../test_create_iceberg_table.ambr | 14 ++ .../test_create_snowflake_table.ambr | 4 + .../test_reflect_snowflake_table.ambr | 7 + .../test_compile_iceberg_table.py | 116 +++++++++++ .../test_compile_snowflake_table.py | 180 ++++++++++++++++++ .../test_create_iceberg_table.py | 43 +++++ .../test_create_snowflake_table.py | 66 +++++++ .../test_reflect_snowflake_table.py | 69 +++++++ tests/test_core.py | 34 ++++ 26 files changed, 916 insertions(+), 26 deletions(-) create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py create mode 100644 tests/__snapshots__/test_core.ambr create mode 100644 tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr create mode 100644 tests/custom_tables/test_compile_iceberg_table.py create mode 100644 tests/custom_tables/test_compile_snowflake_table.py create mode 100644 tests/custom_tables/test_create_iceberg_table.py create mode 100644 tests/custom_tables/test_create_snowflake_table.py create mode 100644 tests/custom_tables/test_reflect_snowflake_table.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 909d52cf..47697d30 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -16,6 +16,8 @@ Source code is also available at: - Fixed SAWarning when registering functions with existing name in default namespace - Update options to be defined in key arguments instead of arguments. 
- Add support for refresh_mode option in DynamicTable + - Add support for iceberg table with Snowflake Catalog + - Fix cluster by option to support explicit expressions - v1.6.1(July 9, 2024) diff --git a/README.md b/README.md index c428353f..c6c13349 100644 --- a/README.md +++ b/README.md @@ -340,7 +340,7 @@ This example shows how to create a table with two columns, `id` and `name`, as t t = Table('myuser', metadata, Column('id', Integer, primary_key=True), Column('name', String), - snowflake_clusterby=['id', 'name'], ... + snowflake_clusterby=['id', 'name', text('id > 5')], ... ) metadata.create_all(engine) ``` diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index e53f9b74..f6c97f0d 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -9,7 +9,7 @@ else: import importlib.metadata as importlib_metadata -from sqlalchemy.types import ( +from sqlalchemy.types import ( # noqa BIGINT, BINARY, BOOLEAN, @@ -27,8 +27,8 @@ VARCHAR, ) -from . import base, snowdialect -from .custom_commands import ( +from . import base, snowdialect # noqa +from .custom_commands import ( # noqa AWSBucket, AzureContainer, CopyFormatter, @@ -41,7 +41,7 @@ MergeInto, PARQUETFormatter, ) -from .custom_types import ( +from .custom_types import ( # noqa ARRAY, BYTEINT, CHARACTER, @@ -61,9 +61,15 @@ VARBINARY, VARIANT, ) -from .sql.custom_schema import DynamicTable, HybridTable -from .sql.custom_schema.options import ( +from .sql.custom_schema import ( # noqa + DynamicTable, + HybridTable, + IcebergTable, + SnowflakeTable, +) +from .sql.custom_schema.options import ( # noqa AsQueryOption, + ClusterByOption, IdentifierOption, KeywordOption, LiteralOption, @@ -72,14 +78,13 @@ TargetLagOption, TimeUnit, ) -from .util import _url as URL +from .util import _url as URL # noqa base.dialect = dialect = snowdialect.dialect __version__ = importlib_metadata.version("snowflake-sqlalchemy") -__all__ = ( - # Custom Types +_custom_types = ( "BIGINT", "BINARY", "BOOLEAN", @@ -114,7 +119,9 @@ "TINYINT", "VARBINARY", "VARIANT", - # Custom Commands +) + +_custom_commands = ( "MergeInto", "CSVFormatter", "JSONFormatter", @@ -126,17 +133,28 @@ "ExternalStage", "CreateStage", "CreateFileFormat", - # Custom Tables - "HybridTable", - "DynamicTable", - # Custom Table Options +) + +_custom_tables = ("HybridTable", "DynamicTable", "IcebergTable", "SnowflakeTable") + +_custom_table_options = ( "AsQueryOption", "TargetLagOption", "LiteralOption", "IdentifierOption", "KeywordOption", - # Enums + "ClusterByOption", +) + +_enums = ( "TimeUnit", "TableOptionKey", "SnowflakeKeyword", ) +__all__ = ( + *_custom_types, + *_custom_commands, + *_custom_tables, + *_custom_table_options, + *_enums, +) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 023f7afb..4e36c4ad 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -908,7 +908,7 @@ def handle_cluster_by(self, table): ... metadata, ... sa.Column('id', sa.Integer, primary_key=True), ... sa.Column('name', sa.String), - ... snowflake_clusterby=['id', 'name'] + ... snowflake_clusterby=['id', 'name', text("id > 5")] ... 
         )
         >>> print(CreateTable(user).compile(engine))
         CREATE TABLE myuser (
         id INTEGER NOT NULL AUTOINCREMENT,
         name VARCHAR,
         PRIMARY KEY (id)
-        ) CLUSTER BY (id, name)
+        ) CLUSTER BY (id, name, id > 5)
 
         """
@@ -925,7 +925,14 @@ def handle_cluster_by(self, table):
         cluster = info.get("clusterby")
         if cluster:
             text += " CLUSTER BY ({})".format(
-                ", ".join(self.denormalize_column_name(key) for key in cluster)
+                ", ".join(
+                    (
+                        self.denormalize_column_name(key)
+                        if isinstance(key, str)
+                        else str(key)
+                    )
+                    for key in cluster
+                )
             )
         return text
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
index 66b9270f..cbc75ebc 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
@@ -3,5 +3,7 @@
 #
 from .dynamic_table import DynamicTable
 from .hybrid_table import HybridTable
+from .iceberg_table import IcebergTable
+from .snowflake_table import SnowflakeTable
 
-__all__ = ["DynamicTable", "HybridTable"]
+__all__ = ["DynamicTable", "HybridTable", "IcebergTable", "SnowflakeTable"]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py
new file mode 100644
index 00000000..6c0904a8
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py
@@ -0,0 +1,36 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Any, Optional
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .custom_table_base import CustomTableBase
+from .options.cluster_by_option import ClusterByOption, ClusterByOptionType
+from .options.table_option import TableOptionKey
+
+
+class ClusteredTableBase(CustomTableBase):
+
+    @property
+    def cluster_by(self) -> Optional[ClusterByOption]:
+        return self._get_dialect_option(TableOptionKey.CLUSTER_BY)
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        cluster_by: ClusterByOptionType = None,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+
+        options = [
+            ClusterByOption.create(cluster_by),
+        ]
+
+        kw.update(self._as_dialect_options(options))
+        super().__init__(name, metadata, *args, **kw)
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
index 6db4312d..91c379f0 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
@@ -12,7 +12,6 @@
     IdentifierOption,
     IdentifierOptionType,
     KeywordOptionType,
-    LiteralOption,
     TableOptionKey,
     TargetLagOption,
     TargetLagOptionType,
@@ -45,7 +44,7 @@ class DynamicTable(TableFromQueryBase):
         as_query="SELECT id, name from test_table_1;"
     )
 
-    Example using full options:
+    Example using explicit options:
     DynamicTable(
         "dynamic_test_table_1",
         metadata,
@@ -67,7 +66,7 @@ class DynamicTable(TableFromQueryBase):
     ]
 
     @property
-    def warehouse(self) -> typing.Optional[LiteralOption]:
+    def warehouse(self) -> typing.Optional[IdentifierOption]:
         return self._get_dialect_option(TableOptionKey.WAREHOUSE)
 
     @property
@@ -112,6 +111,7 @@ def __repr__(self) -> str:
             + [repr(x) for x in self.columns]
             + [repr(self.target_lag)]
             + [repr(self.warehouse)]
+            + [repr(self.cluster_by)]
             + [repr(self.as_query)]
             + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
         )
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
index b7c29e78..16a58d47 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -17,7 +17,7 @@ class HybridTable(CustomTableBase):
     The `HybridTable` class allows for the creation and querying of OLTP Snowflake Tables .
 
     While it does not support reflection at this time, it provides a flexible
-    interface for creating dynamic tables and management.
+    interface for creating hybrid tables and management.
 
     For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py
new file mode 100644
index 00000000..5c9c53d9
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py
@@ -0,0 +1,101 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+import typing
+from typing import Any
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .custom_table_prefix import CustomTablePrefix
+from .options import LiteralOption, LiteralOptionType, TableOptionKey
+from .table_from_query import TableFromQueryBase
+
+
+class IcebergTable(TableFromQueryBase):
+    """
+    A class representing an iceberg table with configurable options and settings.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating iceberg tables and management.
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table
+
+    Example using option values:
+
+    IcebergTable(
+        "iceberg_test_table_1",
+        metadata,
+        Column("id", Integer),
+        Column("name", String),
+        external_volume='my_external_volume',
+        base_location='my_iceberg_table'
+    )
+
+    Example using explicit options:
+    IcebergTable(
+        "iceberg_test_table_1",
+        metadata,
+        Column("id", Integer),
+        Column("name", String),
+        external_volume=LiteralOption('my_external_volume'),
+        base_location=LiteralOption('my_iceberg_table')
+    )
+    """
+
+    __table_prefixes__ = [CustomTablePrefix.ICEBERG]
+
+    @property
+    def external_volume(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.EXTERNAL_VOLUME)
+
+    @property
+    def base_location(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.BASE_LOCATION)
+
+    @property
+    def catalog(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.CATALOG)
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        external_volume: LiteralOptionType = None,
+        base_location: LiteralOptionType = None,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+
+        options = [
+            LiteralOption.create(TableOptionKey.EXTERNAL_VOLUME, external_volume),
+            LiteralOption.create(TableOptionKey.BASE_LOCATION, base_location),
+            LiteralOption.create(TableOptionKey.CATALOG, "SNOWFLAKE"),
+        ]
+
+        kw.update(self._as_dialect_options(options))
+        super().__init__(name, metadata, *args, **kw)
+
+    def _init(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        self.__init__(name, metadata, *args, _no_init=False, **kw)
+
+    def __repr__(self) -> str:
+        return "IcebergTable(%s)" % ", ".join(
+            [repr(self.name)]
+            + [repr(self.metadata)]
+            + [repr(x) for x in self.columns]
+            + [repr(self.external_volume)]
+            + [repr(self.base_location)]
+            + [repr(self.catalog)]
+            + [repr(self.cluster_by)]
+            + [repr(self.as_query)]
+            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
+        )
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py
index 11b54c1a..e94ea46b 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py
@@ -3,6 +3,7 @@
 #
 from .as_query_option import AsQueryOption, AsQueryOptionType
+from .cluster_by_option import ClusterByOption, ClusterByOptionType
 from .identifier_option import IdentifierOption, IdentifierOptionType
 from .keyword_option import KeywordOption, KeywordOptionType
 from .keywords import SnowflakeKeyword
@@ -17,6 +18,7 @@
     "KeywordOption",
     "AsQueryOption",
     "TargetLagOption",
+    "ClusterByOption",
     # Enums
     "TimeUnit",
     "SnowflakeKeyword",
@@ -27,4 +29,5 @@
     "AsQueryOptionType",
     "TargetLagOptionType",
     "KeywordOptionType",
+    "ClusterByOptionType",
 ]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py
new file mode 100644
index 00000000..b92029bb
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py
@@ -0,0 +1,58 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from typing import List, Union
+
+from sqlalchemy.sql.expression import TextClause
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .table_option import Priority, TableOption, TableOptionKey
+
+
+class ClusterByOption(TableOption):
+    """Class to represent the cluster by clause in tables.
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/user-guide/tables-clustering-keys
+    Example:
+        cluster_by=ClusterByOption('name', text('id > 0'))
+
+    is equivalent to:
+
+        cluster by (name, id > 0)
+    """
+
+    def __init__(self, *expressions: Union[str, TextClause]) -> None:
+        super().__init__()
+        self._name: TableOptionKey = TableOptionKey.CLUSTER_BY
+        self.expressions = expressions
+
+    @staticmethod
+    def create(value: "ClusterByOptionType") -> "TableOption":
+        if isinstance(value, (NoneType, ClusterByOption)):
+            return value
+        if isinstance(value, List):
+            return ClusterByOption(*value)
+        return TableOption._get_invalid_table_option(
+            TableOptionKey.CLUSTER_BY,
+            str(type(value).__name__),
+            [ClusterByOption.__name__, list.__name__],
+        )
+
+    def template(self) -> str:
+        return f"{self.option_name.upper()} (%s)"
+
+    @property
+    def priority(self) -> Priority:
+        return Priority.HIGH
+
+    def __get_expression(self):
+        return ", ".join([str(expression) for expression in self.expressions])
+
+    def _render(self, compiler) -> str:
+        return self.template() % (self.__get_expression())
+
+    def __repr__(self) -> str:
+        return "ClusterByOption(%s)" % self.__get_expression()
+
+
+ClusterByOptionType = Union[ClusterByOption, List[Union[str, TextClause]]]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
index 14b91f2e..5ebb4817 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
@@ -73,6 +73,7 @@ class TableOptionKey(Enum):
     BASE_LOCATION = "base_location"
     CATALOG = "catalog"
     CATALOG_SYNC = "catalog_sync"
+    CLUSTER_BY = "cluster by"
     DATA_RETENTION_TIME_IN_DAYS = "data_retention_time_in_days"
     DEFAULT_DDL_COLLATION = "default_ddl_collation"
     EXTERNAL_VOLUME = "external_volume"
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py
new file mode 100644
index 00000000..56a14c83
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py
@@ -0,0 +1,70 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Any
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .table_from_query import TableFromQueryBase
+
+
+class SnowflakeTable(TableFromQueryBase):
+    """
+    A class representing a table in Snowflake with configurable options and settings.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing tables.
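+
+    In addition to explicit column definitions, a table can be defined from a
+    query alone. This sketch is an illustrative addition based on the compile
+    tests in this patch, not part of the original docstring:
+
+        SnowflakeTable(
+            "snowflake_test_table_1",
+            metadata,
+            as_query=select(test_table_1).where(test_table_1.c.id == 23),
+        )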
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table
+    Example usage:
+
+    SnowflakeTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+        cluster_by=["id", text("name > 5")]
+    )
+
+    Example using explicit options:
+
+    SnowflakeTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+        cluster_by=ClusterByOption("id", text("name > 5"))
+    )
+
+    """
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+        super().__init__(name, metadata, *args, **kw)
+
+    def _init(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        self.__init__(name, metadata, *args, _no_init=False, **kw)
+
+    def __repr__(self) -> str:
+        return "SnowflakeTable(%s)" % ", ".join(
+            [repr(self.name)]
+            + [repr(self.metadata)]
+            + [repr(x) for x in self.columns]
+            + [repr(self.cluster_by)]
+            + [repr(self.as_query)]
+            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
+        )
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
index fccc7a0b..cbd65de3 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
@@ -7,12 +7,12 @@
 from sqlalchemy.sql import Selectable
 from sqlalchemy.sql.schema import Column, MetaData, SchemaItem

-from .custom_table_base import CustomTableBase
+from .clustered_table import ClusteredTableBase
 from .options.as_query_option import AsQueryOption, AsQueryOptionType
 from .options.table_option import TableOptionKey


-class TableFromQueryBase(CustomTableBase):
+class TableFromQueryBase(ClusteredTableBase):

     @property
     def as_query(self) -> Optional[AsQueryOption]:
diff --git a/tests/__snapshots__/test_core.ambr b/tests/__snapshots__/test_core.ambr
new file mode 100644
index 00000000..7a4e0f99
--- /dev/null
+++ b/tests/__snapshots__/test_core.ambr
@@ -0,0 +1,4 @@
+# serializer version: 1
+# name: test_compile_table_with_cluster_by_with_expression
+  'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY ("Id")) CLUSTER BY ("Id", name, "Id" > 5)'
+# ---
diff --git a/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr b/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr
new file mode 100644
index 00000000..b243cc09
--- /dev/null
+++ b/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr
@@ -0,0 +1,19 @@
+# serializer version: 1
+# name: test_compile_dynamic_table_orm_with_as_query
+  "CREATE ICEBERG TABLE test_iceberg_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'\tAS SELECT * FROM table"
+# ---
+# name: test_compile_iceberg_table_with_primary_key
+  "CREATE ICEBERG TABLE test_iceberg_table_with_options (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'"
+# ---
+# name: test_compile_iceberg_table
+  "CREATE ICEBERG TABLE test_iceberg_table (\tid INTEGER, \tgeom VARCHAR)\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'"
+# ---
+# name:
test_compile_iceberg_table_with_one_wrong_option_types + ''' + Invalid parameter type 'IdentifierOption' provided for 'external_volume'. Expected one of the following types: 'LiteralOption', 'str', 'int'. + + ''' +# --- +# name: test_compile_iceberg_table_with_options_objects + "CREATE ICEBERG TABLE test_iceberg_table_with_options (\tid INTEGER, \tgeom VARCHAR)\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr new file mode 100644 index 00000000..5ea64c12 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr @@ -0,0 +1,35 @@ +# serializer version: 1 +# name: test_compile_dynamic_table_orm_with_str_keys + 'CREATE TABLE "SCHEMA_DB".test_snowflake_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_foreign_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL, \tgeom VARCHAR, \tPRIMARY KEY (id), \tFOREIGN KEY(id) REFERENCES "table" (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_primary_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table + 'CREATE TABLE test_table_1 (\tid INTEGER, \tgeom VARCHAR)\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_orm_with_str_keys + 'CREATE TABLE "SCHEMA_DB".test_snowflake_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_explicit_options + 'CREATE TABLE test_table_2 (\tid INTEGER, \tgeom VARCHAR)\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_foreign_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL, \tgeom VARCHAR, \tPRIMARY KEY (id), \tFOREIGN KEY(id) REFERENCES "table" (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_primary_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_selectable + 'CREATE TABLE snowflake_test_table_1 (\tid INTEGER, \tgeom VARCHAR)\tAS SELECT test_table_1.id, test_table_1.geom FROM test_table_1 WHERE test_table_1.id = 23' +# --- +# name: test_compile_snowflake_table_with_wrong_option_types + ''' + Invalid parameter type 'AsQueryOption' provided for 'cluster by'. Expected one of the following types: 'ClusterByOption', 'list'. + Invalid parameter type 'ClusterByOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. 
+ + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr b/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr new file mode 100644 index 00000000..908a4c60 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr @@ -0,0 +1,14 @@ +# serializer version: 1 +# name: test_create_iceberg_table + ''' + (snowflake.connector.errors.ProgrammingError) 091017 (22000): S3 bucket 'my_example_bucket' does not exist or not authorized. + [SQL: + CREATE ICEBERG TABLE "Iceberg_Table_1" ( + id INTEGER NOT NULL AUTOINCREMENT, + geom VARCHAR, + PRIMARY KEY (id) + ) EXTERNAL_VOLUME = 'exvol' CATALOG = 'SNOWFLAKE' BASE_LOCATION = 'my_iceberg_table' + + ] + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr new file mode 100644 index 00000000..98d3137f --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_create_snowflake_table_with_cluster_by + "[(1, 'test')]" +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr new file mode 100644 index 00000000..6ef09ff7 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_simple_reflection_of_table_as_snowflake_table + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- +# name: test_simple_reflection_of_table_as_sqlalchemy_table + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/test_compile_iceberg_table.py b/tests/custom_tables/test_compile_iceberg_table.py new file mode 100644 index 00000000..173e7b0a --- /dev/null +++ b/tests/custom_tables/test_compile_iceberg_table.py @@ -0,0 +1,116 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import IcebergTable +from snowflake.sqlalchemy.sql.custom_schema.options import ( + IdentifierOption, + LiteralOption, +) + + +def test_compile_iceberg_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume="my_external_volume", + base_location="my_iceberg_table", + ) + + value = CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_iceberg_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table_with_options" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume=LiteralOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + value = CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_iceberg_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_wrong_iceberg_table" + with pytest.raises(ArgumentError) as argument_error: + IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume=IdentifierOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_icberg_table_with_primary_key(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table_with_options" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + external_volume=LiteralOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + value = CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_as_query(sql_compiler, snapshot): + Base = declarative_base() + + class TestDynamicTableOrm(Base): + __tablename__ = "test_iceberg_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": "my_external_volume", + "base_location": "my_iceberg_table", + "as_query": "SELECT * FROM table", + } + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_compile_snowflake_table.py b/tests/custom_tables/test_compile_snowflake_table.py new file mode 100644 index 00000000..be9383eb --- /dev/null +++ b/tests/custom_tables/test_compile_snowflake_table.py @@ -0,0 +1,180 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + select, + text, +) +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import SnowflakeTable +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + ClusterByOption, +) + + +def test_compile_snowflake_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_1" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + cluster_by=["id", text("id > 100")], + as_query="SELECT * FROM table", + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_explicit_options(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_snowflake_table" + with pytest.raises(ArgumentError) as argument_error: + SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + as_query=ClusterByOption("id", text("id > 100")), + cluster_by=AsQueryOption("SELECT * FROM table"), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_snowflake_table_with_primary_key(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_foreign_key(sql_compiler, snapshot): + metadata = MetaData() + + SnowflakeTable( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_orm_with_str_keys(sql_compiler, snapshot): + Base = declarative_base() + + class TestSnowflakeTableOrm(Base): + __tablename__ = "test_snowflake_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return SnowflakeTable(name, metadata, *arg, **kw) + + __table_args__ = { + "schema": "SCHEMA_DB", + "cluster_by": ["id", text("id > 100")], + "as_query": "SELECT * FROM table", + } + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"" + + value = 
CreateTable(TestSnowflakeTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = SnowflakeTable( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + cluster_by=ClusterByOption("id", text("id > 100")), + ) + + test_table_2 = SnowflakeTable( + "snowflake_test_table_1", + Base.metadata, + as_query=select(test_table_1).where(test_table_1.c.id == 23), + ) + + value = CreateTable(test_table_2) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_create_iceberg_table.py b/tests/custom_tables/test_create_iceberg_table.py new file mode 100644 index 00000000..3ecd703b --- /dev/null +++ b/tests/custom_tables/test_create_iceberg_table.py @@ -0,0 +1,43 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.exc import ProgrammingError + +from snowflake.sqlalchemy import IcebergTable + + +@pytest.mark.aws +def test_create_iceberg_table(engine_testaccount, snapshot): + metadata = MetaData() + external_volume_name = "exvol" + create_external_volume = f""" + CREATE OR REPLACE EXTERNAL VOLUME {external_volume_name} + STORAGE_LOCATIONS = + ( + ( + NAME = 'my-s3-us-west-2' + STORAGE_PROVIDER = 'S3' + STORAGE_BASE_URL = 's3://MY_EXAMPLE_BUCKET/' + STORAGE_AWS_ROLE_ARN = 'arn:aws:iam::123456789012:role/myrole' + ENCRYPTION=(TYPE='AWS_SSE_KMS' KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab') + ) + ); + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_external_volume) + IcebergTable( + "Iceberg_Table_1", + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + external_volume=external_volume_name, + base_location="my_iceberg_table", + ) + + with pytest.raises(ProgrammingError) as argument_error: + metadata.create_all(engine_testaccount) + + error_str = str(argument_error.value) + assert error_str[: error_str.rfind("\n")] == snapshot diff --git a/tests/custom_tables/test_create_snowflake_table.py b/tests/custom_tables/test_create_snowflake_table.py new file mode 100644 index 00000000..09140fb8 --- /dev/null +++ b/tests/custom_tables/test_create_snowflake_table.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from sqlalchemy import Column, Integer, MetaData, String, select, text +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import SnowflakeTable + + +def test_create_snowflake_table_with_cluster_by( + engine_testaccount, db_parameters, snapshot +): + metadata = MetaData() + table_name = "test_create_snowflake_table" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_snowflake_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_snowflake_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return SnowflakeTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_snowflake_table.py b/tests/custom_tables/test_reflect_snowflake_table.py new file mode 100644 index 00000000..ef84622b --- /dev/null +++ b/tests/custom_tables/test_reflect_snowflake_table.py @@ -0,0 +1,69 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from sqlalchemy import MetaData, Table +from sqlalchemy.sql.ddl import CreateTable + +from src.snowflake.sqlalchemy import SnowflakeTable + + +def test_simple_reflection_of_table_as_sqlalchemy_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_of_table_as_snowflake_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = SnowflakeTable( + table_name, metadata, autoload_with=engine_testaccount + ) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_core.py b/tests/test_core.py index 980db1d2..9342ad58 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -36,6 +36,7 @@ ) from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError from sqlalchemy.sql import and_, not_, or_, select +from sqlalchemy.sql.ddl import CreateTable import snowflake.connector.errors import snowflake.sqlalchemy.snowdialect @@ -699,6 +700,39 @@ def test_create_table_with_cluster_by(engine_testaccount): user.drop(engine_testaccount) +def test_create_table_with_cluster_by_with_expression(engine_testaccount): + metadata = MetaData() + Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", String), + snowflake_clusterby=["Id", "name", text('"Id" > 5')], + ) + metadata.create_all(engine_testaccount) + try: + inspector = inspect(engine_testaccount) + columns_in_table = inspector.get_columns("clustered_user") + assert columns_in_table[0]["name"] == "Id", "name" + finally: + metadata.drop_all(engine_testaccount) + + +def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot): + metadata = MetaData() + user = Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", String), + snowflake_clusterby=["Id", "name", text('"Id" > 5')], + ) + + create_table = CreateTable(user) + + assert sql_compiler(create_table) == snapshot + + def test_view_names(engine_testaccount): """ Tests all views From 31d0da643f0bef64b1fbd49e49e840ff3d29eb87 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 19 Nov 2024 10:15:48 -0600 Subject: [PATCH 42/74] Add support for map datatype (#541) Add support for map datatype --- DESCRIPTION.md | 3 +- pyproject.toml | 3 
+-
 src/snowflake/sqlalchemy/__init__.py | 2 +
 src/snowflake/sqlalchemy/_constants.py | 1 +
 src/snowflake/sqlalchemy/base.py | 7 +
 src/snowflake/sqlalchemy/custom_types.py | 20 ++
 src/snowflake/sqlalchemy/exc.py | 8 +
 .../sqlalchemy/parser/custom_type_parser.py | 190 ++++++++++++
 src/snowflake/sqlalchemy/snowdialect.py | 213 ++++++--------
 .../sql/custom_schema/custom_table_base.py | 16 ++
 .../sql/custom_schema/iceberg_table.py | 1 +
 src/snowflake/sqlalchemy/version.py | 2 +-
 .../test_structured_datatypes.ambr | 90 ++++++
 .../test_unit_structured_types.ambr | 4 +
 tests/conftest.py | 30 ++
 .../test_reflect_snowflake_table.ambr | 22 ++
 .../test_reflect_snowflake_table.py | 27 +-
 tests/test_core.py | 136 ++++-----
 tests/test_structured_datatypes.py | 271 ++++++++++++++++++
 tests/test_unit_structured_types.py | 73 +++++
 tests/util.py | 2 +
 21 files changed, 898 insertions(+), 223 deletions(-)
 create mode 100644 src/snowflake/sqlalchemy/parser/custom_type_parser.py
 create mode 100644 tests/__snapshots__/test_structured_datatypes.ambr
 create mode 100644 tests/__snapshots__/test_unit_structured_types.ambr
 create mode 100644 tests/test_structured_datatypes.py
 create mode 100644 tests/test_unit_structured_types.py

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 47697d30..33775996 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,7 +9,7 @@ Source code is also available at:

 # Release Notes

-- (Unreleased)
+- v1.7.0(November 12, 2024)

   - Add support for dynamic tables and required options
   - Add support for hybrid tables
@@ -18,6 +18,7 @@ Source code is also available at:
   - Add support for refresh_mode option in DynamicTable
   - Add support for iceberg table with Snowflake Catalog
   - Fix cluster by option to support explicit expressions
+  - Add support for MAP datatype

 - v1.6.1(July 9, 2024)

diff --git a/pyproject.toml b/pyproject.toml
index 6c72f683..84e64faf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -111,7 +111,7 @@ line-length = 88
 line-length = 88

 [tool.pytest.ini_options]
-addopts = "-m 'not feature_max_lob_size and not aws'"
+addopts = "-m 'not feature_max_lob_size and not aws and not requires_external_volume'"
 markers = [
     # Optional dependency groups markers
     "lambda: AWS lambda tests",
@@ -128,6 +128,7 @@ markers = [
     # Other markers
     "timeout: tests that need a timeout time",
     "internal: tests that could but should only run on our internal CI",
+    "requires_external_volume: tests that need an external volume to be executed",
     "external: tests that could but should only run on our external CI",
     "feature_max_lob_size: tests that could but should only run on our external CI",
 ]
diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py
index f6c97f0d..7d795b2a 100644
--- a/src/snowflake/sqlalchemy/__init__.py
+++ b/src/snowflake/sqlalchemy/__init__.py
@@ -50,6 +50,7 @@
     FIXED,
     GEOGRAPHY,
     GEOMETRY,
+    MAP,
     NUMBER,
     OBJECT,
     STRING,
@@ -119,6 +120,7 @@
     "TINYINT",
     "VARBINARY",
     "VARIANT",
+    "MAP",
 )

 _custom_commands = (
diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py
index 839745ee..205ad5d9 100644
--- a/src/snowflake/sqlalchemy/_constants.py
+++ b/src/snowflake/sqlalchemy/_constants.py
@@ -11,3 +11,4 @@
 APPLICATION_NAME = "SnowflakeSQLAlchemy"
 SNOWFLAKE_SQLALCHEMY_VERSION = VERSION
 DIALECT_NAME = "snowflake"
+NOT_NULL = "NOT NULL"
diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index 4e36c4ad..a1e16062 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@
-27,6 +27,7 @@ ExternalStage, ) +from ._constants import NOT_NULL from .exc import ( CustomOptionsAreOnlySupportedOnSnowflakeTables, UnexpectedOptionTypeError, @@ -1071,6 +1072,12 @@ def visit_TINYINT(self, type_, **kw): def visit_VARIANT(self, type_, **kw): return "VARIANT" + def visit_MAP(self, type_, **kw): + not_null = f" {NOT_NULL}" if type_.not_null else "" + return ( + f"MAP({type_.key_type.compile()}, {type_.value_type.compile()}{not_null})" + ) + def visit_ARRAY(self, type_, **kw): return "ARRAY" diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index 802d1ce1..f2c950dd 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -37,6 +37,26 @@ class VARIANT(SnowflakeType): __visit_name__ = "VARIANT" +class StructuredType(SnowflakeType): + def __init__(self): + super().__init__() + + +class MAP(StructuredType): + __visit_name__ = "MAP" + + def __init__( + self, + key_type: sqltypes.TypeEngine, + value_type: sqltypes.TypeEngine, + not_null: bool = False, + ): + self.key_type = key_type + self.value_type = value_type + self.not_null = not_null + super().__init__() + + class OBJECT(SnowflakeType): __visit_name__ = "OBJECT" diff --git a/src/snowflake/sqlalchemy/exc.py b/src/snowflake/sqlalchemy/exc.py index 898de279..399e94b6 100644 --- a/src/snowflake/sqlalchemy/exc.py +++ b/src/snowflake/sqlalchemy/exc.py @@ -72,3 +72,11 @@ def __init__(self, errors): def __str__(self): return "".join(str(e) for e in self.errors) + + +class StructuredTypeNotSupportedInTableColumnsError(ArgumentError): + def __init__(self, table_type: str, table_name: str, column_name: str): + super().__init__( + f"Column '{column_name}' is of a structured type, which is only supported on Iceberg tables. " + f"The table '{table_name}' is of type '{table_type}', not Iceberg." + ) diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py new file mode 100644 index 00000000..cf69c594 --- /dev/null +++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py @@ -0,0 +1,190 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import sqlalchemy.types as sqltypes +from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy.types import ( + BIGINT, + BINARY, + BOOLEAN, + CHAR, + DATE, + DATETIME, + DECIMAL, + FLOAT, + INTEGER, + REAL, + SMALLINT, + TIME, + TIMESTAMP, + VARCHAR, + NullType, +) + +from ..custom_types import ( + _CUSTOM_DECIMAL, + ARRAY, + DOUBLE, + GEOGRAPHY, + GEOMETRY, + MAP, + OBJECT, + TIMESTAMP_LTZ, + TIMESTAMP_NTZ, + TIMESTAMP_TZ, + VARIANT, +) + +ischema_names = { + "BIGINT": BIGINT, + "BINARY": BINARY, + # 'BIT': BIT, + "BOOLEAN": BOOLEAN, + "CHAR": CHAR, + "CHARACTER": CHAR, + "DATE": DATE, + "DATETIME": DATETIME, + "DEC": DECIMAL, + "DECIMAL": DECIMAL, + "DOUBLE": DOUBLE, + "FIXED": DECIMAL, + "FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't has parameters + "INT": INTEGER, + "INTEGER": INTEGER, + "NUMBER": _CUSTOM_DECIMAL, + # 'OBJECT': ? 
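+    # Illustrative note (an added sketch with assumed values, not part of the
+    # original patch): parse_type below resolves names through this mapping,
+    # including nested parameters; e.g.
+    #     parse_type("MAP(NUMBER(10, 0), VARCHAR(16777216))")
+    # is expected to return
+    #     MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)).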
+ "REAL": REAL, + "BYTEINT": SMALLINT, + "SMALLINT": SMALLINT, + "STRING": VARCHAR, + "TEXT": VARCHAR, + "TIME": TIME, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP_TZ": TIMESTAMP_TZ, + "TIMESTAMP_LTZ": TIMESTAMP_LTZ, + "TIMESTAMP_NTZ": TIMESTAMP_NTZ, + "TINYINT": SMALLINT, + "VARBINARY": BINARY, + "VARCHAR": VARCHAR, + "VARIANT": VARIANT, + "MAP": MAP, + "OBJECT": OBJECT, + "ARRAY": ARRAY, + "GEOGRAPHY": GEOGRAPHY, + "GEOMETRY": GEOMETRY, +} + + +def extract_parameters(text: str) -> list: + """ + Extracts parameters from a comma-separated string, handling parentheses. + + :param text: A string with comma-separated parameters, which may include parentheses. + + :return: A list of parameters as strings. + + :example: + For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`. + """ + + output_parameters = [] + parameter = "" + open_parenthesis = 0 + for c in text: + + if c == "(": + open_parenthesis += 1 + elif c == ")": + open_parenthesis -= 1 + + if open_parenthesis > 0 or c != ",": + parameter += c + elif c == ",": + output_parameters.append(parameter.strip(" ")) + parameter = "" + if parameter != "": + output_parameters.append(parameter.strip(" ")) + return output_parameters + + +def parse_type(type_text: str) -> TypeEngine: + """ + Parses a type definition string and returns the corresponding SQLAlchemy type. + + The function handles types with or without parameters, such as `VARCHAR(255)` or `INTEGER`. + + :param type_text: A string representing a SQLAlchemy type, which may include parameters + in parentheses (e.g., "VARCHAR(255)" or "DECIMAL(10, 2)"). + :return: An instance of the corresponding SQLAlchemy type class (e.g., `String`, `Integer`), + or `NullType` if the type is not recognized. + + :example: + parse_type("VARCHAR(255)") + String(length=255) + """ + index = type_text.find("(") + type_name = type_text[:index] if index != -1 else type_text + parameters = ( + extract_parameters(type_text[index + 1 : -1]) if type_name != type_text else [] + ) + + col_type_class = ischema_names.get(type_name, None) + col_type_kw = {} + if col_type_class is None: + col_type_class = NullType + else: + if issubclass(col_type_class, sqltypes.Numeric): + col_type_kw = __parse_numeric_type_parameters(parameters) + elif issubclass(col_type_class, (sqltypes.String, sqltypes.BINARY)): + col_type_kw = __parse_type_with_length_parameters(parameters) + elif issubclass(col_type_class, MAP): + col_type_kw = __parse_map_type_parameters(parameters) + if col_type_kw is None: + col_type_class = NullType + col_type_kw = {} + + return col_type_class(**col_type_kw) + + +def __parse_map_type_parameters(parameters): + if len(parameters) != 2: + return None + + key_type_str = parameters[0] + value_type_str = parameters[1] + not_null_str = "NOT NULL" + not_null = False + if ( + len(value_type_str) >= len(not_null_str) + and value_type_str[-len(not_null_str) :] == not_null_str + ): + not_null = True + value_type_str = value_type_str[: -len(not_null_str) - 1] + + key_type: TypeEngine = parse_type(key_type_str) + value_type: TypeEngine = parse_type(value_type_str) + if isinstance(key_type, NullType) or isinstance(value_type, NullType): + return None + + return { + "key_type": key_type, + "value_type": value_type, + "not_null": not_null, + } + + +def __parse_type_with_length_parameters(parameters): + return ( + {"length": int(parameters[0])} + if len(parameters) == 1 and str.isdigit(parameters[0]) + else {} + ) + + +def __parse_numeric_type_parameters(parameters): + result = {} + if len(parameters) >= 1 and 
str.isdigit(parameters[0]): + result["precision"] = int(parameters[0]) + if len(parameters) == 2 and str.isdigit(parameters[1]): + result["scale"] = int(parameters[1]) + return result diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index f2fb9b18..f9e2e4c8 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -3,6 +3,7 @@ # import operator +import re from collections import defaultdict from functools import reduce from typing import Any @@ -16,26 +17,7 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.types import ( - BIGINT, - BINARY, - BOOLEAN, - CHAR, - DATE, - DATETIME, - DECIMAL, - FLOAT, - INTEGER, - REAL, - SMALLINT, - TIME, - TIMESTAMP, - VARCHAR, - Date, - DateTime, - Float, - Time, -) +from sqlalchemy.types import FLOAT, Date, DateTime, Float, NullType, Time from snowflake.connector import errors as sf_errors from snowflake.connector.connection import DEFAULT_CONFIGURATION @@ -51,20 +33,13 @@ SnowflakeTypeCompiler, ) from .custom_types import ( - _CUSTOM_DECIMAL, - ARRAY, - GEOGRAPHY, - GEOMETRY, - OBJECT, - TIMESTAMP_LTZ, - TIMESTAMP_NTZ, - TIMESTAMP_TZ, - VARIANT, + MAP, _CUSTOM_Date, _CUSTOM_DateTime, _CUSTOM_Float, _CUSTOM_Time, ) +from .parser.custom_type_parser import ischema_names, parse_type from .sql.custom_schema.custom_table_prefix import CustomTablePrefix from .util import ( _update_connection_application_name, @@ -79,44 +54,6 @@ Float: _CUSTOM_Float, } -ischema_names = { - "BIGINT": BIGINT, - "BINARY": BINARY, - # 'BIT': BIT, - "BOOLEAN": BOOLEAN, - "CHAR": CHAR, - "CHARACTER": CHAR, - "DATE": DATE, - "DATETIME": DATETIME, - "DEC": DECIMAL, - "DECIMAL": DECIMAL, - "DOUBLE": FLOAT, - "FIXED": DECIMAL, - "FLOAT": FLOAT, - "INT": INTEGER, - "INTEGER": INTEGER, - "NUMBER": _CUSTOM_DECIMAL, - # 'OBJECT': ? 
- "REAL": REAL, - "BYTEINT": SMALLINT, - "SMALLINT": SMALLINT, - "STRING": VARCHAR, - "TEXT": VARCHAR, - "TIME": TIME, - "TIMESTAMP": TIMESTAMP, - "TIMESTAMP_TZ": TIMESTAMP_TZ, - "TIMESTAMP_LTZ": TIMESTAMP_LTZ, - "TIMESTAMP_NTZ": TIMESTAMP_NTZ, - "TINYINT": SMALLINT, - "VARBINARY": BINARY, - "VARCHAR": VARCHAR, - "VARIANT": VARIANT, - "OBJECT": OBJECT, - "ARRAY": ARRAY, - "GEOGRAPHY": GEOGRAPHY, - "GEOMETRY": GEOMETRY, -} - _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True @@ -333,8 +270,8 @@ def _denormalize_quote_join(self, *idents): @reflection.cache def _current_database_schema(self, connection, **kw): - res = connection.exec_driver_sql( - "select current_database(), current_schema();" + res = connection.execute( + text("select current_database(), current_schema();") ).fetchone() return ( self.normalize_name(res[0]), @@ -508,6 +445,12 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw): ) return foreign_key_map.get(table_name, []) + def table_columns_as_dict(self, columns): + result = {} + for column in columns: + result[column["name"]] = column + return result + @reflection.cache def _get_schema_columns(self, connection, schema, **kw): """Get all columns in the schema, if we hit 'Information schema query returned too much data' problem return @@ -515,10 +458,12 @@ def _get_schema_columns(self, connection, schema, **kw): ans = {} current_database, _ = self._current_database_schema(connection, **kw) full_schema_name = self._denormalize_quote_join(current_database, schema) + full_columns_descriptions = {} try: schema_primary_keys = self._get_schema_primary_keys( connection, full_schema_name, **kw ) + schema_name = self.denormalize_name(schema) result = connection.execute( text( """ @@ -539,7 +484,7 @@ def _get_schema_columns(self, connection, schema, **kw): WHERE ic.table_schema=:table_schema ORDER BY ic.ordinal_position""" ), - {"table_schema": self.denormalize_name(schema)}, + {"table_schema": schema_name}, ) except sa_exc.ProgrammingError as pe: if pe.orig.errno == 90030: @@ -569,10 +514,7 @@ def _get_schema_columns(self, connection, schema, **kw): col_type = self.ischema_names.get(coltype, None) col_type_kw = {} if col_type is None: - sa_util.warn( - f"Did not recognize type '{coltype}' of column '{column_name}'" - ) - col_type = sqltypes.NULLTYPE + col_type = NullType else: if issubclass(col_type, FLOAT): col_type_kw["precision"] = numeric_precision @@ -582,6 +524,33 @@ def _get_schema_columns(self, connection, schema, **kw): col_type_kw["scale"] = numeric_scale elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): col_type_kw["length"] = character_maximum_length + elif issubclass(col_type, MAP): + if (schema_name, table_name) not in full_columns_descriptions: + full_columns_descriptions[(schema_name, table_name)] = ( + self.table_columns_as_dict( + self._get_table_columns( + connection, table_name, schema_name + ) + ) + ) + + if ( + (schema_name, table_name) in full_columns_descriptions + and column_name + in full_columns_descriptions[(schema_name, table_name)] + ): + ans[table_name].append( + full_columns_descriptions[(schema_name, table_name)][ + column_name + ] + ) + continue + else: + col_type = NullType + if col_type == NullType: + sa_util.warn( + f"Did not recognize type '{coltype}' of column '{column_name}'" + ) type_instance = col_type(**col_type_kw) @@ -616,91 +585,71 @@ def _get_schema_columns(self, connection, schema, **kw): def _get_table_columns(self, connection, table_name, schema=None, **kw): """Get all columns in a table in a schema""" ans 
= []
-        current_database, _ = self._current_database_schema(connection, **kw)
-        full_schema_name = self._denormalize_quote_join(current_database, schema)
-        schema_primary_keys = self._get_schema_primary_keys(
-            connection, full_schema_name, **kw
+        current_database, default_schema = self._current_database_schema(
+            connection, **kw
         )
+        schema = schema if schema else default_schema
+        table_schema = self.denormalize_name(schema)
+        table_name = self.denormalize_name(table_name)
         result = connection.execute(
             text(
-                """
-            SELECT /* sqlalchemy:get_table_columns */
-                   ic.table_name,
-                   ic.column_name,
-                   ic.data_type,
-                   ic.character_maximum_length,
-                   ic.numeric_precision,
-                   ic.numeric_scale,
-                   ic.is_nullable,
-                   ic.column_default,
-                   ic.is_identity,
-                   ic.comment
-            FROM information_schema.columns ic
-            WHERE ic.table_schema=:table_schema
-            AND ic.table_name=:table_name
-            ORDER BY ic.ordinal_position"""
-            ),
-            {
-                "table_schema": self.denormalize_name(schema),
-                "table_name": self.denormalize_name(table_name),
-            },
+                "DESC /* sqlalchemy:_get_table_columns */"
+                f" TABLE {table_schema}.{table_name} TYPE = COLUMNS"
+            )
         )
         for (
-            table_name,
             column_name,
             coltype,
-            character_maximum_length,
-            numeric_precision,
-            numeric_scale,
+            _kind,
             is_nullable,
             column_default,
-            is_identity,
+            primary_key,
+            _unique_key,
+            _check,
+            _expression,
             comment,
+            _policy_name,
+            _privacy_domain,
+            _name_mapping,
         ) in result:
-            table_name = self.normalize_name(table_name)
             column_name = self.normalize_name(column_name)
             if column_name.startswith("sys_clustering_column"):
                 continue  # ignoring clustering column
-            col_type = self.ischema_names.get(coltype, None)
-            col_type_kw = {}
-            if col_type is None:
+            type_instance = parse_type(coltype)
+            if isinstance(type_instance, NullType):
                 sa_util.warn(
                     f"Did not recognize type '{coltype}' of column '{column_name}'"
                 )
-                col_type = sqltypes.NULLTYPE
-            else:
-                if issubclass(col_type, FLOAT):
-                    col_type_kw["precision"] = numeric_precision
-                    col_type_kw["decimal_return_scale"] = numeric_scale
-                elif issubclass(col_type, sqltypes.Numeric):
-                    col_type_kw["precision"] = numeric_precision
-                    col_type_kw["scale"] = numeric_scale
-                elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
-                    col_type_kw["length"] = character_maximum_length
-
-            type_instance = col_type(**col_type_kw)

-            current_table_pks = schema_primary_keys.get(table_name)
+            identity = None
+            match = re.match(
+                r"IDENTITY START (?P<start>\d+) INCREMENT (?P<increment>\d+) (?P<order_type>ORDER|NOORDER)",
+                column_default if column_default else "",
+            )
+            if match:
+                identity = {
+                    "start": int(match.group("start")),
+                    "increment": int(match.group("increment")),
+                    "order_type": match.group("order_type"),
+                }
+            is_identity = identity is not None

             ans.append(
                 {
                     "name": column_name,
                     "type": type_instance,
-                    "nullable": is_nullable == "YES",
-                    "default": column_default,
-                    "autoincrement": is_identity == "YES",
+                    "nullable": is_nullable == "Y",
+                    "default": None if is_identity else column_default,
+                    "autoincrement": is_identity,
                     "comment": comment if comment != "" else None,
-                    "primary_key": (
-                        (
-                            column_name
-                            in schema_primary_keys[table_name]["constrained_columns"]
-                        )
-                        if current_table_pks
-                        else False
-                    ),
+                    "primary_key": primary_key == "Y",
                 }
             )
+            if is_identity:
+                ans[-1]["identity"] = identity
+
         # If we didn't find any columns for the table, the table doesn't exist.
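+        # Illustrative example of the identity parsing above (assumed input,
+        # added for clarity rather than taken from the patch): a column_default
+        # of "IDENTITY START 1 INCREMENT 1 NOORDER" yields
+        # identity == {"start": 1, "increment": 1, "order_type": "NOORDER"},
+        # so the column is reported with autoincrement=True.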
if len(ans) == 0: raise sa_exc.NoSuchTableError() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py index 671c6957..6f7ee0c5 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -9,10 +9,12 @@ from ..._constants import DIALECT_NAME from ...compat import IS_VERSION_20 from ...custom_commands import NoneType +from ...custom_types import StructuredType from ...exc import ( MultipleErrors, NoPrimaryKeyError, RequiredParametersNotProvidedError, + StructuredTypeNotSupportedInTableColumnsError, UnsupportedPrimaryKeysAndForeignKeysError, ) from .custom_table_prefix import CustomTablePrefix @@ -25,6 +27,7 @@ class CustomTableBase(Table): _support_primary_and_foreign_keys: bool = True _enforce_primary_keys: bool = False _required_parameters: List[TableOptionKey] = [] + _support_structured_types: bool = False @property def table_prefixes(self) -> typing.List[str]: @@ -53,6 +56,10 @@ def __init__( def _validate_table(self): exceptions: List[Exception] = [] + columns_validation = self.__validate_columns() + if columns_validation is not None: + exceptions.append(columns_validation) + for _, option in self.dialect_options[DIALECT_NAME].items(): if isinstance(option, InvalidTableOption): exceptions.append(option.exception) @@ -84,6 +91,15 @@ def _validate_table(self): elif len(exceptions) == 1: raise exceptions[0] + def __validate_columns(self): + for column in self.columns: + if not self._support_structured_types and isinstance( + column.type, StructuredType + ): + return StructuredTypeNotSupportedInTableColumnsError( + self.__class__.__name__, self.name, column.name + ) + def _get_dialect_option( self, option_name: TableOptionKey ) -> typing.Optional[TableOption]: diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py index 5c9c53d9..4f62d4f2 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py @@ -44,6 +44,7 @@ class IcebergTable(TableFromQueryBase): """ __table_prefixes__ = [CustomTablePrefix.ICEBERG] + _support_structured_types = True @property def external_volume(self) -> typing.Optional[LiteralOption]: diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index d90f706b..b80a9096 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = "1.6.1" +VERSION = "1.7.0" diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr new file mode 100644 index 00000000..0325a946 --- /dev/null +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -0,0 +1,90 @@ +# serializer version: 1 +# name: test_compile_table_with_cluster_by_with_expression + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, VARCHAR), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_double_map + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))' +# --- +# name: test_insert_map + list([ + (1, '{\n "100": "item1",\n "200": "item2"\n}'), + ]) +# --- +# name: test_insert_map_orm + ''' + 002014 (22000): SQL compilation error: + 
Invalid expression [CAST(OBJECT_CONSTRUCT('100', 'item1', '200', 'item2') AS MAP(NUMBER(10,0), VARCHAR(16777216)))] in VALUES clause + ''' +# --- +# name: test_inspect_structured_data_types[structured_type0] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'map_id', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)), + }), + ]) +# --- +# name: test_inspect_structured_data_types[structured_type1] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'map_id', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))), + }), + ]) +# --- +# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), VARCHAR)] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_select_map_orm + list([ + (1, '{\n "100": "item1",\n "200": "item2"\n}'), + (2, '{\n "100": "item1",\n "200": "item2"\n}'), + ]) +# --- +# name: test_select_map_orm.1 + list([ + ]) +# --- +# name: test_select_map_orm.2 + list([ + ]) +# --- diff --git a/tests/__snapshots__/test_unit_structured_types.ambr b/tests/__snapshots__/test_unit_structured_types.ambr new file mode 100644 index 00000000..ff861351 --- /dev/null +++ b/tests/__snapshots__/test_unit_structured_types.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_map_with_not_null + 'MAP(DECIMAL(10, 0), VARCHAR NOT NULL)' +# --- diff --git a/tests/conftest.py b/tests/conftest.py index d4dab3d1..a91521b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,36 @@ def db_parameters(): yield get_db_parameters() +@pytest.fixture(scope="session") +def external_volume(): + db_parameters = get_db_parameters() + if "external_volume" in db_parameters: + yield db_parameters["external_volume"] + else: + raise ValueError("External_volume is not set") + + +@pytest.fixture(scope="session") +def external_stage(): + db_parameters = get_db_parameters() + if "external_stage" in db_parameters: + yield db_parameters["external_stage"] + else: + raise ValueError("External_stage is not set") + + +@pytest.fixture(scope="function") +def base_location(external_stage, engine_testaccount): + unique_id = str(uuid.uuid4()) + base_location = "L" + unique_id.replace("-", "_") + yield base_location + remove_base_location = f""" + REMOVE @{external_stage} pattern='.*{base_location}.*'; + """ + with engine_testaccount.connect() as connection: + 
connection.exec_driver_sql(remove_base_location) + + def get_db_parameters() -> dict: """ Sets the db connection parameters diff --git a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr index 6ef09ff7..7e85841a 100644 --- a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr +++ b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr @@ -1,4 +1,26 @@ # serializer version: 1 +# name: test_inspect_snowflake_table + list([ + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=38, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'name', + 'nullable': True, + 'primary_key': False, + 'type': VARCHAR(length=16777216), + }), + ]) +# --- # name: test_simple_reflection_of_table_as_snowflake_table 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' # --- diff --git a/tests/custom_tables/test_reflect_snowflake_table.py b/tests/custom_tables/test_reflect_snowflake_table.py index ef84622b..603b6187 100644 --- a/tests/custom_tables/test_reflect_snowflake_table.py +++ b/tests/custom_tables/test_reflect_snowflake_table.py @@ -1,10 +1,10 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import MetaData, Table +from sqlalchemy import MetaData, Table, inspect from sqlalchemy.sql.ddl import CreateTable -from src.snowflake.sqlalchemy import SnowflakeTable +from snowflake.sqlalchemy import SnowflakeTable def test_simple_reflection_of_table_as_sqlalchemy_table( @@ -67,3 +67,26 @@ def test_simple_reflection_of_table_as_snowflake_table( finally: metadata.drop_all(engine_testaccount) + + +def test_inspect_snowflake_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_inspect" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + try: + with engine_testaccount.connect() as conn: + insp = inspect(conn) + table = insp.get_columns(table_name) + assert table == snapshot + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_core.py b/tests/test_core.py index 9342ad58..63f097db 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -30,6 +30,7 @@ UniqueConstraint, create_engine, dialects, + exc, insert, inspect, text, @@ -124,14 +125,26 @@ def test_connect_args(): Snowflake connect string supports account name as a replacement of host:port """ + server = "" + if "host" in CONNECTION_PARAMETERS and "port" in CONNECTION_PARAMETERS: + server = "{host}:{port}".format( + host=CONNECTION_PARAMETERS["host"], port=CONNECTION_PARAMETERS["port"] + ) + elif "account" in CONNECTION_PARAMETERS and "region" in CONNECTION_PARAMETERS: + server = "{account}.{region}".format( + account=CONNECTION_PARAMETERS["account"], + region=CONNECTION_PARAMETERS["region"], + ) + elif "account" in CONNECTION_PARAMETERS: + server = CONNECTION_PARAMETERS["account"] + engine = create_engine( - "snowflake://{user}:{password}@{host}:{port}/{database}/{schema}" + "snowflake://{user}:{password}@{server}/{database}/{schema}" "?account={account}&protocol={protocol}".format( 
user=CONNECTION_PARAMETERS["user"], account=CONNECTION_PARAMETERS["account"], password=CONNECTION_PARAMETERS["password"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], + server=server, database=CONNECTION_PARAMETERS["database"], schema=CONNECTION_PARAMETERS["schema"], protocol=CONNECTION_PARAMETERS["protocol"], @@ -142,32 +155,14 @@ def test_connect_args(): finally: engine.dispose() - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) try: verify_engine_connection(engine) finally: engine.dispose() - - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - warehouse="testwh", - ) - ) + parameters = {**CONNECTION_PARAMETERS} + parameters["warehouse"] = "testwh" + engine = create_engine(URL(**parameters)) try: verify_engine_connection(engine) finally: @@ -175,14 +170,10 @@ def test_connect_args(): def test_boolean_query_argument_parsing(): + engine = create_engine( URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], + **CONNECTION_PARAMETERS, validate_default_parameters=True, ) ) @@ -1549,15 +1540,8 @@ def test_too_many_columns_detection(engine_testaccount, db_parameters): connection = inspector.bind.connect() original_execute = connection.execute - too_many_columns_was_raised = False - - def mock_helper(command, *args, **kwargs): - if "_get_schema_columns" in command.text: - # Creating exception exactly how SQLAlchemy does - nonlocal too_many_columns_was_raised - too_many_columns_was_raised = True - raise DBAPIError.instance( - """ + exception_instance = DBAPIError.instance( + """ SELECT /* sqlalchemy:_get_schema_columns */ ic.table_name, ic.column_name, @@ -1572,27 +1556,32 @@ def mock_helper(command, *args, **kwargs): FROM information_schema.columns ic WHERE ic.table_schema='schema_name' ORDER BY ic.ordinal_position""", - {"table_schema": "TESTSCHEMA"}, - ProgrammingError( - "Information schema query returned too much data. Please repeat query with more " - "selective predicates.", - 90030, - ), - Error, - hide_parameters=False, - connection_invalidated=False, - dialect=SnowflakeDialect(), - ismulti=None, - ) + {"table_schema": "TESTSCHEMA"}, + ProgrammingError( + "Information schema query returned too much data. 
Please repeat query with more " + "selective predicates.", + 90030, + ), + Error, + hide_parameters=False, + connection_invalidated=False, + dialect=SnowflakeDialect(), + ismulti=None, + ) + + def mock_helper(command, *args, **kwargs): + if "_get_schema_columns" in command.text: + # Creating exception exactly how SQLAlchemy does + raise exception_instance else: return original_execute(command, *args, **kwargs) with patch.object(engine_testaccount, "connect") as conn: conn.return_value = connection with patch.object(connection, "execute", side_effect=mock_helper): - column_metadata = inspector.get_columns("users", db_parameters["schema"]) - assert len(column_metadata) == 4 - assert too_many_columns_was_raised + with pytest.raises(exc.ProgrammingError) as exception: + inspector.get_columns("users", db_parameters["schema"]) + assert exception.value.orig == exception_instance.orig # Clean up metadata.drop_all(engine_testaccount) @@ -1636,9 +1625,9 @@ def test_column_type_schema(engine_testaccount): table_reflected = Table(table_name, MetaData(), autoload_with=conn) columns = table_reflected.columns - assert ( - len(columns) == len(ischema_names_baseline) - 1 - ) # -1 because FIXED is not supported + assert len(columns) == ( + len(ischema_names_baseline) - 2 + ) # -2 because FIXED and MAP is not supported def test_result_type_and_value(engine_testaccount): @@ -1816,30 +1805,14 @@ def test_normalize_and_denormalize_empty_string_column_name(engine_testaccount): def test_snowflake_sqlalchemy_as_valid_client_type(): engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ), + URL(**CONNECTION_PARAMETERS), connect_args={"internal_application_name": "UnknownClient"}, ) with engine.connect() as conn: with pytest.raises(snowflake.connector.errors.NotSupportedError): conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) with engine.connect() as conn: conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() @@ -1870,16 +1843,7 @@ def test_snowflake_sqlalchemy_as_valid_client_type(): "3.0.0", (type(None), str), ) - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) with engine.connect() as conn: conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() assert ( diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py new file mode 100644 index 00000000..4ea0892b --- /dev/null +++ b/tests/test_structured_datatypes.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+#
+
+import pytest
+from sqlalchemy import (
+    Column,
+    Integer,
+    MetaData,
+    Sequence,
+    Table,
+    cast,
+    exc,
+    inspect,
+    text,
+)
+from sqlalchemy.orm import Session, declarative_base
+from sqlalchemy.sql import select
+from sqlalchemy.sql.ddl import CreateTable
+
+from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable
+from snowflake.sqlalchemy.custom_types import MAP, TEXT
+from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError
+
+
+def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot):
+    metadata = MetaData()
+    user_table = Table(
+        "clustered_user",
+        metadata,
+        Column("Id", Integer, primary_key=True),
+        Column("name", MAP(NUMBER(), TEXT())),
+    )
+
+    create_table = CreateTable(user_table)
+
+    assert sql_compiler(create_table) == snapshot
+
+
+@pytest.mark.requires_external_volume
+def test_create_table_structured_datatypes(
+    engine_testaccount, external_volume, base_location
+):
+    metadata = MetaData()
+    table_name = "test_map0"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("map_id", MAP(NUMBER(10, 0), TEXT())),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+    try:
+        assert test_map is not None
+    finally:
+        test_map.drop(engine_testaccount)
+
+
+@pytest.mark.requires_external_volume
+def test_insert_map(engine_testaccount, external_volume, base_location, snapshot):
+    metadata = MetaData()
+    table_name = "test_insert_map"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("map_id", MAP(NUMBER(10, 0), TEXT())),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+
+    try:
+        with engine_testaccount.connect() as conn:
+            slt = select(
+                1,
+                cast(
+                    text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())
+                ),
+            )
+            ins = test_map.insert().from_select(["id", "map_id"], slt)
+            conn.execute(ins)
+
+            results = conn.execute(test_map.select())
+            data = results.fetchmany()
+            results.close()
+            snapshot.assert_match(data)
+    finally:
+        test_map.drop(engine_testaccount)
+
+
+@pytest.mark.requires_external_volume
+@pytest.mark.parametrize(
+    "structured_type",
+    [
+        MAP(NUMBER(10, 0), TEXT()),
+        MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())),
+    ],
+)
+def test_inspect_structured_data_types(
+    engine_testaccount, external_volume, base_location, snapshot, structured_type
+):
+    metadata = MetaData()
+    table_name = "test_st_types"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("map_id", structured_type),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+
+    try:
+        inspector = inspect(engine_testaccount)
+        columns = inspector.get_columns(table_name)
+
+        assert isinstance(columns[0]["type"], NUMBER)
+        assert isinstance(columns[1]["type"], MAP)
+        assert columns == snapshot
+
+    finally:
+        test_map.drop(engine_testaccount)
+
+
+@pytest.mark.requires_external_volume
+@pytest.mark.parametrize(
+    "structured_type",
+    [
+        "MAP(NUMBER(10, 0), VARCHAR)",
+        "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))",
+    ],
+)
+def test_reflect_structured_data_types(
+    engine_testaccount,
+    external_volume,
+    base_location,
+    snapshot,
+    structured_type,
+    sql_compiler,
+):
+    metadata = MetaData()
+    table_name = "test_reflect_st_types"
+    create_table_sql = f"""
+CREATE OR REPLACE ICEBERG TABLE
{table_name} (
+    id number(38,0) primary key,
+    map_id {structured_type})
+CATALOG = 'SNOWFLAKE'
+EXTERNAL_VOLUME = '{external_volume}'
+BASE_LOCATION = '{base_location}';
+    """
+
+    with engine_testaccount.connect() as connection:
+        connection.exec_driver_sql(create_table_sql)
+
+    iceberg_table = IcebergTable(table_name, metadata, autoload_with=engine_testaccount)
+    constraint = iceberg_table.constraints.pop()
+    constraint.name = "constraint_name"
+    iceberg_table.constraints.add(constraint)
+
+    try:
+        with engine_testaccount.connect():
+            value = CreateTable(iceberg_table)
+
+            actual = sql_compiler(value)
+
+            assert actual == snapshot
+
+    finally:
+        metadata.drop_all(engine_testaccount)
+
+
+@pytest.mark.requires_external_volume
+def test_insert_map_orm(
+    sql_compiler, external_volume, base_location, engine_testaccount, snapshot
+):
+    Base = declarative_base()
+    session = Session(bind=engine_testaccount)
+
+    class TestIcebergTableOrm(Base):
+        __tablename__ = "test_iceberg_table_orm"
+
+        @classmethod
+        def __table_cls__(cls, name, metadata, *arg, **kw):
+            return IcebergTable(name, metadata, *arg, **kw)
+
+        __table_args__ = {
+            "external_volume": external_volume,
+            "base_location": base_location,
+        }
+
+        id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
+        map_id = Column(MAP(NUMBER(10, 0), TEXT()))
+
+        def __repr__(self):
+            return f"({self.id!r}, {self.map_id!r})"
+
+    Base.metadata.create_all(engine_testaccount)
+
+    try:
+        cast_expr = cast(
+            text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())
+        )
+        instance = TestIcebergTableOrm(id=0, map_id=cast_expr)
+        session.add(instance)
+        with pytest.raises(exc.ProgrammingError) as programming_error:
+            session.commit()
+        # TODO: Support variant in insert statement
+        assert str(programming_error.value.orig) == snapshot
+    finally:
+        Base.metadata.drop_all(engine_testaccount)
+
+
+def test_snowflake_tables_with_structured_types(sql_compiler):
+    metadata = MetaData()
+    with pytest.raises(
+        StructuredTypeNotSupportedInTableColumnsError
+    ) as programming_error:
+        SnowflakeTable(
+            "clustered_user",
+            metadata,
+            Column("Id", Integer, primary_key=True),
+            Column("name", MAP(NUMBER(10, 0), TEXT())),
+        )
+    assert programming_error is not None
+
+
+@pytest.mark.requires_external_volume
+def test_select_map_orm(engine_testaccount, external_volume, base_location, snapshot):
+    metadata = MetaData()
+    table_name = "test_select_map_orm"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("map_id", MAP(NUMBER(10, 0), TEXT())),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+
+    with engine_testaccount.connect() as conn:
+        slt1 = select(
+            2,
+            cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())),
+        )
+        slt2 = select(
+            1,
+            cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())),
+        ).union_all(slt1)
+        ins = test_map.insert().from_select(["id", "map_id"], slt2)
+        conn.execute(ins)
+        conn.commit()
+
+    Base = declarative_base()
+    session = Session(bind=engine_testaccount)
+
+    class TestIcebergTableOrm(Base):
+        __table__ = test_map
+
+        def __repr__(self):
+            return f"({self.id!r}, {self.map_id!r})"
+
+    try:
+        data = session.query(TestIcebergTableOrm).all()
+        snapshot.assert_match(data)
+    finally:
+        test_map.drop(engine_testaccount)
diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py
new file mode 100644
index 00000000..c7bcd6ef
--- /dev/null
+++ 
b/tests/test_unit_structured_types.py
@@ -0,0 +1,73 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+import pytest
+
+from snowflake.sqlalchemy import NUMBER
+from snowflake.sqlalchemy.custom_types import MAP, TEXT
+from snowflake.sqlalchemy.parser.custom_type_parser import (
+    extract_parameters,
+    parse_type,
+)
+
+
+def test_compile_map_with_not_null(snapshot):
+    user_table = MAP(NUMBER(10, 0), TEXT(), not_null=True)
+    assert user_table.compile() == snapshot
+
+
+def test_extract_parameters():
+    example = "a, b(c, d, f), d"
+    assert extract_parameters(example) == ["a", "b(c, d, f)", "d"]
+
+
+@pytest.mark.parametrize(
+    "input_type, expected_type",
+    [
+        ("BIGINT", "BIGINT"),
+        ("BINARY(16)", "BINARY(16)"),
+        ("BOOLEAN", "BOOLEAN"),
+        ("CHAR(5)", "CHAR(5)"),
+        ("CHARACTER(5)", "CHAR(5)"),
+        ("DATE", "DATE"),
+        ("DATETIME(3)", "DATETIME"),
+        ("DECIMAL(10, 2)", "DECIMAL(10, 2)"),
+        ("DEC(10, 2)", "DECIMAL(10, 2)"),
+        ("DOUBLE", "FLOAT"),
+        ("FLOAT", "FLOAT"),
+        ("FIXED(10, 2)", "DECIMAL(10, 2)"),
+        ("INT", "INTEGER"),
+        ("INTEGER", "INTEGER"),
+        ("NUMBER(12, 4)", "DECIMAL(12, 4)"),
+        ("REAL", "REAL"),
+        ("BYTEINT", "SMALLINT"),
+        ("SMALLINT", "SMALLINT"),
+        ("STRING(255)", "VARCHAR(255)"),
+        ("TEXT(255)", "VARCHAR(255)"),
+        ("VARCHAR(255)", "VARCHAR(255)"),
+        ("TIME(6)", "TIME"),
+        ("TIMESTAMP(3)", "TIMESTAMP"),
+        ("TIMESTAMP_TZ(3)", "TIMESTAMP_TZ"),
+        ("TIMESTAMP_LTZ(3)", "TIMESTAMP_LTZ"),
+        ("TIMESTAMP_NTZ(3)", "TIMESTAMP_NTZ"),
+        ("TINYINT", "SMALLINT"),
+        ("VARBINARY(16)", "BINARY(16)"),
+        ("VARCHAR(255)", "VARCHAR(255)"),
+        ("VARIANT", "VARIANT"),
+        (
+            "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR NOT NULL))",
+            "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR NOT NULL))",
+        ),
+        (
+            "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR))",
+            "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR))",
+        ),
+        ("MAP(DECIMAL(10, 0), VARIANT)", "MAP(DECIMAL(10, 0), VARIANT)"),
+        ("OBJECT", "OBJECT"),
+        ("ARRAY", "ARRAY"),
+        ("GEOGRAPHY", "GEOGRAPHY"),
+        ("GEOMETRY", "GEOMETRY"),
+    ],
+)
+def test_snowflake_data_types(input_type, expected_type):
+    assert parse_type(input_type).compile() == expected_type
diff --git a/tests/util.py b/tests/util.py
index db0b0c9c..264478ff 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -29,6 +29,7 @@
     ARRAY,
     GEOGRAPHY,
     GEOMETRY,
+    MAP,
     OBJECT,
     TIMESTAMP_LTZ,
     TIMESTAMP_NTZ,
@@ -72,6 +73,7 @@
     "ARRAY": ARRAY,
     "GEOGRAPHY": GEOGRAPHY,
     "GEOMETRY": GEOMETRY,
+    "MAP": MAP,
 }

From 0d0e6864d9f9f5d32eaf74d2077a28cf907b6298 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Wed, 20 Nov 2024 14:45:21 -0600
Subject: [PATCH 43/74] Update release notes date (#547)

* Update release notes date november 22

---
 DESCRIPTION.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 33775996..e39984b7 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,7 +9,7 @@ Source code is also available at:
 
 # Release Notes
 
-- v1.7.0(November 12, 2024)
+- v1.7.0(November 22, 2024)
 
   - Add support for dynamic tables and required options
   - Add support for hybrid tables

From 3f633e28a1bd37ffc862e0215116f3a87a684cdc Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Thu, 21 Nov 2024 10:10:32 -0600
Subject: [PATCH 44/74] Update CODEOWNERS (#540)

Update CODEOWNERS
---
 .github/CODEOWNERS | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 836e0136..b2168af7 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1 +1 @@
-* @snowflakedb/snowcli
+* @snowflakedb/ORM From 65754a4ab2524d9de2c8b9d56d1fb07f819248d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Calder=C3=B3n=20Ach=C3=ADo?= Date: Thu, 21 Nov 2024 16:06:50 -0600 Subject: [PATCH 45/74] SNOW-878116 Add support for PARTITION BY to COPY INTO location (#542) * add PARTITION BY option for CopyInto --------- Co-authored-by: azban --- DESCRIPTION.md | 3 + src/snowflake/sqlalchemy/base.py | 26 +++++++-- src/snowflake/sqlalchemy/custom_commands.py | 9 ++- tests/test_copy.py | 65 +++++++++++++++------ 4 files changed, 78 insertions(+), 25 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index e39984b7..82ddebc9 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,9 @@ Source code is also available at: # Release Notes +- (Unreleased) + - Add support for partition by to copy into + - v1.7.0(November 22, 2024) - Add support for dynamic tables and required options diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index a1e16062..02e4f741 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -16,7 +16,8 @@ from sqlalchemy.schema import Sequence, Table from sqlalchemy.sql import compiler, expression, functions from sqlalchemy.sql.base import CompileState -from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.elements import BindParameter, quoted_name +from sqlalchemy.sql.expression import Executable from sqlalchemy.sql.selectable import Lateral, SelectState from snowflake.sqlalchemy._constants import DIALECT_NAME @@ -563,9 +564,8 @@ def visit_copy_into(self, copy_into, **kw): if isinstance(copy_into.into, Table) else copy_into.into._compiler_dispatch(self, **kw) ) - from_ = None if isinstance(copy_into.from_, Table): - from_ = copy_into.from_ + from_ = copy_into.from_.name # this is intended to catch AWSBucket and AzureContainer elif ( isinstance(copy_into.from_, AWSBucket) @@ -576,6 +576,21 @@ def visit_copy_into(self, copy_into, **kw): # everything else (selects, etc.) 
else: from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})" + + partition_by_value = None + if isinstance(copy_into.partition_by, (BindParameter, Executable)): + partition_by_value = copy_into.partition_by.compile( + compile_kwargs={"literal_binds": True} + ) + elif copy_into.partition_by is not None: + partition_by_value = copy_into.partition_by + + partition_by = ( + f"PARTITION BY {partition_by_value}" + if partition_by_value is not None and partition_by_value != "" + else "" + ) + credentials, encryption = "", "" if isinstance(into, tuple): into, credentials, encryption = into @@ -586,8 +601,7 @@ def visit_copy_into(self, copy_into, **kw): options_list.sort(key=operator.itemgetter(0)) options = ( ( - " " - + " ".join( + " ".join( [ "{} = {}".format( n, @@ -608,7 +622,7 @@ def visit_copy_into(self, copy_into, **kw): options += f" {credentials}" if encryption: options += f" {encryption}" - return f"COPY INTO {into} FROM {from_} {formatter}{options}" + return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}" def visit_copy_formatter(self, formatter, **kw): options_list = list(formatter.options.items()) diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 15585bd5..1b9260fe 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -115,18 +115,23 @@ class CopyInto(UpdateBase): __visit_name__ = "copy_into" _bind = None - def __init__(self, from_, into, formatter=None): + def __init__(self, from_, into, partition_by=None, formatter=None): self.from_ = from_ self.into = into self.formatter = formatter self.copy_options = {} + self.partition_by = partition_by def __repr__(self): """ repr for debugging / logging purposes only. 
For compilation logic, see the corresponding visitor in base.py """ - return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})" + val = f"COPY INTO {self.into} FROM {repr(self.from_)}" + if self.partition_by is not None: + val += f" PARTITION BY {self.partition_by}" + + return val + f" {repr(self.formatter)} ({self.copy_options})" def bind(self): return None diff --git a/tests/test_copy.py b/tests/test_copy.py index e0752d4f..8dfcf286 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table -from sqlalchemy.sql import select, text +from sqlalchemy.sql import functions, select, text from snowflake.sqlalchemy import ( AWSBucket, @@ -58,8 +58,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_1) - == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv " - "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" + == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv " + "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) copy_stmt_2 = CopyIntoStorage( @@ -73,8 +73,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): sql_compiler(copy_stmt_2) == "COPY INTO 's3://backup' FROM (SELECT python_tests_foods.id, " "python_tests_foods.name, python_tests_foods.quantity FROM python_tests_foods " - "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' " - "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') " + "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' " + "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') " "ENCRYPTION=(TYPE='AWS_SSE_S3')" ) copy_stmt_3 = CopyIntoStorage( @@ -87,7 +87,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): assert ( sql_compiler(copy_stmt_3) == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' " - "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " + "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " "CREDENTIALS=(AZURE_SAS_TOKEN='token')" ) @@ -95,7 +95,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): assert ( sql_compiler(copy_stmt_3) == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' " - "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " + "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " "MAX_FILE_SIZE = 50000000 " "CREDENTIALS=(AZURE_SAS_TOKEN='token')" ) @@ -112,8 +112,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_4) - == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " - "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" + == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " + "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -126,8 +126,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_5) - == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " - "FIELD_DELIMITER=',') ENCRYPTION=" + == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " + 
"FIELD_DELIMITER=',') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -138,7 +138,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_6) - == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " ) copy_stmt_7 = CopyIntoStorage( @@ -148,7 +148,38 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_7) - == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " + ) + + copy_stmt_8 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=text("('YEAR=' || year)"), + ) + assert ( + sql_compiler(copy_stmt_8) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year) " + ) + + copy_stmt_9 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=functions.concat( + text("'YEAR='"), text(food_items.columns["name"].name) + ), + ) + assert ( + sql_compiler(copy_stmt_9) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY concat('YEAR=', name) " + ) + + copy_stmt_10 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by="", + ) + assert ( + sql_compiler(copy_stmt_10) == "COPY INTO @stage_name FROM python_tests_foods " ) # NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but @@ -231,7 +262,7 @@ def test_copy_into_storage_csv_extended(sql_compiler): result = sql_compiler(copy_into) expected = ( r"COPY INTO TEST_IMPORT " - r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " + r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " r"FILE_FORMAT=(TYPE=csv COMPRESSION='auto' DATE_FORMAT='AUTO' " r"ERROR_ON_COLUMN_COUNT_MISMATCH=True ESCAPE=None " r"ESCAPE_UNENCLOSED_FIELD='\134' FIELD_DELIMITER=',' " @@ -288,7 +319,7 @@ def test_copy_into_storage_parquet_named_format(sql_compiler): expected = ( "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " - "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " + "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " "FILE_FORMAT=(format_name = parquet_file_format) force = TRUE" ) assert result == expected @@ -350,7 +381,7 @@ def test_copy_into_storage_parquet_files(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " + "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " "FORCE = true" ) assert result == expected @@ -412,6 +443,6 @@ def test_copy_into_storage_parquet_pattern(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" + "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" ) assert result == expected From 9157932f02725f4b356fb0109875eac5511bdb14 Mon Sep 17 00:00:00 2001 From: David Szmolka <69192509+sfc-gh-dszmolka@users.noreply.github.com> Date: Fri, 22 Nov 2024 01:25:00 +0100 Subject: [PATCH 46/74] Amend README for urgent support (#544) --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md 
b/README.md index c6c13349..dac87fe8 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,11 @@ Snowflake SQLAlchemy runs on the top of the Snowflake Connector for Python as a [dialect](http://docs.sqlalchemy.org/en/latest/dialects/) to bridge a Snowflake database and SQLAlchemy applications. + +| :exclamation: | For production-affecting or urgent issues related to the connector, please [create a case with Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge). | +|---------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| + + ## Prerequisites ### Snowflake Connector for Python From 9b2c6d15c3da64990da4fd49c036d7094cd36e23 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 25 Nov 2024 07:13:35 -0600 Subject: [PATCH 47/74] Fix readme typos (#548) * Fix typo in README.md --------- Co-authored-by: Norman Rosner Co-authored-by: Anthony Holten --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dac87fe8..2dbf6632 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ finally: # Best try: - with engine.connext() as connection: + with engine.connect() as connection: connection.execute(text()) # or connection.exec_driver_sql() @@ -230,7 +230,7 @@ t = Table('mytable', metadata, ### Object Name Case Handling -Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during schema-level communication, i.e. during table and index reflection. If you use uppercase object names, SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause mismatches agaisnt data dictionary data received from Snowflake, so unless identifier names have been truly created as case sensitive using quotes, e.g., `"TestDb"`, all lowercase names should be used on the SQLAlchemy side. +Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during schema-level communication, i.e. during table and index reflection. If you use uppercase object names, SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause mismatches against data dictionary data received from Snowflake, so unless identifier names have been truly created as case sensitive using quotes, e.g., `"TestDb"`, all lowercase names should be used on the SQLAlchemy side. 
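As a concrete illustration of the case-handling rule above, here is a minimal sketch (not part of this patch; the table names are hypothetical):

```python
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.sql import quoted_name

metadata = MetaData()

# Lowercase name: emitted unquoted, stored by Snowflake as TESTTABLE,
# and reflected back as "testtable" by the dialect.
Table("testtable", metadata, Column("id", Integer, primary_key=True))

# Quoted mixed-case name: emitted as "TestTable" and treated as
# case-sensitive on both sides.
Table(
    quoted_name("TestTable", quote=True),
    metadata,
    Column("id", Integer, primary_key=True),
)
```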
### Index Support From 140fec81e74d26c9e036147790f2ef451865875b Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 25 Nov 2024 09:47:16 -0600 Subject: [PATCH 48/74] Fix BOOLEAN not found in snowdialect (#551) * Fix import BOOLEAN error * Update imports * Add test for explicit imports --- DESCRIPTION.md | 1 + src/snowflake/sqlalchemy/snowdialect.py | 2 + tests/test_imports.py | 64 +++++++++++++++++++++++++ 3 files changed, 67 insertions(+) create mode 100644 tests/test_imports.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 82ddebc9..bbb33fb8 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -11,6 +11,7 @@ Source code is also available at: - (Unreleased) - Add support for partition by to copy into + - Fix BOOLEAN type not found in snowdialect - v1.7.0(November 22, 2024) diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index f9e2e4c8..e6baadf7 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -39,6 +39,8 @@ _CUSTOM_Float, _CUSTOM_Time, ) +from .parser.custom_type_parser import * # noqa +from .parser.custom_type_parser import _CUSTOM_DECIMAL # noqa from .parser.custom_type_parser import ischema_names, parse_type from .sql.custom_schema.custom_table_prefix import CustomTablePrefix from .util import ( diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 00000000..0cfe5931 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import importlib +import inspect + +import pytest + + +def get_classes_from_module(module_name): + """Returns a set of class names from a given module.""" + try: + module = importlib.import_module(module_name) + members = inspect.getmembers(module) + return {name for name, obj in members if inspect.isclass(obj)} + + except ImportError: + print(f"Module '{module_name}' could not be imported.") + return set() + + +def test_types_in_snowdialect(): + classes_a = get_classes_from_module( + "snowflake.sqlalchemy.parser.custom_type_parser" + ) + classes_b = get_classes_from_module("snowflake.sqlalchemy.snowdialect") + assert classes_a.issubset(classes_b), str(classes_a - classes_b) + + +@pytest.mark.parametrize( + "type_class_name", + [ + "BIGINT", + "BINARY", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "FLOAT", + "INTEGER", + "REAL", + "SMALLINT", + "TIME", + "TIMESTAMP", + "VARCHAR", + "NullType", + "_CUSTOM_DECIMAL", + "ARRAY", + "DOUBLE", + "GEOGRAPHY", + "GEOMETRY", + "MAP", + "OBJECT", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + "VARIANT", + ], +) +def test_snowflake_data_types_instance(type_class_name): + classes_b = get_classes_from_module("snowflake.sqlalchemy.snowdialect") + assert type_class_name in classes_b, type_class_name From 49f91deacc3ba5c6e036a2300469bc6809b9a085 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 25 Nov 2024 10:32:12 -0600 Subject: [PATCH 49/74] Update README.md with Custom Tables documentation (#552) --- README.md | 145 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 144 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2dbf6632..3985a9d6 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,39 @@ Snowflake stores all case-insensitive object names in uppercase text. In contras ### Index Support -Snowflake does not utilize indexes, so neither does Snowflake SQLAlchemy. 
+Indexes are supported only for Hybrid Tables in Snowflake SQLAlchemy. For more details on limitations and use cases, refer to the [Create Index documentation](https://docs.snowflake.com/en/sql-reference/constraints-indexes.html). You can create an index using the following methods: + +#### Single Column Index + +You can create a single column index by setting the `index=True` parameter on the column or by explicitly defining an `Index` object. + +```python +hybrid_test_table_1 = HybridTable( + "table_name", + metadata, + Column("column1", Integer, primary_key=True), + Column("column2", String, index=True), + Index("index_1", "column1", "column2") +) + +metadata.create_all(engine_testaccount) +``` + +#### Multi-Column Index + +For multi-column indexes, you define the `Index` object specifying the columns that should be indexed. + +```python +hybrid_test_table_1 = HybridTable( + "table_name", + metadata, + Column("column1", Integer, primary_key=True), + Column("column2", String), + Index("index_1", "column1", "column2") +) + +metadata.create_all(engine_testaccount) +``` ### Numpy Data Type Support @@ -461,6 +493,117 @@ copy_into = CopyIntoStorage(from_=users, connection.execute(copy_into) ``` +### Iceberg Table with Snowflake Catalog support + +Snowflake SQLAlchemy supports Iceberg Tables with the Snowflake Catalog, along with various related parameters. For detailed information about Iceberg Tables, refer to the Snowflake [CREATE ICEBERG](https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table-snowflake) documentation. + +To create an Iceberg Table using Snowflake SQLAlchemy, you can define the table using the SQLAlchemy Core syntax as follows: + +```python +table = IcebergTable( + "myuser", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + external_volume=external_volume_name, + base_location="my_iceberg_table", + as_query="SELECT * FROM table" +) +``` + +Alternatively, you can define the table using a declarative approach: + +```python +class MyUser(Base): + __tablename__ = "myuser" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": "my_external_volume", + "base_location": "my_iceberg_table", + "as_query": "SELECT * FROM table", + } + + id = Column(Integer, primary_key=True) + name = Column(String) +``` + +### Hybrid Table support + +Snowflake SQLAlchemy supports Hybrid Tables with indexes. For detailed information, refer to the Snowflake [CREATE HYBRID TABLE](https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table) documentation. + +To create a Hybrid Table and add an index, you can use the SQLAlchemy Core syntax as follows: + +```python +table = HybridTable( + "myuser", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + Index("idx_name", "name") +) +``` + +Alternatively, you can define the table using the declarative approach: + +```python +class MyUser(Base): + __tablename__ = "myuser" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + __table_args__ = ( + Index("idx_name", "name"), + ) + + id = Column(Integer, primary_key=True) + name = Column(String) +``` + +### Dynamic Tables support + +Snowflake SQLAlchemy supports Dynamic Tables. For detailed information, refer to the Snowflake [CREATE DYNAMIC TABLE](https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table) documentation. 
+ +To create a Dynamic Table, you can use the SQLAlchemy Core syntax as follows: + +```python +dynamic_test_table_1 = DynamicTable( + "dynamic_MyUser", + metadata, + Column("id", Integer), + Column("name", String), + target_lag=(1, TimeUnit.HOURS), # Additionally, you can use SnowflakeKeyword.DOWNSTREAM + warehouse='test_wh', + refresh_mode=SnowflakeKeyword.FULL, + as_query="SELECT id, name from MyUser;" +) +``` + +Alternatively, you can define a table without columns using the SQLAlchemy `select()` construct: + +```python +dynamic_test_table_1 = DynamicTable( + "dynamic_MyUser", + metadata, + target_lag=(1, TimeUnit.HOURS), + warehouse='test_wh', + refresh_mode=SnowflakeKeyword.FULL, + as_query=select(MyUser.id, MyUser.name) +) +``` + +### Notes + +- Defining a primary key in a Dynamic Table is not supported, meaning declarative tables don’t support Dynamic Tables. +- When using the `as_query` parameter with a string, you must explicitly define the columns. However, if you use the SQLAlchemy `select()` construct, you don’t need to explicitly define the columns. +- Direct data insertion into Dynamic Tables is not supported. + + ## Support Feel free to file an issue or submit a PR here for general cases. For official support, contact Snowflake support at: From 62bab2f82da8d7d7e0f846eafa57aee397973d8f Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Fri, 29 Nov 2024 14:34:00 -0600 Subject: [PATCH 50/74] Add release version (#553) --- DESCRIPTION.md | 4 ++-- src/snowflake/sqlalchemy/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index bbb33fb8..100b5c56 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,11 +9,11 @@ Source code is also available at: # Release Notes -- (Unreleased) +- v1.7.1(December 02, 2024) - Add support for partition by to copy into - Fix BOOLEAN type not found in snowdialect -- v1.7.0(November 22, 2024) +- v1.7.0(November 21, 2024) - Add support for dynamic tables and required options - Add support for hybrid tables diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index b80a9096..f942f2bd 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = "1.7.0" +VERSION = "1.7.1" From 695c0a98f217f9b10d0abe9ed0bfc7dea1184290 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 5 Dec 2024 09:26:25 +0100 Subject: [PATCH 51/74] mraba/underscore_column_id: use `_` as column identifier (#538) --- DESCRIPTION.md | 1 + src/snowflake/sqlalchemy/base.py | 5 +++- tests/test_compiler.py | 17 ++++++++++++- tests/test_quote.py | 23 +++++++++++++++++ tests/test_quote_identifiers.py | 43 ++++++++++++++++++++++++++++++++ 5 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 tests/test_quote_identifiers.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 100b5c56..2da23e45 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -15,6 +15,7 @@ Source code is also available at: - v1.7.0(November 21, 2024) + - Fixed quoting of `_` as column name - Add support for dynamic tables and required options - Add support for hybrid tables - Fixed SAWarning when registering functions with existing name in default namespace diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 02e4f741..4c632e7a 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -5,6 +5,7 @@ import itertools 
 import operator
 import re
+import string
 from typing import List
 
 from sqlalchemy import exc as sa_exc
@@ -114,7 +115,8 @@
 AUTOCOMMIT_REGEXP = re.compile(
     r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE
 )
-
+# used for quoting identifiers, i.e. table names, column names, etc.
+ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"_", "$"}))
 """
 Overwrite methods to handle Snowflake BCR change:
@@ -439,6 +441,7 @@ def _join_left_to_right(
 
 class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer):
     reserved_words = {x.lower() for x in RESERVED_WORDS}
+    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
 
     def __init__(self, dialect, **kw):
         quote = '"'
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index 40207b41..55451c2f 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -2,7 +2,7 @@
 # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #
 
-from sqlalchemy import Integer, String, and_, func, select
+from sqlalchemy import Integer, String, and_, func, insert, select
 from sqlalchemy.schema import DropColumnComment, DropTableComment
 from sqlalchemy.sql import column, quoted_name, table
 from sqlalchemy.testing.assertions import AssertsCompiledSQL
@@ -33,6 +33,21 @@ def test_now_func(self):
             dialect="snowflake",
         )
 
+    def test_underscore_as_valid_identifier(self):
+        _table = table(
+            "table_1745924",
+            column("ca", Integer),
+            column("cb", String),
+            column("_", String),
+        )
+
+        stmt = insert(_table).values(ca=1, cb="test", _="test_")
+        self.assert_compile(
+            stmt,
+            'INSERT INTO table_1745924 (ca, cb, "_") VALUES (%(ca)s, %(cb)s, %(_)s)',
+            dialect="snowflake",
+        )
+
     def test_multi_table_delete(self):
         statement = table1.delete().where(table1.c.id == table2.c.id)
         self.assert_compile(
diff --git a/tests/test_quote.py b/tests/test_quote.py
index ca6f36dd..0dd69059 100644
--- a/tests/test_quote.py
+++ b/tests/test_quote.py
@@ -38,3 +38,26 @@ def test_table_name_with_reserved_words(engine_testaccount, db_parameters):
     finally:
         insert_table.drop(engine_testaccount)
     return insert_table
+
+
+def test_table_column_as_underscore(engine_testaccount):
+    metadata = MetaData()
+    test_table_name = "table_1745924"
+    insert_table = Table(
+        test_table_name,
+        metadata,
+        Column("ca", Integer),
+        Column("cb", String),
+        Column("_", String),
+    )
+    metadata.create_all(engine_testaccount)
+    try:
+        inspector = inspect(engine_testaccount)
+        columns_in_insert = inspector.get_columns(test_table_name)
+        assert len(columns_in_insert) == 3
+        assert columns_in_insert[0]["name"] == "ca"
+        assert columns_in_insert[1]["name"] == "cb"
+        assert columns_in_insert[2]["name"] == "_"
+    finally:
+        insert_table.drop(engine_testaccount)
+    return insert_table
diff --git a/tests/test_quote_identifiers.py b/tests/test_quote_identifiers.py
new file mode 100644
index 00000000..58575fad
--- /dev/null
+++ b/tests/test_quote_identifiers.py
@@ -0,0 +1,43 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+import pytest
+from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select
+
+
+@pytest.mark.parametrize(
+    "identifier",
+    (
+        pytest.param("_", id="underscore"),
+        pytest.param(".", id="dot"),
+    ),
+)
+def test_insert_with_identifier_as_column_name(identifier: str, engine_testaccount):
+    expected_identifier = f"test: {identifier}"
+    metadata = MetaData()
+    table = Table(
+        "table_1745924",
+        metadata,
+        Column("ca", Integer),
+        Column("cb", String),
+        Column(identifier, String),
+    )
+
+    try:
+        metadata.create_all(engine_testaccount)
+
+        with engine_testaccount.connect() as connection:
+            connection.execute(
+                insert(table).values(
+                    {
+                        "ca": 1,
+                        "cb": "test",
+                        identifier: f"test: {identifier}",
+                    }
+                )
+            )
+            result = connection.execute(select(table)).fetchall()
+            assert result == [(1, "test", expected_identifier)]
+    finally:
+        metadata.drop_all(engine_testaccount)

From 716683fedfb708e3155d5027fff111faa94355e4 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Thu, 12 Dec 2024 09:23:54 -0600
Subject: [PATCH 52/74] Improve index reflection (#556)

Improve index reflection
---
 DESCRIPTION.md                                |   7 +-
 .../sqlalchemy/parser/custom_type_parser.py   |  16 ++
 src/snowflake/sqlalchemy/snowdialect.py       | 184 ++++++++----------
 tests/conftest.py                             |  27 +++
 tests/test_index_reflection.py                |  42 +++-
 5 files changed, 170 insertions(+), 106 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 2da23e45..bed7670b 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,13 +9,16 @@ Source code is also available at:
 
 # Release Notes
 
+- (Unreleased)
+  - Fix quoting of `_` as column name
+  - Fix index columns not being reflected
+  - Fix index reflection cache not working
+
 - v1.7.1(December 02, 2024)
   - Add support for partition by to copy into
   - Fix BOOLEAN type not found in snowdialect
 
 - v1.7.0(November 21, 2024)
-
-  - Fixed quoting of `_` as column name
   - Add support for dynamic tables and required options
   - Add support for hybrid tables
   - Fixed SAWarning when registering functions with existing name in default namespace
diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py
index cf69c594..dada612d 100644
--- a/src/snowflake/sqlalchemy/parser/custom_type_parser.py
+++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py
@@ -1,5 +1,6 @@
 #
 # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+from typing import List
 
 import sqlalchemy.types as sqltypes
 from sqlalchemy.sql.type_api import TypeEngine
@@ -107,6 +108,20 @@ def extract_parameters(text: str) -> list:
     return output_parameters
 
 
+def parse_index_columns(columns: str) -> List[str]:
+    """
+    Parses a string with a list of columns for an index.
+
+    :param columns: A string with a list of columns for an index, which may include parentheses.
+
+    :return: A list of columns as strings.
+
+    :example:
+        For input `"[A, B, C]"`, the output is `['A', 'B', 'C']`.
+    """
+    return [column.strip() for column in columns.strip("[]").split(",")]
+
+
 def parse_type(type_text: str) -> TypeEngine:
     """
     Parses a type definition string and returns the corresponding SQLAlchemy type.
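For reference, a quick sketch of what the new `parse_index_columns` helper returns for the bracketed column list produced by `SHOW INDEXES` (illustrative only; it mirrors the implementation and docstring above):

```python
from snowflake.sqlalchemy.parser.custom_type_parser import parse_index_columns

# "[A, B, C]" is the format SHOW INDEXES uses for its "columns" field.
assert parse_index_columns("[A, B, C]") == ["A", "B", "C"]
assert parse_index_columns("[NAME, NAME2]") == ["NAME", "NAME2"]
```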
diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index e6baadf7..935794d9 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -6,7 +6,7 @@ import re from collections import defaultdict from functools import reduce -from typing import Any +from typing import Any, Collection, Optional from urllib.parse import unquote_plus import sqlalchemy.types as sqltypes @@ -41,7 +41,7 @@ ) from .parser.custom_type_parser import * # noqa from .parser.custom_type_parser import _CUSTOM_DECIMAL # noqa -from .parser.custom_type_parser import ischema_names, parse_type +from .parser.custom_type_parser import ischema_names, parse_index_columns, parse_type from .sql.custom_schema.custom_table_prefix import CustomTablePrefix from .util import ( _update_connection_application_name, @@ -674,27 +674,43 @@ def get_columns(self, connection, table_name, schema=None, **kw): raise sa_exc.NoSuchTableError() return schema_columns[normalized_table_name] + def get_prefixes_from_data(self, name_to_index_map, row, **kw): + prefixes_found = [] + for valid_prefix in CustomTablePrefix: + key = f"is_{valid_prefix.name.lower()}" + if key in name_to_index_map and row[name_to_index_map[key]] == "Y": + prefixes_found.append(valid_prefix.name) + return prefixes_found + @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def _get_schema_tables_info(self, connection, schema=None, **kw): """ - Gets all table names. + Retrieves information about all tables in the specified schema. """ + schema = schema or self.default_schema_name - current_schema = schema - if schema: - cursor = connection.execute( - text( - f"SHOW /* sqlalchemy:get_table_names */ TABLES IN {self._denormalize_quote_join(schema)}" - ) - ) - else: - cursor = connection.execute( - text("SHOW /* sqlalchemy:get_table_names */ TABLES") + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {self._denormalize_quote_join(schema)}" ) - _, current_schema = self._current_database_schema(connection) + ) - ret = [self.normalize_name(row[1]) for row in cursor] + name_to_index_map = self._map_name_to_idx(result) + tables = {} + for row in result.cursor.fetchall(): + table_name = self.normalize_name(str(row[name_to_index_map["name"]])) + table_prefixes = self.get_prefixes_from_data(name_to_index_map, row) + tables[table_name] = {"prefixes": table_prefixes} + return tables + + def get_table_names(self, connection, schema=None, **kw): + """ + Gets all table names. 
+ """ + ret = self._get_schema_tables_info( + connection, schema, info_cache=kw.get("info_cache", None) + ).keys() return ret @reflection.cache @@ -748,17 +764,12 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): def get_temp_table_names(self, connection, schema=None, **kw): schema = schema or self.default_schema_name - if schema: - cursor = connection.execute( - text( - f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES \ - IN {self._denormalize_quote_join(schema)}" - ) - ) - else: - cursor = connection.execute( - text("SHOW /* sqlalchemy:get_temp_table_names */ TABLES") + cursor = connection.execute( + text( + f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES \ + IN SCHEMA {self._denormalize_quote_join(schema)}" ) + ) ret = [] n2i = self.__class__._map_name_to_idx(cursor) @@ -839,62 +850,79 @@ def get_table_comment(self, connection, table_name, schema=None, **kw): ) } - def get_multi_indexes( + def get_table_names_with_prefix( self, connection, *, schema, - filter_names, + prefix, + **kw, + ): + tables_data = self._get_schema_tables_info(connection, schema, **kw) + table_names = [] + for table_name, tables_data_value in tables_data.items(): + if prefix in tables_data_value["prefixes"]: + table_names.append(table_name) + return table_names + + def get_multi_indexes( + self, + connection, + *, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, **kw, ): """ Gets the indexes definition """ - - table_prefixes = self.get_multi_prefixes( - connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name + schema = schema or self.default_schema_name + hybrid_table_names = self.get_table_names_with_prefix( + connection, + schema=schema, + prefix=CustomTablePrefix.HYBRID.name, + info_cache=kw.get("info_cache", None), ) - if len(table_prefixes) == 0: + if len(hybrid_table_names) == 0: return [] - schema = schema or self.default_schema_name - if not schema: - result = connection.execute( - text("SHOW /* sqlalchemy:get_multi_indexes */ INDEXES") - ) - else: - result = connection.execute( - text( - f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}" - ) + + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}" ) + ) - n2i = self.__class__._map_name_to_idx(result) + n2i = self._map_name_to_idx(result) indexes = {} for row in result.cursor.fetchall(): - table = self.normalize_name(str(row[n2i["table"]])) + table_name = self.normalize_name(str(row[n2i["table"]])) if ( row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY' - or table not in filter_names - or (schema, table) not in table_prefixes - or ( - (schema, table) in table_prefixes - and CustomTablePrefix.HYBRID.name - not in table_prefixes[(schema, table)] - ) + or table_name not in filter_names + or table_name not in hybrid_table_names ): continue index = { "name": row[n2i["name"]], "unique": row[n2i["is_unique"]] == "Y", - "column_names": row[n2i["columns"]], - "include_columns": row[n2i["included_columns"]], + "column_names": [ + self.normalize_name(column) + for column in parse_index_columns(row[n2i["columns"]]) + ], + "include_columns": [ + self.normalize_name(column) + for column in parse_index_columns(row[n2i["included_columns"]]) + ], "dialect_options": {}, } - if (schema, table) in indexes: - indexes[(schema, table)] = indexes[(schema, table)].append(index) + + if (schema, table_name) in indexes: + indexes[(schema, table_name)] = 
indexes[(schema, table_name)] + [index]
             else:
-                indexes[(schema, table)] = [index]
+                indexes[(schema, table_name)] = [index]
 
         return list(indexes.items())
 
@@ -906,50 +934,6 @@ def _value_or_default(self, data, table, schema):
         else:
             return []
 
-    def get_prefixes_from_data(self, n2i, row, **kw):
-        prefixes_found = []
-        for valid_prefix in CustomTablePrefix:
-            key = f"is_{valid_prefix.name.lower()}"
-            if key in n2i and row[n2i[key]] == "Y":
-                prefixes_found.append(valid_prefix.name)
-        return prefixes_found
-
-    @reflection.cache
-    def get_multi_prefixes(
-        self, connection, schema, table_name=None, filter_prefix=None, **kw
-    ):
-        """
-        Gets all table prefixes
-        """
-        schema = schema or self.default_schema_name
-        filter = f"LIKE '{table_name}'" if table_name else ""
-        if schema:
-            result = connection.execute(
-                text(
-                    f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}"
-                )
-            )
-        else:
-            result = connection.execute(
-                text(
-                    f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'"
-                )
-            )
-
-        n2i = self.__class__._map_name_to_idx(result)
-        tables_prefixes = {}
-        for row in result.cursor.fetchall():
-            table = self.normalize_name(str(row[n2i["name"]]))
-            table_prefixes = self.get_prefixes_from_data(n2i, row)
-            if filter_prefix and filter_prefix not in table_prefixes:
-                continue
-            if (schema, table) in tables_prefixes:
-                tables_prefixes[(schema, table)].append(table_prefixes)
-            else:
-                tables_prefixes[(schema, table)] = table_prefixes
-
-        return tables_prefixes
-
     @reflection.cache
     def get_indexes(self, connection, tablename, schema, **kw):
         """
diff --git a/tests/conftest.py b/tests/conftest.py
index a91521b9..f2045121 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,7 @@
 #
 from __future__ import annotations
 
+import logging.handlers
 import os
 import sys
 import time
@@ -194,6 +195,32 @@ def engine_testaccount(request):
     yield engine
 
 
+@pytest.fixture()
+def assert_text_in_buf():
+    buf = logging.handlers.BufferingHandler(100)
+    for log in [
+        logging.getLogger("sqlalchemy.engine"),
+    ]:
+        log.addHandler(buf)
+
+    def go(expected, occurrences=1):
+        assert buf.buffer
+        buflines = [rec.getMessage() for rec in buf.buffer]
+
+        occurrences_found = buflines.count(expected)
+        assert occurrences == occurrences_found, (
+            f"Expected {occurrences} of {expected}, got {occurrences_found} "
+            f"occurrences in {buflines}."
+        )
+        buf.flush()
+
+    yield go
+    for log in [
+        logging.getLogger("sqlalchemy.engine"),
+    ]:
+        log.removeHandler(buf)
+
+
 @pytest.fixture()
 def engine_testaccount_with_numpy(request):
     url = url_factory(numpy=True)
diff --git a/tests/test_index_reflection.py b/tests/test_index_reflection.py
index 09f5cfe7..a808703b 100644
--- a/tests/test_index_reflection.py
+++ b/tests/test_index_reflection.py
@@ -2,8 +2,8 @@
 # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
# import pytest -from sqlalchemy import MetaData -from sqlalchemy.engine import reflection +from sqlalchemy import MetaData, inspect +from sqlalchemy.sql.ddl import CreateSchema, DropSchema @pytest.mark.aws @@ -13,15 +13,21 @@ def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler): table_name = "test_hybrid_table_2" index_name = "INDEX_NAME_2" schema = db_parameters["schema"] + index_columns = ["name", "name2"] create_table_sql = f""" - CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + CREATE HYBRID TABLE {table_name} ( + id INT primary key, + name VARCHAR, + name2 VARCHAR, + INDEX {index_name} ({', '.join(index_columns)}) + ); """ with engine_testaccount.connect() as connection: connection.exec_driver_sql(create_table_sql) - insp = reflection.Inspector.from_engine(engine_testaccount) + insp = inspect(engine_testaccount) try: with engine_testaccount.connect(): @@ -29,6 +35,34 @@ def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler): indexes = insp.get_indexes(table_name, schema) assert len(indexes) == 1 assert indexes[0].get("name") == index_name + assert indexes[0].get("column_names") == index_columns finally: metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_simple_reflection_hybrid_table_as_table( + engine_testaccount, assert_text_in_buf, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_simple_reflection_hybrid_table_as_table" + schema = db_parameters["schema"] + "_reflections" + with engine_testaccount.connect() as connection: + try: + connection.execute(CreateSchema(schema)) + + create_table_sql = f""" + CREATE HYBRID TABLE {schema}.{table_name} (id INT primary key, new_column VARCHAR, INDEX index_name (new_column)); + """ + connection.exec_driver_sql(create_table_sql) + + metadata.reflect(engine_testaccount, schema=schema) + + assert_text_in_buf( + f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {schema}", + occurrences=1, + ) + + finally: + connection.execute(DropSchema(schema, cascade=True)) From 6c43ada2fd7a9f0429472daeca366fa911a32679 Mon Sep 17 00:00:00 2001 From: David Szmolka <69192509+sfc-gh-dszmolka@users.noreply.github.com> Date: Fri, 13 Dec 2024 13:12:17 +0100 Subject: [PATCH 53/74] NO-SNOW change jira assignee, labels, priority (#558) --- .github/workflows/jira_issue.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 85c774ca..80a31d31 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -38,7 +38,7 @@ jobs: summary: '${{ github.event.issue.title }}' description: | ${{ github.event.issue.body }} \\ \\ _Created from GitHub Action_ for ${{ github.event.issue.html_url }} - fields: '{"customfield_11401":{"id":"14586"},"assignee":{"id":"61027a237ab143006ecfb9a2"},"components":[{"id":"16161"},{"id":"16403"}]}' + fields: '{"customfield_11401":{"id":"14723"},"assignee":{"id":"712020:e1f41916-da57-4fe8-b317-116d5229aa51"},"components":[{"id":"16161"},{"id":"16403"}], "labels": ["oss"], "priority": {"id": "10001"} }' - name: Update GitHub Issue uses: ./jira/gajira-issue-update From af5457afd267052803a905c7f57add5d3fbf828f Mon Sep 17 00:00:00 2001 From: Gabriel Venegas Castro Date: Mon, 16 Dec 2024 08:51:09 -0600 Subject: [PATCH 54/74] SNOW-1776332 Add support for OBJECT (#559) * SNOW-1776332 Add support for OBJECT * Updated description.md * Add missing @pytest.mark.requires_external_volume * 
Tuple validation in OBJECT class --- DESCRIPTION.md | 1 + pyproject.toml | 3 + src/snowflake/sqlalchemy/base.py | 13 +- src/snowflake/sqlalchemy/custom_types.py | 22 +- .../sqlalchemy/parser/custom_type_parser.py | 37 +- src/snowflake/sqlalchemy/snowdialect.py | 21 +- .../test_structured_datatypes.ambr | 111 ++++- tests/test_structured_datatypes.py | 392 +++++++++++++----- tests/test_unit_structured_types.py | 8 +- 9 files changed, 489 insertions(+), 119 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index bed7670b..d5872bb9 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -13,6 +13,7 @@ Source code is also available at: - Fix quoting of `_` as column name - Fix index columns was not being reflected - Fix index reflection cache not working + - Add support for structured OBJECT datatype - v1.7.1(December 02, 2024) - Add support for partition by to copy into diff --git a/pyproject.toml b/pyproject.toml index 84e64faf..b0ae04c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ path = "src/snowflake/sqlalchemy/version.py" development = [ "pre-commit", "pytest", + "setuptools", "pytest-cov", "pytest-timeout", "pytest-rerunfailures", @@ -74,6 +75,8 @@ exclude = ["/.github"] packages = ["src/snowflake"] [tool.hatch.envs.default] +path = ".venv" +type = "virtual" extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"] features = ["development", "pandas"] python = "3.8" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 4c632e7a..3fef7709 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -1099,7 +1099,18 @@ def visit_ARRAY(self, type_, **kw): return "ARRAY" def visit_OBJECT(self, type_, **kw): - return "OBJECT" + if type_.is_semi_structured: + return "OBJECT" + else: + contents = [] + for key in type_.items_types: + + row_text = f"{key} {type_.items_types[key][0].compile()}" + # Type and not null is specified + if len(type_.items_types[key]) > 1: + row_text += f"{' NOT NULL' if type_.items_types[key][1] else ''}" + contents.append(row_text) + return "OBJECT" if contents == [] else f"OBJECT({', '.join(contents)})" def visit_BLOB(self, type_, **kw): return "BINARY" diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index f2c950dd..ce7ad592 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -1,9 +1,11 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
 #
+from typing import Tuple, Union

 import sqlalchemy.types as sqltypes
 import sqlalchemy.util as util
+from sqlalchemy.types import TypeEngine

 TEXT = sqltypes.VARCHAR
 CHARACTER = sqltypes.CHAR
@@ -57,9 +59,27 @@ def __init__(
         super().__init__()


-class OBJECT(SnowflakeType):
+class OBJECT(StructuredType):
     __visit_name__ = "OBJECT"

+    def __init__(self, **items_types: Union[TypeEngine, Tuple[TypeEngine, bool]]):
+        for key, value in items_types.items():
+            if not isinstance(value, tuple):
+                items_types[key] = (value, False)
+
+        self.items_types = items_types
+        self.is_semi_structured = len(items_types) == 0
+        super().__init__()
+
+    def __repr__(self):
+        quote_char = "'"
+        return "OBJECT(%s)" % ", ".join(
+            [
+                f"{repr(key).strip(quote_char)}={repr(value)}"
+                for key, value in self.items_types.items()
+            ]
+        )
+

 class ARRAY(SnowflakeType):
     __visit_name__ = "ARRAY"
diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py
index dada612d..1e99ba56 100644
--- a/src/snowflake/sqlalchemy/parser/custom_type_parser.py
+++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py
@@ -49,11 +49,10 @@
     "DECIMAL": DECIMAL,
     "DOUBLE": DOUBLE,
     "FIXED": DECIMAL,
-    "FLOAT": FLOAT,  # Snowflake FLOAT datatype doesn't has parameters
+    "FLOAT": FLOAT,  # Snowflake FLOAT datatype doesn't have parameters
     "INT": INTEGER,
     "INTEGER": INTEGER,
     "NUMBER": _CUSTOM_DECIMAL,
-    # 'OBJECT': ?
     "REAL": REAL,
     "BYTEINT": SMALLINT,
     "SMALLINT": SMALLINT,
@@ -76,18 +75,19 @@
 }


-def extract_parameters(text: str) -> list:
+def tokenize_parameters(text: str, character_for_strip=",") -> list:
     """
     Extracts parameters from a comma-separated string, handling parentheses.

     :param text: A string with comma-separated parameters, which may include parentheses.

+    :param character_for_strip: The delimiter character on which to split the top-level parameters (defaults to a comma).
+
     :return: A list of parameters as strings.

     :example:
         For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`.
""" - output_parameters = [] parameter = "" open_parenthesis = 0 @@ -98,9 +98,9 @@ def extract_parameters(text: str) -> list: elif c == ")": open_parenthesis -= 1 - if open_parenthesis > 0 or c != ",": + if open_parenthesis > 0 or c != character_for_strip: parameter += c - elif c == ",": + elif c == character_for_strip: output_parameters.append(parameter.strip(" ")) parameter = "" if parameter != "": @@ -138,14 +138,17 @@ def parse_type(type_text: str) -> TypeEngine: parse_type("VARCHAR(255)") String(length=255) """ + index = type_text.find("(") type_name = type_text[:index] if index != -1 else type_text + parameters = ( - extract_parameters(type_text[index + 1 : -1]) if type_name != type_text else [] + tokenize_parameters(type_text[index + 1 : -1]) if type_name != type_text else [] ) col_type_class = ischema_names.get(type_name, None) col_type_kw = {} + if col_type_class is None: col_type_class = NullType else: @@ -155,6 +158,8 @@ def parse_type(type_text: str) -> TypeEngine: col_type_kw = __parse_type_with_length_parameters(parameters) elif issubclass(col_type_class, MAP): col_type_kw = __parse_map_type_parameters(parameters) + elif issubclass(col_type_class, OBJECT): + col_type_kw = __parse_object_type_parameters(parameters) if col_type_kw is None: col_type_class = NullType col_type_kw = {} @@ -162,6 +167,24 @@ def parse_type(type_text: str) -> TypeEngine: return col_type_class(**col_type_kw) +def __parse_object_type_parameters(parameters): + object_rows = {} + for parameter in parameters: + parameter_parts = tokenize_parameters(parameter, " ") + if len(parameter_parts) >= 2: + key = parameter_parts[0] + value_type = parse_type(parameter_parts[1]) + if isinstance(value_type, NullType): + return None + not_null = ( + len(parameter_parts) == 4 + and parameter_parts[2] == "NOT" + and parameter_parts[3] == "NULL" + ) + object_rows[key] = (value_type, not_null) + return object_rows + + def __parse_map_type_parameters(parameters): if len(parameters) != 2: return None diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 935794d9..dd5e4375 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -1,7 +1,6 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# - import operator import re from collections import defaultdict @@ -9,7 +8,7 @@ from typing import Any, Collection, Optional from urllib.parse import unquote_plus -import sqlalchemy.types as sqltypes +import sqlalchemy.sql.sqltypes as sqltypes from sqlalchemy import event as sa_vnt from sqlalchemy import exc as sa_exc from sqlalchemy import util as sa_util @@ -17,7 +16,8 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.types import FLOAT, Date, DateTime, Float, NullType, Time +from sqlalchemy.sql.sqltypes import NullType +from sqlalchemy.types import FLOAT, Date, DateTime, Float, Time from snowflake.connector import errors as sf_errors from snowflake.connector.connection import DEFAULT_CONFIGURATION @@ -33,7 +33,7 @@ SnowflakeTypeCompiler, ) from .custom_types import ( - MAP, + StructuredType, _CUSTOM_Date, _CUSTOM_DateTime, _CUSTOM_Float, @@ -466,6 +466,14 @@ def _get_schema_columns(self, connection, schema, **kw): connection, full_schema_name, **kw ) schema_name = self.denormalize_name(schema) + + iceberg_table_names = self.get_table_names_with_prefix( + connection, + schema=schema_name, + prefix=CustomTablePrefix.ICEBERG.name, + info_cache=kw.get("info_cache", None), + ) + result = connection.execute( text( """ @@ -526,7 +534,10 @@ def _get_schema_columns(self, connection, schema, **kw): col_type_kw["scale"] = numeric_scale elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): col_type_kw["length"] = character_maximum_length - elif issubclass(col_type, MAP): + elif ( + issubclass(col_type, StructuredType) + and table_name in iceberg_table_names + ): if (schema_name, table_name) not in full_columns_descriptions: full_columns_descriptions[(schema_name, table_name)] = ( self.table_columns_as_dict( diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr index 0325a946..714f5d57 100644 --- a/tests/__snapshots__/test_structured_datatypes.ambr +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -5,6 +5,15 @@ # name: test_compile_table_with_double_map 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))' # --- +# name: test_compile_table_with_structured_data_type[structured_type0] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_structured_data_type[structured_type1] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_structured_data_type[structured_type2] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' +# --- # name: test_insert_map list([ (1, '{\n "100": "item1",\n "200": "item2"\n}'), @@ -16,6 +25,43 @@ Invalid expression [CAST(OBJECT_CONSTRUCT('100', 'item1', '200', 'item2') AS MAP(NUMBER(10,0), VARCHAR(16777216)))] in VALUES clause ''' # --- +# name: test_insert_structured_object + list([ + (1, '{\n "key1": "item1",\n "key2": 15\n}'), + ]) +# --- +# name: test_insert_structured_object_orm + ''' + 002014 (22000): SQL compilation error: + Invalid expression [CAST(OBJECT_CONSTRUCT('key1', 1, 'key2', 'item1') AS OBJECT(key1 NUMBER(10,0), key2 
VARCHAR(16777216)))] in VALUES clause + ''' +# --- +# name: test_inspect_structured_data_types[structured_type0-MAP] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)), + }), + ]) +# --- # name: test_inspect_structured_data_types[structured_type0] list([ dict({ @@ -42,6 +88,32 @@ }), ]) # --- +# name: test_inspect_structured_data_types[structured_type1-MAP] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))), + }), + ]) +# --- # name: test_inspect_structured_data_types[structured_type1] list([ dict({ @@ -68,11 +140,40 @@ }), ]) # --- +# name: test_inspect_structured_data_types[structured_type2-OBJECT] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': OBJECT(key1=(VARCHAR(length=16777216), False), key2=(_CUSTOM_DECIMAL(precision=10, scale=0), False)), + }), + ]) +# --- # name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))] - "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" # --- # name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), VARCHAR)] - "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_reflect_structured_data_types[OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" # --- # name: test_select_map_orm list([ @@ -88,3 +189,9 @@ list([ ]) # --- +# name: test_select_structured_object_orm + list([ + (1, '{\n "key1": 
"value2",\n "key2": 2\n}'), + (2, '{\n "key1": "value1",\n "key2": 1\n}'), + ]) +# --- diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py index 4ea0892b..d6beb3e9 100644 --- a/tests/test_structured_datatypes.py +++ b/tests/test_structured_datatypes.py @@ -1,7 +1,6 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # - import pytest from sqlalchemy import ( Column, @@ -19,17 +18,27 @@ from sqlalchemy.sql.ddl import CreateTable from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable -from snowflake.sqlalchemy.custom_types import MAP, TEXT +from snowflake.sqlalchemy.custom_types import MAP, OBJECT, TEXT from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError -def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot): +@pytest.mark.parametrize( + "structured_type", + [ + MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT(16777216))), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)), + ], +) +def test_compile_table_with_structured_data_type( + sql_compiler, snapshot, structured_type +): metadata = MetaData() user_table = Table( "clustered_user", metadata, Column("Id", Integer, primary_key=True), - Column("name", MAP(NUMBER(), TEXT())), + Column("name", structured_type), ) create_table = CreateTable(user_table) @@ -38,35 +47,152 @@ def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot): @pytest.mark.requires_external_volume -def test_create_table_structured_datatypes( - engine_testaccount, external_volume, base_location +def test_insert_map(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_insert_map" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + """ + Test inserting data into a table with a MAP column type. + + Args: + engine_testaccount: The SQLAlchemy engine connected to the test account. + external_volume: The external volume to use for the table. + base_location: The base location for the table. + snapshot: The snapshot object for assertion. 
+ """ + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + slt = select( + 1, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ) + ins = test_map.insert().from_select(["id", "map_id"], slt) + conn.execute(ins) + + results = conn.execute(test_map.select()) + data = results.fetchmany() + results.close() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_map_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot ): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + map_id = Column(MAP(NUMBER(10, 0), TEXT(16777216))) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast( + text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT(16777216)) + ) + instance = TestIcebergTableOrm(id=0, map_id=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_select_map_orm(engine_testaccount, external_volume, base_location, snapshot): metadata = MetaData() - table_name = "test_map0" + table_name = "test_select_map_orm" test_map = IcebergTable( table_name, metadata, Column("id", Integer, primary_key=True), - Column("map_id", MAP(NUMBER(10, 0), TEXT())), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), external_volume=external_volume, base_location=base_location, ) metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + slt1 = select( + 2, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ) + slt2 = select( + 1, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ).union_all(slt1) + ins = test_map.insert().from_select(["id", "map_id"], slt2) + conn.execute(ins) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = test_map + + def __repr__(self): + return f"({self.id!r}, {self.map_id!r})" + try: - assert test_map is not None + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) finally: test_map.drop(engine_testaccount) @pytest.mark.requires_external_volume -def test_insert_map(engine_testaccount, external_volume, base_location, snapshot): +def test_insert_structured_object( + engine_testaccount, external_volume, base_location, snapshot +): metadata = MetaData() - table_name = "test_insert_map" + table_name = "test_insert_structured_object" test_map = IcebergTable( table_name, metadata, Column("id", Integer, primary_key=True), - Column("map_id", MAP(NUMBER(10, 0), TEXT())), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), external_volume=external_volume, 
base_location=base_location, ) @@ -77,10 +203,11 @@ def test_insert_map(engine_testaccount, external_volume, base_location, snapshot slt = select( 1, cast( - text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT()) + text("{'key1':'item1', 'key2': 15}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), ), ) - ins = test_map.insert().from_select(["id", "map_id"], slt) + ins = test_map.insert().from_select(["id", "object_col"], slt) conn.execute(ins) results = conn.execute(test_map.select()) @@ -91,16 +218,125 @@ def test_insert_map(engine_testaccount, external_volume, base_location, snapshot test_map.drop(engine_testaccount) +@pytest.mark.requires_external_volume +def test_insert_structured_object_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot +): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + object_col = Column( + OBJECT(key1=(NUMBER(10, 0), False), key2=(TEXT(16777216), False)) + ) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast( + text("{ 'key1' : 1, 'key2' : 'item1' }"), + OBJECT(key1=(NUMBER(10, 0), False), key2=(TEXT(16777216), False)), + ) + instance = TestIcebergTableOrm(id=0, object_col=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_select_structured_object_orm( + engine_testaccount, external_volume, base_location, snapshot +): + metadata = MetaData() + table_name = "test_select_structured_object_orm" + iceberg_table = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column( + "structured_obj_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + first_select = select( + 2, + cast( + text("{'key1': 'value1', 'key2': 1}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ) + second_select = select( + 1, + cast( + text("{'key1': 'value2', 'key2': 2}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ).union_all(first_select) + insert_statement = iceberg_table.insert().from_select( + ["id", "structured_obj_col"], second_select + ) + conn.execute(insert_statement) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = iceberg_table + + def __repr__(self): + return f"({self.id!r}, {self.structured_obj_col!r})" + + try: + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) + finally: + iceberg_table.drop(engine_testaccount) + + @pytest.mark.requires_external_volume @pytest.mark.parametrize( - "structured_type", + "structured_type, expected_type", [ - MAP(NUMBER(10, 0), TEXT()), - 
MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())), + (MAP(NUMBER(10, 0), TEXT(16777216)), MAP), + (MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT(16777216))), MAP), + ( + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + OBJECT, + ), ], ) def test_inspect_structured_data_types( - engine_testaccount, external_volume, base_location, snapshot, structured_type + engine_testaccount, + external_volume, + base_location, + snapshot, + structured_type, + expected_type, ): metadata = MetaData() table_name = "test_st_types" @@ -108,7 +344,7 @@ def test_inspect_structured_data_types( table_name, metadata, Column("id", Integer, primary_key=True), - Column("map_id", structured_type), + Column("structured_type_col", structured_type), external_volume=external_volume, base_location=base_location, ) @@ -119,7 +355,7 @@ def test_inspect_structured_data_types( columns = inspecter.get_columns(table_name) assert isinstance(columns[0]["type"], NUMBER) - assert isinstance(columns[1]["type"], MAP) + assert isinstance(columns[1]["type"], expected_type) assert columns == snapshot finally: @@ -132,6 +368,7 @@ def test_inspect_structured_data_types( [ "MAP(NUMBER(10, 0), VARCHAR)", "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))", + "OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))", ], ) def test_reflect_structured_data_types( @@ -147,7 +384,7 @@ def test_reflect_structured_data_types( create_table_sql = f""" CREATE OR REPLACE ICEBERG TABLE {table_name} ( id number(38,0) primary key, - map_id {structured_type}) + structured_type_col {structured_type}) CATALOG = 'SNOWFLAKE' EXTERNAL_VOLUME = '{external_volume}' BASE_LOCATION = '{base_location}'; @@ -174,47 +411,43 @@ def test_reflect_structured_data_types( @pytest.mark.requires_external_volume -def test_insert_map_orm( - sql_compiler, external_volume, base_location, engine_testaccount, snapshot +def test_create_table_structured_datatypes( + engine_testaccount, external_volume, base_location ): - Base = declarative_base() - session = Session(bind=engine_testaccount) - - class TestIcebergTableOrm(Base): - __tablename__ = "test_iceberg_table_orm" - - @classmethod - def __table_cls__(cls, name, metadata, *arg, **kw): - return IcebergTable(name, metadata, *arg, **kw) - - __table_args__ = { - "external_volume": external_volume, - "base_location": base_location, - } - - id = Column(Integer, Sequence("user_id_seq"), primary_key=True) - map_id = Column(MAP(NUMBER(10, 0), TEXT())) - - def __repr__(self): - return f"({self.id!r}, {self.name!r})" - - Base.metadata.create_all(engine_testaccount) - + metadata = MetaData() + table_name = "test_structured0" + test_structured_dt = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) try: - cast_expr = cast( - text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT()) - ) - instance = TestIcebergTableOrm(id=0, map_id=cast_expr) - session.add(instance) - with pytest.raises(exc.ProgrammingError) as programming_error: - session.commit() - # TODO: Support variant in insert statement - assert str(programming_error.value.orig) == snapshot + assert test_structured_dt is not None finally: - Base.metadata.drop_all(engine_testaccount) + test_structured_dt.drop(engine_testaccount) -def 
test_snowflake_tables_with_structured_types(sql_compiler): +@pytest.mark.parametrize( + "structured_type_col", + [ + Column("name", MAP(NUMBER(10, 0), TEXT(16777216))), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ], +) +def test_structured_type_not_supported_in_table_columns_error( + sql_compiler, structured_type_col +): metadata = MetaData() with pytest.raises( StructuredTypeNotSupportedInTableColumnsError @@ -223,49 +456,6 @@ def test_snowflake_tables_with_structured_types(sql_compiler): "clustered_user", metadata, Column("Id", Integer, primary_key=True), - Column("name", MAP(NUMBER(10, 0), TEXT())), + structured_type_col, ) assert programming_error is not None - - -@pytest.mark.requires_external_volume -def test_select_map_orm(engine_testaccount, external_volume, base_location, snapshot): - metadata = MetaData() - table_name = "test_select_map_orm" - test_map = IcebergTable( - table_name, - metadata, - Column("id", Integer, primary_key=True), - Column("map_id", MAP(NUMBER(10, 0), TEXT())), - external_volume=external_volume, - base_location=base_location, - ) - metadata.create_all(engine_testaccount) - - with engine_testaccount.connect() as conn: - slt1 = select( - 2, - cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())), - ) - slt2 = select( - 1, - cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())), - ).union_all(slt1) - ins = test_map.insert().from_select(["id", "map_id"], slt2) - conn.execute(ins) - conn.commit() - - Base = declarative_base() - session = Session(bind=engine_testaccount) - - class TestIcebergTableOrm(Base): - __table__ = test_map - - def __repr__(self): - return f"({self.id!r}, {self.map_id!r})" - - try: - data = session.query(TestIcebergTableOrm).all() - snapshot.assert_match(data) - finally: - test_map.drop(engine_testaccount) diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py index c7bcd6ef..474ebde4 100644 --- a/tests/test_unit_structured_types.py +++ b/tests/test_unit_structured_types.py @@ -6,8 +6,8 @@ from snowflake.sqlalchemy import NUMBER from snowflake.sqlalchemy.custom_types import MAP, TEXT from src.snowflake.sqlalchemy.parser.custom_type_parser import ( - extract_parameters, parse_type, + tokenize_parameters, ) @@ -18,7 +18,7 @@ def test_compile_map_with_not_null(snapshot): def test_extract_parameters(): example = "a, b(c, d, f), d" - assert extract_parameters(example) == ["a", "b(c, d, f)", "d"] + assert tokenize_parameters(example) == ["a", "b(c, d, f)", "d"] @pytest.mark.parametrize( @@ -64,6 +64,10 @@ def test_extract_parameters(): ), ("MAP(DECIMAL(10, 0), VARIANT)", "MAP(DECIMAL(10, 0), VARIANT)"), ("OBJECT", "OBJECT"), + ( + "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)", + "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)", + ), ("ARRAY", "ARRAY"), ("GEOGRAPHY", "GEOGRAPHY"), ("GEOMETRY", "GEOMETRY"), From 0471c1fd7cc94a940359b41c3e4ea168f4683606 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 16 Dec 2024 10:11:42 -0600 Subject: [PATCH 55/74] SNOW-1776333 Add support for ARRAY (#560) --- DESCRIPTION.md | 1 + README.md | 73 +++++++++ src/snowflake/sqlalchemy/base.py | 5 +- src/snowflake/sqlalchemy/custom_types.py | 16 +- .../sqlalchemy/parser/custom_type_parser.py | 52 ++++--- .../test_structured_datatypes.ambr | 49 ++++++ tests/test_structured_datatypes.py | 141 ++++++++++++++++-- tests/test_unit_structured_types.py | 4 + 8 files changed, 309 
insertions(+), 32 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index d5872bb9..0615bd15 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -14,6 +14,7 @@ Source code is also available at:
   - Fix index columns was not being reflected
   - Fix index reflection cache not working
   - Add support for structured OBJECT datatype
+  - Add support for structured ARRAY datatype

 - v1.7.1(December 02, 2024)
   - Add support for partition by to copy into
diff --git a/README.md b/README.md
index 3985a9d6..34d86376 100644
--- a/README.md
+++ b/README.md
@@ -367,6 +367,79 @@ data_object = json.loads(row[1])
 data_array = json.loads(row[2])
 ```

+### Structured Data Types Support
+
+This module defines custom SQLAlchemy types for Snowflake structured data, specifically for **Iceberg tables**.
+The **MAP**, **OBJECT**, and **ARRAY** types allow you to store complex data structures in your SQLAlchemy models.
+For detailed information, refer to the Snowflake [Structured data types](https://docs.snowflake.com/en/sql-reference/data-types-structured) documentation.
+
+---
+
+#### MAP
+
+The `MAP` type represents a collection of key-value pairs, where the key type and the value type are declared separately.
+
+- **Key Type**: The type of the keys (e.g., `TEXT`, `NUMBER`).
+- **Value Type**: The type of the values (e.g., `TEXT`, `NUMBER`).
+- **Not Null**: Whether the values are declared `NOT NULL` (default is `False`, i.e. `NULL` values are allowed).
+
+*Example Usage*
+
+```python
+IcebergTable(
+    table_name,
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column("map_col", MAP(NUMBER(10, 0), TEXT(16777216))),
+    external_volume="external_volume",
+    base_location="base_location",
+)
+```
+
+#### OBJECT
+
+The `OBJECT` type represents a structured object with named fields. Each field has a specific type, and you can also specify whether each field is nullable.
+
+- **Items Types**: A dictionary of field names and their types. Each type may be given either as a bare type or as a `(type, not_null)` tuple, where `True` marks the field `NOT NULL` (default is `False`, i.e. the field is nullable).
+
+*Example Usage*
+
+```python
+IcebergTable(
+    table_name,
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column(
+        "object_col",
+        OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)),
+        # Equivalently, without the nullable flags:
+        # OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)),
+    ),
+    external_volume="external_volume",
+    base_location="base_location",
+)
+```
+
+#### ARRAY
+
+The `ARRAY` type represents an ordered list of values, where each element has the same type. The type of the elements is defined when creating the array.
+
+- **Value Type**: The type of the elements in the array (e.g., `TEXT`, `NUMBER`).
+- **Not Null**: Whether the elements are declared `NOT NULL` (default is `False`, i.e. `NULL` values are allowed).
+
+*Example Usage*
+
+```python
+IcebergTable(
+    table_name,
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column("array_col", ARRAY(TEXT(16777216))),
+    external_volume="external_volume",
+    base_location="base_location",
+)
+```
+
+
 ### CLUSTER BY Support

 Snowflake SQLAlchemy supports the `CLUSTER BY` parameter for tables. For information about the parameter, see :doc:`/sql-reference/sql/create-table`.
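The three structured types documented in the README hunk above compose in a single Iceberg table. A minimal round-trip sketch (create, then inspect), assuming a reachable account with a configured external volume; the connection URL, `my_volume`, and `my_base_location` are placeholders:

```python
from sqlalchemy import Column, Integer, MetaData, create_engine, inspect

from snowflake.sqlalchemy import NUMBER, IcebergTable
from snowflake.sqlalchemy.custom_types import ARRAY, MAP, OBJECT, TEXT

# Placeholder URL: substitute real account credentials.
engine = create_engine("snowflake://<user>:<password>@<account>/<db>/<schema>")
metadata = MetaData()

demo = IcebergTable(
    "structured_demo",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("map_col", MAP(NUMBER(10, 0), TEXT(16777216))),
    # OBJECT fields accept either a bare type or a (type, not_null) tuple.
    Column("object_col", OBJECT(key1=(TEXT(16777216), False), key2=NUMBER(10, 0))),
    Column("array_col", ARRAY(TEXT(16777216))),
    external_volume="my_volume",          # assumption: volume already exists
    base_location="my_base_location",
)
metadata.create_all(engine)

# Reflection reports the structured types back, e.g. MAP(...), OBJECT(...).
for col in inspect(engine).get_columns("structured_demo"):
    print(col["name"], col["type"])
```

As the tests in this patch suggest, writing values into these columns typically goes through an explicit `CAST` of a constructed literal, e.g. `cast(text("{'100':'item1'}"), MAP(...))`, rather than plain bound parameters.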
diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 3fef7709..9ce8b83c 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -1096,7 +1096,10 @@ def visit_MAP(self, type_, **kw): ) def visit_ARRAY(self, type_, **kw): - return "ARRAY" + if type_.is_semi_structured: + return "ARRAY" + not_null = f" {NOT_NULL}" if type_.not_null else "" + return f"ARRAY({type_.value_type.compile()}{not_null})" def visit_OBJECT(self, type_, **kw): if type_.is_semi_structured: diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index ce7ad592..11cd2eb8 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -1,7 +1,7 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from typing import Tuple, Union +from typing import Optional, Tuple, Union import sqlalchemy.types as sqltypes import sqlalchemy.util as util @@ -40,7 +40,8 @@ class VARIANT(SnowflakeType): class StructuredType(SnowflakeType): - def __init__(self): + def __init__(self, is_semi_structured: bool = False): + self.is_semi_structured = is_semi_structured super().__init__() @@ -81,9 +82,18 @@ def __repr__(self): ) -class ARRAY(SnowflakeType): +class ARRAY(StructuredType): __visit_name__ = "ARRAY" + def __init__( + self, + value_type: Optional[sqltypes.TypeEngine] = None, + not_null: bool = False, + ): + self.value_type = value_type + self.not_null = not_null + super().__init__(is_semi_structured=value_type is None) + class TIMESTAMP_TZ(SnowflakeType): __visit_name__ = "TIMESTAMP_TZ" diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py index 1e99ba56..09cb6ab8 100644 --- a/src/snowflake/sqlalchemy/parser/custom_type_parser.py +++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py @@ -74,6 +74,8 @@ "GEOMETRY": GEOMETRY, } +NOT_NULL_STR = "NOT NULL" + def tokenize_parameters(text: str, character_for_strip=",") -> list: """ @@ -160,6 +162,8 @@ def parse_type(type_text: str) -> TypeEngine: col_type_kw = __parse_map_type_parameters(parameters) elif issubclass(col_type_class, OBJECT): col_type_kw = __parse_object_type_parameters(parameters) + elif issubclass(col_type_class, ARRAY): + col_type_kw = __parse_nullable_parameter(parameters) if col_type_kw is None: col_type_class = NullType col_type_kw = {} @@ -169,6 +173,7 @@ def parse_type(type_text: str) -> TypeEngine: def __parse_object_type_parameters(parameters): object_rows = {} + not_null_parts = NOT_NULL_STR.split(" ") for parameter in parameters: parameter_parts = tokenize_parameters(parameter, " ") if len(parameter_parts) >= 2: @@ -178,40 +183,51 @@ def __parse_object_type_parameters(parameters): return None not_null = ( len(parameter_parts) == 4 - and parameter_parts[2] == "NOT" - and parameter_parts[3] == "NULL" + and parameter_parts[2] == not_null_parts[0] + and parameter_parts[3] == not_null_parts[1] ) object_rows[key] = (value_type, not_null) return object_rows -def __parse_map_type_parameters(parameters): - if len(parameters) != 2: +def __parse_nullable_parameter(parameters): + if len(parameters) < 1: + return {} + elif len(parameters) > 1: return None - - key_type_str = parameters[0] - value_type_str = parameters[1] - not_null_str = "NOT NULL" - not_null = False + parameter_str = parameters[0] + is_not_null = False if ( - len(value_type_str) >= len(not_null_str) - and value_type_str[-len(not_null_str) :] == not_null_str + 
len(parameter_str) >= len(NOT_NULL_STR) + and parameter_str[-len(NOT_NULL_STR) :] == NOT_NULL_STR ): - not_null = True - value_type_str = value_type_str[: -len(not_null_str) - 1] + is_not_null = True + parameter_str = parameter_str[: -len(NOT_NULL_STR) - 1] - key_type: TypeEngine = parse_type(key_type_str) - value_type: TypeEngine = parse_type(value_type_str) - if isinstance(key_type, NullType) or isinstance(value_type, NullType): + value_type: TypeEngine = parse_type(parameter_str) + if isinstance(value_type, NullType): return None return { - "key_type": key_type, "value_type": value_type, - "not_null": not_null, + "not_null": is_not_null, } +def __parse_map_type_parameters(parameters): + if len(parameters) != 2: + return None + + key_type_str = parameters[0] + value_type_str = parameters[1] + key_type: TypeEngine = parse_type(key_type_str) + value_type = __parse_nullable_parameter([value_type_str]) + if isinstance(value_type, NullType) or isinstance(key_type, NullType): + return None + + return {"key_type": key_type, **value_type} + + def __parse_type_with_length_parameters(parameters): return ( {"length": int(parameters[0])} diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr index 714f5d57..3dcedf7c 100644 --- a/tests/__snapshots__/test_structured_datatypes.ambr +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -14,6 +14,20 @@ # name: test_compile_table_with_structured_data_type[structured_type2] 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' # --- +# name: test_compile_table_with_structured_data_type[structured_type3] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname ARRAY(MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))' +# --- +# name: test_insert_array + list([ + (1, '[\n "item1",\n "item2"\n]'), + ]) +# --- +# name: test_insert_array_orm + ''' + 002014 (22000): SQL compilation error: + Invalid expression [CAST(ARRAY_CONSTRUCT('item1', 'item2') AS ARRAY(VARCHAR(16777216)))] in VALUES clause + ''' +# --- # name: test_insert_map list([ (1, '{\n "100": "item1",\n "200": "item2"\n}'), @@ -166,6 +180,35 @@ }), ]) # --- +# name: test_inspect_structured_data_types[structured_type3-ARRAY] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': ARRAY(value_type=VARCHAR(length=16777216)), + }), + ]) +# --- +# name: test_reflect_structured_data_types[ARRAY(MAP(NUMBER(10, 0), VARCHAR))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col ARRAY(MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- # name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))] "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" # --- @@ -175,6 +218,12 @@ # name: test_reflect_structured_data_types[OBJECT(key1 VARCHAR, key2 
NUMBER(10, 0))] "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" # --- +# name: test_select_array_orm + list([ + (1, '[\n "item3",\n "item4"\n]'), + (2, '[\n "item1",\n "item2"\n]'), + ]) +# --- # name: test_select_map_orm list([ (1, '{\n "100": "item1",\n "200": "item2"\n}'), diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py index d6beb3e9..ce030bd2 100644 --- a/tests/test_structured_datatypes.py +++ b/tests/test_structured_datatypes.py @@ -18,7 +18,7 @@ from sqlalchemy.sql.ddl import CreateTable from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable -from snowflake.sqlalchemy.custom_types import MAP, OBJECT, TEXT +from snowflake.sqlalchemy.custom_types import ARRAY, MAP, OBJECT, TEXT from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError @@ -28,6 +28,7 @@ MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT(16777216))), OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)), + ARRAY(MAP(NUMBER(10, 0), TEXT(16777216))), ], ) def test_compile_table_with_structured_data_type( @@ -58,15 +59,6 @@ def test_insert_map(engine_testaccount, external_volume, base_location, snapshot external_volume=external_volume, base_location=base_location, ) - """ - Test inserting data into a table with a MAP column type. - - Args: - engine_testaccount: The SQLAlchemy engine connected to the test account. - external_volume: The external volume to use for the table. - base_location: The base location for the table. - snapshot: The snapshot object for assertion. - """ metadata.create_all(engine_testaccount) try: @@ -179,6 +171,128 @@ def __repr__(self): test_map.drop(engine_testaccount) +@pytest.mark.requires_external_volume +def test_select_array_orm(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_select_array_orm" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("array_col", ARRAY(TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + slt1 = select( + 2, + cast( + text("['item1','item2']"), + ARRAY(TEXT(16777216)), + ), + ) + slt2 = select( + 1, + cast( + text("['item3','item4']"), + ARRAY(TEXT(16777216)), + ), + ).union_all(slt1) + ins = test_map.insert().from_select(["id", "array_col"], slt2) + conn.execute(ins) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = test_map + + def __repr__(self): + return f"({self.id!r}, {self.array_col!r})" + + try: + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_array(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_insert_map" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("array_col", ARRAY(TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + slt = select( + 1, + cast( + 
text("['item1','item2']"), + ARRAY(TEXT(16777216)), + ), + ) + ins = test_map.insert().from_select(["id", "array_col"], slt) + conn.execute(ins) + + results = conn.execute(test_map.select()) + data = results.fetchmany() + results.close() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_array_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot +): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + array_col = Column(ARRAY(TEXT(16777216))) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast(text("['item1','item2']"), ARRAY(TEXT(16777216))) + instance = TestIcebergTableOrm(id=0, array_col=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + @pytest.mark.requires_external_volume def test_insert_structured_object( engine_testaccount, external_volume, base_location, snapshot @@ -328,6 +442,7 @@ def __repr__(self): OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), OBJECT, ), + (ARRAY(TEXT(16777216)), ARRAY), ], ) def test_inspect_structured_data_types( @@ -369,6 +484,7 @@ def test_inspect_structured_data_types( "MAP(NUMBER(10, 0), VARCHAR)", "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))", "OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))", + "ARRAY(MAP(NUMBER(10, 0), VARCHAR))", ], ) def test_reflect_structured_data_types( @@ -425,6 +541,10 @@ def test_create_table_structured_datatypes( "object_col", OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), ), + Column( + "array_col", + ARRAY(TEXT(16777216)), + ), external_volume=external_volume, base_location=base_location, ) @@ -443,6 +563,7 @@ def test_create_table_structured_datatypes( "object_col", OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), ), + Column("name", ARRAY(TEXT(16777216))), ], ) def test_structured_type_not_supported_in_table_columns_error( diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py index 474ebde4..472ce2e6 100644 --- a/tests/test_unit_structured_types.py +++ b/tests/test_unit_structured_types.py @@ -69,6 +69,10 @@ def test_extract_parameters(): "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)", ), ("ARRAY", "ARRAY"), + ( + "ARRAY(MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + "ARRAY(MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + ), ("GEOGRAPHY", "GEOGRAPHY"), ("GEOMETRY", "GEOMETRY"), ], From d3e40d1efc915fc619027b9c35fca39d87e05ae3 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 16 Dec 2024 15:45:50 -0600 Subject: [PATCH 56/74] SNOW-1846847 Add support for autocommit (#555) * Add support for autocommit --- DESCRIPTION.md | 1 + src/snowflake/sqlalchemy/snowdialect.py | 34 +++++ tests/test_transactions.py | 157 ++++++++++++++++++++++++ 3 files changed, 192 insertions(+) create mode 100644 
tests/test_transactions.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 0615bd15..0a50408e 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -19,6 +19,7 @@ Source code is also available at: - v1.7.1(December 02, 2024) - Add support for partition by to copy into - Fix BOOLEAN type not found in snowdialect + - Add support for autocommit Isolation Level - v1.7.0(November 21, 2024) - Add support for dynamic tables and required options diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index dd5e4375..96bcac71 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -4,6 +4,7 @@ import operator import re from collections import defaultdict +from enum import Enum from functools import reduce from typing import Any, Collection, Optional from urllib.parse import unquote_plus @@ -59,6 +60,11 @@ _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True +class SnowflakeIsolationLevel(Enum): + READ_COMMITTED = "READ COMMITTED" + AUTOCOMMIT = "AUTOCOMMIT" + + class SnowflakeDialect(default.DefaultDialect): name = DIALECT_NAME driver = "snowflake" @@ -139,6 +145,13 @@ class SnowflakeDialect(default.DefaultDialect): supports_identity_columns = True + def __init__( + self, + isolation_level: Optional[str] = SnowflakeIsolationLevel.READ_COMMITTED.value, + **kwargs: Any, + ): + super().__init__(isolation_level=isolation_level, **kwargs) + @classmethod def dbapi(cls): return cls.import_dbapi() @@ -216,6 +229,27 @@ def has_table(self, connection, table_name, schema=None, **kw): """ return self._has_object(connection, "TABLE", table_name, schema) + def get_isolation_level_values(self, dbapi_connection): + return [ + SnowflakeIsolationLevel.READ_COMMITTED.value, + SnowflakeIsolationLevel.AUTOCOMMIT.value, + ] + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection): + dbapi_connection.commit() + + def get_default_isolation_level(self, dbapi_conn): + return SnowflakeIsolationLevel.READ_COMMITTED.value + + def set_isolation_level(self, dbapi_connection, level): + if level == SnowflakeIsolationLevel.AUTOCOMMIT.value: + dbapi_connection.autocommit(True) + else: + dbapi_connection.autocommit(False) + @reflection.cache def has_sequence(self, connection, sequence_name, schema=None, **kw): """ diff --git a/tests/test_transactions.py b/tests/test_transactions.py new file mode 100644 index 00000000..c163c2b7 --- /dev/null +++ b/tests/test_transactions.py @@ -0,0 +1,157 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from sqlalchemy import Column, Integer, MetaData, String, select, text + +from snowflake.sqlalchemy import SnowflakeTable + +CURRENT_TRANSACTION = text("SELECT CURRENT_TRANSACTION()") + + +def test_connect_read_commited(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_connect_read_commited" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="READ COMMITTED" + ) as connection: + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] != ( + None, + ), "AUTOCOMMIT DISABLED, transaction should be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 0, results # No insert commited + assert_text_in_buf("ROLLBACK", occurrences=1) + finally: + metadata.drop_all(engine_testaccount) + + +def test_begin_read_commited(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_begin_read_commited" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="READ COMMITTED" + ) as connection, connection.begin(): + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] != ( + None, + ), "AUTOCOMMIT DISABLED, transaction should be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results # Insert commited + assert_text_in_buf("COMMIT", occurrences=2) + finally: + metadata.drop_all(engine_testaccount) + + +def test_connect_autocommit(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_connect_autocommit" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="AUTOCOMMIT" + ) as connection: + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == ( + None, + ), "Autocommit enabled, transaction should not be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results + assert_text_in_buf( + "ROLLBACK using DBAPI connection.rollback(), DBAPI should ignore due to autocommit mode", + occurrences=1, + ) + + finally: + metadata.drop_all(engine_testaccount) + + +def test_begin_autocommit(engine_testaccount, 
assert_text_in_buf): + metadata = MetaData() + table_name = "test_begin_autocommit" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="AUTOCOMMIT" + ) as connection, connection.begin(): + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == ( + None, + ), "Autocommit enabled, transaction should not be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results + assert_text_in_buf( + "COMMIT using DBAPI connection.commit(), DBAPI should ignore due to autocommit mode", + occurrences=1, + ) + + finally: + metadata.drop_all(engine_testaccount) From 115da35aee57a6a65ab774e19770cdf8ff386fba Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 17 Dec 2024 07:45:16 -0600 Subject: [PATCH 57/74] Update DESCRIPTION.md (#561) --- DESCRIPTION.md | 2 +- src/snowflake/sqlalchemy/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 0a50408e..f67f91ef 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,7 @@ Source code is also available at: # Release Notes -- (Unreleased) +- v1.7.2(December 18, 2024) - Fix quoting of `_` as column name - Fix index columns was not being reflected - Fix index reflection cache not working diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index f942f2bd..e54f6a9c 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = "1.7.1" +VERSION = "1.7.2" From d84484e148f97586086504f53acf095b67af76af Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Sun, 5 Jan 2025 19:14:08 -0600 Subject: [PATCH 58/74] Fix return value of snowflake get_table_names (#564) * Fix return value of snowflake get_table_names --------- Co-authored-by: T Pham --- DESCRIPTION.md | 2 ++ src/snowflake/sqlalchemy/snowdialect.py | 2 +- tests/test_core.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index f67f91ef..1dab72f1 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -8,6 +8,8 @@ Source code is also available at: # Release Notes +- (Unreleased) + - Fix return value of snowflake get_table_names - v1.7.2(December 18, 2024) - Fix quoting of `_` as column name diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 96bcac71..6f1493c3 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -756,7 +756,7 @@ def get_table_names(self, connection, schema=None, **kw): ret = self._get_schema_tables_info( connection, schema, info_cache=kw.get("info_cache", None) ).keys() - return ret + return list(ret) @reflection.cache def get_view_names(self, connection, schema=None, **kw): diff --git a/tests/test_core.py b/tests/test_core.py index 63f097db..dfe0f714 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -469,6 +469,7 @@ def 
From b9b26e590db424777566c91b099a08a4dcb3b7cf Mon Sep 17 00:00:00 2001
From: Juan Martinez Ramirez <126511805+sfc-gh-jmartinezramirez@users.noreply.github.com>
Date: Thu, 9 Jan 2025 07:31:36 -0600
Subject: [PATCH 59/74] Added flag to allow override SnowflakeDialect
 div_is_floor_div flag (#545)

* Changed default behavior of SnowflakeDialect to disable the use of the /
  division operator as floor division. Changed flag div_is_floordiv to False.

* Update Description.md

* Added flag to allow customers to test the new behavior of div_is_floordiv
  that will be introduced; the new force_div_is_floordiv flag allows testing
  the new division behavior.

Update sa14:scripts to ignore feature_v20 from execution

* Added warning for use of div_is_floordiv with `True` value.

Added tests to validate results of true and floor divisions using
`force_div_is_floordiv` flag.

---
 DESCRIPTION.md                          |  4 ++
 pyproject.toml                          |  6 ++
 src/snowflake/sqlalchemy/base.py        | 21 ++++++
 src/snowflake/sqlalchemy/snowdialect.py | 10 +++
 tests/conftest.py                       | 20 ++++++
 tests/test_compiler.py                  | 79 ++++++++++++++++++++++
 tests/test_core.py                      | 87 ++++++++++++++++++++++++-
 7 files changed, 226 insertions(+), 1 deletion(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 1dab72f1..00f860ce 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -10,6 +10,10 @@ Source code is also available at:
 # Release Notes
 
 - (Unreleased)
   - Fix return value of snowflake get_table_names
+  - Added `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`.
+    - With the flag in `False`, the `/` division operator will be treated as a float division and `//` as a floor division.
+    - This flag is added to maintain backward compatibility with the previous behavior of Snowflake Dialect division.
+    - This flag will be removed in the future and Snowflake Dialect will use `div_is_floordiv` as `False`.
 - v1.7.2(December 18, 2024)
   - Fix quoting of `_` as column name
diff --git a/pyproject.toml b/pyproject.toml
index b0ae04c4..b22dc293 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,11 @@ extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"]
 features = ["development", "pandas"]
 python = "3.8"
 
+[tool.hatch.envs.sa14.scripts]
+test-dialect = "pytest --ignore_v20_test -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
+test-dialect-compatibility = "pytest --ignore_v20_test -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite"
+test-dialect-aws = "pytest --ignore_v20_test -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
+
 [tool.hatch.envs.default.env-vars]
 COVERAGE_FILE = "coverage.xml"
 SQLACHEMY_WARN_20 = "1"
@@ -134,4 +139,5 @@ markers = [
     "requires_external_volume: tests that needs a external volume to be executed",
     "external: tests that could but should only run on our external CI",
     "feature_max_lob_size: tests that could but should only run on our external CI",
+    "feature_v20: tests that could but should only run on SqlAlchemy v20",
 ]
diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index 9ce8b83c..0226d37d 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@ -6,6 +6,7 @@
 import operator
 import re
 import string
+import warnings
 from typing import List
 
 from sqlalchemy import exc as sa_exc
@@ -802,6 +803,26 @@ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
             + join.onclause._compiler_dispatch(self, from_linter=from_linter, **kwargs)
         )
 
+    def visit_truediv_binary(self, binary, operator, **kw):
+        if self.dialect.div_is_floordiv:
+            warnings.warn(
+                "div_is_floordiv value will be changed to False in a future release. This will generate a behavior change on true and floor division. Please review https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#python-division-operator-performs-true-division-for-all-backends-added-floor-division",
+                PendingDeprecationWarning,
+                stacklevel=2,
+            )
+        return (
+            self.process(binary.left, **kw) + " / " + self.process(binary.right, **kw)
+        )
+
+    def visit_floordiv_binary(self, binary, operator, **kw):
+        if self.dialect.div_is_floordiv and IS_VERSION_20:
+            warnings.warn(
+                "div_is_floordiv value will be changed to False in a future release. This will generate a behavior change on true and floor division. Please review https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#python-division-operator-performs-true-division-for-all-backends-added-floor-division",
+                PendingDeprecationWarning,
+                stacklevel=2,
+            )
+        return super().visit_floordiv_binary(binary, operator, **kw)
+
     def render_literal_value(self, value, type_):
         # escape backslash
         return super().render_literal_value(value, type_).replace("\\", "\\\\")
diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py
index 6f1493c3..bba1160f 100644
--- a/src/snowflake/sqlalchemy/snowdialect.py
+++ b/src/snowflake/sqlalchemy/snowdialect.py
@@ -79,6 +79,9 @@ class SnowflakeDialect(default.DefaultDialect):
     colspecs = colspecs
     ischema_names = ischema_names
 
+    # target database treats the / division operator as "floor division"
+    div_is_floordiv = False
+
     # all str types must be converted in Unicode
     convert_unicode = True
 
@@ -147,10 +150,17 @@ class SnowflakeDialect(default.DefaultDialect):
 
     def __init__(
         self,
+        force_div_is_floordiv: bool = True,
        isolation_level: Optional[str] = SnowflakeIsolationLevel.READ_COMMITTED.value,
         **kwargs: Any,
     ):
         super().__init__(isolation_level=isolation_level, **kwargs)
+        self.force_div_is_floordiv = force_div_is_floordiv
+        self.div_is_floordiv = force_div_is_floordiv
+
+    def initialize(self, connection):
+        super().initialize(connection)
+        self.div_is_floordiv = self.force_div_is_floordiv
 
     @classmethod
     def dbapi(cls):
diff --git a/tests/conftest.py b/tests/conftest.py
index f2045121..5e0fd3ed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -47,6 +47,26 @@
 TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}"
 
 
+def pytest_addoption(parser):
+    parser.addoption(
+        "--ignore_v20_test",
+        action="store_true",
+        default=False,
+        help="skip sqlalchemy 2.0 exclusive tests",
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--ignore_v20_test"):
+        # --ignore_v20_test given in cli: skip sqlalchemy 2.0 tests
+        skip_feature_v2 = pytest.mark.skip(
+            reason="need remove --ignore_v20_test option to run"
+        )
+        for item in items:
+            if "feature_v20" in item.keywords:
+                item.add_marker(skip_feature_v2)
+
+
 @pytest.fixture(scope="session")
 def on_travis():
     return os.getenv("TRAVIS", "").lower() == "true"
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index 55451c2f..0eea4607 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -2,12 +2,14 @@
 # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #
 
+import pytest
 from sqlalchemy import Integer, String, and_, func, insert, select
 from sqlalchemy.schema import DropColumnComment, DropTableComment
 from sqlalchemy.sql import column, quoted_name, table
 from sqlalchemy.testing.assertions import AssertsCompiledSQL
 
 from snowflake.sqlalchemy import snowdialect
+from src.snowflake.sqlalchemy.snowdialect import SnowflakeDialect
 
 table1 = table(
     "table1", column("id", Integer), column("name", String), column("value", Integer)
@@ -135,3 +137,80 @@ def test_outer_lateral_join():
         str(stmt.compile(dialect=snowdialect.dialect()))
         == "SELECT colname AS label \nFROM abc JOIN LATERAL flatten(PARSE_JSON(colname2)) AS anon_1 GROUP BY colname"
     )
+
+
+@pytest.mark.feature_v20
+def test_division_operator_with_force_div_is_floordiv_false():
+    col1 = column("col1", Integer)
+    col2 = column("col2", Integer)
+    stmt = col1 / col2
+    assert (
+        str(stmt.compile(dialect=SnowflakeDialect(force_div_is_floordiv=False)))
+        == "col1 / col2"
+    )
+
+
+@pytest.mark.feature_v20
+def test_division_operator_with_denominator_expr_force_div_is_floordiv_false():
+    col1 = column("col1", Integer)
+    col2 = column("col2", Integer)
+    stmt = col1 / func.sqrt(col2)
+    assert (
+        str(stmt.compile(dialect=SnowflakeDialect(force_div_is_floordiv=False)))
+        == "col1 / sqrt(col2)"
+    )
+
+
+@pytest.mark.feature_v20
+def test_division_operator_with_force_div_is_floordiv_default_true():
+    col1 = column("col1", Integer)
+    col2 = column("col2", Integer)
+    stmt = col1 / col2
+    assert str(stmt.compile(dialect=SnowflakeDialect())) == "col1 / col2"
+
+
+@pytest.mark.feature_v20
+def test_division_operator_with_denominator_expr_force_div_is_floordiv_default_true():
+    col1 = column("col1", Integer)
+    col2 = column("col2", Integer)
+    stmt = col1 / func.sqrt(col2)
+    assert str(stmt.compile(dialect=SnowflakeDialect())) == "col1 / sqrt(col2)"
+
+
+@pytest.mark.feature_v20
+def test_floor_division_operator_force_div_is_floordiv_false():
+    col1 = column("col1", Integer)
+    col2 = column("col2", Integer)
+    stmt = col1 // col2
+    assert (
+        str(stmt.compile(dialect=SnowflakeDialect(force_div_is_floordiv=False)))
+        == "FLOOR(col1 / col2)"
+    )
+
+
+@pytest.mark.feature_v20
+def test_floor_division_operator_with_denominator_expr_force_div_is_floordiv_false():
+    col1 = column("col1", Integer)
+    col2 = column("col2", Integer)
+    stmt = col1 // func.sqrt(col2)
+    assert (
+        str(stmt.compile(dialect=SnowflakeDialect(force_div_is_floordiv=False)))
+        == "FLOOR(col1 / sqrt(col2))"
+    )
+
+
+@pytest.mark.feature_v20
+def test_floor_division_operator_force_div_is_floordiv_default_true():
+    col1 = column("col1", Integer)
+    col2 = column("col2", Integer)
+    stmt = col1 // col2
+    assert str(stmt.compile(dialect=SnowflakeDialect())) == "col1 / col2"
+
+
+@pytest.mark.feature_v20
+def test_floor_division_operator_with_denominator_expr_force_div_is_floordiv_default_true():
+    col1 = column("col1", Integer)
+    col2 = column("col2", Integer)
+    stmt = col1 // func.sqrt(col2)
+    res = stmt.compile(dialect=SnowflakeDialect())
+    assert str(res) == "FLOOR(col1 / sqrt(col2))"
diff --git a/tests/test_core.py b/tests/test_core.py
index dfe0f714..a25342ac 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -31,13 +31,15 @@
     create_engine,
     dialects,
     exc,
+    func,
     insert,
     inspect,
     text,
 )
 from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError
-from sqlalchemy.sql import and_, not_, or_, select
+from sqlalchemy.sql import and_, literal, not_, or_, select
 from sqlalchemy.sql.ddl import CreateTable
+from sqlalchemy.testing.assertions import eq_
 
 import snowflake.connector.errors
 import snowflake.sqlalchemy.snowdialect
@@ -1864,3 +1866,86 @@ def test_snowflake_sqlalchemy_as_valid_client_type():
         snowflake.connector.connection.DEFAULT_CONFIGURATION[
             "internal_application_version"
         ] = origin_internal_app_version
+
+
+@pytest.mark.parametrize(
+    "operation",
+    [
+        [
+            literal(5),
+            literal(10),
+            0.5,
+        ],
+        [literal(5), func.sqrt(literal(10)), 1.5811388300841895],
+        [literal(4), literal(5), decimal.Decimal("0.800000")],
+        [literal(2), literal(2), 1.0],
+        [literal(3), literal(2), 1.5],
+        [literal(4), literal(1.5), 2.666667],
+        [literal(5.5), literal(10.7), 0.5140187],
+        [literal(5.5), literal(8), 0.6875],
+    ],
+)
+def test_true_division_operation(engine_testaccount, operation):
+    # expected_warning = "div_is_floordiv value will be changed to False in a future release. This will generate a behavior change on true and floor division. Please review https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#python-division-operator-performs-true-division-for-all-backends-added-floor-division"
+    # with pytest.warns(PendingDeprecationWarning, match=expected_warning):
+    with engine_testaccount.connect() as conn:
+        eq_(
+            conn.execute(select(operation[0] / operation[1])).fetchall(),
+            [((operation[2]),)],
+        )
+
+
+@pytest.mark.parametrize(
+    "operation",
+    [
+        [literal(5), literal(10), 0.5, 0.5],
+        [literal(5), func.sqrt(literal(10)), 1.5811388300841895, 1.0],
+        [
+            literal(4),
+            literal(5),
+            decimal.Decimal("0.800000"),
+            decimal.Decimal("0.800000"),
+        ],
+        [literal(2), literal(2), 1.0, 1.0],
+        [literal(3), literal(2), 1.5, 1.5],
+        [literal(4), literal(1.5), 2.666667, 2.0],
+        [literal(5.5), literal(10.7), 0.5140187, 0],
+        [literal(5.5), literal(8), 0.6875, 0.6875],
+    ],
+)
+@pytest.mark.feature_v20
+def test_division_force_div_is_floordiv_default(engine_testaccount, operation):
+    expected_warning = "div_is_floordiv value will be changed to False in a future release. This will generate a behavior change on true and floor division. Please review https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#python-division-operator-performs-true-division-for-all-backends-added-floor-division"
+    with pytest.warns(PendingDeprecationWarning, match=expected_warning):
+        with engine_testaccount.connect() as conn:
+            eq_(
+                conn.execute(
+                    select(operation[0] / operation[1], operation[0] // operation[1])
+                ).fetchall(),
+                [(operation[2], operation[3])],
+            )
+
+
+@pytest.mark.parametrize(
+    "operation",
+    [
+        [literal(5), literal(10), 0.5, 0],
+        [literal(5), func.sqrt(literal(10)), 1.5811388300841895, 1.0],
+        [literal(4), literal(5), decimal.Decimal("0.800000"), 0],
+        [literal(2), literal(2), 1.0, 1.0],
+        [literal(3), literal(2), 1.5, 1],
+        [literal(4), literal(1.5), 2.666667, 2.0],
+        [literal(5.5), literal(10.7), 0.5140187, 0],
+        [literal(5.5), literal(8), 0.6875, 0],
+    ],
+)
+@pytest.mark.feature_v20
+def test_division_force_div_is_floordiv_false(db_parameters, operation):
+    engine = create_engine(URL(**db_parameters), **{"force_div_is_floordiv": False})
+    with engine.connect() as conn:
+        eq_(
+            conn.execute(
+                select(operation[0] / operation[1], operation[0] // operation[1])
+            ).fetchall(),
+            [(operation[2], operation[3])],
+        )
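To make the flag's effect concrete, here is a condensed sketch of what the tests above verify. It assumes SQLAlchemy 2.0 (the upstream tests are gated behind the `feature_v20` marker) and imports `SnowflakeDialect` from the installed package; compiling under the legacy default also emits the `PendingDeprecationWarning` added in this patch.

```python
from sqlalchemy import Integer
from sqlalchemy.sql import column

from snowflake.sqlalchemy.snowdialect import SnowflakeDialect

col1, col2 = column("col1", Integer), column("col2", Integer)

# Legacy default (force_div_is_floordiv=True): both operators render as "/",
# so "//" silently loses its floor semantics.
legacy = SnowflakeDialect()
assert str((col1 / col2).compile(dialect=legacy)) == "col1 / col2"
assert str((col1 // col2).compile(dialect=legacy)) == "col1 / col2"

# Opt-in behavior (force_div_is_floordiv=False): "/" stays true division and
# "//" is wrapped in FLOOR(), matching SQLAlchemy 2.0 semantics.
new = SnowflakeDialect(force_div_is_floordiv=False)
assert str((col1 / col2).compile(dialect=new)) == "col1 / col2"
assert str((col1 // col2).compile(dialect=new)) == "FLOOR(col1 / col2)"
```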
From 75376ffb518a11971a2e04ef87bde9044d64a423 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Fri, 10 Jan 2025 13:57:12 -0600
Subject: [PATCH 60/74] Fix incorrect quoting of identifiers with `_` as
 initial character. (#569)

---
 DESCRIPTION.md                   |  3 ++-
 src/snowflake/sqlalchemy/base.py | 18 +++++++++++++++++-
 tests/test_compiler.py           | 15 +++++++++++++++
 3 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 00f860ce..1ec1e072 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,7 +9,8 @@ Source code is also available at:
 # Release Notes
 
 - (Unreleased)
-  - Fix return value of snowflake get_table_names
+  - Fix return value of snowflake get_table_names.
+  - Fix incorrect quoting of identifiers with `_` as initial character.
   - Added `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`.
     - With the flag in `False`, the `/` division operator will be treated as a float division and `//` as a floor division.
     - This flag is added to maintain backward compatibility with the previous behavior of Snowflake Dialect division.
diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index 0226d37d..dc624949 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@ -117,7 +117,11 @@
     r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE
 )
 # used for quoting identifiers ie. table names, column names, etc.
-ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"_", "$"}))
+ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"$"}))
+
+
+# used for quoting identifiers ie. table names, column names, etc.
+ILLEGAL_IDENTIFIERS = frozenset({d for d in string.digits}.union({"_"}))
 
 """
 Overwrite methods to handle Snowflake BCR change:
@@ -443,6 +447,7 @@ def _join_left_to_right(
 class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer):
     reserved_words = {x.lower() for x in RESERVED_WORDS}
     illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
+    illegal_identifiers = ILLEGAL_IDENTIFIERS
 
     def __init__(self, dialect, **kw):
         quote = '"'
@@ -471,6 +476,17 @@ def format_label(self, label, name=None):
 
         return self.quote_identifier(s) if n.quote else s
 
+    def _requires_quotes(self, value: str) -> bool:
+        """Return True if the given identifier requires quoting."""
+        lc_value = value.lower()
+        return (
+            lc_value in self.reserved_words
+            or lc_value in self.illegal_identifiers
+            or value[0] in self.illegal_initial_characters
+            or not self.legal_characters.match(str(value))
+            or (lc_value != value)
+        )
+
     def _split_schema_by_dot(self, schema):
         ret = []
         idx = 0
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index 0eea4607..cb9632a4 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -50,6 +50,21 @@ def test_underscore_as_valid_identifier(self):
             dialect="snowflake",
         )
 
+    def test_underscore_as_initial_character_as_non_quoted_identifier(self):
+        _table = table(
+            "table_1745924",
+            column("ca", Integer),
+            column("cb", String),
+            column("_identifier", String),
+        )
+
+        stmt = insert(_table).values(ca=1, cb="test", _identifier="test_")
+        self.assert_compile(
+            stmt,
+            "INSERT INTO table_1745924 (ca, cb, _identifier) VALUES (%(ca)s, %(cb)s, %(_identifier)s)",
+            dialect="snowflake",
+        )
+
     def test_multi_table_delete(self):
         statement = table1.delete().where(table1.c.id == table2.c.id)
         self.assert_compile(

From ff2a7855cad5760c1cf0d64d055c18c727b6dc02 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Fri, 10 Jan 2025 15:22:47 -0600
Subject: [PATCH 61/74] SNOW-1871582 Fix support for SA array (#568)

* Fix drop support for SA array

* Update DESCRIPTION.md

---
 DESCRIPTION.md                                   |  1 +
 src/snowflake/sqlalchemy/base.py                 |  3 +++
 src/snowflake/sqlalchemy/custom_types.py         |  2 +-
 .../__snapshots__/test_structured_datatypes.ambr |  3 +++
 tests/test_structured_datatypes.py               | 15 +++++++++++++++
 5 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 1ec1e072..ec4fb907 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,6 +9,7 @@ Source code is also available at:
 # Release Notes
 
 - (Unreleased)
+  - Fix support for SqlAlchemy ARRAY.
   - Fix return value of snowflake get_table_names.
   - Fix incorrect quoting of identifiers with `_` as initial character.
   - Added `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`.
diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index dc624949..587a497c 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@ -1133,6 +1133,9 @@ def visit_MAP(self, type_, **kw):
         )
 
     def visit_ARRAY(self, type_, **kw):
+        return "ARRAY"
+
+    def visit_SNOWFLAKE_ARRAY(self, type_, **kw):
         if type_.is_semi_structured:
             return "ARRAY"
         not_null = f" {NOT_NULL}" if type_.not_null else ""
diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py
index 11cd2eb8..c742b740 100644
--- a/src/snowflake/sqlalchemy/custom_types.py
+++ b/src/snowflake/sqlalchemy/custom_types.py
@@ -83,7 +83,7 @@ def __repr__(self):
 
 
 class ARRAY(StructuredType):
-    __visit_name__ = "ARRAY"
+    __visit_name__ = "SNOWFLAKE_ARRAY"
 
     def __init__(
         self,
diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr
index 3dcedf7c..453d26e4 100644
--- a/tests/__snapshots__/test_structured_datatypes.ambr
+++ b/tests/__snapshots__/test_structured_datatypes.ambr
@@ -5,6 +5,9 @@
 # name: test_compile_table_with_double_map
   'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))'
 # ---
+# name: test_compile_table_with_sqlalchemy_array
+  'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname ARRAY, \tPRIMARY KEY ("Id"))'
+# ---
 # name: test_compile_table_with_structured_data_type[structured_type0]
   'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))'
 # ---
diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py
index ce030bd2..fb73673b 100644
--- a/tests/test_structured_datatypes.py
+++ b/tests/test_structured_datatypes.py
@@ -2,6 +2,7 @@
 # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #
 import pytest
+import sqlalchemy as sa
 from sqlalchemy import (
     Column,
     Integer,
@@ -47,6 +48,20 @@ def test_compile_table_with_structured_data_type(
     assert sql_compiler(create_table) == snapshot
 
 
+def test_compile_table_with_sqlalchemy_array(sql_compiler, snapshot):
+    metadata = MetaData()
+    user_table = Table(
+        "clustered_user",
+        metadata,
+        Column("Id", Integer, primary_key=True),
+        Column("name", sa.ARRAY(sa.String)),
+    )
+
+    create_table = CreateTable(user_table)
+
+    assert sql_compiler(create_table) == snapshot
+
+
 @pytest.mark.requires_external_volume
 def test_insert_map(engine_testaccount, external_volume, base_location, snapshot):
     metadata = MetaData()
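The net effect of patch 61 is that the generic `sqlalchemy.ARRAY` type now compiles to Snowflake's semi-structured `ARRAY`, while the dialect's own structured `ARRAY` keeps its dedicated rendering under the new `SNOWFLAKE_ARRAY` visit name. A minimal sketch of the generic case, mirroring the snapshot added above (table and column names here are illustrative):

```python
import sqlalchemy as sa
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.sql.ddl import CreateTable

from snowflake.sqlalchemy import snowdialect

metadata = MetaData()
t = Table(
    "demo_array",
    metadata,
    Column("id", Integer, primary_key=True),
    # generic SQLAlchemy ARRAY, not snowflake.sqlalchemy's structured ARRAY
    Column("tags", sa.ARRAY(sa.String)),
)

ddl = str(CreateTable(t).compile(dialect=snowdialect.dialect()))
assert "tags ARRAY" in ddl  # the same rendering the new snapshot captures
```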
From 6a9ac4dd0961e9ede8c5dba1580b0c3e6ed33d07 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Mon, 13 Jan 2025 12:22:19 -0600
Subject: [PATCH 62/74] Add structured type support for hybrid tables (#572)

---
 DESCRIPTION.md                                   |  1 +
 .../sqlalchemy/sql/custom_schema/hybrid_table.py |  1 +
 .../__snapshots__/test_compile_hybrid_table.ambr |  3 +++
 .../test_compile_hybrid_table.py                 | 25 ++++++++++++++++++----
 4 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index ec4fb907..ee903263 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -12,6 +12,7 @@ Source code is also available at:
   - Fix support for SqlAlchemy ARRAY.
   - Fix return value of snowflake get_table_names.
   - Fix incorrect quoting of identifiers with `_` as initial character.
+  - Fix ARRAY type not supported in HYBRID tables.
   - Added `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`.
     - With the flag in `False`, the `/` division operator will be treated as a float division and `//` as a floor division.
     - This flag is added to maintain backward compatibility with the previous behavior of Snowflake Dialect division.
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
index 16a58d47..b3a55f20 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -32,6 +32,7 @@ class HybridTable(CustomTableBase):
     __table_prefixes__ = [CustomTablePrefix.HYBRID]
 
     _enforce_primary_keys: bool = True
+    _support_structured_types = True
 
     def __init__(
         self,
diff --git a/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr
index 9412fb45..2622399c 100644
--- a/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr
+++ b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr
@@ -5,3 +5,6 @@
 # name: test_compile_hybrid_table_orm
   'CREATE HYBRID TABLE test_hybrid_table_orm (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))'
 # ---
+# name: test_compile_hybrid_table_with_array
+  'CREATE HYBRID TABLE test_hybrid_table (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tgeom GEOMETRY, \tarray ARRAY, \tPRIMARY KEY (id))'
+# ---
diff --git a/tests/custom_tables/test_compile_hybrid_table.py b/tests/custom_tables/test_compile_hybrid_table.py
index f1af6dc2..7310e21c 100644
--- a/tests/custom_tables/test_compile_hybrid_table.py
+++ b/tests/custom_tables/test_compile_hybrid_table.py
@@ -1,15 +1,14 @@
 #
 # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
 #
-import pytest
+
 from sqlalchemy import Column, Integer, MetaData, String
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.sql.ddl import CreateTable
 
-from snowflake.sqlalchemy import GEOMETRY, HybridTable
+from snowflake.sqlalchemy import ARRAY, GEOMETRY, HybridTable
 
 
-@pytest.mark.aws
 def test_compile_hybrid_table(sql_compiler, snapshot):
     metadata = MetaData()
     table_name = "test_hybrid_table"
@@ -28,7 +27,25 @@ def test_compile_hybrid_table(sql_compiler, snapshot):
     assert actual == snapshot
 
 
-@pytest.mark.aws
+def test_compile_hybrid_table_with_array(sql_compiler, snapshot):
+    metadata = MetaData()
+    table_name = "test_hybrid_table"
+    test_geometry = HybridTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+        Column("geom", GEOMETRY),
+        Column("array", ARRAY),
+    )
+
+    value = CreateTable(test_geometry)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
 def test_compile_hybrid_table_orm(sql_compiler, snapshot):
     Base = declarative_base()

From 87c48cc4ff6863ebab818657d0c10dbe468304a9 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Mon, 13 Jan 2025 15:52:35 -0600
Subject: [PATCH 63/74] Update DESCRIPTION.md (#573)

---
 DESCRIPTION.md                      | 4 ++--
 src/snowflake/sqlalchemy/version.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index ee903263..1fab870e 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -8,12 +8,12 @@ Source code is also available at:
 
 # Release Notes
 
-- (Unreleased)
+- v1.7.3(January 15, 2025)
   - Fix support for SqlAlchemy ARRAY.
   - Fix return value of snowflake get_table_names.
   - Fix incorrect quoting of identifiers with `_` as initial character.
   - Fix ARRAY type not supported in HYBRID tables.
-  - Added `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`.
+  - Add `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`.
     - With the flag in `False`, the `/` division operator will be treated as a float division and `//` as a floor division.
     - This flag is added to maintain backward compatibility with the previous behavior of Snowflake Dialect division.
     - This flag will be removed in the future and Snowflake Dialect will use `div_is_floordiv` as `False`.
diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py
index e54f6a9c..5d3937ad 100644
--- a/src/snowflake/sqlalchemy/version.py
+++ b/src/snowflake/sqlalchemy/version.py
@@ -3,4 +3,4 @@
 #
 # Update this for the versions
 # Don't change the forth version number from None
-VERSION = "1.7.2"
+VERSION = "1.7.3"

From 6d8c5e9c51cdc8cc1f397b37953b0a057521346d Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Wed, 15 Jan 2025 13:29:42 -0600
Subject: [PATCH 64/74] Update README.md (#574)

---
 DESCRIPTION.md |  2 ++
 README.md      | 22 ++++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 1fab870e..b36cef68 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -6,6 +6,8 @@ Snowflake Documentation is available at:
 
 Source code is also available at:
 
+# Unreleased Notes
+  - Update README.md to include instructions on how to verify package signatures using cosign.
 # Release Notes
 
 - v1.7.3(January 15, 2025)
diff --git a/README.md b/README.md
index 34d86376..38b11c82 100644
--- a/README.md
+++ b/README.md
@@ -677,6 +677,28 @@ dynamic_test_table_1 = DynamicTable(
 
 - Direct data insertion into Dynamic Tables is not supported.
 
+## Verifying Package Signatures
+
+To ensure the authenticity and integrity of the Python package, follow the steps below to verify the package signature using `cosign`.
+
+**Steps to verify the signature:**
+- Install cosign:
+  - This example is using golang installation: [installing-cosign-with-go](https://edu.chainguard.dev/open-source/sigstore/cosign/how-to-install-cosign/#installing-cosign-with-go)
+- Download the file from the repository like pypi:
+  - https://pypi.org/project/snowflake-sqlalchemy/#files
+- Download the signature files from the release tag, replace the version number with the version you are verifying:
+  - https://github.com/snowflakedb/snowflake-sqlalchemy/releases/tag/v1.7.3
+- Verify signature:
+  ````bash
+  # replace the version number with the version you are verifying
+  ./cosign verify-blob snowflake_sqlalchemy-1.7.3-py3-none-any.whl \
+  --certificate snowflake_sqlalchemy-1.7.3-py3-none-any.whl.crt \
+  --certificate-identity https://github.com/snowflakedb/snowflake-sqlalchemy/.github/workflows/python-publish.yml@refs/tags/v1.7.3 \
+  --certificate-oidc-issuer https://token.actions.githubusercontent.com \
+  --signature snowflake_sqlalchemy-1.7.3-py3-none-any.whl.sig
+  Verified OK
+  ````
+
 ## Support
 
 Feel free to file an issue or submit a PR here for general cases. For official support, contact Snowflake support at:

From bb38b4ba0ed7a816197862ade7cd05842706b916 Mon Sep 17 00:00:00 2001
From: David Szmolka <69192509+sfc-gh-dszmolka@users.noreply.github.com>
Date: Wed, 26 Mar 2025 18:51:50 +0100
Subject: [PATCH 65/74] jira creation ownership (to another support group)
 (#578)

---
 .github/workflows/jira_issue.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml
index 80a31d31..d12ff3e5 100644
--- a/.github/workflows/jira_issue.yml
+++ b/.github/workflows/jira_issue.yml
@@ -38,7 +38,7 @@ jobs:
           summary: '${{ github.event.issue.title }}'
           description: |
             ${{ github.event.issue.body }} \\ \\ _Created from GitHub Action_ for ${{ github.event.issue.html_url }}
-          fields: '{"customfield_11401":{"id":"14723"},"assignee":{"id":"712020:e1f41916-da57-4fe8-b317-116d5229aa51"},"components":[{"id":"16161"},{"id":"16403"}], "labels": ["oss"], "priority": {"id": "10001"} }'
+          fields: '{"customfield_11401":{"id":"14723"},"assignee":{"id":"712020:e527ae71-55cc-4e02-9217-1ca4ca8028a2"},"components":[{"id":"16161"},{"id":"16403"}], "labels": ["oss"], "priority": {"id": "10001"} }'
 
       - name: Update GitHub Issue
         uses: ./jira/gajira-issue-update

From 8eea6dbc9075443a0585b51186e6c4e9ebf6aecf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Olman=20Garc=C3=ADa?= <127251112+sfc-gh-ogarciabarquero@users.noreply.github.com>
Date: Thu, 8 May 2025 11:10:27 -0600
Subject: [PATCH 66/74] Update README.md to reflect maintenance mode

In maintenance mode, showing it in readme file
---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 38b11c82..6356d798 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ Snowflake SQLAlchemy runs on the top of the Snowflake Connector for Python as a
 [dialect](http://docs.sqlalchemy.org/en/latest/dialects/) to bridge a Snowflake
 database and SQLAlchemy applications.
 
-| :exclamation: | For production-affecting or urgent issues related to the connector, please [create a case with Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge). |
+| :exclamation: | Effective May 8th, 2025, Snowflake SQLAlchemy will transition to maintenance mode and will cease active development. Support will be limited to addressing critical bugs and security vulnerabilities. To report such issues, please [create a case with Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge) for individual evaluation. Please note that pull requests from external contributors may not receive action from Snowflake. |
 |---------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|

From 8c9784edb60e9901d14c78a65b81dc95e4d86afe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Olman=20Garc=C3=ADa?= <127251112+sfc-gh-ogarciabarquero@users.noreply.github.com>
Date: Fri, 30 May 2025 13:29:29 -0600
Subject: [PATCH 67/74] Getting values for specific columns instead of trying
 to fetch all columns. (#586)

Trying to get all columns generates errors when parameters enable extra
columns for describe table command.
---
 src/snowflake/sqlalchemy/snowdialect.py | 22 +++++++---------------
 1 file changed, 7 insertions(+), 15 deletions(-)

diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py
index bba1160f..2db2fb00 100644
--- a/src/snowflake/sqlalchemy/snowdialect.py
+++ b/src/snowflake/sqlalchemy/snowdialect.py
@@ -654,21 +654,13 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
                 f" TABLE {table_schema}.{table_name} TYPE = COLUMNS"
             )
         )
-        for (
-            column_name,
-            coltype,
-            _kind,
-            is_nullable,
-            column_default,
-            primary_key,
-            _unique_key,
-            _check,
-            _expression,
-            comment,
-            _policy_name,
-            _privacy_domain,
-            _name_mapping,
-        ) in result:
+        for desc_data in result:
+            column_name = desc_data[0]
+            coltype = desc_data[1]
+            is_nullable = desc_data[3]
+            column_default = desc_data[4]
+            primary_key = desc_data[5]
+            comment = desc_data[9]
             column_name = self.normalize_name(column_name)
             if column_name.startswith("sys_clustering_column"):

From d5b736e0eb0ab20bec11ca8460d7799dd851c52b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Olman=20Garc=C3=ADa?= <127251112+sfc-gh-ogarciabarquero@users.noreply.github.com>
Date: Fri, 30 May 2025 14:47:21 -0600
Subject: [PATCH 68/74] Version 1.7.4

Fixing exception caused by dependency on number of columns returned by
DESCRIBE TABLE
---
 DESCRIPTION.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index b36cef68..ab862740 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -10,6 +10,8 @@ Source code is also available at:
   - Update README.md to include instructions on how to verify package signatures using cosign.
 # Release Notes
 
+- v1.7.4(May 30, 2025)
+  - Fix dependency on DESCRIBE TABLE columns quantity (differences in columns caused by Snowflake parameters)
 - v1.7.3(January 15, 2025)
   - Fix support for SqlAlchemy ARRAY.
   - Fix return value of snowflake get_table_names.
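The failure mode patch 67 removes is easy to reproduce in miniature: exhaustive tuple unpacking requires DESCRIBE TABLE to return exactly thirteen columns, while positional indexing tolerates extra trailing columns that account-level parameters can append. A self-contained sketch with synthetic rows — the column layout is assumed from the unpacking list in the diff, and the values are made up:

```python
# Synthetic DESCRIBE TABLE rows (layout assumed from the patch): name, type,
# kind, nullable, default, primary key, ..., comment at index 9.
row_13 = ("ID", "NUMBER(38,0)", "COLUMN", "N", None, "Y",
          "N", None, None, "primary key column", None, None, None)
row_14 = row_13 + ("extra-column-enabled-by-an-account-parameter",)

try:
    (name, coltype, _kind, is_nullable, default, pk,
     _uniq, _check, _expr, comment, _pol, _priv, _map) = row_14
except ValueError as err:
    print(err)  # too many values to unpack (expected 13)

# The fixed code only touches the positions it needs, so extra columns are harmless:
name, coltype, comment = row_14[0], row_14[1], row_14[9]
assert (name, comment) == ("ID", "primary key column")
```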
From 342629674dfd3a3da4ad6c65cc5bbeb0d5a2a33d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Olman=20Garc=C3=ADa?= <127251112+sfc-gh-ogarciabarquero@users.noreply.github.com>
Date: Fri, 30 May 2025 14:52:48 -0600
Subject: [PATCH 69/74] Update DESCRIPTION.md

---
 DESCRIPTION.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index ab862740..1ab1af28 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -10,7 +10,7 @@ Source code is also available at:
   - Update README.md to include instructions on how to verify package signatures using cosign.
 # Release Notes
 
-- v1.7.4(May 30, 2025)
+- (Unreleased)
   - Fix dependency on DESCRIBE TABLE columns quantity (differences in columns caused by Snowflake parameters)
 - v1.7.3(January 15, 2025)
   - Fix support for SqlAlchemy ARRAY.

From 5e74644e86f9d8f49f265c9263a58e63d1e931a3 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Mon, 2 Jun 2025 09:24:43 -0600
Subject: [PATCH 70/74] Skip failing test (#587)

* Fix failing tests

* Try fixing again

* Fix test third try

* Fix tests fourth try, add insecure_mode instead with disable_ocsp_checks

* Remove skip, it is not necessary

---
 DESCRIPTION.md                                   |  5 +++--
 tests/conftest.py                                |  6 ++++++
 tests/custom_tables/test_create_iceberg_table.py |  9 ++++++---
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 1ab1af28..d52dfadf 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -7,11 +7,12 @@ Snowflake Documentation is available at:
 Source code is also available at:
 
 # Unreleased Notes
-  - Update README.md to include instructions on how to verify package signatures using cosign.
 # Release Notes
 
-- (Unreleased)
+- v1.7.3(June 2, 2025)
   - Fix dependency on DESCRIBE TABLE columns quantity (differences in columns caused by Snowflake parameters)
+  - Update README.md to include instructions on how to verify package signatures using cosign.
+
 - v1.7.3(January 15, 2025)
   - Fix support for SqlAlchemy ARRAY.
   - Fix return value of snowflake get_table_names.
diff --git a/tests/conftest.py b/tests/conftest.py
index 5e0fd3ed..df45bb2a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -203,6 +203,12 @@ def get_engine(url: URL, **engine_kwargs):
         "echo": True,
     }
     engine_params.update(engine_kwargs)
+
+    connect_args = engine_params.get("connect_args", {}).copy()
+    connect_args["disable_ocsp_checks"] = True
+    connect_args["insecure_mode"] = True
+    engine_params["connect_args"] = connect_args
+
     engine = create_engine(url, **engine_params)
     return engine
 
diff --git a/tests/custom_tables/test_create_iceberg_table.py b/tests/custom_tables/test_create_iceberg_table.py
index 3ecd703b..5ce75909 100644
--- a/tests/custom_tables/test_create_iceberg_table.py
+++ b/tests/custom_tables/test_create_iceberg_table.py
@@ -9,7 +9,7 @@
 
 
 @pytest.mark.aws
-def test_create_iceberg_table(engine_testaccount, snapshot):
+def test_create_iceberg_table(engine_testaccount):
     metadata = MetaData()
     external_volume_name = "exvol"
     create_external_volume = f"""
@@ -19,7 +19,7 @@ def test_create_iceberg_table(engine_testaccount, snapshot):
             (
                 NAME = 'my-s3-us-west-2'
                 STORAGE_PROVIDER = 'S3'
-                STORAGE_BASE_URL = 's3://MY_EXAMPLE_BUCKET/'
+                STORAGE_BASE_URL = 's3://myexamplebucket/'
                 STORAGE_AWS_ROLE_ARN = 'arn:aws:iam::123456789012:role/myrole'
                 ENCRYPTION=(TYPE='AWS_SSE_KMS' KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab')
             )
@@ -40,4 +40,7 @@ def test_create_iceberg_table(engine_testaccount, snapshot):
         metadata.create_all(engine_testaccount)
 
     error_str = str(argument_error.value)
-    assert error_str[: error_str.rfind("\n")] == snapshot
+    assert (
+        "(snowflake.connector.errors.ProgrammingError)"
+        in error_str[: error_str.rfind("\n")]
+    )
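The conftest change above relies on a general mechanism worth calling out: any `snowflake.connector` keyword can be threaded through SQLAlchemy via `connect_args`. A hedged sketch with a placeholder URL — these two flags weaken OCSP/TLS revocation checking and are used here only to stabilize CI, not as a production recommendation:

```python
from sqlalchemy import create_engine

# The URL below is a placeholder, not a real account. Engine creation is
# lazy, so no connection is attempted until first use.
engine = create_engine(
    "snowflake://user:password@my_account/my_db/my_schema",
    connect_args={
        "disable_ocsp_checks": True,  # flag added in the conftest above
        "insecure_mode": True,        # legacy spelling kept alongside it
    },
)
```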
From 426395ff65c4d5cfc09a6f82f34ccc07d848aa8e Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Mon, 2 Jun 2025 13:19:14 -0600
Subject: [PATCH 71/74] Fix small typo (#588)

---
 DESCRIPTION.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index d52dfadf..82b051f7 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,7 +9,7 @@ Source code is also available at:
 # Unreleased Notes
 # Release Notes
 
-- v1.7.3(June 2, 2025)
+- v1.7.4(June 2, 2025)
   - Fix dependency on DESCRIBE TABLE columns quantity (differences in columns caused by Snowflake parameters)
   - Update README.md to include instructions on how to verify package signatures using cosign.

From a77b86adcd465a2de69fc4ba76387078d6f82784 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Mon, 9 Jun 2025 09:38:11 -0600
Subject: [PATCH 72/74] Fix condition is failing when parsing SqlAlchemy
 columns (#592)

* Fix issue

* Update release notes

* Add tests

---
 DESCRIPTION.md                                      |  3 +-
 src/snowflake/sqlalchemy/snowdialect.py             | 12 +------
 .../__snapshots__/test_reflect_snowflake_table.ambr |  3 ++
 .../custom_tables/test_reflect_snowflake_table.py   | 30 +++++++++++++++++++
 4 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 82b051f7..f1735d57 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -10,7 +10,8 @@ Source code is also available at:
 
 # Release Notes
 - v1.7.4(June 2, 2025)
-  - Fix dependency on DESCRIBE TABLE columns quantity (differences in columns caused by Snowflake parameters)
+  - Fix dependency on DESCRIBE TABLE columns quantity (differences in columns caused by Snowflake parameters).
+  - Fix unnecessary condition that was causing issues when parsing StructuredTypes columns.
   - Update README.md to include instructions on how to verify package signatures using cosign.
 
 - v1.7.3(January 15, 2025)
diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py
index 2db2fb00..1e7ccaef 100644
--- a/src/snowflake/sqlalchemy/snowdialect.py
+++ b/src/snowflake/sqlalchemy/snowdialect.py
@@ -511,13 +511,6 @@ def _get_schema_columns(self, connection, schema, **kw):
         )
         schema_name = self.denormalize_name(schema)
 
-        iceberg_table_names = self.get_table_names_with_prefix(
-            connection,
-            schema=schema_name,
-            prefix=CustomTablePrefix.ICEBERG.name,
-            info_cache=kw.get("info_cache", None),
-        )
-
         result = connection.execute(
             text(
                 """
@@ -578,10 +571,7 @@ def _get_schema_columns(self, connection, schema, **kw):
                 col_type_kw["scale"] = numeric_scale
             elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
                 col_type_kw["length"] = character_maximum_length
-            elif (
-                issubclass(col_type, StructuredType)
-                and table_name in iceberg_table_names
-            ):
+            elif issubclass(col_type, StructuredType):
                 if (schema_name, table_name) not in full_columns_descriptions:
                     full_columns_descriptions[(schema_name, table_name)] = (
                         self.table_columns_as_dict(
diff --git a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr
index 7e85841a..e9a4ac83 100644
--- a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr
+++ b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr
@@ -21,6 +21,9 @@
   }),
 ])
 # ---
+# name: test_reflection_of_table_with_object_data_type
+  'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname OBJECT, \tCONSTRAINT demo_name PRIMARY KEY (id))'
+# ---
 # name: test_simple_reflection_of_table_as_snowflake_table
   'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))'
 # ---
diff --git a/tests/custom_tables/test_reflect_snowflake_table.py b/tests/custom_tables/test_reflect_snowflake_table.py
index 603b6187..323dd281 100644
--- a/tests/custom_tables/test_reflect_snowflake_table.py
+++ b/tests/custom_tables/test_reflect_snowflake_table.py
@@ -7,6 +7,36 @@
 from snowflake.sqlalchemy import SnowflakeTable
 
 
+def test_reflection_of_table_with_object_data_type(
+    engine_testaccount, db_parameters, sql_compiler, snapshot
+):
+    metadata = MetaData()
+    table_name = "test_snowflake_table_reflection"
+
+    create_table_sql = f"""
+    CREATE TABLE {table_name} (id INT primary key, name OBJECT);
+    """
+
+    with engine_testaccount.connect() as connection:
+        connection.exec_driver_sql(create_table_sql)
+
+    snowflake_test_table = Table(table_name, metadata, autoload_with=engine_testaccount)
+    constraint = snowflake_test_table.constraints.pop()
+    constraint.name = "demo_name"
+    snowflake_test_table.constraints.add(constraint)
+
+    try:
+        with engine_testaccount.connect():
+            value = CreateTable(snowflake_test_table)
+
+            actual = sql_compiler(value)
+
+            assert actual == snapshot
+
+    finally:
+        metadata.drop_all(engine_testaccount)
+
+
 def test_simple_reflection_of_table_as_sqlalchemy_table(
     engine_testaccount, db_parameters, sql_compiler, snapshot
 ):

From 09f47b02d0de998647ecb7a5a7c2bad8ae62af63 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Tue, 10 Jun 2025 08:41:17 -0600
Subject: [PATCH 73/74] Update version.py (#593)

---
 DESCRIPTION.md                      | 2 +-
 src/snowflake/sqlalchemy/version.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index f1735d57..236e1b67 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,7 +9,7 @@ Source code is also available at:
 # Unreleased Notes
 # Release Notes
 
-- v1.7.4(June 2, 2025)
+- v1.7.4(June 10, 2025)
   - Fix dependency on DESCRIBE TABLE columns quantity (differences in columns caused by Snowflake parameters).
   - Fix unnecessary condition that was causing issues when parsing StructuredTypes columns.
   - Update README.md to include instructions on how to verify package signatures using cosign.
diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py
index 5d3937ad..3d03a626 100644
--- a/src/snowflake/sqlalchemy/version.py
+++ b/src/snowflake/sqlalchemy/version.py
@@ -3,4 +3,4 @@
 #
 # Update this for the versions
 # Don't change the forth version number from None
-VERSION = "1.7.3"
+VERSION = "1.7.4"

From afd2e752df8091fa182aef4d2661959cdfbdadb4 Mon Sep 17 00:00:00 2001
From: Simon Hewitt
Date: Tue, 10 Jan 2023 13:20:28 -0800
Subject: [PATCH 74/74] call _compiler_dispatch for merge_into and copy_into
 clauses

---
 src/snowflake/sqlalchemy/base.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index 587a497c..59c3f91e 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@ -530,8 +530,11 @@ def visit_merge_into(self, merge_into, **kw):
         clauses = " ".join(
             clause._compiler_dispatch(self, **kw) for clause in merge_into.clauses
         )
+        target = merge_into.target._compiler_dispatch(self, asfrom=True, **kw)
+        source = merge_into.source._compiler_dispatch(self, asfrom=True, **kw)
+        on = merge_into.on._compiler_dispatch(self, **kw)
         return (
-            f"MERGE INTO {merge_into.target} USING {merge_into.source} ON {merge_into.on}"
+            f"MERGE INTO {target} USING {source} ON {on}"
             + (" " + clauses if clauses else "")
         )
 
@@ -579,11 +582,8 @@ def visit_copy_into(self, copy_into, **kw):
             formatter = copy_into.formatter._compiler_dispatch(self, **kw)
         else:
             formatter = ""
-        into = (
-            copy_into.into
-            if isinstance(copy_into.into, Table)
-            else copy_into.into._compiler_dispatch(self, **kw)
-        )
+        into = copy_into.into._compiler_dispatch(self, asfrom=True, **kw)
+
         from_ = None
         if isinstance(copy_into.from_, Table):
            from_ = copy_into.from_.name  # this is intended to catch AWSBucket and AzureContainer
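The series closes by re-applying its opening change: rendering the MERGE target, source, and ON clause through `_compiler_dispatch` instead of raw `str()` interpolation, so dialect quoting and compilation rules apply. A minimal sketch of the statement this visitor compiles, using the `MergeInto` construct from the package's public API; the table names are illustrative and the output shown in comments is approximate:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table

from snowflake.sqlalchemy import MergeInto, snowdialect

metadata = MetaData()
target = Table("t_target", metadata, Column("id", Integer), Column("name", String))
source = Table("t_source", metadata, Column("id", Integer), Column("name", String))

merge = MergeInto(target, source, target.c.id == source.c.id)
merge.when_matched_then_update().values(name=source.c.name)

# With this patch, target/source/on go through the compiler, so identifier
# quoting follows SnowflakeIdentifierPreparer rather than Python's str():
print(merge.compile(dialect=snowdialect.dialect()))
# ~ MERGE INTO t_target USING t_source ON t_target.id = t_source.id
#   WHEN MATCHED THEN UPDATE SET name = t_source.name
```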