diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d8130079..b2168af7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -@snowflakedb/snowpark-python-api +* @snowflakedb/ORM diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 9c648c73..5e9823f2 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -24,164 +24,282 @@ jobs: name: Check linting runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: '3.8' - - name: Display Python version - run: python -c "import sys; import os; print(\"\n\".join(os.environ[\"PATH\"].split(os.pathsep))); print(sys.version); print(sys.executable);" - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox + - name: Upgrade and install tools + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default - name: Set PY - run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV - - uses: actions/cache@v1 + run: echo "PY=$(hatch run gh-cache-sum)" >> $GITHUB_ENV + - uses: actions/cache@v4 with: path: ~/.cache/pre-commit key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }} - - name: Run fix_lint - run: python -m tox -e fix_lint + - name: Run lint checks + run: hatch run check + + build-install: + name: Test package build and installation + runs-on: ubuntu-latest + needs: lint + strategy: + fail-fast: true + matrix: + hatch-env: [default, sa14] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Setup up Python + uses: actions/setup-python@v5 + with: + python-version: '3.8' + - name: Upgrade and install tools + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + - name: Build package + run: | + python -m hatch -e ${{ matrix.hatch-env }} build --clean + - name: Install and check import + run: | + python -m uv pip install dist/snowflake_sqlalchemy-*.whl + python -c "import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)" + + test-dialect: + name: Test dialect ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: [ lint, build-install ] + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and prepare environment + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run test for AWS + run: hatch run test-dialect-aws + if: matrix.cloud-provider == 'aws' + - name: Run tests + run: hatch run test-dialect + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-${{ matrix.os }}-${{ matrix.python-version }}-${{ 
matrix.cloud-provider }} + path: | + ./coverage.xml + + test-dialect-compatibility: + name: Test dialect compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: [ lint, build-install ] + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml - test: - name: Test ${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - needs: lint - runs-on: ${{ matrix.os.image_name }} - strategy: - fail-fast: false - matrix: - os: - - image_name: ubuntu-latest - download_name: manylinux_x86_64 - - image_name: macos-latest - download_name: macosx_x86_64 - - image_name: windows-2019 - download_name: win_amd64 - python-version: ["3.8"] - cloud-provider: [aws, azure, gcp] - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Setup parameters file - shell: bash - env: - PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} - run: | - gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ - .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: List installed packages - run: python -m pip freeze - - name: Run tests - run: python -m tox -e "py${PYTHON_VERSION/\./}" --skip-missing-interpreters false - env: - PYTHON_VERSION: ${{ matrix.python-version }} - PYTEST_ADDOPTS: -vvv --color=yes --tb=short - TOX_PARALLEL_NO_SPINNER: 1 - - name: Combine coverages - run: python -m tox -e coverage --skip-missing-interpreters false - shell: bash - - uses: actions/upload-artifact@v2 - with: - name: coverage_${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - path: | - .tox/.coverage - .tox/coverage.xml + test-dialect-v14: + name: Test dialect v14 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: [ lint, build-install ] + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: 
${{ matrix.python-version }} + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Run test for AWS + run: hatch run sa14:test-dialect-aws + if: matrix.cloud-provider == 'aws' + - name: Run tests + run: hatch run sa14:test-dialect + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-v14-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml - test_connector_regression: - name: Connector Regression Test ${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - needs: lint - runs-on: ${{ matrix.os.image_name }} - strategy: - fail-fast: false - matrix: - os: - - image_name: ubuntu-latest - download_name: manylinux_x86_64 - python-version: ["3.8"] - cloud-provider: [aws] - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Setup parameters file - shell: bash - env: - PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} - run: | - gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ - .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/connector_regression/test/parameters.py - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: List installed packages - run: python -m pip freeze - - name: Run tests - run: python -m tox -e connector_regression --skip-missing-interpreters false - env: - PYTEST_ADDOPTS: -vvv --color=yes --tb=short - TOX_PARALLEL_NO_SPINNER: 1 + test-dialect-compatibility-v14: + name: Test dialect v14 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run sa14:test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-v14-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml combine-coverage: - if: ${{ success() || failure() }} name: Combine coverage - needs: [test, test_connector_regression] + if: ${{ success() || failure() 
}}
+    needs: [test-dialect, test-dialect-compatibility, test-dialect-v14, test-dialect-compatibility-v14]
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - uses: actions/download-artifact@v2
-        with:
-          path: artifacts
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
-          python-version: '3.8'
-      - name: Display Python version
-        run: python -c "import sys; print(sys.version)"
-      - name: Upgrade setuptools and pip
-        run: python -m pip install -U setuptools pip wheel
-      - name: Install tox
-        run: python -m pip install tox
-      - name: Collect all coverages to one dir
+          python-version: "3.8"
+      - name: Prepare environment
+        run: |
+          python -m pip install -U uv
+          python -m uv pip install -U hatch
+          hatch env create default
+      - uses: actions/checkout@v4
+        with:
+          persist-credentials: false
+      - uses: actions/download-artifact@v4
+        with:
+          path: artifacts/
+      - name: Combine coverage files
         run: |
-          python -c '
-          from pathlib import Path
-          import shutil
-          src_dir = Path("artifacts")
-          dst_dir = Path(".") / ".tox"
-          dst_dir.mkdir()
-          for src_file in src_dir.glob("*/.coverage"):
-              dst_file = dst_dir / f".coverage.{src_file.parent.name[9:]}"
-              print(f"{src_file} copy to {dst_file}")
-              shutil.copy(str(src_file), str(dst_file))'
-      - name: Combine coverages
-        run: python -m tox -e coverage
-      - name: Publish html coverage
-        uses: actions/upload-artifact@v2
-        with:
-          name: overall_cov_html
-          path: .tox/htmlcov
-      - name: Publish xml coverage
-        uses: actions/upload-artifact@v2
-        with:
-          name: overall_cov_xml
-          path: .tox/coverage.xml
-      - uses: codecov/codecov-action@v1
-        with:
-          file: .tox/coverage.xml
+          hatch run coverage combine -a artifacts/coverage.xml_*/coverage.xml
+          hatch run coverage report -m
+      - name: Store coverage reports
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage.xml
+          path: coverage.xml
+      - name: Upload to codecov
+        uses: codecov/codecov-action@v4
+        with:
+          file: coverage.xml
+          env_vars: OS,PYTHON
+          fail_ci_if_error: false
+          flags: unittests
+          token: ${{ secrets.CODECOV_TOKEN }}
+          verbose: true
+          url: https://snowflake.codecov.io/
diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml
new file mode 100644
index 00000000..252405fd
--- /dev/null
+++ b/.github/workflows/changelog.yml
@@ -0,0 +1,21 @@
+name: Changelog Check
+
+on:
+  pull_request:
+    types: [opened, synchronize, labeled, unlabeled]
+    branches:
+      - main
+
+jobs:
+  check_change_log:
+    runs-on: ubuntu-latest
+    if: ${{!contains(github.event.pull_request.labels.*.name, 'NO-CHANGELOG-UPDATES')}}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
+          fetch-depth: 0
+
+      - name: Ensure DESCRIPTION.md is updated
+        run: git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -wq "DESCRIPTION.md"
diff --git a/.github/workflows/cla_bot.yml b/.github/workflows/cla_bot.yml
index 5574667a..2c87fc92 100644
--- a/.github/workflows/cla_bot.yml
+++ b/.github/workflows/cla_bot.yml
@@ -8,6 +8,11 @@ on:
 jobs:
   CLAssistant:
     runs-on: ubuntu-latest
+    permissions:
+      actions: write
+      contents: write
+      pull-requests: write
+      statuses: write
     steps:
       - name: "CLA Assistant"
         if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
@@ -17,7 +22,7 @@ jobs:
         PERSONAL_ACCESS_TOKEN : ${{ secrets.CLA_BOT_TOKEN }}
       with:
         path-to-signatures: 'signatures/version1.json'
-
path-to-document: 'https://github.com/Snowflake-Labs/CLA/blob/main/README.md' + path-to-document: 'https://github.com/snowflakedb/CLA/blob/main/README.md' branch: 'main' allowlist: 'dependabot[bot],github-actions, sfc-gh-snyk-sca-sa' remote-organization-name: 'snowflakedb' diff --git a/.github/workflows/create_req_files.yml b/.github/workflows/create_req_files.yml index 57f7efb8..2cb7a371 100644 --- a/.github/workflows/create_req_files.yml +++ b/.github/workflows/create_req_files.yml @@ -11,18 +11,20 @@ jobs: matrix: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Display Python version run: python -c "import sys; print(sys.version)" - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel + run: python -m pip install -U setuptools pip wheel uv - name: Install Snowflake SQLAlchemy shell: bash - run: python -m pip install . + run: python -m uv pip install . - name: Generate reqs file name shell: bash run: echo "requirements_file=temp_requirement/requirements_$(python -c 'from sys import version_info;print(str(version_info.major)+str(version_info.minor))').reqs" >> $GITHUB_ENV @@ -32,12 +34,12 @@ jobs: mkdir temp_requirement echo "# Generated on: $(python --version)" >${{ env.requirements_file }} python -m pip freeze | grep -v snowflake-sqlalchemy 1>>${{ env.requirements_file }} 2>/dev/null - echo "snowflake-sqlalchemy==$(python -m pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} + echo "snowflake-sqlalchemy==$(python -m uv pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} id: create-reqs-file - name: Show created req file shell: bash run: cat ${{ env.requirements_file }} - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: path: temp_requirement @@ -46,11 +48,12 @@ jobs: name: Commit and push files runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: + persist-credentials: false token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets - name: Download requirement files - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v4 with: name: artifact path: tested_requirements diff --git a/.github/workflows/jira_close.yml b/.github/workflows/jira_close.yml index dfcb8bc7..7862f483 100644 --- a/.github/workflows/jira_close.yml +++ b/.github/workflows/jira_close.yml @@ -9,14 +9,15 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/gh-actions ref: jira_v1 token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets path: . 
- name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} diff --git a/.github/workflows/jira_comment.yml b/.github/workflows/jira_comment.yml index 954929fa..8533c14c 100644 --- a/.github/workflows/jira_comment.yml +++ b/.github/workflows/jira_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} @@ -22,7 +22,7 @@ jobs: jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://') echo ::set-output name=jira::$jira - name: Comment on issue - uses: atlassian/gajira-comment@master + uses: atlassian/gajira-comment@v3 if: startsWith(steps.extract.outputs.jira, 'SNOW-') with: issue: "${{ steps.extract.outputs.jira }}" diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 3683bbba..d12ff3e5 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -9,18 +9,21 @@ on: jobs: create-issue: runs-on: ubuntu-latest + permissions: + issues: write if: ((github.event_name == 'issue_comment' && github.event.comment.body == 'recreate jira' && github.event.comment.user.login == 'sfc-gh-mkeller') || (github.event_name == 'issues' && github.event.pull_request.user.login != 'whitesource-for-github-com[bot]')) steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/gh-actions ref: jira_v1 token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets path: . - name: Login - uses: atlassian/gajira-login@v2.0.0 + uses: atlassian/gajira-login@v3 env: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} @@ -28,14 +31,14 @@ jobs: - name: Create JIRA Ticket id: create - uses: atlassian/gajira-create@v2.0.1 + uses: atlassian/gajira-create@v3 with: project: SNOW issuetype: Bug summary: '${{ github.event.issue.title }}' description: | ${{ github.event.issue.body }} \\ \\ _Created from GitHub Action_ for ${{ github.event.issue.html_url }} - fields: '{"customfield_11401":{"id":"14586"},"assignee":{"id":"61027a237ab143006ecfb9a2"},"components":[{"id":"16161"},{"id":"16403"}]}' + fields: '{"customfield_11401":{"id":"14723"},"assignee":{"id":"712020:e527ae71-55cc-4e02-9217-1ca4ca8028a2"},"components":[{"id":"16161"},{"id":"16403"}], "labels": ["oss"], "priority": {"id": "10001"} }' - name: Update GitHub Issue uses: ./jira/gajira-issue-update diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index dd1e1ba6..52f43106 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -13,7 +13,8 @@ on: types: [published] permissions: - contents: read + contents: write + id-token: write jobs: deploy: @@ -21,17 +22,59 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install build + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package - run: python -m build + run: python -m hatch build --clean + - name: List artifacts + run: ls ./dist + - name: 
Install sigstore + run: python -m pip install sigstore + - name: Signing + run: | + for dist in dist/*; do + dist_base="$(basename "${dist}")" + echo "dist: ${dist}" + echo "dist_base: ${dist_base}" + python -m \ + sigstore sign "${dist}" \ + --output-signature "${dist_base}.sig" \ + --output-certificate "${dist_base}.crt" \ + --bundle "${dist_base}.sigstore" + + # Verify using `.sig` `.crt` pair; + python -m \ + sigstore verify identity "${dist}" \ + --signature "${dist_base}.sig" \ + --cert "${dist_base}.crt" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} + + # Verify using `.sigstore` bundle; + python -m \ + sigstore verify identity "${dist}" \ + --bundle "${dist_base}.sigstore" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} + done + - name: List artifacts after sign + run: ls ./dist + - name: Copy files to release + run: | + gh release upload ${{ github.event.release.tag_name }} *.sigstore + gh release upload ${{ github.event.release.tag_name }} *.sig + gh release upload ${{ github.event.release.tag_name }} *.crt + env: + GITHUB_TOKEN: ${{ github.TOKEN }} - name: Publish package uses: pypa/gh-action-pypi-publish@release/v1 with: diff --git a/.github/workflows/snyk-issue.yml b/.github/workflows/snyk-issue.yml index 7098b01e..94dfeb53 100644 --- a/.github/workflows/snyk-issue.yml +++ b/.github/workflows/snyk-issue.yml @@ -4,6 +4,11 @@ on: schedule: - cron: '* */12 * * *' +permissions: + contents: read + issues: write + pull-requests: write + concurrency: snyk-issue jobs: @@ -11,13 +16,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Action - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/whitesource-actions token: ${{ secrets.whitesource_action_token }} path: whitesource-actions - name: Set Env - run: echo "repo=$(basename $github_repository)" >> $github_env + run: echo "repo=${{ github.event.repository.name }}" >> $GITHUB_ENV - name: Jira Creation uses: ./whitesource-actions/snyk-issue with: diff --git a/.github/workflows/snyk-pr.yml b/.github/workflows/snyk-pr.yml index 51e531f4..cc5e8644 100644 --- a/.github/workflows/snyk-pr.yml +++ b/.github/workflows/snyk-pr.yml @@ -3,20 +3,28 @@ on: pull_request: branches: - main + +permissions: + contents: read + issues: write + pull-requests: write + jobs: snyk: runs-on: ubuntu-latest if: ${{ github.event.pull_request.user.login == 'sfc-gh-snyk-sca-sa' }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false ref: ${{ github.event.pull_request.head.ref }} fetch-depth: 0 - name: Checkout Action - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: + persist-credentials: false repository: snowflakedb/whitesource-actions token: ${{ secrets.whitesource_action_token }} path: whitesource-actions diff --git a/.github/workflows/stale_issue_bot.yml b/.github/workflows/stale_issue_bot.yml index 6d76e9f4..4ee56ff8 100644 --- a/.github/workflows/stale_issue_bot.yml +++ b/.github/workflows/stale_issue_bot.yml @@ -10,7 +10,7 @@ jobs: stale: runs-on: ubuntu-latest steps: - - uses: actions/stale@v7 + - uses: actions/stale@v9 with: close-issue-message: 'To clean up and re-prioritize bugs and feature requests we are closing all issues older than 6 months as of 
Apr 1, 2023. If there are any issues or feature requests that you would like us to address, please re-create them. For urgent issues, opening a support case with this link [Snowflake Community](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge) is the fastest way to get a response'
         days-before-issue-stale: ${{ inputs.staleDays }}
diff --git a/.gitmodules b/.gitmodules
index 6f8702e6..e69de29b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +0,0 @@
-[submodule "tests/connector_regression"]
-    path = tests/connector_regression
-    url = git@github.com:snowflakedb/snowflake-connector-python
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3be42964..b7370b74 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,31 +1,32 @@
 exclude: '^(.*egg.info.*|.*/parameters.py).*$'
 repos:
 -   repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.5.0
     hooks:
     -   id: trailing-whitespace
+        exclude: '\.ambr$'
     -   id: end-of-file-fixer
     -   id: check-yaml
         exclude: .github/repo_meta.yaml
     -   id: debug-statements
 -   repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
     -   id: isort
 -   repo: https://github.com/asottile/pyupgrade
-    rev: v2.37.3
+    rev: v3.15.1
     hooks:
     -   id: pyupgrade
         args: [--py37-plus]
 -   repo: https://github.com/psf/black
-    rev: 22.6.0
+    rev: 24.2.0
     hooks:
     -   id: black
         args:
         - --safe
         language_version: python3
 -   repo: https://github.com/Lucas-C/pre-commit-hooks.git
-    rev: v1.3.0
+    rev: v1.5.5
     hooks:
     -   id: insert-license
         name: insert-py-license
@@ -39,8 +40,15 @@ repos:
         - --license-filepath
         - license_header.txt
 -   repo: https://github.com/pycqa/flake8
-    rev: 5.0.4
+    rev: 7.0.0
     hooks:
     -   id: flake8
         additional_dependencies:
         - flake8-bugbear
+-   repo: local
+    hooks:
+    -   id: requirements-update
+        name: "Update dependencies from pyproject.toml to snyk/requirements.txt"
+        language: system
+        entry: python snyk/update_requirements.py
+        files: ^pyproject.toml$
diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 5219de1e..236e1b67 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -6,8 +6,79 @@ Snowflake Documentation is available at:
 Source code is also available at:

+# Unreleased Notes
 # Release Notes
+- v1.7.4(June 10, 2025)
+  - Fix dependency on the number of columns returned by DESCRIBE TABLE (differences in columns caused by Snowflake parameters).
+  - Fix an unnecessary condition that was causing issues when parsing StructuredTypes columns.
+  - Update README.md to include instructions on how to verify package signatures using cosign.
+
+- v1.7.3(January 15, 2025)
+  - Fix support for SqlAlchemy ARRAY.
+  - Fix return value of snowflake get_table_names.
+  - Fix incorrect quoting of identifiers with `_` as initial character.
+  - Fix ARRAY type not supported in HYBRID tables.
+  - Add `force_div_is_floordiv` flag to override the new `div_is_floordiv` default value `False` in `SnowflakeDialect`.
+    - With the flag set to `False`, the `/` division operator will be treated as a float division and `//` as a floor division.
+    - This flag is added to maintain backward compatibility with the previous behavior of Snowflake Dialect division.
+    - This flag will be removed in the future and Snowflake Dialect will use `div_is_floordiv` as `False`.
+
+- v1.7.2(December 18, 2024)
+  - Fix quoting of `_` as column name
+  - Fix index columns not being reflected
+  - Fix index reflection cache not working
+  - Add support for structured OBJECT datatype
+  - Add support for structured ARRAY datatype
+
+- v1.7.1(December 02, 2024)
+  - Add support for partition by to copy into
+  - Fix BOOLEAN type not found in snowdialect
+  - Add support for autocommit Isolation Level
+
+- v1.7.0(November 21, 2024)
+  - Add support for dynamic tables and required options
+  - Add support for hybrid tables
+  - Fixed SAWarning when registering functions with existing name in default namespace
+  - Update options to be defined as keyword arguments instead of positional arguments.
+  - Add support for refresh_mode option in DynamicTable
+  - Add support for iceberg table with Snowflake Catalog
+  - Fix cluster by option to support explicit expressions
+  - Add support for MAP datatype
+
+- v1.6.1(July 9, 2024)
+
+  - Update internal project workflow with pypi publishing
+
+- v1.6.0(July 8, 2024)
+
+  - Support for installing with SQLAlchemy 2.0.x
+  - Use `hatch` & `uv` for managing project virtual environments
+
+- v1.5.4
+
+  - Add ability to set ORDER / NOORDER sequence on columns with IDENTITY
+
+- v1.5.3(April 16, 2024)
+
+  - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0
+
+- v1.5.2(April 11, 2024)
+
+  - Bump min SQLAlchemy to 1.4.19 for outer lateral join
+  - Add support for sequence ordering in tests
+
+- v1.5.1(November 03, 2023)
+
+  - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check .
+  - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters.
+  - This fixes other boolean parameter passing to driver as well.
+
+- v1.5.0(Aug 23, 2023)
+
+  - Added option to create a temporary stage command.
+  - Added support for geometry type.
+  - Fixed a compatibility issue of regex expression with SQLAlchemy 1.4.49.

 - v1.4.7(Mar 22, 2023)
diff --git a/README.md b/README.md
index 0c75854e..6356d798 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,11 @@ Snowflake SQLAlchemy runs on the top of the Snowflake Connector for Python as a
 [dialect](http://docs.sqlalchemy.org/en/latest/dialects/) to bridge a Snowflake database and SQLAlchemy
 applications.
+
+| :exclamation: | Effective May 8th, 2025, Snowflake SQLAlchemy will transition to maintenance mode and will cease active development. Support will be limited to addressing critical bugs and security vulnerabilities. To report such issues, please [create a case with Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge) for individual evaluation. Please note that pull requests from external contributors may not receive action from Snowflake. |
+|---------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+
+
 ## Prerequisites

 ### Snowflake Connector for Python
@@ -101,6 +106,7 @@ containing special characters need to be URL encoded to be parsed correctly. Thi
characters could lead to authentication failure.
The encoding for the password can be generated using `urllib.parse`: + ```python import urllib.parse urllib.parse.quote("kx@% jj5/g") @@ -111,6 +117,7 @@ urllib.parse.quote("kx@% jj5/g") To create an engine with the proper encodings, either manually constructing the url string by formatting or taking advantage of the `snowflake.sqlalchemy.URL` helper method: + ```python import urllib.parse from snowflake.sqlalchemy import URL @@ -191,14 +198,23 @@ engine = create_engine(...) engine.execute() engine.dispose() -# Do this. +# Better. engine = create_engine(...) connection = engine.connect() try: - connection.execute() + connection.execute(text()) finally: connection.close() engine.dispose() + +# Best +try: + with engine.connect() as connection: + connection.execute(text()) + # or + connection.exec_driver_sql() +finally: + engine.dispose() ``` ### Auto-increment Behavior @@ -214,11 +230,43 @@ t = Table('mytable', metadata, ### Object Name Case Handling -Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during schema-level communication, i.e. during table and index reflection. If you use uppercase object names, SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause mismatches agaisnt data dictionary data received from Snowflake, so unless identifier names have been truly created as case sensitive using quotes, e.g., `"TestDb"`, all lowercase names should be used on the SQLAlchemy side. +Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during schema-level communication, i.e. during table and index reflection. If you use uppercase object names, SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause mismatches against data dictionary data received from Snowflake, so unless identifier names have been truly created as case sensitive using quotes, e.g., `"TestDb"`, all lowercase names should be used on the SQLAlchemy side. ### Index Support -Snowflake does not utilize indexes, so neither does Snowflake SQLAlchemy. +Indexes are supported only for Hybrid Tables in Snowflake SQLAlchemy. For more details on limitations and use cases, refer to the [Create Index documentation](https://docs.snowflake.com/en/sql-reference/constraints-indexes.html). You can create an index using the following methods: + +#### Single Column Index + +You can create a single column index by setting the `index=True` parameter on the column or by explicitly defining an `Index` object. + +```python +hybrid_test_table_1 = HybridTable( + "table_name", + metadata, + Column("column1", Integer, primary_key=True), + Column("column2", String, index=True), + Index("index_1", "column1", "column2") +) + +metadata.create_all(engine_testaccount) +``` + +#### Multi-Column Index + +For multi-column indexes, you define the `Index` object specifying the columns that should be indexed. 
+
+```python
+hybrid_test_table_1 = HybridTable(
+    "table_name",
+    metadata,
+    Column("column1", Integer, primary_key=True),
+    Column("column2", String),
+    Index("index_1", "column1", "column2")
+)
+
+metadata.create_all(engine_testaccount)
+```

 ### Numpy Data Type Support

@@ -242,14 +290,14 @@ engine = create_engine(URL(

 specific_date = np.datetime64('2016-03-04T12:03:05.123456789Z')

-connection = engine.connect()
-connection.execute(
-    "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)")
-connection.execute(
-    "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,)
-)
-df = pd.read_sql_query("SELECT * FROM ts_tbl", engine)
-assert df.c1.values[0] == specific_date
+with engine.connect() as connection:
+    connection.exec_driver_sql(
+        "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)")
+    connection.exec_driver_sql(
+        "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,)
+    )
+    df = pd.read_sql_query("SELECT * FROM ts_tbl", connection)
+    assert df.c1.values[0] == specific_date
 ```

 The following `NumPy` data types are supported:
@@ -319,6 +367,79 @@ data_object = json.loads(row[1])
 data_array = json.loads(row[2])
 ```

+### Structured Data Types Support
+
+This module defines custom SQLAlchemy types for Snowflake structured data, specifically for **Iceberg tables**.
+The types —**MAP**, **OBJECT**, and **ARRAY**— allow you to store complex data structures in your SQLAlchemy models.
+For detailed information, refer to the Snowflake [Structured data types](https://docs.snowflake.com/en/sql-reference/data-types-structured) documentation.
+
+---
+
+#### MAP
+
+The `MAP` type represents a collection of key-value pairs, where each key and value can have different types.
+
+- **Key Type**: The type of the keys (e.g., `TEXT`, `NUMBER`).
+- **Value Type**: The type of the values (e.g., `TEXT`, `NUMBER`).
+- **Not Null**: Whether `NULL` values are disallowed (default is `False`).
+
+*Example Usage*
+
+```python
+IcebergTable(
+    table_name,
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column("map_col", MAP(NUMBER(10, 0), TEXT(16777216))),
+    external_volume="external_volume",
+    base_location="base_location",
+)
+```
+
+#### OBJECT
+
+The `OBJECT` type represents a semi-structured object with named fields. Each field can have a specific type, and you can also specify whether each field is nullable.
+
+- **Items Types**: A dictionary of field names and their types. The type can optionally include a nullable flag (`True` for not nullable, `False` for nullable, default is `False`).
+
+*Example Usage*
+
+```python
+IcebergTable(
+    table_name,
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column(
+        "object_col",
+        OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)),
+        # OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)),  # Without nullable flag
+    ),
+    external_volume="external_volume",
+    base_location="base_location",
+)
+```
+
+#### ARRAY
+
+The `ARRAY` type represents an ordered list of values, where each element has the same type. The type of the elements is defined when creating the array.
+
+- **Value Type**: The type of the elements in the array (e.g., `TEXT`, `NUMBER`).
+- **Not Null**: Whether `NULL` values are disallowed (default is `False`).
+
+*Example Usage*
+
+```python
+IcebergTable(
+    table_name,
+    metadata,
+    Column("id", Integer, primary_key=True),
+    Column("array_col", ARRAY(TEXT(16777216))),
+    external_volume="external_volume",
+    base_location="base_location",
+)
+```
+
+
 ### CLUSTER BY Support

 Snowflake SQLAchemy supports the `CLUSTER BY` parameter for tables.
For information about the parameter, see :doc:`/sql-reference/sql/create-table`. @@ -329,7 +450,7 @@ This example shows how to create a table with two columns, `id` and `name`, as t t = Table('myuser', metadata, Column('id', Integer, primary_key=True), Column('name', String), - snowflake_clusterby=['id', 'name'], ... + snowflake_clusterby=['id', 'name', text('id > 5')], ... ) metadata.create_all(engine) ``` @@ -445,6 +566,139 @@ copy_into = CopyIntoStorage(from_=users, connection.execute(copy_into) ``` +### Iceberg Table with Snowflake Catalog support + +Snowflake SQLAlchemy supports Iceberg Tables with the Snowflake Catalog, along with various related parameters. For detailed information about Iceberg Tables, refer to the Snowflake [CREATE ICEBERG](https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table-snowflake) documentation. + +To create an Iceberg Table using Snowflake SQLAlchemy, you can define the table using the SQLAlchemy Core syntax as follows: + +```python +table = IcebergTable( + "myuser", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + external_volume=external_volume_name, + base_location="my_iceberg_table", + as_query="SELECT * FROM table" +) +``` + +Alternatively, you can define the table using a declarative approach: + +```python +class MyUser(Base): + __tablename__ = "myuser" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": "my_external_volume", + "base_location": "my_iceberg_table", + "as_query": "SELECT * FROM table", + } + + id = Column(Integer, primary_key=True) + name = Column(String) +``` + +### Hybrid Table support + +Snowflake SQLAlchemy supports Hybrid Tables with indexes. For detailed information, refer to the Snowflake [CREATE HYBRID TABLE](https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table) documentation. + +To create a Hybrid Table and add an index, you can use the SQLAlchemy Core syntax as follows: + +```python +table = HybridTable( + "myuser", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + Index("idx_name", "name") +) +``` + +Alternatively, you can define the table using the declarative approach: + +```python +class MyUser(Base): + __tablename__ = "myuser" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + __table_args__ = ( + Index("idx_name", "name"), + ) + + id = Column(Integer, primary_key=True) + name = Column(String) +``` + +### Dynamic Tables support + +Snowflake SQLAlchemy supports Dynamic Tables. For detailed information, refer to the Snowflake [CREATE DYNAMIC TABLE](https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table) documentation. 
+ +To create a Dynamic Table, you can use the SQLAlchemy Core syntax as follows: + +```python +dynamic_test_table_1 = DynamicTable( + "dynamic_MyUser", + metadata, + Column("id", Integer), + Column("name", String), + target_lag=(1, TimeUnit.HOURS), # Additionally, you can use SnowflakeKeyword.DOWNSTREAM + warehouse='test_wh', + refresh_mode=SnowflakeKeyword.FULL, + as_query="SELECT id, name from MyUser;" +) +``` + +Alternatively, you can define a table without columns using the SQLAlchemy `select()` construct: + +```python +dynamic_test_table_1 = DynamicTable( + "dynamic_MyUser", + metadata, + target_lag=(1, TimeUnit.HOURS), + warehouse='test_wh', + refresh_mode=SnowflakeKeyword.FULL, + as_query=select(MyUser.id, MyUser.name) +) +``` + +### Notes + +- Defining a primary key in a Dynamic Table is not supported, meaning declarative tables don’t support Dynamic Tables. +- When using the `as_query` parameter with a string, you must explicitly define the columns. However, if you use the SQLAlchemy `select()` construct, you don’t need to explicitly define the columns. +- Direct data insertion into Dynamic Tables is not supported. + + +## Verifying Package Signatures + +To ensure the authenticity and integrity of the Python package, follow the steps below to verify the package signature using `cosign`. + +**Steps to verify the signature:** +- Install cosign: + - This example is using golang installation: [installing-cosign-with-go](https://edu.chainguard.dev/open-source/sigstore/cosign/how-to-install-cosign/#installing-cosign-with-go) +- Download the file from the repository like pypi: + - https://pypi.org/project/snowflake-sqlalchemy/#files +- Download the signature files from the release tag, replace the version number with the version you are verifying: + - https://github.com/snowflakedb/snowflake-sqlalchemy/releases/tag/v1.7.3 +- Verify signature: + ````bash + # replace the version number with the version you are verifying + ./cosign verify-blob snowflake_sqlalchemy-1.7.3-py3-none-any.whl \ + --certificate snowflake_sqlalchemy-1.7.3-py3-none-any.whl.crt \ + --certificate-identity https://github.com/snowflakedb/snowflake-sqlalchemy/.github/workflows/python-publish.yml@refs/tags/v1.7.3 \ + --certificate-oidc-issuer https://token.actions.githubusercontent.com \ + --signature snowflake_sqlalchemy-1.7.3-py3-none-any.whl.sig + Verified OK + ```` + ## Support Feel free to file an issue or submit a PR here for general cases. For official support, contact Snowflake support at: diff --git a/ci/build.sh b/ci/build.sh index 4229506d..b63c8e01 100755 --- a/ci/build.sh +++ b/ci/build.sh @@ -3,7 +3,7 @@ # Build snowflake-sqlalchemy set -o pipefail -PYTHON="python3.7" +PYTHON="python3.8" THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")" DIST_DIR="${SQLALCHEMY_DIR}/dist" @@ -11,14 +11,16 @@ DIST_DIR="${SQLALCHEMY_DIR}/dist" cd "$SQLALCHEMY_DIR" # Clean up previously built DIST_DIR if [ -d "${DIST_DIR}" ]; then - echo "[WARN] ${DIST_DIR} already existing, deleting it..." - rm -rf "${DIST_DIR}" + echo "[WARN] ${DIST_DIR} already existing, deleting it..." + rm -rf "${DIST_DIR}" fi # Constants and setup +export PATH=$PATH:$HOME/.local/bin echo "[Info] Building snowflake-sqlalchemy with $PYTHON" # Clean up possible build artifacts rm -rf build generated_version.py -${PYTHON} -m pip install --upgrade pip setuptools wheel build -${PYTHON} -m build --outdir ${DIST_DIR} . 
+export UV_NO_CACHE=true +${PYTHON} -m pip install uv hatch +${PYTHON} -m hatch build diff --git a/ci/test_linux.sh b/ci/test_linux.sh index 695251e6..f5afc4fb 100755 --- a/ci/test_linux.sh +++ b/ci/test_linux.sh @@ -6,9 +6,9 @@ # - This script assumes that ../dist/repaired_wheels has the wheel(s) built for all versions to be tested # - This is the script that test_docker.sh runs inside of the docker container -PYTHON_VERSIONS="${1:-3.7 3.8 3.9 3.10 3.11}" -THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -SQLALCHEMY_DIR="$( dirname "${THIS_DIR}")" +PYTHON_VERSIONS="${1:-3.8 3.9 3.10 3.11}" +THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")" # Install one copy of tox python3 -m pip install -U tox @@ -16,10 +16,10 @@ python3 -m pip install -U tox # Run tests cd $SQLALCHEMY_DIR for PYTHON_VERSION in ${PYTHON_VERSIONS}; do - echo "[Info] Testing with ${PYTHON_VERSION}" - SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") - SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py2.py3-none-any.whl | sort -r | head -n 1) - TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage - echo "[Info] Running tox for ${TEST_ENVLIST}" - python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} + echo "[Info] Testing with ${PYTHON_VERSION}" + SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") + SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py3-none-any.whl | sort -r | head -n 1) + TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage + echo "[Info] Running tox for ${TEST_ENVLIST}" + python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} done diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..b22dc293 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,143 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "snowflake-sqlalchemy" +dynamic = ["version"] +description = "Snowflake SQLAlchemy Dialect" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.8" +authors = [ + { name = "Snowflake Inc.", email = "triage-snowpark-python-api-dl@snowflake.com" }, +] +keywords = ["Snowflake", "analytics", "cloud", "database", "db", "warehouse"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Other Environment", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Information Technology", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Scientific/Engineering :: Information Analysis", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Application Frameworks", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = ["SQLAlchemy>=1.4.19", "snowflake-connector-python<4.0.0"] + +[tool.hatch.version] +path = "src/snowflake/sqlalchemy/version.py" + +[project.optional-dependencies] 
+development = [ + "pre-commit", + "pytest", + "setuptools", + "pytest-cov", + "pytest-timeout", + "pytest-rerunfailures", + "pytz", + "numpy", + "mock", + "syrupy==4.6.1", +] +pandas = ["snowflake-connector-python[pandas]"] + +[project.entry-points."sqlalchemy.dialects"] +snowflake = "snowflake.sqlalchemy:dialect" + +[project.urls] +Changelog = "https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/DESCRIPTION.md" +Documentation = "https://docs.snowflake.com/en/user-guide/sqlalchemy.html" +Homepage = "https://www.snowflake.com/" +Issues = "https://github.com/snowflakedb/snowflake-sqlalchemy/issues" +Source = "https://github.com/snowflakedb/snowflake-sqlalchemy" + +[tool.hatch.build.targets.sdist] +exclude = ["/.github"] + +[tool.hatch.build.targets.wheel] +packages = ["src/snowflake"] + +[tool.hatch.envs.default] +path = ".venv" +type = "virtual" +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"] +features = ["development", "pandas"] +python = "3.8" +installer = "uv" + +[tool.hatch.envs.sa14] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] +features = ["development", "pandas"] +python = "3.8" + +[tool.hatch.envs.sa14.scripts] +test-dialect = "pytest --ignore_v20_test -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" +test-dialect-compatibility = "pytest --ignore_v20_test -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" +test-dialect-aws = "pytest --ignore_v20_test -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" + +[tool.hatch.envs.default.env-vars] +COVERAGE_FILE = "coverage.xml" +SQLACHEMY_WARN_20 = "1" + +[tool.hatch.envs.default.scripts] +check = "pre-commit run --all-files" +test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" +test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" +test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" +gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" +check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" + +[[tool.hatch.envs.release.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +features = ["development", "pandas"] + +[tool.hatch.envs.release.scripts] +test-dialect = "pytest -ra -vvv --tb=short --ignore=tests/sqlalchemy_test_suite tests/" +test-compatibility = "pytest -ra -vvv --tb=short tests/sqlalchemy_test_suite tests/" + +[tool.ruff] +line-length = 88 + +[tool.black] +line-length = 88 + +[tool.pytest.ini_options] +addopts = "-m 'not feature_max_lob_size and not aws and not requires_external_volume'" +markers = [ + # Optional dependency groups markers + "lambda: AWS lambda tests", + "pandas: tests for pandas integration", + "sso: tests for sso optional dependency integration", + # Cloud provider markers + "aws: tests for Amazon Cloud storage", + "azure: tests for Azure Cloud storage", + "gcp: tests for Google Cloud storage", + # Test type markers + "integ: integration tests", + "unit: unit tests", + "skipolddriver: skip for old driver tests", + # Other markers + "timeout: tests that need a timeout time", + "internal: tests that could but should 
only run on our internal CI", + "requires_external_volume: tests that needs a external volume to be executed", + "external: tests that could but should only run on our external CI", + "feature_max_lob_size: tests that could but should only run on our external CI", + "feature_v20: tests that could but should only run on SqlAlchemy v20", +] diff --git a/setup.cfg b/setup.cfg index 04011f04..7924cc57 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,80 +1,3 @@ -[bdist_wheel] -universal = 1 - -[metadata] -name = snowflake-sqlalchemy -description = Snowflake SQLAlchemy Dialect -long_description = file: DESCRIPTION.md -long_description_content_type = text/markdown -url = https://www.snowflake.com/ -author = Snowflake, Inc -author_email = triage-snowpark-python-api-dl@snowflake.com -license = Apache-2.0 -license_files = LICENSE.txt -classifiers = - Development Status :: 5 - Production/Stable - Environment :: Console - Environment :: Other Environment - Intended Audience :: Developers - Intended Audience :: Education - Intended Audience :: Information Technology - Intended Audience :: System Administrators - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Programming Language :: Python :: 3 - Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Programming Language :: SQL - Topic :: Database - Topic :: Scientific/Engineering :: Information Analysis - Topic :: Software Development - Topic :: Software Development :: Libraries - Topic :: Software Development :: Libraries :: Application Frameworks - Topic :: Software Development :: Libraries :: Python Modules -keywords = Snowflake db database cloud analytics warehouse -project_urls = - Documentation=https://docs.snowflake.com/en/user-guide/sqlalchemy.html - Source=https://github.com/snowflakedb/snowflake-sqlalchemy - Issues=https://github.com/snowflakedb/snowflake-sqlalchemy/issues - Changelog=https://github.com/snowflakedb/snowflake-sqlalchemy/blob/main/DESCRIPTION.md - -[options] -python_requires = >=3.7 -packages = find_namespace: -install_requires = - importlib-metadata;python_version<"3.8" - sqlalchemy<2.0.0,>=1.4.0 -; Keep in sync with extras dependency - snowflake-connector-python<4.0.0 -include_package_data = True -package_dir = - =src -zip_safe = False - -[options.packages.find] -where = src -include = snowflake.* - -[options.entry_points] -sqlalchemy.dialects = - snowflake=snowflake.sqlalchemy:dialect - -[options.extras_require] -development = - pytest - pytest-cov - pytest-rerunfailures - pytest-timeout - mock - pytz - numpy -pandas = - snowflake-connector-python[pandas]<4.0.0 - [sqla_testing] requirement_cls=snowflake.sqlalchemy.requirements:Requirements profile_file=tests/profiles.txt diff --git a/setup.py b/setup.py deleted file mode 100644 index 0ec32717..00000000 --- a/setup.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - -import os - -from setuptools import setup - -SQLALCHEMY_SRC_DIR = os.path.join("src", "snowflake", "sqlalchemy") -VERSION = (1, 1, 1, None) # Default -with open(os.path.join(SQLALCHEMY_SRC_DIR, "version.py"), encoding="utf-8") as f: - exec(f.read()) - version = ".".join([str(v) for v in VERSION if v is not None]) - -setup( - version=version, -) diff --git a/snyk/requirements.txt b/snyk/requirements.txt new file mode 100644 index 00000000..0166d751 --- /dev/null +++ b/snyk/requirements.txt @@ -0,0 +1,2 @@ +SQLAlchemy>=1.4.19 +snowflake-connector-python<4.0.0 diff --git a/snyk/requiremtnts.txt b/snyk/requiremtnts.txt new file mode 100644 index 00000000..a92c527e --- /dev/null +++ b/snyk/requiremtnts.txt @@ -0,0 +1,2 @@ +snowflake-connector-python<4.0.0 +SQLAlchemy>=1.4.19,<2.1.0 diff --git a/snyk/update_requirements.py b/snyk/update_requirements.py new file mode 100644 index 00000000..e0771fbd --- /dev/null +++ b/snyk/update_requirements.py @@ -0,0 +1,17 @@ +from pathlib import Path + +import tomlkit + + +def sync(): + pyproject = tomlkit.loads(Path("pyproject.toml").read_text()) + snyk_reqiurements = Path("snyk/requirements.txt") + dependencies = pyproject.get("project", {}).get("dependencies", []) + + with snyk_reqiurements.open("w") as fh: + fh.write("\n".join(dependencies)) + fh.write("\n") + + +if __name__ == "__main__": + sync() diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 063910fe..7d795b2a 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -9,7 +9,7 @@ else: import importlib.metadata as importlib_metadata -from sqlalchemy.types import ( +from sqlalchemy.types import ( # noqa BIGINT, BINARY, BOOLEAN, @@ -27,8 +27,8 @@ VARCHAR, ) -from . import base, snowdialect -from .custom_commands import ( +from . 
import base, snowdialect # noqa +from .custom_commands import ( # noqa AWSBucket, AzureContainer, CopyFormatter, @@ -41,7 +41,7 @@ MergeInto, PARQUETFormatter, ) -from .custom_types import ( +from .custom_types import ( # noqa ARRAY, BYTEINT, CHARACTER, @@ -49,6 +49,8 @@ DOUBLE, FIXED, GEOGRAPHY, + GEOMETRY, + MAP, NUMBER, OBJECT, STRING, @@ -60,13 +62,30 @@ VARBINARY, VARIANT, ) -from .util import _url as URL +from .sql.custom_schema import ( # noqa + DynamicTable, + HybridTable, + IcebergTable, + SnowflakeTable, +) +from .sql.custom_schema.options import ( # noqa + AsQueryOption, + ClusterByOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + TimeUnit, +) +from .util import _url as URL # noqa base.dialect = dialect = snowdialect.dialect __version__ = importlib_metadata.version("snowflake-sqlalchemy") -__all__ = ( +_custom_types = ( "BIGINT", "BINARY", "BOOLEAN", @@ -90,6 +109,7 @@ "DOUBLE", "FIXED", "GEOGRAPHY", + "GEOMETRY", "OBJECT", "NUMBER", "STRING", @@ -100,6 +120,10 @@ "TINYINT", "VARBINARY", "VARIANT", + "MAP", +) + +_custom_commands = ( "MergeInto", "CSVFormatter", "JSONFormatter", @@ -112,3 +136,27 @@ "CreateStage", "CreateFileFormat", ) + +_custom_tables = ("HybridTable", "DynamicTable", "IcebergTable", "SnowflakeTable") + +_custom_table_options = ( + "AsQueryOption", + "TargetLagOption", + "LiteralOption", + "IdentifierOption", + "KeywordOption", + "ClusterByOption", +) + +_enums = ( + "TimeUnit", + "TableOptionKey", + "SnowflakeKeyword", +) +__all__ = ( + *_custom_types, + *_custom_commands, + *_custom_tables, + *_custom_table_options, + *_enums, +) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index dad5b19b..205ad5d9 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -9,4 +9,6 @@ PARAM_INTERNAL_APPLICATION_VERSION = "internal_application_version" APPLICATION_NAME = "SnowflakeSQLAlchemy" -SNOWFLAKE_SQLALCHEMY_VERSION = ".".join([str(v) for v in VERSION if v is not None]) +SNOWFLAKE_SQLALCHEMY_VERSION = VERSION +DIALECT_NAME = "snowflake" +NOT_NULL = "NOT NULL" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index d87b78c1..59c3f91e 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -2,18 +2,48 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# +import itertools import operator import re +import string +import warnings +from typing import List +from sqlalchemy import exc as sa_exc +from sqlalchemy import inspect, sql from sqlalchemy import util as sa_util from sqlalchemy.engine import default +from sqlalchemy.orm import context +from sqlalchemy.orm.context import _MapperEntity from sqlalchemy.schema import Sequence, Table -from sqlalchemy.sql import compiler, expression -from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.util.compat import string_types +from sqlalchemy.sql import compiler, expression, functions +from sqlalchemy.sql.base import CompileState +from sqlalchemy.sql.elements import BindParameter, quoted_name +from sqlalchemy.sql.expression import Executable +from sqlalchemy.sql.selectable import Lateral, SelectState + +from snowflake.sqlalchemy._constants import DIALECT_NAME +from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer, string_types +from snowflake.sqlalchemy.custom_commands import ( + AWSBucket, + AzureContainer, + ExternalStage, +) -from .custom_commands import AWSBucket, AzureContainer, ExternalStage -from .util import _set_connection_interpolate_empty_sequences +from ._constants import NOT_NULL +from .exc import ( + CustomOptionsAreOnlySupportedOnSnowflakeTables, + UnexpectedOptionTypeError, +) +from .functions import flatten +from .sql.custom_schema.custom_table_base import CustomTableBase +from .sql.custom_schema.options.table_option import TableOption +from .util import ( + _find_left_clause_to_join_from, + _set_connection_interpolate_empty_sequences, + _Snowflake_ORMJoin, + _Snowflake_Selectable_Join, +) RESERVED_WORDS = frozenset( [ @@ -86,10 +116,338 @@ AUTOCOMMIT_REGEXP = re.compile( r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE ) +# used for quoting identifiers ie. table names, column names, etc. +ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"$"})) + + +# used for quoting identifiers ie. table names, column names, etc. 
+ILLEGAL_IDENTIFIERS = frozenset({d for d in string.digits}.union({"_"})) + +""" +Overwrite methods to handle Snowflake BCR change: +https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 +- _join_determine_implicit_left_side +- _join_left_to_right +""" + + +# handle Snowflake BCR bcr-1057 +@CompileState.plugin_for("default", "select") +class SnowflakeSelectState(SelectState): + def _setup_joins(self, args, raw_columns): + for right, onclause, left, flags in args: + isouter = flags["isouter"] + full = flags["full"] + + if left is None: + ( + left, + replace_from_obj_index, + ) = self._join_determine_implicit_left_side( + raw_columns, left, right, onclause + ) + else: + (replace_from_obj_index) = self._join_place_explicit_left_side(left) + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + ( + _Snowflake_Selectable_Join( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=isouter, + full=full, + ), + ) + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + self.from_clauses = self.from_clauses + ( + # handle Snowflake BCR bcr-1057 + _Snowflake_Selectable_Join( + left, right, onclause, isouter=isouter, full=full + ), + ) + + @sa_util.preload_module("sqlalchemy.sql.util") + def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + replace_from_obj_index = None + + from_clauses = self.from_clauses + + if from_clauses: + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from(from_clauses, right, onclause) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = from_clauses[replace_from_obj_index] + else: + potential = {} + statement = self.statement + + for from_clause in itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in raw_columns] + ), + itertools.chain.from_iterable( + [element._from_objects for element in statement._where_criteria] + ), + ): + potential[from_clause] = () + + all_clauses = list(potential.keys()) + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from(all_clauses, right, onclause) + + if len(indexes) == 1: + left = all_clauses[indexes[0]] + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already to " + "help resolve the ambiguity." + ) + elif not indexes: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already to " + "help resolve the ambiguity." 
% (right,) + ) + return left, replace_from_obj_index + + +# handle Snowflake BCR bcr-1057 +@sql.base.CompileState.plugin_for("orm", "select") +class SnowflakeORMSelectCompileState(context.ORMSelectCompileState): + def _join_determine_implicit_left_side( + self, entities_collection, left, right, onclause + ): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self.from_clauses: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from( + self.from_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self.from_clauses[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + + elif entities_collection: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. + + potential = {} + for entity_index, ent in enumerate(entities_collection): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, context._MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + + all_clauses = list(potential.keys()) + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." 
% (right,) + ) + else: + raise sa_exc.InvalidRequestError( + "No entities to join from; please use " + "select_from() to establish the left " + "entity/selectable of this join" + ) + + return left, replace_from_obj_index, use_entity_index + + @args_reducer(positions_to_drop=(6, 7)) + def _join_left_to_right( + self, entities_collection, left, right, onclause, prop, outerjoin, full + ): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ + + if left is None: + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + ( + left, + replace_from_obj_index, + use_entity_index, + ) = self._join_determine_implicit_left_side( + entities_collection, left, right, onclause + ) + else: + # left is given via a relationship/name, or as explicit left side. + # Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + ( + replace_from_obj_index, + use_entity_index, + ) = self._join_place_explicit_left_side(entities_collection, left) + + if left is right: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they " + "are the same entity" % (left, right) + ) + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. handle all that and + # get back the new effective "right" side + + if IS_VERSION_20: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + else: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, False, False + ) + + if not r_info.is_selectable: + extra_criteria = self._get_extra_criteria(r_info) + else: + extra_criteria = () + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + [ + _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + # add a new element to the self._from_obj list + if use_entity_index is not None: + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance(entities_collection[use_entity_index], _MapperEntity) + left_clause = entities_collection[use_entity_index].selectable + else: + left_clause = left + + self.from_clauses = self.from_clauses + [ + _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 + left_clause, + r_info, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + illegal_identifiers = ILLEGAL_IDENTIFIERS def __init__(self, dialect, **kw): quote = '"' @@ -118,6 +476,17 @@ def format_label(self, label, name=None): return self.quote_identifier(s) if n.quote else s + def _requires_quotes(self, value: str) -> bool: + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return ( + 
lc_value in self.reserved_words + or lc_value in self.illegal_identifiers + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(str(value)) + or (lc_value != value) + ) + def _split_schema_by_dot(self, schema): ret = [] idx = 0 @@ -161,8 +530,11 @@ def visit_merge_into(self, merge_into, **kw): clauses = " ".join( clause._compiler_dispatch(self, **kw) for clause in merge_into.clauses ) + target = merge_into.target._compiler_dispatch(self, asfrom=True, **kw) + source = merge_into.source._compiler_dispatch(self, asfrom=True, **kw) + on = merge_into.on._compiler_dispatch(self, **kw) return ( - f"MERGE INTO {merge_into.target} USING {merge_into.source} ON {merge_into.on}" + f"MERGE INTO {target} USING {source} ON {on}" + (" " + clauses if clauses else "") ) @@ -210,14 +582,10 @@ def visit_copy_into(self, copy_into, **kw): formatter = copy_into.formatter._compiler_dispatch(self, **kw) else: formatter = "" - into = ( - copy_into.into - if isinstance(copy_into.into, Table) - else copy_into.into._compiler_dispatch(self, **kw) - ) + into = copy_into.into._compiler_dispatch(self, asfrom=True, **kw) from_ = None if isinstance(copy_into.from_, Table): - from_ = copy_into.from_ + from_ = copy_into.from_.name # this is intended to catch AWSBucket and AzureContainer elif ( isinstance(copy_into.from_, AWSBucket) @@ -228,6 +596,21 @@ def visit_copy_into(self, copy_into, **kw): # everything else (selects, etc.) else: from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})" + + partition_by_value = None + if isinstance(copy_into.partition_by, (BindParameter, Executable)): + partition_by_value = copy_into.partition_by.compile( + compile_kwargs={"literal_binds": True} + ) + elif copy_into.partition_by is not None: + partition_by_value = copy_into.partition_by + + partition_by = ( + f"PARTITION BY {partition_by_value}" + if partition_by_value is not None and partition_by_value != "" + else "" + ) + credentials, encryption = "", "" if isinstance(into, tuple): into, credentials, encryption = into @@ -238,14 +621,15 @@ def visit_copy_into(self, copy_into, **kw): options_list.sort(key=operator.itemgetter(0)) options = ( ( - " " - + " ".join( + " ".join( [ "{} = {}".format( n, - v._compiler_dispatch(self, **kw) - if getattr(v, "compiler_dispatch", False) - else str(v), + ( + v._compiler_dispatch(self, **kw) + if getattr(v, "compiler_dispatch", False) + else str(v) + ), ) for n, v in options_list ] @@ -258,7 +642,7 @@ def visit_copy_into(self, copy_into, **kw): options += f" {credentials}" if encryption: options += f" {encryption}" - return f"COPY INTO {into} FROM {from_} {formatter}{options}" + return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}" def visit_copy_formatter(self, formatter, **kw): options_list = list(formatter.options.items()) @@ -268,20 +652,24 @@ def visit_copy_formatter(self, formatter, **kw): return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})" return "FILE_FORMAT=(TYPE={}{})".format( formatter.file_format, - " " - + " ".join( - [ - "{}={}".format( - name, - value._compiler_dispatch(self, **kw) - if hasattr(value, "_compiler_dispatch") - else formatter.value_repr(name, value), - ) - for name, value in options_list - ] - ) - if formatter.options - else "", + ( + " " + + " ".join( + [ + "{}={}".format( + name, + ( + value._compiler_dispatch(self, **kw) + if hasattr(value, "_compiler_dispatch") + else formatter.value_repr(name, value) + ), + ) + for name, value in options_list + ] + ) + if formatter.options + else 
"" + ), ) def visit_aws_bucket(self, aws_bucket, **kw): @@ -376,7 +764,14 @@ def visit_regexp_match_op_binary(self, binary, operator, **kw): def visit_regexp_replace_op_binary(self, binary, operator, **kw): string, pattern, flags = self._get_regexp_args(binary, kw) - replacement = self.process(binary.modifiers["replacement"], **kw) + try: + replacement = self.process(binary.modifiers["replacement"], **kw) + except KeyError: + # in sqlalchemy 1.4.49, the internal structure of the expression is changed + # that binary.modifiers doesn't have "replacement": + # https://docs.sqlalchemy.org/en/20/changelog/changelog_14.html#change-1.4.49 + return f"REGEXP_REPLACE({string}, {pattern}{'' if flags is None else f', {flags}'})" + if flags is None: return f"REGEXP_REPLACE({string}, {pattern}, {replacement})" else: @@ -385,6 +780,65 @@ def visit_regexp_replace_op_binary(self, binary, operator, **kw): def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return f"NOT {self.visit_regexp_match_op_binary(binary, operator, **kw)}" + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.update( + itertools.product(join.left._from_objects, join.right._from_objects) + ) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " JOIN " + + join_statement = ( + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + join_type + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + ) + + if join.onclause is None and isinstance(join.right, Lateral): + # in snowflake, onclause is not accepted for lateral due to BCR change: + # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 + # sqlalchemy only allows join with on condition. + # to adapt to snowflake syntax change, + # we make the change such that when oncaluse is None and the right part is + # Lateral, we do not append the on condition + return join_statement + + return ( + join_statement + + " ON " + # TODO: likely need asfrom=True here? + + join.onclause._compiler_dispatch(self, from_linter=from_linter, **kwargs) + ) + + def visit_truediv_binary(self, binary, operator, **kw): + if self.dialect.div_is_floordiv: + warnings.warn( + "div_is_floordiv value will be changed to False in a future release. This will generate a behavior change on true and floor division. Please review https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#python-division-operator-performs-true-division-for-all-backends-added-floor-division", + PendingDeprecationWarning, + stacklevel=2, + ) + return ( + self.process(binary.left, **kw) + " / " + self.process(binary.right, **kw) + ) + + def visit_floordiv_binary(self, binary, operator, **kw): + if self.dialect.div_is_floordiv and IS_VERSION_20: + warnings.warn( + "div_is_floordiv value will be changed to False in a future release. This will generate a behavior change on true and floor division. 
Please review https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#python-division-operator-performs-true-division-for-all-backends-added-floor-division", + PendingDeprecationWarning, + stacklevel=2, + ) + return super().visit_floordiv_binary(binary, operator, **kw) + def render_literal_value(self, value, type_): # escape backslash return super().render_literal_value(value, type_).replace("\\", "\\\\") @@ -492,7 +946,7 @@ def get_column_specification(self, column, **kwargs): return " ".join(colspec) - def post_create_table(self, table): + def handle_cluster_by(self, table): """ Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. @@ -509,7 +963,7 @@ def post_create_table(self, table): ... metadata, ... sa.Column('id', sa.Integer, primary_key=True), ... sa.Column('name', sa.String), - ... snowflake_clusterby=['id', 'name'] + ... snowflake_clusterby=['id', 'name', text("id > 5")] ... ) >>> print(CreateTable(user).compile(engine)) @@ -517,28 +971,59 @@ def post_create_table(self, table): id INTEGER NOT NULL AUTOINCREMENT, name VARCHAR, PRIMARY KEY (id) - ) CLUSTER BY (id, name) + ) CLUSTER BY (id, name, id > 5) """ text = "" - info = table.dialect_options["snowflake"] + info = table.dialect_options[DIALECT_NAME] cluster = info.get("clusterby") if cluster: text += " CLUSTER BY ({})".format( - ", ".join(self.denormalize_column_name(key) for key in cluster) + ", ".join( + ( + self.denormalize_column_name(key) + if isinstance(key, str) + else str(key) + ) + for key in cluster + ) ) return text + def post_create_table(self, table): + text = self.handle_cluster_by(table) + options = [] + invalid_options: List[str] = [] + + for key, option in table.dialect_options[DIALECT_NAME].items(): + if isinstance(option, TableOption): + options.append(option) + elif key not in ["clusterby", "*"]: + invalid_options.append(key) + + if len(invalid_options) > 0: + raise UnexpectedOptionTypeError(sorted(invalid_options)) + + if isinstance(table, CustomTableBase): + options.sort(key=lambda x: (x.priority.value, x.option_name), reverse=True) + for option in options: + text += "\t" + option.render_option(self) + elif len(options) > 0: + raise CustomOptionsAreOnlySupportedOnSnowflakeTables() + + return text + def visit_create_stage(self, create_stage, **kw): """ This visitor will create the SQL representation for a CREATE STAGE command. 
""" - return "CREATE {}STAGE {}{} URL={}".format( - "OR REPLACE " if create_stage.replace_if_exists else "", + return "CREATE {or_replace}{temporary}STAGE {}{} URL={}".format( create_stage.stage.namespace, create_stage.stage.name, repr(create_stage.container), + or_replace="OR REPLACE " if create_stage.replace_if_exists else "", + temporary="TEMPORARY " if create_stage.temporary else "", ) def visit_create_file_format(self, file_format, **kw): @@ -577,13 +1062,38 @@ def visit_drop_column_comment(self, drop, **kw): ) def visit_identity_column(self, identity, **kw): - text = " IDENTITY" + text = "IDENTITY" if identity.start is not None or identity.increment is not None: start = 1 if identity.start is None else identity.start increment = 1 if identity.increment is None else identity.increment text += f"({start},{increment})" + if identity.order is not None: + order = "ORDER" if identity.order else "NOORDER" + text += f" {order}" return text + def get_identity_options(self, identity_options): + text = [] + if identity_options.increment is not None: + text.append("INCREMENT BY %d" % identity_options.increment) + if identity_options.start is not None: + text.append("START WITH %d" % identity_options.start) + if identity_options.minvalue is not None: + text.append("MINVALUE %d" % identity_options.minvalue) + if identity_options.maxvalue is not None: + text.append("MAXVALUE %d" % identity_options.maxvalue) + if identity_options.nominvalue is not None: + text.append("NO MINVALUE") + if identity_options.nomaxvalue is not None: + text.append("NO MAXVALUE") + if identity_options.cache is not None: + text.append("CACHE %d" % identity_options.cache) + if identity_options.cycle is not None: + text.append("CYCLE" if identity_options.cycle else "NO CYCLE") + if identity_options.order is not None: + text.append("ORDER" if identity_options.order else "NOORDER") + return " ".join(text) + class SnowflakeTypeCompiler(compiler.GenericTypeCompiler): def visit_BYTEINT(self, type_, **kw): @@ -616,11 +1126,34 @@ def visit_TINYINT(self, type_, **kw): def visit_VARIANT(self, type_, **kw): return "VARIANT" + def visit_MAP(self, type_, **kw): + not_null = f" {NOT_NULL}" if type_.not_null else "" + return ( + f"MAP({type_.key_type.compile()}, {type_.value_type.compile()}{not_null})" + ) + def visit_ARRAY(self, type_, **kw): return "ARRAY" + def visit_SNOWFLAKE_ARRAY(self, type_, **kw): + if type_.is_semi_structured: + return "ARRAY" + not_null = f" {NOT_NULL}" if type_.not_null else "" + return f"ARRAY({type_.value_type.compile()}{not_null})" + def visit_OBJECT(self, type_, **kw): - return "OBJECT" + if type_.is_semi_structured: + return "OBJECT" + else: + contents = [] + for key in type_.items_types: + + row_text = f"{key} {type_.items_types[key][0].compile()}" + # Type and not null is specified + if len(type_.items_types[key]) > 1: + row_text += f"{' NOT NULL' if type_.items_types[key][1] else ''}" + contents.append(row_text) + return "OBJECT" if contents == [] else f"OBJECT({', '.join(contents)})" def visit_BLOB(self, type_, **kw): return "BINARY" @@ -646,5 +1179,10 @@ def visit_TIMESTAMP(self, type_, **kw): def visit_GEOGRAPHY(self, type_, **kw): return "GEOGRAPHY" + def visit_GEOMETRY(self, type_, **kw): + return "GEOMETRY" + construct_arguments = [(Table, {"clusterby": None})] + +functions.register_function("flatten", flatten, "snowflake") diff --git a/src/snowflake/sqlalchemy/compat.py b/src/snowflake/sqlalchemy/compat.py new file mode 100644 index 00000000..9e97e574 --- /dev/null +++ 
b/src/snowflake/sqlalchemy/compat.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +from __future__ import annotations + +import functools +from typing import Callable + +from sqlalchemy import __version__ as SA_VERSION +from sqlalchemy import util + +string_types = (str,) +returns_unicode = util.symbol("RETURNS_UNICODE") + +IS_VERSION_20 = tuple(int(v) for v in SA_VERSION.split(".")) >= (2, 0, 0) + + +def args_reducer(positions_to_drop: tuple): + """Removes args at positions provided in tuple positions_to_drop. + + For example tuple (3, 5) will remove items at third and fifth position. + Keep in mind that on class methods first postion is cls or self. + """ + + def fn_wrapper(fn: Callable): + @functools.wraps(fn) + def wrapper(*args): + reduced_args = args + if not IS_VERSION_20: + reduced_args = tuple( + arg for idx, arg in enumerate(args) if idx not in positions_to_drop + ) + fn(*reduced_args) + + return wrapper + + return fn_wrapper diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index 9cc14389..1b9260fe 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -10,7 +10,8 @@ from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.roles import FromClauseRole -from sqlalchemy.util.compat import string_types + +from .compat import string_types NoneType = type(None) @@ -114,18 +115,23 @@ class CopyInto(UpdateBase): __visit_name__ = "copy_into" _bind = None - def __init__(self, from_, into, formatter=None): + def __init__(self, from_, into, partition_by=None, formatter=None): self.from_ = from_ self.into = into self.formatter = formatter self.copy_options = {} + self.partition_by = partition_by def __repr__(self): """ repr for debugging / logging purposes only. For compilation logic, see the corresponding visitor in base.py """ - return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})" + val = f"COPY INTO {self.into} FROM {repr(self.from_)}" + if self.partition_by is not None: + val += f" PARTITION BY {self.partition_by}" + + return val + f" {repr(self.formatter)} ({self.copy_options})" def bind(self): return None @@ -259,7 +265,8 @@ def field_delimiter(self, deli_type): def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service.""" + responsible for specifying a valid file extension that can be read by the desired software or service. + """ if not isinstance(ext, (NoneType, string_types)): raise TypeError("File extension should be a string") self.options["FILE_EXTENSION"] = ext @@ -386,7 +393,8 @@ def compression(self, comp_type): def file_extension(self, ext): """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service.""" + responsible for specifying a valid file extension that can be read by the desired software or service. 
+ """ if not isinstance(ext, (NoneType, string_types)): raise TypeError("File extension should be a string") self.options["FILE_EXTENSION"] = ext @@ -482,9 +490,10 @@ class CreateStage(DDLElement): __visit_name__ = "create_stage" - def __init__(self, container, stage, replace_if_exists=False): + def __init__(self, container, stage, replace_if_exists=False, *, temporary=False): super().__init__() self.container = container + self.temporary = temporary self.stage = stage self.replace_if_exists = replace_if_exists diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index 3f42f034..c742b740 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -1,9 +1,11 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +from typing import Optional, Tuple, Union import sqlalchemy.types as sqltypes import sqlalchemy.util as util +from sqlalchemy.types import TypeEngine TEXT = sqltypes.VARCHAR CHARACTER = sqltypes.CHAR @@ -37,12 +39,60 @@ class VARIANT(SnowflakeType): __visit_name__ = "VARIANT" -class OBJECT(SnowflakeType): +class StructuredType(SnowflakeType): + def __init__(self, is_semi_structured: bool = False): + self.is_semi_structured = is_semi_structured + super().__init__() + + +class MAP(StructuredType): + __visit_name__ = "MAP" + + def __init__( + self, + key_type: sqltypes.TypeEngine, + value_type: sqltypes.TypeEngine, + not_null: bool = False, + ): + self.key_type = key_type + self.value_type = value_type + self.not_null = not_null + super().__init__() + + +class OBJECT(StructuredType): __visit_name__ = "OBJECT" + def __init__(self, **items_types: Union[TypeEngine, Tuple[TypeEngine, bool]]): + for key, value in items_types.items(): + if not isinstance(value, tuple): + items_types[key] = (value, False) + + self.items_types = items_types + self.is_semi_structured = len(items_types) == 0 + super().__init__() -class ARRAY(SnowflakeType): - __visit_name__ = "ARRAY" + def __repr__(self): + quote_char = "'" + return "OBJECT(%s)" % ", ".join( + [ + f"{repr(key).strip(quote_char)}={repr(value)}" + for key, value in self.items_types.items() + ] + ) + + +class ARRAY(StructuredType): + __visit_name__ = "SNOWFLAKE_ARRAY" + + def __init__( + self, + value_type: Optional[sqltypes.TypeEngine] = None, + not_null: bool = False, + ): + self.value_type = value_type + self.not_null = not_null + super().__init__(is_semi_structured=value_type is None) class TIMESTAMP_TZ(SnowflakeType): @@ -61,6 +111,10 @@ class GEOGRAPHY(SnowflakeType): __visit_name__ = "GEOGRAPHY" +class GEOMETRY(SnowflakeType): + __visit_name__ = "GEOMETRY" + + class _CUSTOM_Date(SnowflakeType, sqltypes.Date): def literal_processor(self, dialect): def process(value): diff --git a/src/snowflake/sqlalchemy/exc.py b/src/snowflake/sqlalchemy/exc.py new file mode 100644 index 00000000..399e94b6 --- /dev/null +++ b/src/snowflake/sqlalchemy/exc.py @@ -0,0 +1,82 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+ +from typing import List + +from sqlalchemy.exc import ArgumentError + + +class NoPrimaryKeyError(ArgumentError): + def __init__(self, target: str): + super().__init__(f"Table {target} required primary key.") + + +class UnsupportedPrimaryKeysAndForeignKeysError(ArgumentError): + def __init__(self, target: str): + super().__init__(f"Primary key and foreign keys are not supported in {target}.") + + +class RequiredParametersNotProvidedError(ArgumentError): + def __init__(self, target: str, parameters: List[str]): + super().__init__( + f"{target} requires the following parameters: %s." % ", ".join(parameters) + ) + + +class UnexpectedTableOptionKeyError(ArgumentError): + def __init__(self, expected: str, actual: str): + super().__init__(f"Expected table option {expected} but got {actual}.") + + +class OptionKeyNotProvidedError(ArgumentError): + def __init__(self, target: str): + super().__init__( + f"Expected option key in {target} option but got NoneType instead." + ) + + +class UnexpectedOptionParameterTypeError(ArgumentError): + def __init__(self, parameter_name: str, target: str, types: List[str]): + super().__init__( + f"Parameter {parameter_name} of {target} requires to be one" + f" of following types: {', '.join(types)}." + ) + + +class CustomOptionsAreOnlySupportedOnSnowflakeTables(ArgumentError): + def __init__(self): + super().__init__( + "Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables." + ) + + +class UnexpectedOptionTypeError(ArgumentError): + def __init__(self, options: List[str]): + super().__init__( + f"The following options are either unsupported or should be defined using a Snowflake table: {', '.join(options)}." + ) + + +class InvalidTableParameterTypeError(ArgumentError): + def __init__(self, name: str, input_type: str, expected_types: List[str]): + expected_types_str = "', '".join(expected_types) + super().__init__( + f"Invalid parameter type '{input_type}' provided for '{name}'. " + f"Expected one of the following types: '{expected_types_str}'.\n" + ) + + +class MultipleErrors(ArgumentError): + def __init__(self, errors): + self.errors = errors + + def __str__(self): + return "".join(str(e) for e in self.errors) + + +class StructuredTypeNotSupportedInTableColumnsError(ArgumentError): + def __init__(self, table_type: str, table_name: str, column_name: str): + super().__init__( + f"Column '{column_name}' is of a structured type, which is only supported on Iceberg tables. " + f"The table '{table_name}' is of type '{table_type}', not Iceberg." + ) diff --git a/src/snowflake/sqlalchemy/functions.py b/src/snowflake/sqlalchemy/functions.py new file mode 100644 index 00000000..c08aa734 --- /dev/null +++ b/src/snowflake/sqlalchemy/functions.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import warnings + +from sqlalchemy.sql import functions as sqlfunc + +FLATTEN_WARNING = "For backward compatibility params are not rendered." + + +class flatten(sqlfunc.GenericFunction): + name = "flatten" + + def __init__(self, *args, **kwargs): + warnings.warn(FLATTEN_WARNING, DeprecationWarning, stacklevel=2) + super().__init__(*args, **kwargs) diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py new file mode 100644 index 00000000..09cb6ab8 --- /dev/null +++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py @@ -0,0 +1,245 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
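The flatten GenericFunction above keeps func.flatten() resolvable for backward compatibility while warning that its parameters are not rendered. A small sketch of what a caller sees, assuming snowflake.sqlalchemy has been imported so the class is registered:

```python
import warnings

import snowflake.sqlalchemy  # noqa: F401  registers the flatten GenericFunction
from sqlalchemy import func, select

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    stmt = select(func.flatten())

print(stmt)                                       # roughly: SELECT flatten() AS flatten_1
print(caught[0].category is DeprecationWarning)   # True
```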
+from typing import List + +import sqlalchemy.types as sqltypes +from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy.types import ( + BIGINT, + BINARY, + BOOLEAN, + CHAR, + DATE, + DATETIME, + DECIMAL, + FLOAT, + INTEGER, + REAL, + SMALLINT, + TIME, + TIMESTAMP, + VARCHAR, + NullType, +) + +from ..custom_types import ( + _CUSTOM_DECIMAL, + ARRAY, + DOUBLE, + GEOGRAPHY, + GEOMETRY, + MAP, + OBJECT, + TIMESTAMP_LTZ, + TIMESTAMP_NTZ, + TIMESTAMP_TZ, + VARIANT, +) + +ischema_names = { + "BIGINT": BIGINT, + "BINARY": BINARY, + # 'BIT': BIT, + "BOOLEAN": BOOLEAN, + "CHAR": CHAR, + "CHARACTER": CHAR, + "DATE": DATE, + "DATETIME": DATETIME, + "DEC": DECIMAL, + "DECIMAL": DECIMAL, + "DOUBLE": DOUBLE, + "FIXED": DECIMAL, + "FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't have parameters + "INT": INTEGER, + "INTEGER": INTEGER, + "NUMBER": _CUSTOM_DECIMAL, + "REAL": REAL, + "BYTEINT": SMALLINT, + "SMALLINT": SMALLINT, + "STRING": VARCHAR, + "TEXT": VARCHAR, + "TIME": TIME, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP_TZ": TIMESTAMP_TZ, + "TIMESTAMP_LTZ": TIMESTAMP_LTZ, + "TIMESTAMP_NTZ": TIMESTAMP_NTZ, + "TINYINT": SMALLINT, + "VARBINARY": BINARY, + "VARCHAR": VARCHAR, + "VARIANT": VARIANT, + "MAP": MAP, + "OBJECT": OBJECT, + "ARRAY": ARRAY, + "GEOGRAPHY": GEOGRAPHY, + "GEOMETRY": GEOMETRY, +} + +NOT_NULL_STR = "NOT NULL" + + +def tokenize_parameters(text: str, character_for_strip=",") -> list: + """ + Extracts parameters from a comma-separated string, handling parentheses. + + :param text: A string with comma-separated parameters, which may include parentheses. + + :param character_for_strip: A character to strip the text. + + :return: A list of parameters as strings. + + :example: + For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`. + """ + output_parameters = [] + parameter = "" + open_parenthesis = 0 + for c in text: + + if c == "(": + open_parenthesis += 1 + elif c == ")": + open_parenthesis -= 1 + + if open_parenthesis > 0 or c != character_for_strip: + parameter += c + elif c == character_for_strip: + output_parameters.append(parameter.strip(" ")) + parameter = "" + if parameter != "": + output_parameters.append(parameter.strip(" ")) + return output_parameters + + +def parse_index_columns(columns: str) -> List[str]: + """ + Parses a string with a list of columns for an index. + + :param columns: A string with a list of columns for an index, which may include parentheses. + :param compiler: A SQLAlchemy compiler. + + :return: A list of columns as strings. + + :example: + For input `"[A, B, C]"`, the output is `['A', 'B', 'C']`. + """ + return [column.strip() for column in columns.strip("[]").split(",")] + + +def parse_type(type_text: str) -> TypeEngine: + """ + Parses a type definition string and returns the corresponding SQLAlchemy type. + + The function handles types with or without parameters, such as `VARCHAR(255)` or `INTEGER`. + + :param type_text: A string representing a SQLAlchemy type, which may include parameters + in parentheses (e.g., "VARCHAR(255)" or "DECIMAL(10, 2)"). + :return: An instance of the corresponding SQLAlchemy type class (e.g., `String`, `Integer`), + or `NullType` if the type is not recognized. 
+ + :example: + parse_type("VARCHAR(255)") + String(length=255) + """ + + index = type_text.find("(") + type_name = type_text[:index] if index != -1 else type_text + + parameters = ( + tokenize_parameters(type_text[index + 1 : -1]) if type_name != type_text else [] + ) + + col_type_class = ischema_names.get(type_name, None) + col_type_kw = {} + + if col_type_class is None: + col_type_class = NullType + else: + if issubclass(col_type_class, sqltypes.Numeric): + col_type_kw = __parse_numeric_type_parameters(parameters) + elif issubclass(col_type_class, (sqltypes.String, sqltypes.BINARY)): + col_type_kw = __parse_type_with_length_parameters(parameters) + elif issubclass(col_type_class, MAP): + col_type_kw = __parse_map_type_parameters(parameters) + elif issubclass(col_type_class, OBJECT): + col_type_kw = __parse_object_type_parameters(parameters) + elif issubclass(col_type_class, ARRAY): + col_type_kw = __parse_nullable_parameter(parameters) + if col_type_kw is None: + col_type_class = NullType + col_type_kw = {} + + return col_type_class(**col_type_kw) + + +def __parse_object_type_parameters(parameters): + object_rows = {} + not_null_parts = NOT_NULL_STR.split(" ") + for parameter in parameters: + parameter_parts = tokenize_parameters(parameter, " ") + if len(parameter_parts) >= 2: + key = parameter_parts[0] + value_type = parse_type(parameter_parts[1]) + if isinstance(value_type, NullType): + return None + not_null = ( + len(parameter_parts) == 4 + and parameter_parts[2] == not_null_parts[0] + and parameter_parts[3] == not_null_parts[1] + ) + object_rows[key] = (value_type, not_null) + return object_rows + + +def __parse_nullable_parameter(parameters): + if len(parameters) < 1: + return {} + elif len(parameters) > 1: + return None + parameter_str = parameters[0] + is_not_null = False + if ( + len(parameter_str) >= len(NOT_NULL_STR) + and parameter_str[-len(NOT_NULL_STR) :] == NOT_NULL_STR + ): + is_not_null = True + parameter_str = parameter_str[: -len(NOT_NULL_STR) - 1] + + value_type: TypeEngine = parse_type(parameter_str) + if isinstance(value_type, NullType): + return None + + return { + "value_type": value_type, + "not_null": is_not_null, + } + + +def __parse_map_type_parameters(parameters): + if len(parameters) != 2: + return None + + key_type_str = parameters[0] + value_type_str = parameters[1] + key_type: TypeEngine = parse_type(key_type_str) + value_type = __parse_nullable_parameter([value_type_str]) + if isinstance(value_type, NullType) or isinstance(key_type, NullType): + return None + + return {"key_type": key_type, **value_type} + + +def __parse_type_with_length_parameters(parameters): + return ( + {"length": int(parameters[0])} + if len(parameters) == 1 and str.isdigit(parameters[0]) + else {} + ) + + +def __parse_numeric_type_parameters(parameters): + result = {} + if len(parameters) >= 1 and str.isdigit(parameters[0]): + result["precision"] = int(parameters[0]) + if len(parameters) == 2 and str.isdigit(parameters[1]): + result["scale"] = int(parameters[1]) + return result diff --git a/src/snowflake/sqlalchemy/requirements.py b/src/snowflake/sqlalchemy/requirements.py index ea30a823..f2844804 100644 --- a/src/snowflake/sqlalchemy/requirements.py +++ b/src/snowflake/sqlalchemy/requirements.py @@ -289,9 +289,25 @@ def datetime_implicit_bound(self): # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
return exclusions.closed() + @property + def date_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + + @property + def time_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + @property def timestamp_microseconds_implicit_bound(self): # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding # parameters in string forms of timestamp values. # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. return exclusions.closed() + + @property + def array_type(self): + return exclusions.closed() diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index dc1dea1b..1e7ccaef 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -1,45 +1,31 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # - import operator +import re from collections import defaultdict +from enum import Enum from functools import reduce +from typing import Any, Collection, Optional from urllib.parse import unquote_plus -import sqlalchemy.types as sqltypes +import sqlalchemy.sql.sqltypes as sqltypes from sqlalchemy import event as sa_vnt from sqlalchemy import exc as sa_exc from sqlalchemy import util as sa_util -from sqlalchemy.engine import default, reflection +from sqlalchemy.engine import URL, default, reflection from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.sqltypes import String -from sqlalchemy.types import ( - BIGINT, - BINARY, - BOOLEAN, - CHAR, - DATE, - DATETIME, - DECIMAL, - FLOAT, - INTEGER, - REAL, - SMALLINT, - TIME, - TIMESTAMP, - VARCHAR, - Date, - DateTime, - Float, - Time, -) +from sqlalchemy.sql.sqltypes import NullType +from sqlalchemy.types import FLOAT, Date, DateTime, Float, Time from snowflake.connector import errors as sf_errors +from snowflake.connector.connection import DEFAULT_CONFIGURATION from snowflake.connector.constants import UTF8 +from snowflake.sqlalchemy.compat import returns_unicode +from ._constants import DIALECT_NAME from .base import ( SnowflakeCompiler, SnowflakeDDLCompiler, @@ -48,20 +34,21 @@ SnowflakeTypeCompiler, ) from .custom_types import ( - _CUSTOM_DECIMAL, - ARRAY, - GEOGRAPHY, - OBJECT, - TIMESTAMP_LTZ, - TIMESTAMP_NTZ, - TIMESTAMP_TZ, - VARIANT, + StructuredType, _CUSTOM_Date, _CUSTOM_DateTime, _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name +from .parser.custom_type_parser import * # noqa +from .parser.custom_type_parser import _CUSTOM_DECIMAL # noqa +from .parser.custom_type_parser import ischema_names, parse_index_columns, parse_type +from .sql.custom_schema.custom_table_prefix import CustomTablePrefix +from .util import ( + _update_connection_application_name, + parse_url_boolean, + parse_url_integer, +) colspecs = { Date: _CUSTOM_Date, @@ -70,49 +57,16 @@ Float: _CUSTOM_Float, } -ischema_names = { - "BIGINT": BIGINT, - "BINARY": BINARY, - # 'BIT': BIT, - "BOOLEAN": BOOLEAN, - "CHAR": CHAR, - "CHARACTER": CHAR, - "DATE": DATE, - "DATETIME": DATETIME, - "DEC": DECIMAL, - "DECIMAL": DECIMAL, - "DOUBLE": 
FLOAT, - "FIXED": DECIMAL, - "FLOAT": FLOAT, - "INT": INTEGER, - "INTEGER": INTEGER, - "NUMBER": _CUSTOM_DECIMAL, - # 'OBJECT': ? - "REAL": REAL, - "BYTEINT": SMALLINT, - "SMALLINT": SMALLINT, - "STRING": VARCHAR, - "TEXT": VARCHAR, - "TIME": TIME, - "TIMESTAMP": TIMESTAMP, - "TIMESTAMP_TZ": TIMESTAMP_TZ, - "TIMESTAMP_LTZ": TIMESTAMP_LTZ, - "TIMESTAMP_NTZ": TIMESTAMP_NTZ, - "TINYINT": SMALLINT, - "VARBINARY": BINARY, - "VARCHAR": VARCHAR, - "VARIANT": VARIANT, - "OBJECT": OBJECT, - "ARRAY": ARRAY, - "GEOGRAPHY": GEOGRAPHY, -} +_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True -_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True +class SnowflakeIsolationLevel(Enum): + READ_COMMITTED = "READ COMMITTED" + AUTOCOMMIT = "AUTOCOMMIT" class SnowflakeDialect(default.DefaultDialect): - name = "snowflake" + name = DIALECT_NAME driver = "snowflake" max_identifier_length = 255 cte_follows_insert = True @@ -125,6 +79,9 @@ class SnowflakeDialect(default.DefaultDialect): colspecs = colspecs ischema_names = ischema_names + # target database treats the / division operator as “floor division” + div_is_floordiv = False + # all str types must be converted in Unicode convert_unicode = True @@ -132,7 +89,7 @@ class SnowflakeDialect(default.DefaultDialect): # unicode strings supports_unicode_statements = True supports_unicode_binds = True - returns_unicode_strings = String.RETURNS_UNICODE + returns_unicode_strings = returns_unicode description_encoding = None # No lastrowid support. See SNOW-11155 @@ -191,13 +148,51 @@ class SnowflakeDialect(default.DefaultDialect): supports_identity_columns = True + def __init__( + self, + force_div_is_floordiv: bool = True, + isolation_level: Optional[str] = SnowflakeIsolationLevel.READ_COMMITTED.value, + **kwargs: Any, + ): + super().__init__(isolation_level=isolation_level, **kwargs) + self.force_div_is_floordiv = force_div_is_floordiv + self.div_is_floordiv = force_div_is_floordiv + + def initialize(self, connection): + super().initialize(connection) + self.div_is_floordiv = self.force_div_is_floordiv + @classmethod def dbapi(cls): + return cls.import_dbapi() + + @classmethod + def import_dbapi(cls): from snowflake import connector return connector - def create_connect_args(self, url): + @staticmethod + def parse_query_param_type(name: str, value: Any) -> Any: + """Cast param value if possible to type defined in connector-python.""" + if not (maybe_type_configuration := DEFAULT_CONFIGURATION.get(name)): + return value + + _, expected_type = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance(value, expected_type): + return value + + elif bool in expected_type: + return parse_url_boolean(value) + elif int in expected_type: + return parse_url_integer(value) + else: + return value + + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: name_spaces = [unquote_plus(e) for e in opts["database"].split("/")] @@ -224,26 +219,55 @@ def create_connect_args(self, url): opts["host"] = opts["host"] + ".snowflakecomputing.com" opts["port"] = "443" opts["autocommit"] = False # autocommit is disabled by default - opts.update(url.query) + + query = dict(**url.query) # make mutable + cache_column_metadata = query.pop("cache_column_metadata", None) self._cache_column_metadata = ( - opts.get("cache_column_metadata", "false").lower() == "true" + parse_url_boolean(cache_column_metadata) if cache_column_metadata else False ) + + # URL sets the query parameter values as strings, 
we need to cast to expected types when necessary + for name, value in query.items(): + opts[name] = self.parse_query_param_type(name, value) + return ([], opts) - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): """ Checks if the table exists """ return self._has_object(connection, "TABLE", table_name, schema) - def has_sequence(self, connection, sequence_name, schema=None): + def get_isolation_level_values(self, dbapi_connection): + return [ + SnowflakeIsolationLevel.READ_COMMITTED.value, + SnowflakeIsolationLevel.AUTOCOMMIT.value, + ] + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection): + dbapi_connection.commit() + + def get_default_isolation_level(self, dbapi_conn): + return SnowflakeIsolationLevel.READ_COMMITTED.value + + def set_isolation_level(self, dbapi_connection, level): + if level == SnowflakeIsolationLevel.AUTOCOMMIT.value: + dbapi_connection.autocommit(True) + else: + dbapi_connection.autocommit(False) + + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): """ Checks if the sequence exists """ return self._has_object(connection, "SEQUENCE", sequence_name, schema) def _has_object(self, connection, object_type, object_name, schema=None): - full_name = self._denormalize_quote_join(schema, object_name) try: results = connection.execute( @@ -292,8 +316,8 @@ def _denormalize_quote_join(self, *idents): @reflection.cache def _current_database_schema(self, connection, **kw): - res = connection.exec_driver_sql( - "select current_database(), current_schema();" + res = connection.execute( + text("select current_database(), current_schema();") ).fetchone() return ( self.normalize_name(res[0]), @@ -312,14 +336,6 @@ def _map_name_to_idx(result): name_to_idx[col[0]] = idx return name_to_idx - @reflection.cache - def get_indexes(self, connection, table_name, schema=None, **kw): - """ - Gets all indexes - """ - # no index is supported by Snowflake - return [] - @reflection.cache def get_check_constraints(self, connection, table_name, schema, **kw): # check constraints are not supported by Snowflake @@ -475,6 +491,12 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw): ) return foreign_key_map.get(table_name, []) + def table_columns_as_dict(self, columns): + result = {} + for column in columns: + result[column["name"]] = column + return result + @reflection.cache def _get_schema_columns(self, connection, schema, **kw): """Get all columns in the schema, if we hit 'Information schema query returned too much data' problem return @@ -482,10 +504,13 @@ def _get_schema_columns(self, connection, schema, **kw): ans = {} current_database, _ = self._current_database_schema(connection, **kw) full_schema_name = self._denormalize_quote_join(current_database, schema) + full_columns_descriptions = {} try: schema_primary_keys = self._get_schema_primary_keys( connection, full_schema_name, **kw ) + schema_name = self.denormalize_name(schema) + result = connection.execute( text( """ @@ -506,7 +531,7 @@ def _get_schema_columns(self, connection, schema, **kw): WHERE ic.table_schema=:table_schema ORDER BY ic.ordinal_position""" ), - {"table_schema": self.denormalize_name(schema)}, + {"table_schema": schema_name}, ) except sa_exc.ProgrammingError as pe: if pe.orig.errno == 90030: @@ -536,10 +561,7 @@ def _get_schema_columns(self, connection, schema, **kw): col_type = self.ischema_names.get(coltype, 
None) col_type_kw = {} if col_type is None: - sa_util.warn( - f"Did not recognize type '{coltype}' of column '{column_name}'" - ) - col_type = sqltypes.NULLTYPE + col_type = NullType else: if issubclass(col_type, FLOAT): col_type_kw["precision"] = numeric_precision @@ -549,6 +571,33 @@ def _get_schema_columns(self, connection, schema, **kw): col_type_kw["scale"] = numeric_scale elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): col_type_kw["length"] = character_maximum_length + elif issubclass(col_type, StructuredType): + if (schema_name, table_name) not in full_columns_descriptions: + full_columns_descriptions[(schema_name, table_name)] = ( + self.table_columns_as_dict( + self._get_table_columns( + connection, table_name, schema_name + ) + ) + ) + + if ( + (schema_name, table_name) in full_columns_descriptions + and column_name + in full_columns_descriptions[(schema_name, table_name)] + ): + ans[table_name].append( + full_columns_descriptions[(schema_name, table_name)][ + column_name + ] + ) + continue + else: + col_type = NullType + if col_type == NullType: + sa_util.warn( + f"Did not recognize type '{coltype}' of column '{column_name}'" + ) type_instance = col_type(**col_type_kw) @@ -563,11 +612,13 @@ def _get_schema_columns(self, connection, schema, **kw): "autoincrement": is_identity == "YES", "comment": comment, "primary_key": ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False, + ( + column_name + in schema_primary_keys[table_name]["constrained_columns"] + ) + if current_table_pks + else False + ), } ) if is_identity == "YES": @@ -581,89 +632,63 @@ def _get_schema_columns(self, connection, schema, **kw): def _get_table_columns(self, connection, table_name, schema=None, **kw): """Get all columns in a table in a schema""" ans = [] - current_database, _ = self._current_database_schema(connection, **kw) - full_schema_name = self._denormalize_quote_join(current_database, schema) - schema_primary_keys = self._get_schema_primary_keys( - connection, full_schema_name, **kw + current_database, default_schema = self._current_database_schema( + connection, **kw ) + schema = schema if schema else default_schema + table_schema = self.denormalize_name(schema) + table_name = self.denormalize_name(table_name) result = connection.execute( text( - """ - SELECT /* sqlalchemy:get_table_columns */ - ic.table_name, - ic.column_name, - ic.data_type, - ic.character_maximum_length, - ic.numeric_precision, - ic.numeric_scale, - ic.is_nullable, - ic.column_default, - ic.is_identity, - ic.comment - FROM information_schema.columns ic - WHERE ic.table_schema=:table_schema - AND ic.table_name=:table_name - ORDER BY ic.ordinal_position""" - ), - { - "table_schema": self.denormalize_name(schema), - "table_name": self.denormalize_name(table_name), - }, + "DESC /* sqlalchemy:_get_schema_columns */" + f" TABLE {table_schema}.{table_name} TYPE = COLUMNS" + ) ) - for ( - table_name, - column_name, - coltype, - character_maximum_length, - numeric_precision, - numeric_scale, - is_nullable, - column_default, - is_identity, - comment, - ) in result: - table_name = self.normalize_name(table_name) + for desc_data in result: + column_name = desc_data[0] + coltype = desc_data[1] + is_nullable = desc_data[3] + column_default = desc_data[4] + primary_key = desc_data[5] + comment = desc_data[9] + column_name = self.normalize_name(column_name) if column_name.startswith("sys_clustering_column"): continue # ignoring clustering column - col_type = 
self.ischema_names.get(coltype, None) - col_type_kw = {} - if col_type is None: + type_instance = parse_type(coltype) + if isinstance(type_instance, NullType): sa_util.warn( f"Did not recognize type '{coltype}' of column '{column_name}'" ) - col_type = sqltypes.NULLTYPE - else: - if issubclass(col_type, FLOAT): - col_type_kw["precision"] = numeric_precision - col_type_kw["decimal_return_scale"] = numeric_scale - elif issubclass(col_type, sqltypes.Numeric): - col_type_kw["precision"] = numeric_precision - col_type_kw["scale"] = numeric_scale - elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): - col_type_kw["length"] = character_maximum_length - type_instance = col_type(**col_type_kw) - - current_table_pks = schema_primary_keys.get(table_name) + identity = None + match = re.match( + r"IDENTITY START (?P\d+) INCREMENT (?P\d+) (?PORDER|NOORDER)", + column_default if column_default else "", + ) + if match: + identity = { + "start": int(match.group("start")), + "increment": int(match.group("increment")), + "order_type": match.group("order_type"), + } + is_identity = identity is not None ans.append( { "name": column_name, "type": type_instance, - "nullable": is_nullable == "YES", - "default": column_default, - "autoincrement": is_identity == "YES", + "nullable": is_nullable == "Y", + "default": None if is_identity else column_default, + "autoincrement": is_identity, "comment": comment if comment != "" else None, - "primary_key": ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False, + "primary_key": primary_key == "Y", } ) + if is_identity: + ans[-1]["identity"] = identity + # If we didn't find any columns for the table, the table doesn't exist. if len(ans) == 0: raise sa_exc.NoSuchTableError() @@ -686,28 +711,44 @@ def get_columns(self, connection, table_name, schema=None, **kw): raise sa_exc.NoSuchTableError() return schema_columns[normalized_table_name] + def get_prefixes_from_data(self, name_to_index_map, row, **kw): + prefixes_found = [] + for valid_prefix in CustomTablePrefix: + key = f"is_{valid_prefix.name.lower()}" + if key in name_to_index_map and row[name_to_index_map[key]] == "Y": + prefixes_found.append(valid_prefix.name) + return prefixes_found + @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def _get_schema_tables_info(self, connection, schema=None, **kw): """ - Gets all table names. + Retrieves information about all tables in the specified schema. """ + schema = schema or self.default_schema_name - current_schema = schema - if schema: - cursor = connection.execute( - text( - f"SHOW /* sqlalchemy:get_table_names */ TABLES IN {self._denormalize_quote_join(schema)}" - ) - ) - else: - cursor = connection.execute( - text("SHOW /* sqlalchemy:get_table_names */ TABLES") + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {self._denormalize_quote_join(schema)}" ) - _, current_schema = self._current_database_schema(connection) + ) - ret = [self.normalize_name(row[1]) for row in cursor] + name_to_index_map = self._map_name_to_idx(result) + tables = {} + for row in result.cursor.fetchall(): + table_name = self.normalize_name(str(row[name_to_index_map["name"]])) + table_prefixes = self.get_prefixes_from_data(name_to_index_map, row) + tables[table_name] = {"prefixes": table_prefixes} - return ret + return tables + + def get_table_names(self, connection, schema=None, **kw): + """ + Gets all table names. 
+ """ + ret = self._get_schema_tables_info( + connection, schema, info_cache=kw.get("info_cache", None) + ).keys() + return list(ret) @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -760,17 +801,12 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): def get_temp_table_names(self, connection, schema=None, **kw): schema = schema or self.default_schema_name - if schema: - cursor = connection.execute( - text( - f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES \ - IN {self._denormalize_quote_join(schema)}" - ) - ) - else: - cursor = connection.execute( - text("SHOW /* sqlalchemy:get_temp_table_names */ TABLES") + cursor = connection.execute( + text( + f"SHOW /* sqlalchemy:get_temp_table_names */ TABLES \ + IN SCHEMA {self._denormalize_quote_join(schema)}" ) + ) ret = [] n2i = self.__class__._map_name_to_idx(cursor) @@ -844,18 +880,118 @@ def get_table_comment(self, connection, table_name, schema=None, **kw): result = self._get_view_comment(connection, table_name, schema) return { - "text": result._mapping["comment"] - if result and result._mapping["comment"] - else None + "text": ( + result._mapping["comment"] + if result and result._mapping["comment"] + else None + ) } + def get_table_names_with_prefix( + self, + connection, + *, + schema, + prefix, + **kw, + ): + tables_data = self._get_schema_tables_info(connection, schema, **kw) + table_names = [] + for table_name, tables_data_value in tables_data.items(): + if prefix in tables_data_value["prefixes"]: + table_names.append(table_name) + return table_names + + def get_multi_indexes( + self, + connection, + *, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw, + ): + """ + Gets the indexes definition + """ + schema = schema or self.default_schema_name + hybrid_table_names = self.get_table_names_with_prefix( + connection, + schema=schema, + prefix=CustomTablePrefix.HYBRID.name, + info_cache=kw.get("info_cache", None), + ) + if len(hybrid_table_names) == 0: + return [] + + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}" + ) + ) + + n2i = self._map_name_to_idx(result) + indexes = {} + + for row in result.cursor.fetchall(): + table_name = self.normalize_name(str(row[n2i["table"]])) + if ( + row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY' + or table_name not in filter_names + or table_name not in hybrid_table_names + ): + continue + index = { + "name": row[n2i["name"]], + "unique": row[n2i["is_unique"]] == "Y", + "column_names": [ + self.normalize_name(column) + for column in parse_index_columns(row[n2i["columns"]]) + ], + "include_columns": [ + self.normalize_name(column) + for column in parse_index_columns(row[n2i["included_columns"]]) + ], + "dialect_options": {}, + } + + if (schema, table_name) in indexes: + indexes[(schema, table_name)] = indexes[(schema, table_name)].append( + index + ) + else: + indexes[(schema, table_name)] = [index] + + return list(indexes.items()) + + def _value_or_default(self, data, table, schema): + table = self.normalize_name(str(table)) + dic_data = dict(data) + if (schema, table) in dic_data: + return dic_data[(schema, table)] + else: + return [] + + @reflection.cache + def get_indexes(self, connection, tablename, schema, **kw): + """ + Gets the indexes definition + """ + table_name = self.normalize_name(str(tablename)) + data = self.get_multi_indexes( + connection=connection, schema=schema, filter_names=[table_name], 
**kw + ) + + return self._value_or_default(data, table_name, schema) + def connect(self, *cargs, **cparams): return ( super().connect( *cargs, - **_update_connection_application_name(**cparams) - if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME - else cparams, + **( + _update_connection_application_name(**cparams) + if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME + else cparams + ), ) if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME else super().connect(*cargs, **cparams) @@ -864,8 +1000,12 @@ def connect(self, *cargs, **cparams): @sa_vnt.listens_for(Table, "before_create") def check_table(table, connection, _ddl_runner, **kw): + from .sql.custom_schema.hybrid_table import HybridTable + + if HybridTable.is_equal_type(table): # noqa + return True if isinstance(_ddl_runner.dialect, SnowflakeDialect) and table.indexes: - raise NotImplementedError("Snowflake does not support indexes") + raise NotImplementedError("Only Snowflake Hybrid Tables supports indexes") dialect = SnowflakeDialect diff --git a/src/snowflake/sqlalchemy/sql/__init__.py b/src/snowflake/sqlalchemy/sql/__init__.py new file mode 100644 index 00000000..ef416f64 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py new file mode 100644 index 00000000..cbc75ebc --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py @@ -0,0 +1,9 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from .dynamic_table import DynamicTable +from .hybrid_table import HybridTable +from .iceberg_table import IcebergTable +from .snowflake_table import SnowflakeTable + +__all__ = ["DynamicTable", "HybridTable", "IcebergTable", "SnowflakeTable"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py new file mode 100644 index 00000000..6c0904a8 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from typing import Any, Optional + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .custom_table_base import CustomTableBase +from .options.as_query_option import AsQueryOption +from .options.cluster_by_option import ClusterByOption, ClusterByOptionType +from .options.table_option import TableOptionKey + + +class ClusteredTableBase(CustomTableBase): + + @property + def cluster_by(self) -> Optional[AsQueryOption]: + return self._get_dialect_option(TableOptionKey.CLUSTER_BY) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + cluster_by: ClusterByOptionType = None, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + + options = [ + ClusterByOption.create(cluster_by), + ] + + kw.update(self._as_dialect_options(options)) + super().__init__(name, metadata, *args, **kw) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py new file mode 100644 index 00000000..6f7ee0c5 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -0,0 +1,127 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import typing +from typing import Any, List + +from sqlalchemy.sql.schema import MetaData, SchemaItem, Table + +from ..._constants import DIALECT_NAME +from ...compat import IS_VERSION_20 +from ...custom_commands import NoneType +from ...custom_types import StructuredType +from ...exc import ( + MultipleErrors, + NoPrimaryKeyError, + RequiredParametersNotProvidedError, + StructuredTypeNotSupportedInTableColumnsError, + UnsupportedPrimaryKeysAndForeignKeysError, +) +from .custom_table_prefix import CustomTablePrefix +from .options.invalid_table_option import InvalidTableOption +from .options.table_option import TableOption, TableOptionKey + + +class CustomTableBase(Table): + __table_prefixes__: typing.List[CustomTablePrefix] = [] + _support_primary_and_foreign_keys: bool = True + _enforce_primary_keys: bool = False + _required_parameters: List[TableOptionKey] = [] + _support_structured_types: bool = False + + @property + def table_prefixes(self) -> typing.List[str]: + return [prefix.name for prefix in self.__table_prefixes__] + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if len(self.__table_prefixes__) > 0: + prefixes = kw.get("prefixes", []) + self.table_prefixes + kw.update(prefixes=prefixes) + + if not IS_VERSION_20 and hasattr(super(), "_init"): + kw.pop("_no_init", True) + super()._init(name, metadata, *args, **kw) + else: + super().__init__(name, metadata, *args, **kw) + + if not kw.get("autoload_with", False): + self._validate_table() + + def _validate_table(self): + exceptions: List[Exception] = [] + + columns_validation = self.__validate_columns() + if columns_validation is not None: + exceptions.append(columns_validation) + + for _, option in self.dialect_options[DIALECT_NAME].items(): + if isinstance(option, InvalidTableOption): + exceptions.append(option.exception) + + if isinstance(self.key, NoneType) and self._enforce_primary_keys: + exceptions.append(NoPrimaryKeyError(self.__class__.__name__)) + missing_parameters: List[str] = [] + + for required_parameter in self._required_parameters: + if isinstance(self._get_dialect_option(required_parameter), NoneType): + missing_parameters.append(required_parameter.name.lower()) + if missing_parameters: + exceptions.append( + RequiredParametersNotProvidedError( + self.__class__.__name__, missing_parameters + ) + ) + + if not self._support_primary_and_foreign_keys and ( + self.primary_key or self.foreign_keys + ): + exceptions.append( + UnsupportedPrimaryKeysAndForeignKeysError(self.__class__.__name__) + ) + + if len(exceptions) > 1: + exceptions.sort(key=lambda e: str(e)) + raise MultipleErrors(exceptions) + elif len(exceptions) == 1: + raise exceptions[0] + + def __validate_columns(self): + for column in self.columns: + if not self._support_structured_types and isinstance( + column.type, StructuredType + ): + return StructuredTypeNotSupportedInTableColumnsError( + self.__class__.__name__, self.name, column.name + ) + + def _get_dialect_option( + self, option_name: TableOptionKey + ) -> typing.Optional[TableOption]: + if option_name.value in self.dialect_options[DIALECT_NAME]: + return self.dialect_options[DIALECT_NAME][option_name.value] + return None + + def _as_dialect_options( + self, table_options: List[TableOption] + ) -> typing.Dict[str, TableOption]: + result = {} + for table_option in table_options: + if isinstance(table_option, TableOption) and isinstance( + table_option.option_name, str + ): + result[DIALECT_NAME + "_" + table_option.option_name] = 
table_option + return result + + @classmethod + def is_equal_type(cls, table: Table) -> bool: + for prefix in cls.__table_prefixes__: + if prefix.name not in table._prefixes: + return False + + return True diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py new file mode 100644 index 00000000..de7835d1 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from enum import Enum + + +class CustomTablePrefix(Enum): + DEFAULT = 0 + EXTERNAL = 1 + EVENT = 2 + HYBRID = 3 + ICEBERG = 4 + DYNAMIC = 5 diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py new file mode 100644 index 00000000..91c379f0 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py @@ -0,0 +1,117 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import typing +from typing import Any, Union + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .custom_table_prefix import CustomTablePrefix +from .options import ( + IdentifierOption, + IdentifierOptionType, + KeywordOptionType, + TableOptionKey, + TargetLagOption, + TargetLagOptionType, +) +from .options.keyword_option import KeywordOption +from .table_from_query import TableFromQueryBase + + +class DynamicTable(TableFromQueryBase): + """ + A class representing a dynamic table with configurable options and settings. + + The `DynamicTable` class allows for the creation and querying of tables with + specific options, such as `Warehouse` and `TargetLag`. + + While it does not support reflection at this time, it provides a flexible + interface for creating dynamic tables and management. 
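+
+    `target_lag` accepts either a `(int, TimeUnit)` tuple, a `TargetLagOption`,
+    or `SnowflakeKeyword.DOWNSTREAM`; `refresh_mode` accepts one of
+    `SnowflakeKeyword.AUTO`, `SnowflakeKeyword.FULL` or
+    `SnowflakeKeyword.INCREMENTAL` (see `options/keywords.py`), as shown in the
+    examples below.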
+ + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + Example using option values: + DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + target_lag=(1, TimeUnit.HOURS), + warehouse='warehouse_name', + refresh_mode=SnowflakeKeyword.AUTO + as_query="SELECT id, name from test_table_1;" + ) + + Example using explicit options: + DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + target_lag=TargetLag(1, TimeUnit.HOURS), + warehouse=Identifier('warehouse_name'), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO) + as_query=AsQuery("SELECT id, name from test_table_1;") + ) + """ + + __table_prefixes__ = [CustomTablePrefix.DYNAMIC] + _support_primary_and_foreign_keys = False + _required_parameters = [ + TableOptionKey.WAREHOUSE, + TableOptionKey.AS_QUERY, + TableOptionKey.TARGET_LAG, + ] + + @property + def warehouse(self) -> typing.Optional[IdentifierOption]: + return self._get_dialect_option(TableOptionKey.WAREHOUSE) + + @property + def target_lag(self) -> typing.Optional[TargetLagOption]: + return self._get_dialect_option(TableOptionKey.TARGET_LAG) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + warehouse: IdentifierOptionType = None, + target_lag: Union[TargetLagOptionType, KeywordOptionType] = None, + refresh_mode: KeywordOptionType = None, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + + options = [ + IdentifierOption.create(TableOptionKey.WAREHOUSE, warehouse), + TargetLagOption.create(target_lag), + KeywordOption.create(TableOptionKey.REFRESH_MODE, refresh_mode), + ] + + kw.update(self._as_dialect_options(options)) + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + self.__init__(name, metadata, *args, _no_init=False, **kw) + + def __repr__(self) -> str: + return "DynamicTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [repr(self.target_lag)] + + [repr(self.warehouse)] + + [repr(self.cluster_by)] + + [repr(self.as_query)] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py new file mode 100644 index 00000000..b3a55f20 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from typing import Any + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .custom_table_base import CustomTableBase +from .custom_table_prefix import CustomTablePrefix + + +class HybridTable(CustomTableBase): + """ + A class representing a hybrid table with configurable options and settings. + + The `HybridTable` class allows for the creation and querying of OLTP Snowflake Tables . + + While it does not support reflection at this time, it provides a flexible + interface for creating hybrid tables and management. 
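+
+    Hybrid tables are the only table type for which this dialect accepts
+    `Index` definitions (see the `check_table` before_create hook in the
+    dialect module); a minimal sketch with illustrative names, in addition to
+    the basic example further below:
+
+        HybridTable(
+            "user",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+            Index("idx_user_name", "name"),
+        )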
+ + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table + + Example usage: + HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String) + ) + """ + + __table_prefixes__ = [CustomTablePrefix.HYBRID] + _enforce_primary_keys: bool = True + _support_structured_types = True + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + self.__init__(name, metadata, *args, _no_init=False, **kw) + + def __repr__(self) -> str: + return "HybridTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py new file mode 100644 index 00000000..4f62d4f2 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py @@ -0,0 +1,102 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import typing +from typing import Any + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .custom_table_prefix import CustomTablePrefix +from .options import LiteralOption, LiteralOptionType, TableOptionKey +from .table_from_query import TableFromQueryBase + + +class IcebergTable(TableFromQueryBase): + """ + A class representing an iceberg table with configurable options and settings. + + While it does not support reflection at this time, it provides a flexible + interface for creating iceberg tables and management. 
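+
+    `external_volume` and `base_location` map to the corresponding
+    CREATE ICEBERG TABLE parameters, and the `catalog` option is always
+    rendered as 'SNOWFLAKE' by this class (see `__init__` below).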
+ + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table + + Example using option values: + + IcebergTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + external_volume='my_external_volume', + base_location='my_iceberg_table'" + ) + + Example using explicit options: + DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + external_volume=LiteralOption('my_external_volume') + base_location=LiteralOption('my_iceberg_table') + ) + """ + + __table_prefixes__ = [CustomTablePrefix.ICEBERG] + _support_structured_types = True + + @property + def external_volume(self) -> typing.Optional[LiteralOption]: + return self._get_dialect_option(TableOptionKey.EXTERNAL_VOLUME) + + @property + def base_location(self) -> typing.Optional[LiteralOption]: + return self._get_dialect_option(TableOptionKey.BASE_LOCATION) + + @property + def catalog(self) -> typing.Optional[LiteralOption]: + return self._get_dialect_option(TableOptionKey.CATALOG) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + external_volume: LiteralOptionType = None, + base_location: LiteralOptionType = None, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + + options = [ + LiteralOption.create(TableOptionKey.EXTERNAL_VOLUME, external_volume), + LiteralOption.create(TableOptionKey.BASE_LOCATION, base_location), + LiteralOption.create(TableOptionKey.CATALOG, "SNOWFLAKE"), + ] + + kw.update(self._as_dialect_options(options)) + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + self.__init__(name, metadata, *args, _no_init=False, **kw) + + def __repr__(self) -> str: + return "IcebergTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [repr(self.external_volume)] + + [repr(self.base_location)] + + [repr(self.catalog)] + + [repr(self.cluster_by)] + + [repr(self.as_query)] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py new file mode 100644 index 00000000..e94ea46b --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from .as_query_option import AsQueryOption, AsQueryOptionType +from .cluster_by_option import ClusterByOption, ClusterByOptionType +from .identifier_option import IdentifierOption, IdentifierOptionType +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .literal_option import LiteralOption, LiteralOptionType +from .table_option import TableOptionKey +from .target_lag_option import TargetLagOption, TargetLagOptionType, TimeUnit + +__all__ = [ + # Options + "IdentifierOption", + "LiteralOption", + "KeywordOption", + "AsQueryOption", + "TargetLagOption", + "ClusterByOption", + # Enums + "TimeUnit", + "SnowflakeKeyword", + "TableOptionKey", + # Types + "IdentifierOptionType", + "LiteralOptionType", + "AsQueryOptionType", + "TargetLagOptionType", + "KeywordOptionType", + "ClusterByOptionType", +] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py new file mode 100644 index 00000000..93994abc --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from sqlalchemy.sql import Selectable + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class AsQueryOption(TableOption): + """Class to represent an AS clause in tables. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + Example: + as_query=AsQueryOption('select name, address from existing_table where name = "test"') + + is equivalent to: + + as select name, address from existing_table where name = "test" + """ + + def __init__(self, query: Union[str, Selectable]) -> None: + super().__init__() + self._name: TableOptionKey = TableOptionKey.AS_QUERY + self.query = query + + @staticmethod + def create( + value: Optional[Union["AsQueryOption", str, Selectable]] + ) -> "TableOption": + if isinstance(value, (NoneType, AsQueryOption)): + return value + if isinstance(value, (str, Selectable)): + return AsQueryOption(value) + return TableOption._get_invalid_table_option( + TableOptionKey.AS_QUERY, + str(type(value).__name__), + [AsQueryOption.__name__, str.__name__, Selectable.__name__], + ) + + def template(self) -> str: + return "AS %s" + + @property + def priority(self) -> Priority: + return Priority.LOWEST + + def __get_expression(self): + if isinstance(self.query, Selectable): + return self.query.compile(compile_kwargs={"literal_binds": True}) + return self.query + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "AsQueryOption(%s)" % self.__get_expression() + + +AsQueryOptionType = Union[AsQueryOption, str, Selectable] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py new file mode 100644 index 00000000..b92029bb --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from typing import List, Union + +from sqlalchemy.sql.expression import TextClause + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class ClusterByOption(TableOption): + """Class to represent the cluster by clause in tables. + For further information on this clause, please refer to: https://docs.snowflake.com/en/user-guide/tables-clustering-keys + Example: + cluster_by=ClusterByOption('name', text('id > 0')) + + is equivalent to: + + cluster by (name, id > 0) + """ + + def __init__(self, *expressions: Union[str, TextClause]) -> None: + super().__init__() + self._name: TableOptionKey = TableOptionKey.CLUSTER_BY + self.expressions = expressions + + @staticmethod + def create(value: "ClusterByOptionType") -> "TableOption": + if isinstance(value, (NoneType, ClusterByOption)): + return value + if isinstance(value, List): + return ClusterByOption(*value) + return TableOption._get_invalid_table_option( + TableOptionKey.CLUSTER_BY, + str(type(value).__name__), + [ClusterByOption.__name__, list.__name__], + ) + + def template(self) -> str: + return f"{self.option_name.upper()} (%s)" + + @property + def priority(self) -> Priority: + return Priority.HIGH + + def __get_expression(self): + return ", ".join([str(expression) for expression in self.expressions]) + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "ClusterByOption(%s)" % self.__get_expression() + + +ClusterByOptionType = Union[ClusterByOption, List[Union[str, TextClause]]] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py new file mode 100644 index 00000000..b296898b --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class IdentifierOption(TableOption): + """Class to represent an identifier option in Snowflake Tables. 
+ + Example: + warehouse = IdentifierOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = my_warehouse + """ + + def __init__(self, value: Union[str]) -> None: + super().__init__() + self.value: str = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, "IdentifierOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + + if isinstance(value, str): + value = IdentifierOption(value) + + if isinstance(value, IdentifierOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, str(type(value).__name__), [IdentifierOption.__name__, str.__name__] + ) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"IdentifierOption(value='{self.value}'{option_name})" + + +IdentifierOptionType = Union[IdentifierOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py new file mode 100644 index 00000000..2bdc9dd3 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional + +from .table_option import TableOption, TableOptionKey + + +class InvalidTableOption(TableOption): + """Class to store errors and raise them after table initialization in order to avoid recursion error.""" + + def __init__(self, name: TableOptionKey, value: Exception) -> None: + super().__init__() + self.exception: Exception = value + self._name = name + + @staticmethod + def create(name: TableOptionKey, value: Exception) -> Optional[TableOption]: + return InvalidTableOption(name, value) + + def _render(self, compiler) -> str: + raise self.exception + + def __repr__(self) -> str: + return f"ErrorOption(value='{self.exception}')" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py new file mode 100644 index 00000000..ff6b444d --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class KeywordOption(TableOption): + """Class to represent a keyword option in Snowflake Tables. 
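+    The accepted values are the members of `SnowflakeKeyword`: `DOWNSTREAM`
+    for TARGET_LAG, and `AUTO`, `FULL` or `INCREMENTAL` for REFRESH_MODE.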
+ + Example: + target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM) + + is equivalent to: + + TARGET_LAG = DOWNSTREAM + """ + + def __init__(self, value: Union[SnowflakeKeyword]) -> None: + super().__init__() + self.value: str = value.value + + @property + def priority(self): + return Priority.HIGH + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value.upper() + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[SnowflakeKeyword, "KeywordOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return value + if isinstance(value, SnowflakeKeyword): + value = KeywordOption(value) + + if isinstance(value, KeywordOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [KeywordOption.__name__, SnowflakeKeyword.__name__], + ) + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if isinstance(self.option_name, NoneType) + else "" + ) + return f"KeywordOption(value='{self.value}'{option_name})" + + +KeywordOptionType = Union[KeywordOption, SnowflakeKeyword] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py new file mode 100644 index 00000000..f4ba87ea --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py @@ -0,0 +1,14 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from enum import Enum + + +class SnowflakeKeyword(Enum): + # TARGET_LAG + DOWNSTREAM = "DOWNSTREAM" + + # REFRESH_MODE + AUTO = "AUTO" + FULL = "FULL" + INCREMENTAL = "INCREMENTAL" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py new file mode 100644 index 00000000..55dd7675 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Any, Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class LiteralOption(TableOption): + """Class to represent a literal option in Snowflake Table. 
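+    String values are rendered single-quoted and integer values unquoted
+    (see `template` below).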
+ + Example: + warehouse = LiteralOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = 'my_warehouse' + """ + + def __init__(self, value: Union[int, str]) -> None: + super().__init__() + self.value: Any = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, int, "LiteralOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + if isinstance(value, (str, int)): + value = LiteralOption(value) + + if isinstance(value, LiteralOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [LiteralOption.__name__, str.__name__, int.__name__], + ) + + def template(self) -> str: + if isinstance(self.value, int): + return f"{self.option_name.upper()} = %d" + else: + return f"{self.option_name.upper()} = '%s'" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"LiteralOption(value='{self.value}'{option_name})" + + +LiteralOptionType = Union[LiteralOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py new file mode 100644 index 00000000..5ebb4817 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py @@ -0,0 +1,84 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from enum import Enum +from typing import List, Optional + +from snowflake.sqlalchemy import exc +from snowflake.sqlalchemy.custom_commands import NoneType + + +class Priority(Enum): + LOWEST = 0 + VERY_LOW = 1 + LOW = 2 + MEDIUM = 4 + HIGH = 6 + VERY_HIGH = 7 + HIGHEST = 8 + + +class TableOption: + + def __init__(self) -> None: + self._name: Optional[TableOptionKey] = None + + @property + def option_name(self) -> str: + if isinstance(self._name, NoneType): + return None + return str(self._name.value) + + def _set_option_name(self, name: Optional["TableOptionKey"]): + self._name = name + + @property + def priority(self) -> Priority: + return Priority.MEDIUM + + @staticmethod + def create(**kwargs) -> "TableOption": + raise NotImplementedError + + @staticmethod + def _get_invalid_table_option( + parameter_name: "TableOptionKey", input_type: str, expected_types: List[str] + ) -> "TableOption": + from .invalid_table_option import InvalidTableOption + + return InvalidTableOption( + parameter_name, + exc.InvalidTableParameterTypeError( + parameter_name.value, input_type, expected_types + ), + ) + + def _validate_option(self): + if isinstance(self.option_name, NoneType): + raise exc.OptionKeyNotProvidedError(self.__class__.__name__) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def render_option(self, compiler) -> str: + self._validate_option() + return self._render(compiler) + + def _render(self, compiler) -> str: + raise NotImplementedError + + +class TableOptionKey(Enum): + AS_QUERY = "as_query" + BASE_LOCATION = "base_location" + CATALOG = "catalog" + CATALOG_SYNC = "catalog_sync" + CLUSTER_BY = "cluster by" + DATA_RETENTION_TIME_IN_DAYS = "data_retention_time_in_days" + DEFAULT_DDL_COLLATION = "default_ddl_collation" + EXTERNAL_VOLUME = "external_volume" + MAX_DATA_EXTENSION_TIME_IN_DAYS = "max_data_extension_time_in_days" + 
REFRESH_MODE = "refresh_mode" + STORAGE_SERIALIZATION_POLICY = "storage_serialization_policy" + TARGET_LAG = "target_lag" + WAREHOUSE = "warehouse" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py new file mode 100644 index 00000000..7c1c0825 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# from enum import Enum +from enum import Enum +from typing import Optional, Tuple, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class TimeUnit(Enum): + SECONDS = "seconds" + MINUTES = "minutes" + HOURS = "hours" + DAYS = "days" + + +class TargetLagOption(TableOption): + """Class to represent the target lag clause in Dynamic Tables. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + Example using the time and unit parameters: + + target_lag = TargetLagOption(10, TimeUnit.SECONDS) + + is equivalent to: + + TARGET_LAG = '10 SECONDS' + + Example using keyword parameter: + + target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM) + + is equivalent to: + + TARGET_LAG = DOWNSTREAM + + """ + + def __init__( + self, + time: Optional[int] = 0, + unit: Optional[TimeUnit] = TimeUnit.MINUTES, + ) -> None: + super().__init__() + self.time = time + self.unit = unit + self._name: TableOptionKey = TableOptionKey.TARGET_LAG + + @staticmethod + def create( + value: Union["TargetLagOption", Tuple[int, TimeUnit], KeywordOptionType] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return value + + if isinstance(value, Tuple): + time, unit = value + value = TargetLagOption(time, unit) + + if isinstance(value, TargetLagOption): + return value + + if isinstance(value, (KeywordOption, SnowflakeKeyword)): + return KeywordOption.create(TableOptionKey.TARGET_LAG, value) + + return TableOption._get_invalid_table_option( + TableOptionKey.TARGET_LAG, + str(type(value).__name__), + [ + TargetLagOption.__name__, + f"Tuple[int, {TimeUnit.__name__}])", + SnowflakeKeyword.__name__, + ], + ) + + def __get_expression(self): + return f"'{str(self.time)} {str(self.unit.value)}'" + + @property + def priority(self) -> Priority: + return Priority.HIGH + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "TargetLagOption(%s)" % self.__get_expression() + + +TargetLagOptionType = Union[TargetLagOption, Tuple[int, TimeUnit]] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py new file mode 100644 index 00000000..56a14c83 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from typing import Any + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .table_from_query import TableFromQueryBase + + +class SnowflakeTable(TableFromQueryBase): + """ + A class representing a table in Snowflake with configurable options and settings. 
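+
+    Besides `cluster_by`, the table accepts an `as_query` option (CTAS); when
+    `as_query` is given a SQLAlchemy `Selectable` and no columns are declared,
+    the columns are derived from that selectable (see `TableFromQueryBase`).
+    A minimal sketch with illustrative names:
+
+        SnowflakeTable(
+            "user_copy",
+            metadata,
+            as_query=select(users_table),
+        )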
+ + While it does not support reflection at this time, it provides a flexible + interface for creating tables and management. + + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table + Example usage: + + SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by = ["id", text("name > 5")] + ) + + Example using explict options: + + SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by = ClusterByOption("id", text("name > 5")) + ) + + """ + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + self.__init__(name, metadata, *args, _no_init=False, **kw) + + def __repr__(self) -> str: + return "SnowflakeTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [repr(self.cluster_by)] + + [repr(self.as_query)] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py new file mode 100644 index 00000000..cbd65de3 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import typing +from typing import Any, Optional + +from sqlalchemy.sql import Selectable +from sqlalchemy.sql.schema import Column, MetaData, SchemaItem + +from .clustered_table import ClusteredTableBase +from .options.as_query_option import AsQueryOption, AsQueryOptionType +from .options.table_option import TableOptionKey + + +class TableFromQueryBase(ClusteredTableBase): + + @property + def as_query(self) -> Optional[AsQueryOption]: + return self._get_dialect_option(TableOptionKey.AS_QUERY) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + as_query: AsQueryOptionType = None, + **kw: Any, + ) -> None: + items = [item for item in args] + as_query = AsQueryOption.create(as_query) # noqa + kw.update(self._as_dialect_options([as_query])) + if ( + isinstance(as_query, AsQueryOption) + and isinstance(as_query.query, Selectable) + and not self.__has_defined_columns(items) + ): + columns = self.__create_columns_from_selectable(as_query.query) + args = items + columns + super().__init__(name, metadata, *args, **kw) + + def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool: + for item in items: + if isinstance(item, Column): + return True + + def __create_columns_from_selectable( + self, selectable: Selectable + ) -> Optional[typing.List[Column]]: + if not isinstance(selectable, Selectable): + return + columns: typing.List[Column] = [] + for _, c in selectable.exported_columns.items(): + columns += [Column(c.name, c.type)] + return columns diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 56b5de5b..a1aefff9 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -3,13 +3,23 @@ # import re +from itertools import chain from typing import Any from urllib.parse import quote_plus -from sqlalchemy import exc +from sqlalchemy import exc, inspection, sql +from sqlalchemy.exc 
import NoForeignKeysError +from sqlalchemy.orm.interfaces import MapperProperty +from sqlalchemy.orm.util import _ORMJoin as sa_orm_util_ORMJoin +from sqlalchemy.orm.util import attributes +from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql.base import _expand_cloned, _from_objects +from sqlalchemy.sql.elements import _find_columns +from sqlalchemy.sql.selectable import Join, Lateral, coercions, operators, roles from snowflake.connector.compat import IS_STR from snowflake.connector.connection import SnowflakeConnection +from snowflake.sqlalchemy import compat from ._constants import ( APPLICATION_NAME, @@ -104,3 +114,231 @@ def _update_connection_application_name(**conn_kwargs: Any) -> Any: if PARAM_INTERNAL_APPLICATION_VERSION not in conn_kwargs: conn_kwargs[PARAM_INTERNAL_APPLICATION_VERSION] = SNOWFLAKE_SQLALCHEMY_VERSION return conn_kwargs + + +def parse_url_boolean(value: str) -> bool: + if value == "True": + return True + elif value == "False": + return False + else: + raise ValueError(f"Invalid boolean value detected: '{value}'") + + +def parse_url_integer(value: str) -> int: + try: + return int(value) + except ValueError as e: + raise ValueError(f"Invalid int value detected: '{value}") from e + + +# handle Snowflake BCR bcr-1057 +# the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState +# which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that +# can not handle the BCR change, we implement it in a way that lateral join does not need onclause +def _find_left_clause_to_join_from(clauses, join_to, onclause): + """Given a list of FROM clauses, a selectable, + and optional ON clause, return a list of integer indexes from the + clauses list indicating the clauses that can be joined from. + + The presence of an "onclause" indicates that at least one clause can + definitely be joined from; if the list of clauses is of length one + and the onclause is given, returns that index. If the list of clauses + is more than length one, and the onclause is given, attempts to locate + which clauses contain the same columns. + + """ + idx = [] + selectables = set(_from_objects(join_to)) + + # if we are given more than one target clause to join + # from, use the onclause to provide a more specific answer. + # otherwise, don't try to limit, after all, "ON TRUE" is a valid + # on clause + if len(clauses) > 1 and onclause is not None: + resolve_ambiguity = True + cols_in_onclause = _find_columns(onclause) + else: + resolve_ambiguity = False + cols_in_onclause = None + + for i, f in enumerate(clauses): + for s in selectables.difference([f]): + if resolve_ambiguity: + if set(f.c).union(s.c).issuperset(cols_in_onclause): + idx.append(i) + break + elif onclause is not None or Join._can_join(f, s): + idx.append(i) + break + elif onclause is None and isinstance(s, Lateral): + # in snowflake, onclause is not accepted for lateral due to BCR change: + # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 + # sqlalchemy only allows join with on condition. + # to adapt to snowflake syntax change, + # we make the change such that when oncaluse is None and the right part is + # Lateral, we append the index indicating Lateral clause can be joined from with without onclause. 
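+                # hypothetical illustration of the pattern this enables:
+                #   lat = lateral(select(child).where(child.c.parent_id == parent.c.id))
+                #   select(parent).join(lat)  # joined without an explicit ON clause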
+ idx.append(i) + break + + if len(idx) > 1: + # this is the same "hide froms" logic from + # Selectable._get_display_froms + toremove = set(chain(*[_expand_cloned(f._hide_froms) for f in clauses])) + idx = [i for i in idx if clauses[i] not in toremove] + + # onclause was given and none of them resolved, so assume + # all indexes can match + if not idx and onclause is not None: + return range(len(clauses)) + else: + return idx + + +class _Snowflake_ORMJoin(sa_orm_util_ORMJoin): + def __init__( + self, + left, + right, + onclause=None, + isouter=False, + full=False, + _left_memo=None, + _right_memo=None, + _extra_criteria=(), + ): + left_info = inspection.inspect(left) + + right_info = inspection.inspect(right) + adapt_to = right_info.selectable + + # used by joined eager loader + self._left_memo = _left_memo + self._right_memo = _right_memo + + # legacy, for string attr name ON clause. if that's removed + # then the "_joined_from_info" concept can go + left_orm_info = getattr(left, "_joined_from_info", left_info) + self._joined_from_info = right_info + if isinstance(onclause, compat.string_types): + onclause = getattr(left_orm_info.entity, onclause) + # #### + + if isinstance(onclause, attributes.QueryableAttribute): + on_selectable = onclause.comparator._source_selectable() + prop = onclause.property + _extra_criteria += onclause._extra_criteria + elif isinstance(onclause, MapperProperty): + # used internally by joined eager loader...possibly not ideal + prop = onclause + on_selectable = prop.parent.selectable + else: + prop = None + + if prop: + left_selectable = left_info.selectable + + if sql_util.clause_is_present(on_selectable, left_selectable): + adapt_from = on_selectable + else: + adapt_from = left_selectable + + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + of_type_entity=right_info, + alias_secondary=True, + extra_criteria=_extra_criteria, + ) + + if sj is not None: + if isouter: + # note this is an inner join from secondary->right + right = sql.join(secondary, right, sj) + onclause = pj + else: + left = sql.join(left, secondary, pj, isouter) + onclause = sj + else: + onclause = pj + + self._target_adapter = target_adapter + + # we don't use the normal coercions logic for _ORMJoin + # (probably should), so do some gymnastics to get the entity. + # logic here is for #8721, which was a major bug in 1.4 + # for almost two years, not reported/fixed until 1.4.43 (!) + if left_info.is_selectable: + parententity = left_selectable._annotations.get("parententity", None) + elif left_info.is_mapper or left_info.is_aliased_class: + parententity = left_info + else: + parententity = None + + if parententity is not None: + self._annotations = self._annotations.union( + {"parententity": parententity} + ) + + augment_onclause = onclause is None and _extra_criteria + # handle Snowflake BCR bcr-1057 + _Snowflake_Selectable_Join.__init__(self, left, right, onclause, isouter, full) + + if augment_onclause: + self.onclause &= sql.and_(*_extra_criteria) + + if ( + not prop + and getattr(right_info, "mapper", None) + and right_info.mapper.single + ): + # if single inheritance target and we are using a manual + # or implicit ON clause, augment it the same way we'd augment the + # WHERE. 
+ single_crit = right_info.mapper._single_table_criterion + if single_crit is not None: + if right_info.is_aliased_class: + single_crit = right_info._adapter.traverse(single_crit) + self.onclause = self.onclause & single_crit + + +class _Snowflake_Selectable_Join(Join): + def __init__(self, left, right, onclause=None, isouter=False, full=False): + """Construct a new :class:`_expression.Join`. + + The usual entrypoint here is the :func:`_expression.join` + function or the :meth:`_expression.FromClause.join` method of any + :class:`_expression.FromClause` object. + + """ + self.left = coercions.expect(roles.FromClauseRole, left, deannotate=True) + self.right = coercions.expect( + roles.FromClauseRole, right, deannotate=True + ).self_group() + + if onclause is None: + try: + self.onclause = self._match_primaries(self.left, self.right) + except NoForeignKeysError: + # handle Snowflake BCR bcr-1057 + if isinstance(self.right, Lateral): + self.onclause = None + else: + raise + else: + # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba + # not merged yet + self.onclause = coercions.expect(roles.OnClauseRole, onclause).self_group( + against=operators._asbool + ) + + self.isouter = isouter + self.full = full diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index f2b7d15d..3d03a626 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = (1, 4, 7, None) +VERSION = "1.7.4" diff --git a/tests/__snapshots__/test_compile_dynamic_table.ambr b/tests/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..81c7f90f --- /dev/null +++ b/tests/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/__snapshots__/test_core.ambr b/tests/__snapshots__/test_core.ambr new file mode 100644 index 00000000..7a4e0f99 --- /dev/null +++ b/tests/__snapshots__/test_core.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_table_with_cluster_by_with_expression + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY ("Id")) CLUSTER BY ("Id", name, "Id" > 5)' +# --- diff --git a/tests/__snapshots__/test_orm.ambr b/tests/__snapshots__/test_orm.ambr new file mode 100644 index 00000000..2116e9e9 --- /dev/null +++ b/tests/__snapshots__/test_orm.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_orm_one_to_many_relationship_with_hybrid_table + ProgrammingError('(snowflake.connector.errors.ProgrammingError) 200009 (22000): Foreign 
key constraint "SYS_INDEX_HB_TBL_ADDRESS_FOREIGN_KEY_USER_ID_HB_TBL_USER_ID" was violated.') +# --- diff --git a/tests/__snapshots__/test_reflect_dynamic_table.ambr b/tests/__snapshots__/test_reflect_dynamic_table.ambr new file mode 100644 index 00000000..d4cc22b5 --- /dev/null +++ b/tests/__snapshots__/test_reflect_dynamic_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr new file mode 100644 index 00000000..453d26e4 --- /dev/null +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -0,0 +1,249 @@ +# serializer version: 1 +# name: test_compile_table_with_cluster_by_with_expression + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, VARCHAR), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_double_map + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_sqlalchemy_array + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname ARRAY, \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_structured_data_type[structured_type0] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_structured_data_type[structured_type1] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_structured_data_type[structured_type2] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_structured_data_type[structured_type3] + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname ARRAY(MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tPRIMARY KEY ("Id"))' +# --- +# name: test_insert_array + list([ + (1, '[\n "item1",\n "item2"\n]'), + ]) +# --- +# name: test_insert_array_orm + ''' + 002014 (22000): SQL compilation error: + Invalid expression [CAST(ARRAY_CONSTRUCT('item1', 'item2') AS ARRAY(VARCHAR(16777216)))] in VALUES clause + ''' +# --- +# name: test_insert_map + list([ + (1, '{\n "100": "item1",\n "200": "item2"\n}'), + ]) +# --- +# name: test_insert_map_orm + ''' + 002014 (22000): SQL compilation error: + Invalid expression [CAST(OBJECT_CONSTRUCT('100', 'item1', '200', 'item2') AS MAP(NUMBER(10,0), VARCHAR(16777216)))] in VALUES clause + ''' +# --- +# name: test_insert_structured_object + list([ + (1, '{\n "key1": "item1",\n "key2": 15\n}'), + ]) +# --- +# name: test_insert_structured_object_orm + ''' + 002014 (22000): SQL compilation error: + Invalid expression [CAST(OBJECT_CONSTRUCT('key1', 1, 'key2', 'item1') AS OBJECT(key1 NUMBER(10,0), key2 VARCHAR(16777216)))] in VALUES clause + ''' +# --- +# name: test_inspect_structured_data_types[structured_type0-MAP] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 
'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)), + }), + ]) +# --- +# name: test_inspect_structured_data_types[structured_type0] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'map_id', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)), + }), + ]) +# --- +# name: test_inspect_structured_data_types[structured_type1-MAP] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))), + }), + ]) +# --- +# name: test_inspect_structured_data_types[structured_type1] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'map_id', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))), + }), + ]) +# --- +# name: test_inspect_structured_data_types[structured_type2-OBJECT] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': OBJECT(key1=(VARCHAR(length=16777216), False), key2=(_CUSTOM_DECIMAL(precision=10, scale=0), False)), + }), + ]) +# --- +# name: test_inspect_structured_data_types[structured_type3-ARRAY] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'structured_type_col', + 'nullable': True, + 'primary_key': False, + 'type': ARRAY(value_type=VARCHAR(length=16777216)), + }), + ]) +# --- +# name: test_reflect_structured_data_types[ARRAY(MAP(NUMBER(10, 0), VARCHAR))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col ARRAY(MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: 
test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), VARCHAR)] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_reflect_structured_data_types[OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tstructured_type_col OBJECT(key1 VARCHAR(16777216), key2 DECIMAL(10, 0)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_select_array_orm + list([ + (1, '[\n "item3",\n "item4"\n]'), + (2, '[\n "item1",\n "item2"\n]'), + ]) +# --- +# name: test_select_map_orm + list([ + (1, '{\n "100": "item1",\n "200": "item2"\n}'), + (2, '{\n "100": "item1",\n "200": "item2"\n}'), + ]) +# --- +# name: test_select_map_orm.1 + list([ + ]) +# --- +# name: test_select_map_orm.2 + list([ + ]) +# --- +# name: test_select_structured_object_orm + list([ + (1, '{\n "key1": "value2",\n "key2": 2\n}'), + (2, '{\n "key1": "value1",\n "key2": 1\n}'), + ]) +# --- diff --git a/tests/__snapshots__/test_unit_structured_types.ambr b/tests/__snapshots__/test_unit_structured_types.ambr new file mode 100644 index 00000000..ff861351 --- /dev/null +++ b/tests/__snapshots__/test_unit_structured_types.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_map_with_not_null + 'MAP(DECIMAL(10, 0), VARCHAR NOT NULL)' +# --- diff --git a/tests/conftest.py b/tests/conftest.py index e22e4d42..df45bb2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,18 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +from __future__ import annotations +import logging.handlers import os import sys import time import uuid -from functools import partial from logging import getLogger import pytest from sqlalchemy import create_engine +from sqlalchemy.pool import NullPool import snowflake.connector import snowflake.connector.connection @@ -44,19 +46,27 @@ TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}" -create_engine_with_future_flag = create_engine - def pytest_addoption(parser): parser.addoption( - "--run_v20_sqlalchemy", - help="Use only 2.0 SQLAlchemy APIs, any legacy features (< 2.0) will not be supported." 
- "Turning on this option will set future flag to True on Engine and Session objects according to" - "the migration guide: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html", + "--ignore_v20_test", action="store_true", + default=False, + help="skip sqlalchemy 2.0 exclusive tests", ) +def pytest_collection_modifyitems(config, items): + if config.getoption("--ignore_v20_test"): + # --ignore_v20_test given in cli: skip sqlalchemy 2.0 tests + skip_feature_v2 = pytest.mark.skip( + reason="need remove --ignore_v20_test option to run" + ) + for item in items: + if "feature_v20" in item.keywords: + item.add_marker(skip_feature_v2) + + @pytest.fixture(scope="session") def on_travis(): return os.getenv("TRAVIS", "").lower() == "true" @@ -102,10 +112,40 @@ def help(): @pytest.fixture(scope="session") def db_parameters(): - return get_db_parameters() + yield get_db_parameters() + + +@pytest.fixture(scope="session") +def external_volume(): + db_parameters = get_db_parameters() + if "external_volume" in db_parameters: + yield db_parameters["external_volume"] + else: + raise ValueError("External_volume is not set") -def get_db_parameters(): +@pytest.fixture(scope="session") +def external_stage(): + db_parameters = get_db_parameters() + if "external_stage" in db_parameters: + yield db_parameters["external_stage"] + else: + raise ValueError("External_stage is not set") + + +@pytest.fixture(scope="function") +def base_location(external_stage, engine_testaccount): + unique_id = str(uuid.uuid4()) + base_location = "L" + unique_id.replace("-", "_") + yield base_location + remove_base_location = f""" + REMOVE @{external_stage} pattern='.*{base_location}.*'; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(remove_base_location) + + +def get_db_parameters() -> dict: """ Sets the db connection parameters """ @@ -113,12 +153,9 @@ def get_db_parameters(): os.environ["TZ"] = "UTC" if not IS_WINDOWS: time.tzset() - for k, v in CONNECTION_PARAMETERS.items(): - ret[k] = v - for k, v in DEFAULT_PARAMETERS.items(): - if k not in ret: - ret[k] = v + ret.update(DEFAULT_PARAMETERS) + ret.update(CONNECTION_PARAMETERS) if "account" in ret and ret["account"] == DEFAULT_PARAMETERS["account"]: help() @@ -153,43 +190,82 @@ def get_db_parameters(): return ret -def get_engine(user=None, password=None, account=None, schema=None): - """ - Creates a connection using the parameters defined in JDBC connect string - """ - ret = get_db_parameters() - - if user is not None: - ret["user"] = user - if password is not None: - ret["password"] = password - if account is not None: - ret["account"] = account - - from sqlalchemy.pool import NullPool - - engine = create_engine_with_future_flag( - URL( - user=ret["user"], - password=ret["password"], - host=ret["host"], - port=ret["port"], - database=ret["database"], - schema=TEST_SCHEMA if not schema else schema, - account=ret["account"], - protocol=ret["protocol"], - ), - poolclass=NullPool, - ) +def url_factory(**kwargs) -> URL: + url_params = get_db_parameters() + url_params.update(kwargs) + return URL(**url_params) + + +def get_engine(url: URL, **engine_kwargs): + engine_params = { + "poolclass": NullPool, + "future": True, + "echo": True, + } + engine_params.update(engine_kwargs) - return engine, ret + connect_args = engine_params.get("connect_args", {}).copy() + connect_args["disable_ocsp_checks"] = True + connect_args["insecure_mode"] = True + engine_params["connect_args"] = connect_args + + engine = create_engine(url, **engine_params) + return 
engine


 @pytest.fixture()
 def engine_testaccount(request):
-    engine, _ = get_engine()
+    url = url_factory()
+    engine = get_engine(url)
     request.addfinalizer(engine.dispose)
-    return engine
+    yield engine
+
+
+@pytest.fixture()
+def assert_text_in_buf():
+    buf = logging.handlers.BufferingHandler(100)
+    for log in [
+        logging.getLogger("sqlalchemy.engine"),
+    ]:
+        log.addHandler(buf)
+
+    def go(expected, occurrences=1):
+        assert buf.buffer
+        buflines = [rec.getMessage() for rec in buf.buffer]
+
+        occurrences_found = buflines.count(expected)
+        assert occurrences == occurrences_found, (
+            f"Expected {occurrences} of {expected}, got {occurrences_found} "
+            f"occurrences in {buflines}."
+        )
+        buf.flush()
+
+    yield go
+    for log in [
+        logging.getLogger("sqlalchemy.engine"),
+    ]:
+        log.removeHandler(buf)
+
+
+@pytest.fixture()
+def engine_testaccount_with_numpy(request):
+    url = url_factory(numpy=True)
+    engine = get_engine(url)
+    request.addfinalizer(engine.dispose)
+    yield engine
+
+
+@pytest.fixture()
+def engine_testaccount_with_qmark(request):
+    snowflake.connector.paramstyle = "qmark"
+
+    url = url_factory()
+    engine = get_engine(url)
+    request.addfinalizer(engine.dispose)
+
+    yield engine
+
+    snowflake.connector.paramstyle = "pyformat"

 @pytest.fixture(scope="session", autouse=True)
@@ -232,19 +308,6 @@ def sql_compiler():
     ).replace("\n", "")


-@pytest.fixture(scope="session")
-def run_v20_sqlalchemy(pytestconfig):
-    return pytestconfig.option.run_v20_sqlalchemy
-
-
-def pytest_sessionstart(session):
-    # patch the create_engine with future flag
-    global create_engine_with_future_flag
-    create_engine_with_future_flag = partial(
-        create_engine, future=session.config.option.run_v20_sqlalchemy
-    )
-
-
 def running_on_public_ci() -> bool:
     """Whether or not tests are currently running on one of our public CIs."""
     return os.getenv("GITHUB_ACTIONS") == "true"
diff --git a/tests/connector_regression b/tests/connector_regression
deleted file mode 160000
index ec95c563..00000000
--- a/tests/connector_regression
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit ec95c563ded4694f69e8bde4eb2f010f92681e58
diff --git a/tests/custom_tables/__init__.py b/tests/custom_tables/__init__.py
new file mode 100644
index 00000000..d43f066c
--- /dev/null
+++ b/tests/custom_tables/__init__.py
@@ -0,0 +1,2 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr
new file mode 100644
index 00000000..66c8f98e
--- /dev/null
+++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr
@@ -0,0 +1,40 @@
+# serializer version: 1
+# name: test_compile_dynamic_table
+  "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
+# ---
+# name: test_compile_dynamic_table_orm
+  "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
+# ---
+# name: test_compile_dynamic_table_orm_with_str_keys
+  'CREATE DYNAMIC TABLE "SCHEMA_DB".test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = \'10 seconds\'\tAS SELECT * FROM table'
+# ---
+# name: test_compile_dynamic_table_with_multiple_wrong_option_types
+  '''
+  Invalid parameter type 'IdentifierOption' provided for 'refresh_mode'. Expected one of the following types: 'KeywordOption', 'SnowflakeKeyword'.
+ Invalid parameter type 'IdentifierOption' provided for 'target_lag'. Expected one of the following types: 'TargetLagOption', 'Tuple[int, TimeUnit])', 'SnowflakeKeyword'. + Invalid parameter type 'KeywordOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. + Invalid parameter type 'KeywordOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_one_wrong_option_types + ''' + Invalid parameter type 'LiteralOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_options_objects + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.AUTO] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.FULL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = FULL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.INCREMENTAL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = INCREMENTAL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr new file mode 100644 index 00000000..2622399c --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr @@ -0,0 +1,10 @@ +# serializer version: 1 +# name: test_compile_hybrid_table + 'CREATE HYBRID TABLE test_hybrid_table (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tgeom GEOMETRY, \tPRIMARY KEY (id))' +# --- +# name: test_compile_hybrid_table_orm + 'CREATE HYBRID TABLE test_hybrid_table_orm (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))' +# --- +# name: test_compile_hybrid_table_with_array + 'CREATE HYBRID TABLE test_hybrid_table (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tgeom GEOMETRY, \tarray ARRAY, \tPRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr b/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr new file mode 100644 index 00000000..b243cc09 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr @@ -0,0 +1,19 @@ +# serializer version: 1 +# name: test_compile_dynamic_table_orm_with_as_query + "CREATE ICEBERG TABLE test_iceberg_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'\tAS SELECT * FROM table" +# --- +# name: test_compile_icberg_table_with_primary_key + "CREATE ICEBERG TABLE 
test_iceberg_table_with_options (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'" +# --- +# name: test_compile_iceberg_table + "CREATE ICEBERG TABLE test_iceberg_table (\tid INTEGER, \tgeom VARCHAR)\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'" +# --- +# name: test_compile_iceberg_table_with_one_wrong_option_types + ''' + Invalid parameter type 'IdentifierOption' provided for 'external_volume'. Expected one of the following types: 'LiteralOption', 'str', 'int'. + + ''' +# --- +# name: test_compile_iceberg_table_with_options_objects + "CREATE ICEBERG TABLE test_iceberg_table_with_options (\tid INTEGER, \tgeom VARCHAR)\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr new file mode 100644 index 00000000..5ea64c12 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr @@ -0,0 +1,35 @@ +# serializer version: 1 +# name: test_compile_dynamic_table_orm_with_str_keys + 'CREATE TABLE "SCHEMA_DB".test_snowflake_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_foreign_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL, \tgeom VARCHAR, \tPRIMARY KEY (id), \tFOREIGN KEY(id) REFERENCES "table" (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_primary_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table + 'CREATE TABLE test_table_1 (\tid INTEGER, \tgeom VARCHAR)\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_orm_with_str_keys + 'CREATE TABLE "SCHEMA_DB".test_snowflake_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_explicit_options + 'CREATE TABLE test_table_2 (\tid INTEGER, \tgeom VARCHAR)\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_foreign_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL, \tgeom VARCHAR, \tPRIMARY KEY (id), \tFOREIGN KEY(id) REFERENCES "table" (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_primary_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_selectable + 'CREATE TABLE snowflake_test_table_1 (\tid INTEGER, \tgeom VARCHAR)\tAS SELECT test_table_1.id, test_table_1.geom FROM test_table_1 WHERE test_table_1.id = 23' +# --- +# name: test_compile_snowflake_table_with_wrong_option_types + ''' + Invalid parameter type 'AsQueryOption' provided for 'cluster by'. Expected one of the following types: 'ClusterByOption', 'list'. + Invalid parameter type 'ClusterByOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. 
+ + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr new file mode 100644 index 00000000..80201495 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_dynamic_table_without_dynamictable_and_defined_options + CustomOptionsAreOnlySupportedOnSnowflakeTables('Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables.') +# --- +# name: test_create_dynamic_table_without_dynamictable_class + UnexpectedOptionTypeError('The following options are either unsupported or should be defined using a Snowflake table: as_query, warehouse.') +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr new file mode 100644 index 00000000..696ff9c8 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_hybrid_table + "[(1, 'test')]" +# --- +# name: test_create_hybrid_table_with_multiple_index + ProgrammingError("(snowflake.connector.errors.ProgrammingError) 391480 (0A000): Another index is being built on table 'TEST_HYBRID_TABLE_WITH_MULTIPLE_INDEX'. Only one index can be built at a time. Either cancel the other index creation or wait until it is complete.") +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr b/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr new file mode 100644 index 00000000..908a4c60 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr @@ -0,0 +1,14 @@ +# serializer version: 1 +# name: test_create_iceberg_table + ''' + (snowflake.connector.errors.ProgrammingError) 091017 (22000): S3 bucket 'my_example_bucket' does not exist or not authorized. + [SQL: + CREATE ICEBERG TABLE "Iceberg_Table_1" ( + id INTEGER NOT NULL AUTOINCREMENT, + geom VARCHAR, + PRIMARY KEY (id) + ) EXTERNAL_VOLUME = 'exvol' CATALOG = 'SNOWFLAKE' BASE_LOCATION = 'my_iceberg_table' + + ] + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr new file mode 100644 index 00000000..98d3137f --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_create_snowflake_table_with_cluster_by + "[(1, 'test')]" +# --- diff --git a/tests/custom_tables/__snapshots__/test_generic_options.ambr b/tests/custom_tables/__snapshots__/test_generic_options.ambr new file mode 100644 index 00000000..eef5e6fd --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_generic_options.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_identifier_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'.\n") +# --- +# name: test_identifier_option_without_name + OptionKeyNotProvidedError('Expected option key in IdentifierOption option but got NoneType instead.') +# --- +# name: test_invalid_as_query_option + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'as_query'. 
Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'.\n") +# --- +# name: test_literal_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'SnowflakeKeyword' provided for 'warehouse'. Expected one of the following types: 'LiteralOption', 'str', 'int'.\n") +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr new file mode 100644 index 00000000..6f6cd395 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_simple_reflection_hybrid_table_as_table + 'CREATE TABLE test_hybrid_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr new file mode 100644 index 00000000..e9a4ac83 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr @@ -0,0 +1,32 @@ +# serializer version: 1 +# name: test_inspect_snowflake_table + list([ + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=38, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'name', + 'nullable': True, + 'primary_key': False, + 'type': VARCHAR(length=16777216), + }), + ]) +# --- +# name: test_reflection_of_table_with_object_data_type + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname OBJECT, \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- +# name: test_simple_reflection_of_table_as_snowflake_table + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- +# name: test_simple_reflection_of_table_as_sqlalchemy_table + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py new file mode 100644 index 00000000..935c61cd --- /dev/null +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + Table, + exc, + select, +) +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, DynamicTable +from snowflake.sqlalchemy.exc import MultipleErrors +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + TargetLagOption, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword + + +def test_compile_dynamic_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +@pytest.mark.parametrize( + "refresh_mode_keyword", + [ + SnowflakeKeyword.AUTO, + SnowflakeKeyword.FULL, + SnowflakeKeyword.INCREMENTAL, + ], +) +def test_compile_dynamic_table_with_refresh_mode( + sql_compiler, snapshot, refresh_mode_keyword +): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + refresh_mode=refresh_mode_keyword, + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=IdentifierOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(ArgumentError) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=LiteralOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_with_multiple_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(MultipleErrors) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=IdentifierOption(SnowflakeKeyword.AUTO), + warehouse=KeywordOption(SnowflakeKeyword.AUTO), + as_query=KeywordOption(SnowflakeKeyword.AUTO), + refresh_mode=IdentifierOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_without_required_args(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="DynamicTable requires the following parameters: 
warehouse, " + "as_query, target_lag.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + + +def test_compile_dynamic_table_with_primary_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DynamicTable.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + +def test_compile_dynamic_table_with_foreign_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DynamicTable.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer), + Column("geom", GEOMETRY), + ForeignKeyConstraint(["id"], ["table.id"]), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + +def test_compile_dynamic_table_orm(sql_compiler, snapshot): + Base = declarative_base() + metadata = MetaData() + table_name = "test_dynamic_table_orm" + test_dynamic_table_orm = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("name", String), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + class TestDynamicTableOrm(Base): + __table__ = test_dynamic_table_orm + __mapper_args__ = { + "primary_key": [test_dynamic_table_orm.c.id, test_dynamic_table_orm.c.name] + } + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): + Base = declarative_base() + + class TestDynamicTableOrm(Base): + __tablename__ = "test_dynamic_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return DynamicTable(name, metadata, *arg, **kw) + + __table_args__ = { + "schema": "SCHEMA_DB", + "target_lag": (10, TimeUnit.SECONDS), + "warehouse": "warehouse", + "as_query": "SELECT * FROM table", + } + + id = Column(Integer) + name = Column(String) + + __mapper_args__ = {"primary_key": [id, name]} + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = Table( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + dynamic_test_table = DynamicTable( + "dynamic_test_table_1", + Base.metadata, + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query=select(test_table_1).where(test_table_1.c.id == 23), + ) + + value = CreateTable(dynamic_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_compile_hybrid_table.py b/tests/custom_tables/test_compile_hybrid_table.py new file mode 100644 index 00000000..7310e21c --- /dev/null +++ b/tests/custom_tables/test_compile_hybrid_table.py @@ -0,0 +1,69 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import ARRAY, GEOMETRY, HybridTable + + +def test_compile_hybrid_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_hybrid_table" + test_geometry = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + Column("geom", GEOMETRY), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_hybrid_table_with_array(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_hybrid_table" + test_geometry = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + Column("geom", GEOMETRY), + Column("array", ARRAY), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_hybrid_table_orm(sql_compiler, snapshot): + Base = declarative_base() + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"" + + value = CreateTable(TestHybridTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_compile_iceberg_table.py b/tests/custom_tables/test_compile_iceberg_table.py new file mode 100644 index 00000000..173e7b0a --- /dev/null +++ b/tests/custom_tables/test_compile_iceberg_table.py @@ -0,0 +1,116 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import IcebergTable +from snowflake.sqlalchemy.sql.custom_schema.options import ( + IdentifierOption, + LiteralOption, +) + + +def test_compile_iceberg_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume="my_external_volume", + base_location="my_iceberg_table", + ) + + value = CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_iceberg_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table_with_options" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume=LiteralOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + value = CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_iceberg_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_wrong_iceberg_table" + with pytest.raises(ArgumentError) as argument_error: + IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume=IdentifierOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_icberg_table_with_primary_key(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table_with_options" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + external_volume=LiteralOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + value = CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_as_query(sql_compiler, snapshot): + Base = declarative_base() + + class TestDynamicTableOrm(Base): + __tablename__ = "test_iceberg_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": "my_external_volume", + "base_location": "my_iceberg_table", + "as_query": "SELECT * FROM table", + } + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_compile_snowflake_table.py b/tests/custom_tables/test_compile_snowflake_table.py new file mode 100644 index 00000000..be9383eb --- /dev/null +++ b/tests/custom_tables/test_compile_snowflake_table.py @@ -0,0 +1,180 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + select, + text, +) +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import SnowflakeTable +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + ClusterByOption, +) + + +def test_compile_snowflake_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_1" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + cluster_by=["id", text("id > 100")], + as_query="SELECT * FROM table", + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_explicit_options(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_snowflake_table" + with pytest.raises(ArgumentError) as argument_error: + SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + as_query=ClusterByOption("id", text("id > 100")), + cluster_by=AsQueryOption("SELECT * FROM table"), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_snowflake_table_with_primary_key(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_foreign_key(sql_compiler, snapshot): + metadata = MetaData() + + SnowflakeTable( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_orm_with_str_keys(sql_compiler, snapshot): + Base = declarative_base() + + class TestSnowflakeTableOrm(Base): + __tablename__ = "test_snowflake_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return SnowflakeTable(name, metadata, *arg, **kw) + + __table_args__ = { + "schema": "SCHEMA_DB", + "cluster_by": ["id", text("id > 100")], + "as_query": "SELECT * FROM table", + } + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"" + + value = 
CreateTable(TestSnowflakeTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = SnowflakeTable( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + cluster_by=ClusterByOption("id", text("id > 100")), + ) + + test_table_2 = SnowflakeTable( + "snowflake_test_table_1", + Base.metadata, + as_query=select(test_table_1).where(test_table_1.c.id == 23), + ) + + value = CreateTable(test_table_2) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py new file mode 100644 index 00000000..b583faad --- /dev/null +++ b/tests/custom_tables/test_create_dynamic_table.py @@ -0,0 +1,124 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable, exc +from snowflake.sqlalchemy.sql.custom_schema.options.as_query_option import AsQueryOption +from snowflake.sqlalchemy.sql.custom_schema.options.identifier_option import ( + IdentifierOption, +) +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword +from snowflake.sqlalchemy.sql.custom_schema.options.table_option import TableOptionKey +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag_option import ( + TargetLagOption, + TimeUnit, +) + + +def test_create_dynamic_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + target_lag=(1, TimeUnit.HOURS), + warehouse=warehouse, + as_query="SELECT id, name from test_table_1;", + refresh_mode=SnowflakeKeyword.FULL, + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_dynamic_table_without_dynamictable_class( + engine_testaccount, db_parameters, snapshot +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + snowflake_warehouse=warehouse, + snowflake_as_query="SELECT id, name from test_table_1;", + prefixes=["DYNAMIC"], + ) + + with pytest.raises(exc.UnexpectedOptionTypeError) as exc_info: + metadata.create_all(engine_testaccount) + assert 
exc_info.value == snapshot + + +def test_create_dynamic_table_without_dynamictable_and_defined_options( + engine_testaccount, db_parameters, snapshot +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + snowflake_target_lag=TargetLagOption.create((1, TimeUnit.HOURS)), + snowflake_warehouse=IdentifierOption.create( + TableOptionKey.WAREHOUSE, warehouse + ), + snowflake_as_query=AsQueryOption.create("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + with pytest.raises(exc.CustomOptionsAreOnlySupportedOnSnowflakeTables) as exc_info: + metadata.create_all(engine_testaccount) + assert exc_info.value == snapshot diff --git a/tests/custom_tables/test_create_hybrid_table.py b/tests/custom_tables/test_create_hybrid_table.py new file mode 100644 index 00000000..43ae3ab6 --- /dev/null +++ b/tests/custom_tables/test_create_hybrid_table.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +import sqlalchemy.exc +from sqlalchemy import Column, Index, Integer, MetaData, String, select +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import HybridTable + + +@pytest.mark.aws +def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot): + metadata = MetaData() + table_name = "test_create_hybrid_table" + + dynamic_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = dynamic_test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_multiple_index( + engine_testaccount, db_parameters, snapshot, sql_compiler +): + metadata = MetaData() + table_name = "test_hybrid_table_with_multiple_index" + + hybrid_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String, index=True), + Column("name2", String), + Column("name3", String), + ) + + metadata.create_all(engine_testaccount) + + index = Index("idx_col34", hybrid_test_table_1.c.name2, hybrid_test_table_1.c.name3) + + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + index.create(engine_testaccount) + try: + assert exc_info.value == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, 
{self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_create_iceberg_table.py b/tests/custom_tables/test_create_iceberg_table.py new file mode 100644 index 00000000..5ce75909 --- /dev/null +++ b/tests/custom_tables/test_create_iceberg_table.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.exc import ProgrammingError + +from snowflake.sqlalchemy import IcebergTable + + +@pytest.mark.aws +def test_create_iceberg_table(engine_testaccount): + metadata = MetaData() + external_volume_name = "exvol" + create_external_volume = f""" + CREATE OR REPLACE EXTERNAL VOLUME {external_volume_name} + STORAGE_LOCATIONS = + ( + ( + NAME = 'my-s3-us-west-2' + STORAGE_PROVIDER = 'S3' + STORAGE_BASE_URL = 's3://myexamplebucket/' + STORAGE_AWS_ROLE_ARN = 'arn:aws:iam::123456789012:role/myrole' + ENCRYPTION=(TYPE='AWS_SSE_KMS' KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab') + ) + ); + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_external_volume) + IcebergTable( + "Iceberg_Table_1", + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + external_volume=external_volume_name, + base_location="my_iceberg_table", + ) + + with pytest.raises(ProgrammingError) as argument_error: + metadata.create_all(engine_testaccount) + + error_str = str(argument_error.value) + assert ( + "(snowflake.connector.errors.ProgrammingError)" + in error_str[: error_str.rfind("\n")] + ) diff --git a/tests/custom_tables/test_create_snowflake_table.py b/tests/custom_tables/test_create_snowflake_table.py new file mode 100644 index 00000000..09140fb8 --- /dev/null +++ b/tests/custom_tables/test_create_snowflake_table.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from sqlalchemy import Column, Integer, MetaData, String, select, text +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import SnowflakeTable + + +def test_create_snowflake_table_with_cluster_by( + engine_testaccount, db_parameters, snapshot +): + metadata = MetaData() + table_name = "test_create_snowflake_table" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_snowflake_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_snowflake_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return SnowflakeTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_generic_options.py b/tests/custom_tables/test_generic_options.py new file mode 100644 index 00000000..916b94c6 --- /dev/null +++ b/tests/custom_tables/test_generic_options.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+ +import pytest + +from snowflake.sqlalchemy import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + exc, +) +from snowflake.sqlalchemy.sql.custom_schema.options.invalid_table_option import ( + InvalidTableOption, +) + + +def test_identifier_option(): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert identifier.render_option(None) == "WAREHOUSE = xsmall" + + +def test_literal_option(): + literal = LiteralOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert literal.render_option(None) == "WAREHOUSE = 'xsmall'" + + +def test_identifier_option_without_name(snapshot): + identifier = IdentifierOption("xsmall") + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_identifier_option_with_wrong_type(snapshot): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, 23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_literal_option_with_wrong_type(snapshot): + literal = LiteralOption.create( + TableOptionKey.WAREHOUSE, SnowflakeKeyword.DOWNSTREAM + ) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + literal.render_option(None) + assert exc_info.value == snapshot + + +def test_invalid_as_query_option(snapshot): + as_query = AsQueryOption.create(23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + as_query.render_option(None) + assert exc_info.value == snapshot + + +@pytest.mark.parametrize( + "table_option", + [ + IdentifierOption, + LiteralOption, + KeywordOption, + ], +) +def test_generic_option_with_wrong_type(table_option): + literal = table_option.create(TableOptionKey.WAREHOUSE, 0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" + + +@pytest.mark.parametrize( + "table_option", + [ + TargetLagOption, + AsQueryOption, + ], +) +def test_non_generic_option_with_wrong_type(table_option): + literal = table_option.create(0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" diff --git a/tests/custom_tables/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py new file mode 100644 index 00000000..52eb4457 --- /dev/null +++ b/tests/custom_tables/test_reflect_dynamic_table.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.custom_commands import NoneType + + +def test_simple_reflection_dynamic_table_as_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = Table( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_without_options_loading(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = DynamicTable( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + # TODO: Add support for loading options when table is reflected + assert isinstance(dynamic_test_table.warehouse, NoneType) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_hybrid_table.py b/tests/custom_tables/test_reflect_hybrid_table.py new file mode 100644 index 00000000..4a777bf0 --- /dev/null +++ b/tests/custom_tables/test_reflect_hybrid_table.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import MetaData, Table +from sqlalchemy.sql.ddl import CreateTable + + +@pytest.mark.aws +def test_simple_reflection_hybrid_table_as_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_hybrid_table_reflection" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX index_name (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + hybrid_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + + constraint = hybrid_test_table.constraints.pop() + constraint.name = "demo_name" + hybrid_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(hybrid_test_table) + + actual = sql_compiler(value) + + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_reflect_hybrid_table_with_index( + engine_testaccount, db_parameters, sql_compiler +): + metadata = MetaData() + schema = db_parameters["schema"] + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + table = Table(table_name, metadata, schema=schema, autoload_with=engine_testaccount) + + try: + assert len(table.indexes) == 1 and table.indexes.pop().name == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_snowflake_table.py b/tests/custom_tables/test_reflect_snowflake_table.py new file mode 100644 index 00000000..323dd281 --- /dev/null +++ b/tests/custom_tables/test_reflect_snowflake_table.py @@ -0,0 +1,122 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from sqlalchemy import MetaData, Table, inspect +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import SnowflakeTable + + +def test_reflection_of_table_with_object_data_type( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name OBJECT); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_of_table_as_sqlalchemy_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_of_table_as_snowflake_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = SnowflakeTable( + table_name, metadata, autoload_with=engine_testaccount + ) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +def test_inspect_snowflake_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_inspect" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + try: + with engine_testaccount.connect() as conn: + insp = inspect(conn) + table = insp.get_columns(table_name) + assert table == snapshot + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/sqlalchemy_test_suite/conftest.py b/tests/sqlalchemy_test_suite/conftest.py index 31cd7c5c..f0464c7d 100644 --- a/tests/sqlalchemy_test_suite/conftest.py +++ b/tests/sqlalchemy_test_suite/conftest.py @@ -15,6 +15,7 @@ import snowflake.connector from snowflake.sqlalchemy import URL +from snowflake.sqlalchemy.compat import 
IS_VERSION_20 from ..conftest import get_db_parameters from ..util import random_string @@ -25,6 +26,12 @@ TEST_SCHEMA_2 = f"{TEST_SCHEMA}_2" +if IS_VERSION_20: + collect_ignore_glob = ["test_suite.py"] +else: + collect_ignore_glob = ["test_suite_20.py"] + + # patch sqlalchemy.testing.config.Confi.__init__ for schema name randomization # same schema name would result in conflict as we're running tests in parallel in the CI def config_patched__init__(self, db, db_opts, options, file_config): diff --git a/tests/sqlalchemy_test_suite/test_suite.py b/tests/sqlalchemy_test_suite/test_suite.py index d79e511e..643d1559 100644 --- a/tests/sqlalchemy_test_suite/test_suite.py +++ b/tests/sqlalchemy_test_suite/test_suite.py @@ -69,6 +69,10 @@ def test_empty_insert(self, connection): def test_empty_insert_multiple(self, connection): pass + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + # 2. Patched Tests diff --git a/tests/sqlalchemy_test_suite/test_suite_20.py b/tests/sqlalchemy_test_suite/test_suite_20.py new file mode 100644 index 00000000..1f79c4e9 --- /dev/null +++ b/tests/sqlalchemy_test_suite/test_suite_20.py @@ -0,0 +1,205 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import Integer, testing +from sqlalchemy.schema import Column, Sequence, Table +from sqlalchemy.testing import config +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.suite import ( + BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest, +) +from sqlalchemy.testing.suite import ( + CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite import DateTimeHistoricTest as _DateTimeHistoricTest +from sqlalchemy.testing.suite import FetchLimitOffsetTest as _FetchLimitOffsetTest +from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest +from sqlalchemy.testing.suite import InsertBehaviorTest as _InsertBehaviorTest +from sqlalchemy.testing.suite import LikeFunctionsTest as _LikeFunctionsTest +from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest +from sqlalchemy.testing.suite import SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest +from sqlalchemy.testing.suite import TimeMicrosecondsTest as _TimeMicrosecondsTest +from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest +from sqlalchemy.testing.suite import * # noqa + +# 1. 
Unsupported by snowflake db + +del ComponentReflectionTest # require indexes not supported by snowflake +del HasIndexTest # require indexes not supported by snowflake +del QuotedNameArgumentTest # require indexes not supported by snowflake + + +class LongNameBlowoutTest(_LongNameBlowoutTest): + # The combination ("ix",) is removed due to Snowflake not supporting indexes + def ix(self, metadata, connection): + pytest.skip("ix required index feature not supported by Snowflake") + + +class FetchLimitOffsetTest(_FetchLimitOffsetTest): + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_bound_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_limit_expr_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset_zero(self, connection): + pass + + +class InsertBehaviorTest(_InsertBehaviorTest): + @pytest.mark.skip( + "Snowflake does not support inserting empty values, the value may be a literal or an expression." + ) + def test_empty_insert(self, connection): + pass + + @pytest.mark.skip( + "Snowflake does not support inserting empty values, The value may be a literal or an expression." + ) + def test_empty_insert_multiple(self, connection): + pass + + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + + +# road to 2.0 +class TrueDivTest(_TrueDivTest): + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer_bound(self, connection): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer(self, connection, left, right, expected): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + +class TimeMicrosecondsTest(_TimeMicrosecondsTest): + def __init__(self): + super().__init__() + + +class DateTimeHistoricTest(_DateTimeHistoricTest): + def __init__(self): + super().__init__() + + +# 2. 
Patched Tests + + +class HasSequenceTest(_HasSequenceTest): + # Override the define_tables method as snowflake does not support 'nomaxvalue'/'nominvalue' + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + # Replace Sequence("other_seq") creation as in the original test suite, + # the Sequence created with 'nomaxvalue' and 'nominvalue' + # which snowflake does not support: + # Sequence("other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True) + Sequence("other_seq", metadata=metadata) + if testing.requires.schemas.enabled: + Sequence("user_id_seq", schema=config.test_schema, metadata=metadata) + Sequence("schema_seq", schema=config.test_schema, metadata=metadata) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + +class LikeFunctionsTest(_LikeFunctionsTest): + @testing.requires.regexp_match + @testing.combinations( + ("a.cde.*", {1, 5, 6, 9}), + ("abc.*", {1, 5, 6, 9, 10}), + ("^abc.*", {1, 5, 6, 9, 10}), + (".*9cde.*", {8}), + ("^a.*", set(range(1, 11))), + (".*(b|c).*", set(range(1, 11))), + ("^(b|c).*", set()), + ) + def test_regexp_match(self, text, expected): + super().test_regexp_match(text, expected) + + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde.*"), {2, 3, 4, 7, 8, 10}) + + +class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.update().where(t.c.id == 2), dict(data="d2_new")) + assert not r.is_insert + # snowflake returns a row with numbers of rows updated and number of multi-joined rows updated + assert r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + # snowflake returns a row with number of rows deleted + assert r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + +class CompositeKeyReflectionTest(_CompositeKeyReflectionTest): + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_fk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. + super().test_fk_column_order() + + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_pk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. + super().test_pk_column_order() + + +class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + ("(2)",), + ("[brackets]",), + argnames="tablename", + ) + def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname): + super().test_fk_ref(connection, metadata, use_composite, tablename, columnname) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 4098f915..cb9632a4 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -2,10 +2,14 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# -from sqlalchemy import Integer, String, and_, func, select +import pytest +from sqlalchemy import Integer, String, and_, func, insert, select from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table -from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing.assertions import AssertsCompiledSQL + +from snowflake.sqlalchemy import snowdialect +from src.snowflake.sqlalchemy.snowdialect import SnowflakeDialect table1 = table( "table1", column("id", Integer), column("name", String), column("value", Integer) @@ -31,6 +35,36 @@ def test_now_func(self): dialect="snowflake", ) + def test_underscore_as_valid_identifier(self): + _table = table( + "table_1745924", + column("ca", Integer), + column("cb", String), + column("_", String), + ) + + stmt = insert(_table).values(ca=1, cb="test", _="test_") + self.assert_compile( + stmt, + 'INSERT INTO table_1745924 (ca, cb, "_") VALUES (%(ca)s, %(cb)s, %(_)s)', + dialect="snowflake", + ) + + def test_underscore_as_initial_character_as_non_quoted_identifier(self): + _table = table( + "table_1745924", + column("ca", Integer), + column("cb", String), + column("_identifier", String), + ) + + stmt = insert(_table).values(ca=1, cb="test", _identifier="test_") + self.assert_compile( + stmt, + "INSERT INTO table_1745924 (ca, cb, _identifier) VALUES (%(ca)s, %(cb)s, %(_identifier)s)", + dialect="snowflake", + ) + def test_multi_table_delete(self): statement = table1.delete().where(table1.c.id == table2.c.id) self.assert_compile( @@ -107,3 +141,91 @@ def test_quoted_name_label(engine_testaccount): sel_from_tbl = select(col).group_by(col).select_from(table("abc")) compiled_result = sel_from_tbl.compile() assert str(compiled_result) == t["output"] + + +def test_outer_lateral_join(): + col = column("colname").label("label") + col2 = column("colname2").label("label2") + lateral_table = func.flatten(func.PARSE_JSON(col2), outer=True).lateral() + stmt = select(col).select_from(table("abc")).join(lateral_table).group_by(col) + assert ( + str(stmt.compile(dialect=snowdialect.dialect())) + == "SELECT colname AS label \nFROM abc JOIN LATERAL flatten(PARSE_JSON(colname2)) AS anon_1 GROUP BY colname" + ) + + +@pytest.mark.feature_v20 +def test_division_operator_with_force_div_is_floordiv_false(): + col1 = column("col1", Integer) + col2 = column("col2", Integer) + stmt = col1 / col2 + assert ( + str(stmt.compile(dialect=SnowflakeDialect(force_div_is_floordiv=False))) + == "col1 / col2" + ) + + +@pytest.mark.feature_v20 +def test_division_operator_with_denominator_expr_force_div_is_floordiv_false(): + col1 = column("col1", Integer) + col2 = column("col2", Integer) + stmt = col1 / func.sqrt(col2) + assert ( + str(stmt.compile(dialect=SnowflakeDialect(force_div_is_floordiv=False))) + == "col1 / sqrt(col2)" + ) + + +@pytest.mark.feature_v20 +def test_division_operator_with_force_div_is_floordiv_default_true(): + col1 = column("col1", Integer) + col2 = column("col2", Integer) + stmt = col1 / col2 + assert str(stmt.compile(dialect=SnowflakeDialect())) == "col1 / col2" + + +@pytest.mark.feature_v20 +def test_division_operator_with_denominator_expr_force_div_is_floordiv_default_true(): + col1 = column("col1", Integer) + col2 = column("col2", Integer) + stmt = col1 / func.sqrt(col2) + assert str(stmt.compile(dialect=SnowflakeDialect())) == "col1 / sqrt(col2)" + + +@pytest.mark.feature_v20 +def test_floor_division_operator_force_div_is_floordiv_false(): + col1 = column("col1", Integer) + col2 = column("col2", 
Integer) + stmt = col1 // col2 + assert ( + str(stmt.compile(dialect=SnowflakeDialect(force_div_is_floordiv=False))) + == "FLOOR(col1 / col2)" + ) + + +@pytest.mark.feature_v20 +def test_floor_division_operator_with_denominator_expr_force_div_is_floordiv_false(): + col1 = column("col1", Integer) + col2 = column("col2", Integer) + stmt = col1 // func.sqrt(col2) + assert ( + str(stmt.compile(dialect=SnowflakeDialect(force_div_is_floordiv=False))) + == "FLOOR(col1 / sqrt(col2))" + ) + + +@pytest.mark.feature_v20 +def test_floor_division_operator_force_div_is_floordiv_default_true(): + col1 = column("col1", Integer) + col2 = column("col2", Integer) + stmt = col1 // col2 + assert str(stmt.compile(dialect=SnowflakeDialect())) == "col1 / col2" + + +@pytest.mark.feature_v20 +def test_floor_division_operator_with_denominator_expr_force_div_is_floordiv_default_true(): + col1 = column("col1", Integer) + col2 = column("col2", Integer) + stmt = col1 // func.sqrt(col2) + res = stmt.compile(dialect=SnowflakeDialect()) + assert str(res) == "FLOOR(col1 / sqrt(col2))" diff --git a/tests/test_copy.py b/tests/test_copy.py index e0752d4f..8dfcf286 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table -from sqlalchemy.sql import select, text +from sqlalchemy.sql import functions, select, text from snowflake.sqlalchemy import ( AWSBucket, @@ -58,8 +58,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_1) - == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv " - "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" + == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv " + "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) copy_stmt_2 = CopyIntoStorage( @@ -73,8 +73,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): sql_compiler(copy_stmt_2) == "COPY INTO 's3://backup' FROM (SELECT python_tests_foods.id, " "python_tests_foods.name, python_tests_foods.quantity FROM python_tests_foods " - "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' " - "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') " + "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' " + "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') " "ENCRYPTION=(TYPE='AWS_SSE_S3')" ) copy_stmt_3 = CopyIntoStorage( @@ -87,7 +87,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): assert ( sql_compiler(copy_stmt_3) == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' " - "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " + "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " "CREDENTIALS=(AZURE_SAS_TOKEN='token')" ) @@ -95,7 +95,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): assert ( sql_compiler(copy_stmt_3) == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' " - "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " + "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " "MAX_FILE_SIZE = 50000000 " "CREDENTIALS=(AZURE_SAS_TOKEN='token')" ) @@ -112,8 +112,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_4) - == "COPY INTO python_tests_foods 
FROM 's3://backup' FILE_FORMAT=(TYPE=csv " - "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" + == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " + "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -126,8 +126,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_5) - == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " - "FIELD_DELIMITER=',') ENCRYPTION=" + == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " + "FIELD_DELIMITER=',') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -138,7 +138,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_6) - == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " ) copy_stmt_7 = CopyIntoStorage( @@ -148,7 +148,38 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_7) - == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " + ) + + copy_stmt_8 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=text("('YEAR=' || year)"), + ) + assert ( + sql_compiler(copy_stmt_8) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year) " + ) + + copy_stmt_9 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=functions.concat( + text("'YEAR='"), text(food_items.columns["name"].name) + ), + ) + assert ( + sql_compiler(copy_stmt_9) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY concat('YEAR=', name) " + ) + + copy_stmt_10 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by="", + ) + assert ( + sql_compiler(copy_stmt_10) == "COPY INTO @stage_name FROM python_tests_foods " ) # NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but @@ -231,7 +262,7 @@ def test_copy_into_storage_csv_extended(sql_compiler): result = sql_compiler(copy_into) expected = ( r"COPY INTO TEST_IMPORT " - r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " + r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " r"FILE_FORMAT=(TYPE=csv COMPRESSION='auto' DATE_FORMAT='AUTO' " r"ERROR_ON_COLUMN_COUNT_MISMATCH=True ESCAPE=None " r"ESCAPE_UNENCLOSED_FIELD='\134' FIELD_DELIMITER=',' " @@ -288,7 +319,7 @@ def test_copy_into_storage_parquet_named_format(sql_compiler): expected = ( "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " - "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " + "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " "FILE_FORMAT=(format_name = parquet_file_format) force = TRUE" ) assert result == expected @@ -350,7 +381,7 @@ def test_copy_into_storage_parquet_files(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " + "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " "FORCE = true" ) assert result == expected @@ -412,6 +443,6 @@ def 
test_copy_into_storage_parquet_pattern(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" + "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" ) assert result == expected diff --git a/tests/test_core.py b/tests/test_core.py index 157889ff..a25342ac 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -28,13 +28,18 @@ String, Table, UniqueConstraint, + create_engine, dialects, + exc, + func, + insert, inspect, text, ) -from sqlalchemy.exc import DBAPIError, NoSuchTableError -from sqlalchemy.pool import NullPool -from sqlalchemy.sql import and_, not_, or_, select +from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError +from sqlalchemy.sql import and_, literal, not_, or_, select +from sqlalchemy.sql.ddl import CreateTable +from sqlalchemy.testing.assertions import eq_ import snowflake.connector.errors import snowflake.sqlalchemy.snowdialect @@ -46,8 +51,7 @@ ) from snowflake.sqlalchemy.snowdialect import SnowflakeDialect -from .conftest import create_engine_with_future_flag as create_engine -from .conftest import get_engine +from .conftest import get_engine, url_factory from .parameters import CONNECTION_PARAMETERS from .util import ischema_names_baseline, random_string @@ -123,14 +127,26 @@ def test_connect_args(): Snowflake connect string supports account name as a replacement of host:port """ + server = "" + if "host" in CONNECTION_PARAMETERS and "port" in CONNECTION_PARAMETERS: + server = "{host}:{port}".format( + host=CONNECTION_PARAMETERS["host"], port=CONNECTION_PARAMETERS["port"] + ) + elif "account" in CONNECTION_PARAMETERS and "region" in CONNECTION_PARAMETERS: + server = "{account}.{region}".format( + account=CONNECTION_PARAMETERS["account"], + region=CONNECTION_PARAMETERS["region"], + ) + elif "account" in CONNECTION_PARAMETERS: + server = CONNECTION_PARAMETERS["account"] + engine = create_engine( - "snowflake://{user}:{password}@{host}:{port}/{database}/{schema}" + "snowflake://{user}:{password}@{server}/{database}/{schema}" "?account={account}&protocol={protocol}".format( user=CONNECTION_PARAMETERS["user"], account=CONNECTION_PARAMETERS["account"], password=CONNECTION_PARAMETERS["password"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], + server=server, database=CONNECTION_PARAMETERS["database"], schema=CONNECTION_PARAMETERS["schema"], protocol=CONNECTION_PARAMETERS["protocol"], @@ -141,35 +157,34 @@ def test_connect_args(): finally: engine.dispose() - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) try: verify_engine_connection(engine) finally: engine.dispose() + parameters = {**CONNECTION_PARAMETERS} + parameters["warehouse"] = "testwh" + engine = create_engine(URL(**parameters)) + try: + verify_engine_connection(engine) + finally: + engine.dispose() + + +def test_boolean_query_argument_parsing(): engine = create_engine( URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - 
protocol=CONNECTION_PARAMETERS["protocol"], - warehouse="testwh", + **CONNECTION_PARAMETERS, + validate_default_parameters=True, ) ) try: verify_engine_connection(engine) + connection = engine.raw_connection() + assert connection.validate_default_parameters is True finally: + connection.close() engine.dispose() @@ -385,16 +400,6 @@ def test_insert_tables(engine_testaccount): str(users.join(addresses)) == "users JOIN addresses ON " "users.id = addresses.user_id" ) - assert ( - str( - users.join( - addresses, - addresses.c.email_address.like(users.c.name + "%"), - ) - ) - == "users JOIN addresses " - "ON addresses.email_address LIKE users.name || :name_1" - ) s = select(users.c.fullname).select_from( users.join( @@ -423,7 +428,7 @@ def test_table_does_not_exist(engine_testaccount): """ meta = MetaData() with pytest.raises(NoSuchTableError): - Table("does_not_exist", meta, autoload=True, autoload_with=engine_testaccount) + Table("does_not_exist", meta, autoload_with=engine_testaccount) @pytest.mark.skip( @@ -449,9 +454,7 @@ def test_reflextion(engine_testaccount): ) try: meta = MetaData() - user_reflected = Table( - "user", meta, autoload=True, autoload_with=engine_testaccount - ) + user_reflected = Table("user", meta, autoload_with=engine_testaccount) assert user_reflected.c == ["user.id", "user.name", "user.fullname"] finally: conn.execute("DROP TABLE IF EXISTS user") @@ -468,6 +471,7 @@ def test_inspect_column(engine_testaccount): try: inspector = inspect(engine_testaccount) all_table_names = inspector.get_table_names() + assert isinstance(all_table_names, list) assert "users" in all_table_names assert "addresses" in all_table_names @@ -493,19 +497,20 @@ def test_inspect_column(engine_testaccount): users.drop(engine_testaccount) -def test_get_indexes(engine_testaccount): +def test_get_indexes(engine_testaccount, db_parameters): """ Tests get indexes - NOTE: Snowflake doesn't support indexes + NOTE: Only Snowflake Hybrid Tables support indexes """ + schema = db_parameters["schema"] metadata = MetaData() users, addresses = _create_users_addresses_tables_without_sequence( engine_testaccount, metadata ) try: inspector = inspect(engine_testaccount) - assert inspector.get_indexes("users") == [] + assert inspector.get_indexes("users", schema) == [] finally: addresses.drop(engine_testaccount) @@ -689,6 +694,39 @@ def test_create_table_with_cluster_by(engine_testaccount): user.drop(engine_testaccount) +def test_create_table_with_cluster_by_with_expression(engine_testaccount): + metadata = MetaData() + Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", String), + snowflake_clusterby=["Id", "name", text('"Id" > 5')], + ) + metadata.create_all(engine_testaccount) + try: + inspector = inspect(engine_testaccount) + columns_in_table = inspector.get_columns("clustered_user") + assert columns_in_table[0]["name"] == "Id", "name" + finally: + metadata.drop_all(engine_testaccount) + + +def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot): + metadata = MetaData() + user = Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", String), + snowflake_clusterby=["Id", "name", text('"Id" > 5')], + ) + + create_table = CreateTable(user) + + assert sql_compiler(create_table) == snapshot + + def test_view_names(engine_testaccount): """ Tests all views @@ -915,37 +953,6 @@ class Appointment(Base): assert str(t.columns["real_data"].type) == "FLOAT" -def _get_engine_with_columm_metadata_cache( - 
db_parameters, user=None, password=None, account=None -): - """ - Creates a connection with column metadata cache - """ - if user is not None: - db_parameters["user"] = user - if password is not None: - db_parameters["password"] = password - if account is not None: - db_parameters["account"] = account - - engine = create_engine( - URL( - user=db_parameters["user"], - password=db_parameters["password"], - host=db_parameters["host"], - port=db_parameters["port"], - database=db_parameters["database"], - schema=db_parameters["schema"], - account=db_parameters["account"], - protocol=db_parameters["protocol"], - cache_column_metadata=True, - ), - poolclass=NullPool, - ) - - return engine - - def test_many_table_column_metadta(db_parameters): """ Get dozens of table metadata with column metadata cache. @@ -953,7 +960,9 @@ def test_many_table_column_metadta(db_parameters): cache_column_metadata=True will cache all column metadata for all tables in the schema. """ - engine = _get_engine_with_columm_metadata_cache(db_parameters) + url = url_factory(cache_column_metadata=True) + engine = get_engine(url) + RE_SUFFIX_NUM = re.compile(r".*(\d+)$") metadata = MetaData() total_objects = 10 @@ -1079,28 +1088,16 @@ def harass_inspector(): assert outcome -@pytest.mark.timeout(15) -def test_region(): - engine = create_engine( - URL( - user="testuser", - password="testpassword", - account="testaccount", - region="eu-central-1", - login_timeout=5, - ) - ) - try: - engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.eu-central-1.snowflakecomputing.com" in ex.orig.msg - - -@pytest.mark.timeout(15) -def test_azure(): +@pytest.mark.skip(reason="Testaccount is not available, it returns 404 error.") +@pytest.mark.timeout(10) +@pytest.mark.parametrize( + "region", + ( + pytest.param("eu-central-1", id="region"), + pytest.param("east-us-2.azure", id="azure"), + ), +) +def test_connection_timeout_error(region): engine = create_engine( URL( user="testuser", @@ -1110,13 +1107,13 @@ def test_azure(): login_timeout=5, ) ) - try: + + with pytest.raises(OperationalError) as excinfo: engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.east-us-2.azure.snowflakecomputing.com" in ex.orig.msg + + assert excinfo.value.orig.errno == 250001 + assert "Could not connect to Snowflake backend" in excinfo.value.orig.msg + assert region not in excinfo.value.orig.msg def test_load_dialect(): @@ -1322,7 +1319,8 @@ def test_comment_sqlalchemy(db_parameters, engine_testaccount, on_public_ci): column_comment1 = random_string(10, choices=string.ascii_uppercase) table_comment2 = random_string(10, choices=string.ascii_uppercase) column_comment2 = random_string(10, choices=string.ascii_uppercase) - engine2, _ = get_engine(schema=new_schema) + + engine2 = get_engine(url_factory(schema=new_schema)) con2 = None if not on_public_ci: con2 = engine2.connect() @@ -1403,47 +1401,51 @@ def test_special_schema_character(db_parameters, on_public_ci): def test_autoincrement(engine_testaccount): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. 
+ https://docs.snowflake.com/en/user-guide/querying-sequences + """ metadata = MetaData() users = Table( "users", metadata, - Column("uid", Integer, Sequence("id_seq"), primary_key=True), + Column("uid", Integer, Sequence("id_seq", order=True), primary_key=True), Column("name", String(39)), ) try: - users.create(engine_testaccount) - - with engine_testaccount.connect() as connection: - with connection.begin(): - connection.execute(users.insert(), [{"name": "sf1"}]) - assert connection.execute(select(users)).fetchall() == [(1, "sf1")] - connection.execute(users.insert(), [{"name": "sf2"}, {"name": "sf3"}]) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - ] - connection.execute(users.insert(), {"name": "sf4"}) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - (4, "sf4"), - ] - - seq = Sequence("id_seq") - nextid = connection.execute(seq) - connection.execute(users.insert(), [{"uid": nextid, "name": "sf5"}]) - assert connection.execute(select(users)).fetchall() == [ - (1, "sf1"), - (2, "sf2"), - (3, "sf3"), - (4, "sf4"), - (5, "sf5"), - ] + metadata.create_all(engine_testaccount) + + with engine_testaccount.begin() as connection: + connection.execute(insert(users), [{"name": "sf1"}]) + assert connection.execute(select(users)).fetchall() == [(1, "sf1")] + connection.execute(insert(users), [{"name": "sf2"}, {"name": "sf3"}]) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + ] + connection.execute(insert(users), {"name": "sf4"}) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + (4, "sf4"), + ] + + seq = Sequence("id_seq") + nextid = connection.execute(seq) + connection.execute(insert(users), [{"uid": nextid, "name": "sf5"}]) + assert connection.execute(select(users)).fetchall() == [ + (1, "sf1"), + (2, "sf2"), + (3, "sf3"), + (4, "sf4"), + (5, "sf5"), + ] finally: - users.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) @pytest.mark.skip( @@ -1538,13 +1540,11 @@ def test_too_many_columns_detection(engine_testaccount, db_parameters): metadata.create_all(engine_testaccount) inspector = inspect(engine_testaccount) # Do test - original_execute = inspector.bind.execute + connection = inspector.bind.connect() + original_execute = connection.execute - def mock_helper(command, *args, **kwargs): - if "_get_schema_columns" in command: - # Creating exception exactly how SQLAlchemy does - raise DBAPIError.instance( - """ + exception_instance = DBAPIError.instance( + """ SELECT /* sqlalchemy:_get_schema_columns */ ic.table_name, ic.column_name, @@ -1559,24 +1559,32 @@ def mock_helper(command, *args, **kwargs): FROM information_schema.columns ic WHERE ic.table_schema='schema_name' ORDER BY ic.ordinal_position""", - {"table_schema": "TESTSCHEMA"}, - ProgrammingError( - "Information schema query returned too much data. Please repeat query with more " - "selective predicates.", - 90030, - ), - Error, - hide_parameters=False, - connection_invalidated=False, - dialect=SnowflakeDialect(), - ismulti=None, - ) + {"table_schema": "TESTSCHEMA"}, + ProgrammingError( + "Information schema query returned too much data. 
Please repeat query with more " + "selective predicates.", + 90030, + ), + Error, + hide_parameters=False, + connection_invalidated=False, + dialect=SnowflakeDialect(), + ismulti=None, + ) + + def mock_helper(command, *args, **kwargs): + if "_get_schema_columns" in command.text: + # Creating exception exactly how SQLAlchemy does + raise exception_instance else: return original_execute(command, *args, **kwargs) - with patch.object(inspector.bind, "execute", side_effect=mock_helper): - column_metadata = inspector.get_columns("users", db_parameters["schema"]) - assert len(column_metadata) == 4 + with patch.object(engine_testaccount, "connect") as conn: + conn.return_value = connection + with patch.object(connection, "execute", side_effect=mock_helper): + with pytest.raises(exc.ProgrammingError) as exception: + inspector.get_columns("users", db_parameters["schema"]) + assert exception.value.orig == exception_instance.orig # Clean up metadata.drop_all(engine_testaccount) @@ -1612,18 +1620,17 @@ def test_column_type_schema(engine_testaccount): C1 BIGINT, C2 BINARY, C3 BOOLEAN, C4 CHAR, C5 CHARACTER, C6 DATE, C7 DATETIME, C8 DEC, C9 DECIMAL, C10 DOUBLE, C11 FLOAT, C12 INT, C13 INTEGER, C14 NUMBER, C15 REAL, C16 BYTEINT, C17 SMALLINT, C18 STRING, C19 TEXT, C20 TIME, C21 TIMESTAMP, C22 TIMESTAMP_TZ, C23 TIMESTAMP_LTZ, - C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY + C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY, + C32 GEOMETRY ) """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) columns = table_reflected.columns - assert ( - len(columns) == len(ischema_names_baseline) - 1 - ) # -1 because FIXED is not supported + assert len(columns) == ( + len(ischema_names_baseline) - 2 + ) # -2 because FIXED and MAP is not supported def test_result_type_and_value(engine_testaccount): @@ -1635,13 +1642,12 @@ def test_result_type_and_value(engine_testaccount): C1 BIGINT, C2 BINARY, C3 BOOLEAN, C4 CHAR, C5 CHARACTER, C6 DATE, C7 DATETIME, C8 DEC(12,3), C9 DECIMAL(12,3), C10 DOUBLE, C11 FLOAT, C12 INT, C13 INTEGER, C14 NUMBER, C15 REAL, C16 BYTEINT, C17 SMALLINT, C18 STRING, C19 TEXT, C20 TIME, C21 TIMESTAMP, C22 TIMESTAMP_TZ, C23 TIMESTAMP_LTZ, - C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY + C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR, C28 VARIANT, C29 OBJECT, C30 ARRAY, C31 GEOGRAPHY, + C32 GEOMETRY ) """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) current_date = date.today() current_utctime = datetime.utcnow() current_localtime = pytz.utc.localize(current_utctime, is_dst=False).astimezone( @@ -1661,6 +1667,8 @@ def test_result_type_and_value(engine_testaccount): CHAR_VALUE = "A" GEOGRAPHY_VALUE = "POINT(-122.35 37.55)" GEOGRAPHY_RESULT_VALUE = '{"coordinates": [-122.35,37.55],"type": "Point"}' + GEOMETRY_VALUE = "POINT(-94.58473 39.08985)" + GEOMETRY_RESULT_VALUE = '{"coordinates": [-94.58473,39.08985],"type": "Point"}' ins = table_reflected.insert().values( c1=MAX_INT_VALUE, # BIGINT @@ -1694,6 +1702,7 @@ def test_result_type_and_value(engine_testaccount): c29=None, # OBJECT, currently snowflake-sqlalchemy/connector does not support binding variant c30=None, # ARRAY, currently 
snowflake-sqlalchemy/connector does not support binding variant c31=GEOGRAPHY_VALUE, # GEOGRAPHY + c32=GEOMETRY_VALUE, # GEOMETRY ) conn.execute(ins) @@ -1732,6 +1741,7 @@ def test_result_type_and_value(engine_testaccount): and result[28] is None and result[29] is None and json.loads(result[30]) == json.loads(GEOGRAPHY_RESULT_VALUE) + and json.loads(result[31]) == json.loads(GEOMETRY_RESULT_VALUE) ) sql = f""" @@ -1798,30 +1808,14 @@ def test_normalize_and_denormalize_empty_string_column_name(engine_testaccount): def test_snowflake_sqlalchemy_as_valid_client_type(): engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ), + URL(**CONNECTION_PARAMETERS), connect_args={"internal_application_name": "UnknownClient"}, ) with engine.connect() as conn: with pytest.raises(snowflake.connector.errors.NotSupportedError): conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) with engine.connect() as conn: conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() @@ -1842,20 +1836,17 @@ def test_snowflake_sqlalchemy_as_valid_client_type(): ) snowflake.connector.connection.DEFAULT_CONFIGURATION[ "internal_application_name" - ] = ("PythonConnector", (type(None), str)) + ] = ( + "PythonConnector", + (type(None), str), + ) snowflake.connector.connection.DEFAULT_CONFIGURATION[ "internal_application_version" - ] = ("3.0.0", (type(None), str)) - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) + ] = ( + "3.0.0", + (type(None), str), ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) with engine.connect() as conn: conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() assert ( @@ -1875,3 +1866,86 @@ def test_snowflake_sqlalchemy_as_valid_client_type(): snowflake.connector.connection.DEFAULT_CONFIGURATION[ "internal_application_version" ] = origin_internal_app_version + + +@pytest.mark.parametrize( + "operation", + [ + [ + literal(5), + literal(10), + 0.5, + ], + [literal(5), func.sqrt(literal(10)), 1.5811388300841895], + [literal(4), literal(5), decimal.Decimal("0.800000")], + [literal(2), literal(2), 1.0], + [literal(3), literal(2), 1.5], + [literal(4), literal(1.5), 2.666667], + [literal(5.5), literal(10.7), 0.5140187], + [literal(5.5), literal(8), 0.6875], + ], +) +def test_true_division_operation(engine_testaccount, operation): + # expected_warning = "div_is_floordiv value will be changed to False in a future release. This will generate a behavior change on true and floor division. 
Please review https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#python-division-operator-performs-true-division-for-all-backends-added-floor-division" + # with pytest.warns(PendingDeprecationWarning, match=expected_warning): + with engine_testaccount.connect() as conn: + eq_( + conn.execute(select(operation[0] / operation[1])).fetchall(), + [((operation[2]),)], + ) + + +@pytest.mark.parametrize( + "operation", + [ + [literal(5), literal(10), 0.5, 0.5], + [literal(5), func.sqrt(literal(10)), 1.5811388300841895, 1.0], + [ + literal(4), + literal(5), + decimal.Decimal("0.800000"), + decimal.Decimal("0.800000"), + ], + [literal(2), literal(2), 1.0, 1.0], + [literal(3), literal(2), 1.5, 1.5], + [literal(4), literal(1.5), 2.666667, 2.0], + [literal(5.5), literal(10.7), 0.5140187, 0], + [literal(5.5), literal(8), 0.6875, 0.6875], + ], +) +@pytest.mark.feature_v20 +def test_division_force_div_is_floordiv_default(engine_testaccount, operation): + expected_warning = "div_is_floordiv value will be changed to False in a future release. This will generate a behavior change on true and floor division. Please review https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#python-division-operator-performs-true-division-for-all-backends-added-floor-division" + with pytest.warns(PendingDeprecationWarning, match=expected_warning): + with engine_testaccount.connect() as conn: + eq_( + conn.execute( + select(operation[0] / operation[1], operation[0] // operation[1]) + ).fetchall(), + [(operation[2], operation[3])], + ) + + +@pytest.mark.parametrize( + "operation", + [ + [literal(5), literal(10), 0.5, 0], + [literal(5), func.sqrt(literal(10)), 1.5811388300841895, 1.0], + [literal(4), literal(5), decimal.Decimal("0.800000"), 0], + [literal(2), literal(2), 1.0, 1.0], + [literal(3), literal(2), 1.5, 1], + [literal(4), literal(1.5), 2.666667, 2.0], + [literal(5.5), literal(10.7), 0.5140187, 0], + [literal(5.5), literal(8), 0.6875, 0], + ], +) +@pytest.mark.feature_v20 +def test_division_force_div_is_floordiv_false(db_parameters, operation): + engine = create_engine(URL(**db_parameters), **{"force_div_is_floordiv": False}) + with engine.connect() as conn: + eq_( + conn.execute( + select(operation[0] / operation[1], operation[0] // operation[1]) + ).fetchall(), + [(operation[2], operation[3])], + ) diff --git a/tests/test_create.py b/tests/test_create.py index 5271118f..0b8b48fa 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -54,6 +54,16 @@ def test_create_stage(sql_compiler): ) assert actual == expected + create_stage = CreateStage(stage=stage, container=container, temporary=True) + # validate that the resulting SQL is as expected + actual = sql_compiler(create_stage) + expected = ( + "CREATE TEMPORARY STAGE MY_DB.MY_SCHEMA.AZURE_STAGE " + "URL='azure://myaccount.blob.core.windows.net/my-container' " + "CREDENTIALS=(AZURE_SAS_TOKEN='saas_token')" + ) + assert actual == expected + def test_create_csv_format(sql_compiler): """ diff --git a/tests/test_custom_functions.py b/tests/test_custom_functions.py new file mode 100644 index 00000000..2a1e1cb5 --- /dev/null +++ b/tests/test_custom_functions.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import pytest +from sqlalchemy import func + +from snowflake.sqlalchemy import snowdialect + + +def test_flatten_does_not_render_params(): + """This behavior is for backward compatibility. + + In previous version params were not rendered. + In future this behavior will change. 
+ """ + flat = func.flatten("[1, 2]", outer=True) + res = flat.compile(dialect=snowdialect.dialect()) + + assert str(res) == "flatten(%(flatten_1)s)" + + +def test_flatten_emits_warning(): + expected_warning = "For backward compatibility params are not rendered." + with pytest.warns(DeprecationWarning, match=expected_warning): + func.flatten().compile(dialect=snowdialect.dialect()) diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index b7962199..3961a5d3 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -2,7 +2,10 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from snowflake.sqlalchemy import custom_types +import pytest +from sqlalchemy import Column, Integer, MetaData, Table, text + +from snowflake.sqlalchemy import TEXT, custom_types def test_string_conversions(): @@ -15,6 +18,7 @@ def test_string_conversions(): "TIMESTAMP_LTZ", "TIMESTAMP_NTZ", "GEOGRAPHY", + "GEOMETRY", ] sf_types = [ "TEXT", @@ -33,3 +37,31 @@ def test_string_conversions(): sample = getattr(custom_types, type_)() if type_ in sf_custom_types: assert type_ == str(sample) + + +@pytest.mark.feature_max_lob_size +def test_create_table_with_text_type(engine_testaccount): + metadata = MetaData() + table_name = "test_max_lob_size_0" + test_max_lob_size = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("full_name", TEXT(), server_default=text("id::varchar")), + ) + + metadata.create_all(engine_testaccount) + try: + assert test_max_lob_size is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{table_name}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + test_max_lob_size.drop(engine_testaccount) diff --git a/tests/test_geometry.py b/tests/test_geometry.py new file mode 100644 index 00000000..742b518e --- /dev/null +++ b/tests/test_geometry.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from json import loads + +from sqlalchemy import Column, Integer, MetaData, Table +from sqlalchemy.sql import select + +from snowflake.sqlalchemy import GEOMETRY + + +def test_create_table_geometry_datatypes(engine_testaccount): + """ + Create table including geometry data types + """ + metadata = MetaData() + table_name = "test_geometry0" + test_geometry = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + metadata.create_all(engine_testaccount) + try: + assert test_geometry is not None + finally: + test_geometry.drop(engine_testaccount) + + +def test_inspect_geometry_datatypes(engine_testaccount): + """ + Create table including geometry data types + """ + metadata = MetaData() + table_name = "test_geometry0" + test_geometry = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom1", GEOMETRY), + Column("geom2", GEOMETRY), + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + test_point = "POINT(-94.58473 39.08985)" + test_point1 = '{"coordinates": [-94.58473, 39.08985],"type": "Point"}' + + ins = test_geometry.insert().values( + id=1, geom1=test_point, geom2=test_point1 + ) + + with conn.begin(): + results = conn.execute(ins) + results.close() + + s = select(test_geometry) + results = conn.execute(s) + rows = results.fetchone() + results.close() + assert rows[0] == 1 + assert rows[1] == rows[2] + assert loads(rows[2]) == loads(test_point1) + finally: + test_geometry.drop(engine_testaccount) diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 00000000..0cfe5931 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import importlib +import inspect + +import pytest + + +def get_classes_from_module(module_name): + """Returns a set of class names from a given module.""" + try: + module = importlib.import_module(module_name) + members = inspect.getmembers(module) + return {name for name, obj in members if inspect.isclass(obj)} + + except ImportError: + print(f"Module '{module_name}' could not be imported.") + return set() + + +def test_types_in_snowdialect(): + classes_a = get_classes_from_module( + "snowflake.sqlalchemy.parser.custom_type_parser" + ) + classes_b = get_classes_from_module("snowflake.sqlalchemy.snowdialect") + assert classes_a.issubset(classes_b), str(classes_a - classes_b) + + +@pytest.mark.parametrize( + "type_class_name", + [ + "BIGINT", + "BINARY", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "FLOAT", + "INTEGER", + "REAL", + "SMALLINT", + "TIME", + "TIMESTAMP", + "VARCHAR", + "NullType", + "_CUSTOM_DECIMAL", + "ARRAY", + "DOUBLE", + "GEOGRAPHY", + "GEOMETRY", + "MAP", + "OBJECT", + "TIMESTAMP_LTZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_TZ", + "VARIANT", + ], +) +def test_snowflake_data_types_instance(type_class_name): + classes_b = get_classes_from_module("snowflake.sqlalchemy.snowdialect") + assert type_class_name in classes_b, type_class_name diff --git a/tests/test_index_reflection.py b/tests/test_index_reflection.py new file mode 100644 index 00000000..a808703b --- /dev/null +++ b/tests/test_index_reflection.py @@ -0,0 +1,68 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import MetaData, inspect +from sqlalchemy.sql.ddl import CreateSchema, DropSchema + + +@pytest.mark.aws +def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler): + metadata = MetaData() + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + schema = db_parameters["schema"] + index_columns = ["name", "name2"] + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} ( + id INT primary key, + name VARCHAR, + name2 VARCHAR, + INDEX {index_name} ({', '.join(index_columns)}) + ); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + insp = inspect(engine_testaccount) + + try: + with engine_testaccount.connect(): + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + indexes = insp.get_indexes(table_name, schema) + assert len(indexes) == 1 + assert indexes[0].get("name") == index_name + assert indexes[0].get("column_names") == index_columns + + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_simple_reflection_hybrid_table_as_table( + engine_testaccount, assert_text_in_buf, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_simple_reflection_hybrid_table_as_table" + schema = db_parameters["schema"] + "_reflections" + with engine_testaccount.connect() as connection: + try: + connection.execute(CreateSchema(schema)) + + create_table_sql = f""" + CREATE HYBRID TABLE {schema}.{table_name} (id INT primary key, new_column VARCHAR, INDEX index_name (new_column)); + """ + connection.exec_driver_sql(create_table_sql) + + metadata.reflect(engine_testaccount, schema=schema) + + assert_text_in_buf( + f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {schema}", + occurrences=1, + ) + + finally: + connection.execute(DropSchema(schema, cascade=True)) diff --git a/tests/test_orm.py b/tests/test_orm.py index 363da671..cb3a7768 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -3,13 +3,28 @@ # import enum +import logging import pytest -from sqlalchemy import Column, Enum, ForeignKey, Integer, Sequence, String, text +from sqlalchemy import ( + TEXT, + Column, + Enum, + ForeignKey, + Integer, + Sequence, + String, + exc, + func, + select, + text, +) from sqlalchemy.orm import Session, declarative_base, relationship +from snowflake.sqlalchemy import HybridTable + -def test_basic_orm(engine_testaccount, run_v20_sqlalchemy): +def test_basic_orm(engine_testaccount): """ Tests declarative """ @@ -35,7 +50,6 @@ def __repr__(self): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) our_user = session.query(User).filter_by(name="ed").first() @@ -45,14 +59,15 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount, run_v20_sqlalchemy): +def test_orm_one_to_many_relationship(engine_testaccount, db_parameters): """ Tests One to Many relationship """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -62,13 +77,13 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, 
ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", backref="addresses") + user = relationship(User, backref="addresses") def __repr__(self): return f"" @@ -86,7 +101,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -113,14 +127,143 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_delete_cascade(engine_testaccount, run_v20_sqlalchemy): +@pytest.mark.aws +def test_orm_one_to_many_relationship_with_hybrid_table(engine_testaccount, snapshot): + """ + Tests One to Many relationship + """ + Base = declarative_base() + + class User(Base): + __tablename__ = "hb_tbl_user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = "hb_tbl_address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, backref="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + session.delete(jack) + + with pytest.raises(exc.ProgrammingError) as exc_info: + session.query(Address).all() + + assert exc_info.value == snapshot, "Iceberg Table enforce FK constraint" + + finally: + Base.metadata.drop_all(engine_testaccount) + + +def test_delete_cascade(engine_testaccount): """ Test delete cascade """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + addresses = relationship( + "Address", back_populates="user", cascade="all, delete, delete-orphan" + ) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = prefix + "address" + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, back_populates="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + got_jack = 
session.query(User).first() + assert got_jack == jack + + session.delete(jack) + got_addresses = session.query(Address).all() + assert len(got_addresses) == 0, "no address record" + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_delete_cascade_hybrid_table(engine_testaccount): + """ + Test delete cascade + """ + Base = declarative_base() + prefix = "hb_tbl_" + + class User(Base): + __tablename__ = prefix + "user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -134,13 +277,17 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", back_populates="addresses") + user = relationship(User, back_populates="addresses") def __repr__(self): return f"" @@ -158,7 +305,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -178,7 +324,7 @@ def __repr__(self): WIP """, ) -def test_orm_query(engine_testaccount, run_v20_sqlalchemy): +def test_orm_query(engine_testaccount): """ Tests ORM query """ @@ -199,7 +345,6 @@ def __repr__(self): # TODO: insert rows session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy # TODO: query.all() for name, fullname in session.query(User.name, User.fullname): @@ -209,7 +354,7 @@ def __repr__(self): # MultipleResultsFound if not one result -def test_schema_including_db(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_db(engine_testaccount, db_parameters): """ Test schema parameter including database separated by a dot. """ @@ -232,7 +377,6 @@ class User(Base): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) ret_user = session.query(User.id, User.name).first() @@ -244,7 +388,7 @@ class User(Base): Base.metadata.drop_all(engine_testaccount) -def test_schema_including_dot(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_dot(engine_testaccount, db_parameters): """ Tests pseudo schema name including dot. 
""" @@ -265,7 +409,6 @@ class User(Base): fullname = Column(String) session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy query = session.query(User.id) assert str(query).startswith( 'SELECT {db}."{schema}.{schema}".{db}.users.id'.format( @@ -274,9 +417,7 @@ class User(Base): ) -def test_schema_translate_map( - engine_testaccount, db_parameters, sql_compiler, run_v20_sqlalchemy -): +def test_schema_translate_map(engine_testaccount, db_parameters): """ Test schema translate map execution option works replaces schema correctly """ @@ -299,7 +440,6 @@ class User(Base): schema_translate_map={schema_map: db_parameters["schema"]} ) as con: session = Session(bind=con) - session.future = run_v20_sqlalchemy with con.begin(): Base.metadata.create_all(con) try: @@ -326,3 +466,119 @@ class User(Base): assert user.fullname == "test_user" finally: Base.metadata.drop_all(con) + + +def test_outer_lateral_join(engine_testaccount, caplog): + Base = declarative_base() + + class Employee(Base): + __tablename__ = "employees" + + employee_id = Column(Integer, primary_key=True) + last_name = Column(String) + + class Department(Base): + __tablename__ = "departments" + + department_id = Column(Integer, primary_key=True) + name = Column(String) + + Base.metadata.create_all(engine_testaccount) + session = Session(bind=engine_testaccount) + e1 = Employee(employee_id=101, last_name="Richards") + d1 = Department(department_id=1, name="Engineering") + session.add_all([e1, d1]) + session.commit() + + sub = select(Department).lateral() + query = ( + select(Employee.employee_id, Department.department_id) + .select_from(Employee) + .outerjoin(sub) + ) + compiled_stmts = ( + # v1.x + "SELECT employees.employee_id, departments.department_id " + "FROM departments, employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1", + # v2.x + "SELECT employees.employee_id, departments.department_id " + "FROM employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1, departments", + ) + compiled_stmt = str(query.compile(engine_testaccount)).replace("\n", "") + assert compiled_stmt in compiled_stmts + + with caplog.at_level(logging.DEBUG): + assert [res for res in session.execute(query)] + assert ( + "SELECT employees.employee_id, departments.department_id FROM departments" + in caplog.text + ) or ( + "SELECT employees.employee_id, departments.department_id FROM employees" + in caplog.text + ) + + +def test_lateral_join_without_condition(engine_testaccount, caplog): + Base = declarative_base() + + class Employee(Base): + __tablename__ = "Employee" + + pkey = Column(String, primary_key=True) + uid = Column(Integer) + content = Column(String) + + Base.metadata.create_all(engine_testaccount) + lateral_table = func.flatten( + func.PARSE_JSON(Employee.content), outer=False + ).lateral() + query = ( + select( + Employee.uid, + ) + .select_from(Employee) + .join(lateral_table) + .where(Employee.uid == "123") + ) + session = Session(bind=engine_testaccount) + with caplog.at_level(logging.DEBUG): + session.execute(query) + assert ( + '[SELECT "Employee".uid FROM "Employee" JOIN LATERAL flatten(PARSE_JSON("Employee"' + in caplog.text + ) + + +@pytest.mark.feature_max_lob_size +def test_basic_table_with_large_lob_size_in_memory(engine_testaccount, sql_compiler): + Base = declarative_base() + + class User(Base): + __tablename__ = "user" + + id 
= Column(Integer, primary_key=True) + full_name = Column(TEXT(), server_default=text("id::varchar")) + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + assert User.__table__ is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{User.__tablename__}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index aadfb9cf..2a6b9f1b 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -24,13 +24,10 @@ select, text, ) -from sqlalchemy.pool import NullPool from snowflake.connector import ProgrammingError from snowflake.connector.pandas_tools import make_pd_writer, pd_writer -from snowflake.sqlalchemy import URL - -from .conftest import create_engine_with_future_flag as create_engine +from snowflake.sqlalchemy.compat import IS_VERSION_20 def _create_users_addresses_tables(engine_testaccount, metadata): @@ -113,40 +110,8 @@ def test_a_simple_read_sql(engine_testaccount): users.drop(engine_testaccount) -def get_engine_with_numpy(db_parameters, user=None, password=None, account=None): - """ - Creates a connection using the parameters defined in JDBC connect string - """ - from snowflake.sqlalchemy import URL - - if user is not None: - db_parameters["user"] = user - if password is not None: - db_parameters["password"] = password - if account is not None: - db_parameters["account"] = account - - engine = create_engine( - URL( - user=db_parameters["user"], - password=db_parameters["password"], - host=db_parameters["host"], - port=db_parameters["port"], - database=db_parameters["database"], - schema=db_parameters["schema"], - account=db_parameters["account"], - protocol=db_parameters["protocol"], - numpy=True, - ), - poolclass=NullPool, - ) - - return engine - - -def test_numpy_datatypes(db_parameters): - engine = get_engine_with_numpy(db_parameters) - with engine.connect() as conn: +def test_numpy_datatypes(engine_testaccount_with_numpy, db_parameters): + with engine_testaccount_with_numpy.connect() as conn: try: specific_date = np.datetime64("2016-03-04T12:03:05.123456789") with conn.begin(): @@ -163,12 +128,11 @@ def test_numpy_datatypes(db_parameters): assert df.c1.values[0] == specific_date finally: conn.exec_driver_sql(f"DROP TABLE IF EXISTS {db_parameters['name']}") - engine.dispose() + engine_testaccount_with_numpy.dispose() -def test_to_sql(db_parameters): - engine = get_engine_with_numpy(db_parameters) - with engine.connect() as conn: +def test_to_sql(engine_testaccount_with_numpy, db_parameters): + with engine_testaccount_with_numpy.connect() as conn: total_rows = 10000 conn.exec_driver_sql( textwrap.dedent( @@ -182,7 +146,13 @@ def test_to_sql(db_parameters): conn.exec_driver_sql("create or replace table dst(c1 float)") tbl = pd.read_sql_query(text("select * from src"), conn) - tbl.to_sql("dst", engine, if_exists="append", chunksize=1000, index=False) + tbl.to_sql( + "dst", + engine_testaccount_with_numpy, + if_exists="append", + chunksize=1000, + index=False, + ) df = pd.read_sql_query(text("select count(*) as cnt from dst"), conn) assert df.cnt.values[0] == total_rows @@ -199,41 +169,15 @@ def test_no_indexes(engine_testaccount, db_parameters): con=conn, if_exists="replace", ) - assert str(exc.value) == "Snowflake does not support indexes" + assert 
str(exc.value) == "Only Snowflake Hybrid Tables supports indexes" -def test_timezone(db_parameters): +def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_numpy): test_table_name = "".join([random.choice(string.ascii_letters) for _ in range(5)]) - sa_engine = create_engine( - URL( - account=db_parameters["account"], - password=db_parameters["password"], - database=db_parameters["database"], - port=db_parameters["port"], - user=db_parameters["user"], - host=db_parameters["host"], - protocol=db_parameters["protocol"], - schema=db_parameters["schema"], - numpy=True, - ) - ) - - sa_engine2_raw_conn = create_engine( - URL( - account=db_parameters["account"], - password=db_parameters["password"], - database=db_parameters["database"], - port=db_parameters["port"], - user=db_parameters["user"], - host=db_parameters["host"], - protocol=db_parameters["protocol"], - schema=db_parameters["schema"], - timezone="America/Los_Angeles", - numpy="", - ) - ).raw_connection() + sa_engine = engine_testaccount_with_numpy + sa_engine2_raw_conn = engine_testaccount.raw_connection() with sa_engine.connect() as conn: @@ -297,8 +241,8 @@ def test_timezone(db_parameters): conn.exec_driver_sql(f"DROP TABLE {test_table_name};") -def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_pandas_writeback(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." ) @@ -316,18 +260,13 @@ def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy): sf_connector_version_df = pd.DataFrame( sf_connector_version_data, columns=["NAME", "NEWEST_VERSION"] ) - sf_connector_version_df.to_sql(table_name, conn, index=False, method=pd_writer) - - assert ( - ( - pd.read_sql_table(table_name, conn).rename( - columns={"newest_version": "NEWEST_VERSION", "name": "NAME"} - ) - == sf_connector_version_df - ) - .all() - .all() + sf_connector_version_df.to_sql( + table_name, conn, index=False, method=pd_writer, if_exists="replace" + ) + results = pd.read_sql_table(table_name, conn).rename( + columns={"newest_version": "NEWEST_VERSION", "name": "NAME"} ) + assert results.equals(sf_connector_version_df) @pytest.mark.parametrize("chunk_size", [5, 1]) @@ -414,8 +353,8 @@ def test_pandas_invalid_make_pd_writer(engine_testaccount): ) -def test_percent_signs(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_percent_signs(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." 
) @@ -438,7 +377,7 @@ def test_percent_signs(engine_testaccount, run_v20_sqlalchemy): not_like_sql = f"select * from {table_name} where c2 not like '%b%'" like_sql = f"select * from {table_name} where c2 like '%b%'" calculate_sql = "SELECT 1600 % 400 AS a, 1599 % 400 as b" - if run_v20_sqlalchemy: + if IS_VERSION_20: not_like_sql = sqlalchemy.text(not_like_sql) like_sql = sqlalchemy.text(like_sql) calculate_sql = sqlalchemy.text(calculate_sql) diff --git a/tests/test_qmark.py b/tests/test_qmark.py index fe50fae9..3761181a 100644 --- a/tests/test_qmark.py +++ b/tests/test_qmark.py @@ -5,76 +5,35 @@ import os import sys +import pandas as pd import pytest from sqlalchemy import text -from snowflake.sqlalchemy import URL - -from .conftest import create_engine_with_future_flag as create_engine - THIS_DIR = os.path.dirname(os.path.realpath(__file__)) -def _get_engine_with_qmark(db_parameters, user=None, password=None, account=None): - """ - Creates a connection with column metadata cache - """ - if user is not None: - db_parameters["user"] = user - if password is not None: - db_parameters["password"] = password - if account is not None: - db_parameters["account"] = account - - engine = create_engine( - URL( - user=db_parameters["user"], - password=db_parameters["password"], - host=db_parameters["host"], - port=db_parameters["port"], - database=db_parameters["database"], - schema=db_parameters["schema"], - account=db_parameters["account"], - protocol=db_parameters["protocol"], - ) - ) - return engine - - -def test_qmark_bulk_insert(db_parameters, run_v20_sqlalchemy): +def test_qmark_bulk_insert(engine_testaccount_with_qmark): """ Bulk insert using qmark paramstyle """ - if run_v20_sqlalchemy and sys.version_info < (3, 8): + if sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlalchemy 2.0, and pandas does not support Python 3.7 anymore." 
) - import snowflake.connector - - snowflake.connector.paramstyle = "qmark" - - engine = _get_engine_with_qmark(db_parameters) - import pandas as pd - - with engine.connect() as con: - try: - with con.begin(): - con.exec_driver_sql( - """ - create or replace table src(c1 int, c2 string) as select seq8(), - randstr(100, random()) from table(generator(rowcount=>100000)) - """ + with engine_testaccount_with_qmark.connect() as con: + with con.begin(): + con.exec_driver_sql( + """ + create or replace table src(c1 int, c2 string) as select seq8(), + randstr(100, random()) from table(generator(rowcount=>100000)) + """ + ) + con.exec_driver_sql("create or replace table dst like src") + + for data in pd.read_sql_query( + text("select * from src"), con, chunksize=16000 + ): + data.to_sql( + "dst", con, if_exists="append", index=False, index_label=None ) - con.exec_driver_sql("create or replace table dst like src") - - for data in pd.read_sql_query( - text("select * from src"), con, chunksize=16000 - ): - data.to_sql( - "dst", con, if_exists="append", index=False, index_label=None - ) - - finally: - engine.dispose() - snowflake.connector.paramstyle = "pyformat" diff --git a/tests/test_quote.py b/tests/test_quote.py index ca6f36dd..0dd69059 100644 --- a/tests/test_quote.py +++ b/tests/test_quote.py @@ -38,3 +38,26 @@ def test_table_name_with_reserved_words(engine_testaccount, db_parameters): finally: insert_table.drop(engine_testaccount) return insert_table + + +def test_table_column_as_underscore(engine_testaccount): + metadata = MetaData() + test_table_name = "table_1745924" + insert_table = Table( + test_table_name, + metadata, + Column("ca", Integer), + Column("cb", String), + Column("_", String), + ) + metadata.create_all(engine_testaccount) + try: + inspector = inspect(engine_testaccount) + columns_in_insert = inspector.get_columns(test_table_name) + assert len(columns_in_insert) == 3 + assert columns_in_insert[0]["name"] == "ca" + assert columns_in_insert[1]["name"] == "cb" + assert columns_in_insert[2]["name"] == "_" + finally: + insert_table.drop(engine_testaccount) + return insert_table diff --git a/tests/test_quote_identifiers.py b/tests/test_quote_identifiers.py new file mode 100644 index 00000000..58575fad --- /dev/null +++ b/tests/test_quote_identifiers.py @@ -0,0 +1,43 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select + + +@pytest.mark.parametrize( + "identifier", + ( + pytest.param("_", id="underscore"), + pytest.param(".", id="dot"), + ), +) +def test_insert_with_identifier_as_column_name(identifier: str, engine_testaccount): + expected_identifier = f"test: {identifier}" + metadata = MetaData() + table = Table( + "table_1745924", + metadata, + Column("ca", Integer), + Column("cb", String), + Column(identifier, String), + ) + + try: + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as connection: + connection.execute( + insert(table).values( + { + "ca": 1, + "cb": "test", + identifier: f"test: {identifier}", + } + ) + ) + result = connection.execute(select(table)).fetchall() + assert result == [(1, "test", expected_identifier)] + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 78658012..32fc390e 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,89 +2,162 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table, select +from sqlalchemy import ( + Column, + Identity, + Integer, + MetaData, + Sequence, + String, + Table, + insert, + select, +) +from sqlalchemy.sql import text +from sqlalchemy.sql.ddl import CreateTable def test_table_with_sequence(engine_testaccount, db_parameters): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. + https://docs.snowflake.com/en/user-guide/querying-sequences + """ # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124 test_table_name = "sequence" test_sequence_name = f"{test_table_name}_id_seq" + metadata = MetaData() + sequence_table = Table( test_table_name, - MetaData(), - Column("id", Integer, Sequence(test_sequence_name), primary_key=True), + metadata, + Column( + "id", Integer, Sequence(test_sequence_name, order=True), primary_key=True + ), Column("data", String(39)), ) - sequence_table.create(engine_testaccount) - seq = Sequence(test_sequence_name) + + autoload_metadata = MetaData() + try: - with engine_testaccount.connect() as conn: - with conn.begin(): - conn.execute(sequence_table.insert(), [{"data": "test_insert_1"}]) - select_stmt = select(sequence_table).order_by("id") - result = conn.execute(select_stmt).fetchall() - assert result == [(1, "test_insert_1")] - autoload_sequence_table = Table( - test_table_name, MetaData(), autoload_with=engine_testaccount - ) - conn.execute( - autoload_sequence_table.insert(), - [{"data": "multi_insert_1"}, {"data": "multi_insert_2"}], - ) - conn.execute( - autoload_sequence_table.insert(), [{"data": "test_insert_2"}] - ) - nextid = conn.execute(seq) - conn.execute( - autoload_sequence_table.insert(), - [{"id": nextid, "data": "test_insert_seq"}], - ) - result = conn.execute(select_stmt).fetchall() - assert result == [ - (1, "test_insert_1"), - (2, "multi_insert_1"), - (3, "multi_insert_2"), - (4, "test_insert_2"), - (5, "test_insert_seq"), - ] + metadata.create_all(engine_testaccount) + + with engine_testaccount.begin() as conn: + conn.execute(insert(sequence_table), ({"data": "test_insert_1"})) + result = conn.execute(select(sequence_table)).fetchall() + assert result == [(1, "test_insert_1")], result + + autoload_sequence_table = Table( + test_table_name, + autoload_metadata, + autoload_with=engine_testaccount, + ) + seq = 
Sequence(test_sequence_name, order=True) + + conn.execute( + insert(autoload_sequence_table), + ( + {"data": "multi_insert_1"}, + {"data": "multi_insert_2"}, + ), + ) + conn.execute( + insert(autoload_sequence_table), + ({"data": "test_insert_2"},), + ) + + nextid = conn.execute(seq) + conn.execute( + insert(autoload_sequence_table), + ({"id": nextid, "data": "test_insert_seq"}), + ) + + result = conn.execute(select(sequence_table)).fetchall() + + assert result == [ + (1, "test_insert_1"), + (2, "multi_insert_1"), + (3, "multi_insert_2"), + (4, "test_insert_2"), + (5, "test_insert_seq"), + ], result + finally: - sequence_table.drop(engine_testaccount) - seq.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) -def test_table_with_autoincrement(engine_testaccount, db_parameters): +def test_table_with_autoincrement(engine_testaccount): + """Snowflake does not guarantee generating sequence numbers without gaps. + + The generated numbers are not necessarily contiguous. + https://docs.snowflake.com/en/user-guide/querying-sequences + """ # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/124 test_table_name = "sequence" + metadata = MetaData() autoincrement_table = Table( test_table_name, - MetaData(), + metadata, Column("id", Integer, autoincrement=True, primary_key=True), Column("data", String(39)), ) - autoincrement_table.create(engine_testaccount) + + select_stmt = select(autoincrement_table).order_by("id") + try: - with engine_testaccount.connect() as conn: - with conn.begin(): - conn.execute(autoincrement_table.insert(), [{"data": "test_insert_1"}]) - select_stmt = select(autoincrement_table).order_by("id") - result = conn.execute(select_stmt).fetchall() - assert result == [(1, "test_insert_1")] - autoload_sequence_table = Table( - test_table_name, MetaData(), autoload_with=engine_testaccount - ) - conn.execute( - autoload_sequence_table.insert(), - [{"data": "multi_insert_1"}, {"data": "multi_insert_2"}], - ) - conn.execute( - autoload_sequence_table.insert(), [{"data": "test_insert_2"}] - ) - result = conn.execute(select_stmt).fetchall() - assert result == [ - (1, "test_insert_1"), - (2, "multi_insert_1"), - (3, "multi_insert_2"), - (4, "test_insert_2"), - ] + with engine_testaccount.begin() as conn: + conn.execute(text("ALTER SESSION SET NOORDER_SEQUENCE_AS_DEFAULT = FALSE")) + metadata.create_all(conn) + + conn.execute(insert(autoincrement_table), ({"data": "test_insert_1"})) + result = conn.execute(select_stmt).fetchall() + assert result == [(1, "test_insert_1")] + + autoload_sequence_table = Table( + test_table_name, MetaData(), autoload_with=engine_testaccount + ) + conn.execute( + insert(autoload_sequence_table), + [ + {"data": "multi_insert_1"}, + {"data": "multi_insert_2"}, + ], + ) + conn.execute( + insert(autoload_sequence_table), + [{"data": "test_insert_2"}], + ) + result = conn.execute(select_stmt).fetchall() + assert result == [ + (1, "test_insert_1"), + (2, "multi_insert_1"), + (3, "multi_insert_2"), + (4, "test_insert_2"), + ], result + finally: - autoincrement_table.drop(engine_testaccount) + metadata.drop_all(engine_testaccount) + + +def test_table_with_identity(sql_compiler): + test_table_name = "identity" + metadata = MetaData() + identity_autoincrement_table = Table( + test_table_name, + metadata, + Column( + "id", Integer, Identity(start=1, increment=1, order=True), primary_key=True + ), + Column("identity_col_unordered", Integer, Identity(order=False)), + Column("identity_col", Integer, Identity()), + ) + create_table = 
CreateTable(identity_autoincrement_table) + actual = sql_compiler(create_table) + expected = ( + "CREATE TABLE identity (" + "\tid INTEGER NOT NULL IDENTITY(1,1) ORDER, " + "\tidentity_col_unordered INTEGER NOT NULL IDENTITY NOORDER, " + "\tidentity_col INTEGER NOT NULL IDENTITY, " + "\tPRIMARY KEY (id))" + ) + assert actual == expected diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py new file mode 100644 index 00000000..fb73673b --- /dev/null +++ b/tests/test_structured_datatypes.py @@ -0,0 +1,597 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +import sqlalchemy as sa +from sqlalchemy import ( + Column, + Integer, + MetaData, + Sequence, + Table, + cast, + exc, + inspect, + text, +) +from sqlalchemy.orm import Session, declarative_base +from sqlalchemy.sql import select +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable +from snowflake.sqlalchemy.custom_types import ARRAY, MAP, OBJECT, TEXT +from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError + + +@pytest.mark.parametrize( + "structured_type", + [ + MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT(16777216))), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + OBJECT(key1=TEXT(16777216), key2=NUMBER(10, 0)), + ARRAY(MAP(NUMBER(10, 0), TEXT(16777216))), + ], +) +def test_compile_table_with_structured_data_type( + sql_compiler, snapshot, structured_type +): + metadata = MetaData() + user_table = Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", structured_type), + ) + + create_table = CreateTable(user_table) + + assert sql_compiler(create_table) == snapshot + + +def test_compile_table_with_sqlalchemy_array(sql_compiler, snapshot): + metadata = MetaData() + user_table = Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", sa.ARRAY(sa.String)), + ) + + create_table = CreateTable(user_table) + + assert sql_compiler(create_table) == snapshot + + +@pytest.mark.requires_external_volume +def test_insert_map(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_insert_map" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + slt = select( + 1, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ) + ins = test_map.insert().from_select(["id", "map_id"], slt) + conn.execute(ins) + + results = conn.execute(test_map.select()) + data = results.fetchmany() + results.close() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_map_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot +): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), 
primary_key=True) + map_id = Column(MAP(NUMBER(10, 0), TEXT(16777216))) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast( + text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT(16777216)) + ) + instance = TestIcebergTableOrm(id=0, map_id=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_select_map_orm(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_select_map_orm" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + slt1 = select( + 2, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ) + slt2 = select( + 1, + cast( + text("{'100':'item1', '200':'item2'}"), + MAP(NUMBER(10, 0), TEXT(16777216)), + ), + ).union_all(slt1) + ins = test_map.insert().from_select(["id", "map_id"], slt2) + conn.execute(ins) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = test_map + + def __repr__(self): + return f"({self.id!r}, {self.map_id!r})" + + try: + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_select_array_orm(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_select_array_orm" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("array_col", ARRAY(TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + slt1 = select( + 2, + cast( + text("['item1','item2']"), + ARRAY(TEXT(16777216)), + ), + ) + slt2 = select( + 1, + cast( + text("['item3','item4']"), + ARRAY(TEXT(16777216)), + ), + ).union_all(slt1) + ins = test_map.insert().from_select(["id", "array_col"], slt2) + conn.execute(ins) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = test_map + + def __repr__(self): + return f"({self.id!r}, {self.array_col!r})" + + try: + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_array(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_insert_map" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("array_col", ARRAY(TEXT(16777216))), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + slt = select( + 1, + cast( + text("['item1','item2']"), + ARRAY(TEXT(16777216)), + ), + ) + ins = 
test_map.insert().from_select(["id", "array_col"], slt) + conn.execute(ins) + + results = conn.execute(test_map.select()) + data = results.fetchmany() + results.close() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_array_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot +): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + array_col = Column(ARRAY(TEXT(16777216))) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast(text("['item1','item2']"), ARRAY(TEXT(16777216))) + instance = TestIcebergTableOrm(id=0, array_col=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_structured_object( + engine_testaccount, external_volume, base_location, snapshot +): + metadata = MetaData() + table_name = "test_insert_structured_object" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + slt = select( + 1, + cast( + text("{'key1':'item1', 'key2': 15}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ) + ins = test_map.insert().from_select(["id", "object_col"], slt) + conn.execute(ins) + + results = conn.execute(test_map.select()) + data = results.fetchmany() + results.close() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_structured_object_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot +): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + object_col = Column( + OBJECT(key1=(NUMBER(10, 0), False), key2=(TEXT(16777216), False)) + ) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast( + text("{ 'key1' : 1, 'key2' : 'item1' }"), + OBJECT(key1=(NUMBER(10, 0), False), key2=(TEXT(16777216), False)), + ) + instance = TestIcebergTableOrm(id=0, object_col=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert 
statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_select_structured_object_orm( + engine_testaccount, external_volume, base_location, snapshot +): + metadata = MetaData() + table_name = "test_select_structured_object_orm" + iceberg_table = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column( + "structured_obj_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + first_select = select( + 2, + cast( + text("{'key1': 'value1', 'key2': 1}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ) + second_select = select( + 1, + cast( + text("{'key1': 'value2', 'key2': 2}"), + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + ).union_all(first_select) + insert_statement = iceberg_table.insert().from_select( + ["id", "structured_obj_col"], second_select + ) + conn.execute(insert_statement) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = iceberg_table + + def __repr__(self): + return f"({self.id!r}, {self.structured_obj_col!r})" + + try: + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) + finally: + iceberg_table.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +@pytest.mark.parametrize( + "structured_type, expected_type", + [ + (MAP(NUMBER(10, 0), TEXT(16777216)), MAP), + (MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT(16777216))), MAP), + ( + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + OBJECT, + ), + (ARRAY(TEXT(16777216)), ARRAY), + ], +) +def test_inspect_structured_data_types( + engine_testaccount, + external_volume, + base_location, + snapshot, + structured_type, + expected_type, +): + metadata = MetaData() + table_name = "test_st_types" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("structured_type_col", structured_type), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + try: + inspecter = inspect(engine_testaccount) + columns = inspecter.get_columns(table_name) + + assert isinstance(columns[0]["type"], NUMBER) + assert isinstance(columns[1]["type"], expected_type) + assert columns == snapshot + + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +@pytest.mark.parametrize( + "structured_type", + [ + "MAP(NUMBER(10, 0), VARCHAR)", + "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))", + "OBJECT(key1 VARCHAR, key2 NUMBER(10, 0))", + "ARRAY(MAP(NUMBER(10, 0), VARCHAR))", + ], +) +def test_reflect_structured_data_types( + engine_testaccount, + external_volume, + base_location, + snapshot, + structured_type, + sql_compiler, +): + metadata = MetaData() + table_name = "test_reflect_st_types" + create_table_sql = f""" +CREATE OR REPLACE ICEBERG TABLE {table_name} ( + id number(38,0) primary key, + structured_type_col {structured_type}) +CATALOG = 'SNOWFLAKE' +EXTERNAL_VOLUME = '{external_volume}' +BASE_LOCATION = '{base_location}'; + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + iceberg_table = IcebergTable(table_name, 
metadata, autoload_with=engine_testaccount) + constraint = iceberg_table.constraints.pop() + constraint.name = "constraint_name" + iceberg_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(iceberg_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_create_table_structured_datatypes( + engine_testaccount, external_volume, base_location +): + metadata = MetaData() + table_name = "test_structured0" + test_structured_dt = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT(16777216))), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + Column( + "array_col", + ARRAY(TEXT(16777216)), + ), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + try: + assert test_structured_dt is not None + finally: + test_structured_dt.drop(engine_testaccount) + + +@pytest.mark.parametrize( + "structured_type_col", + [ + Column("name", MAP(NUMBER(10, 0), TEXT(16777216))), + Column( + "object_col", + OBJECT(key1=(TEXT(16777216), False), key2=(NUMBER(10, 0), False)), + ), + Column("name", ARRAY(TEXT(16777216))), + ], +) +def test_structured_type_not_supported_in_table_columns_error( + sql_compiler, structured_type_col +): + metadata = MetaData() + with pytest.raises( + StructuredTypeNotSupportedInTableColumnsError + ) as programming_error: + SnowflakeTable( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + structured_type_col, + ) + assert programming_error is not None diff --git a/tests/test_transactions.py b/tests/test_transactions.py new file mode 100644 index 00000000..c163c2b7 --- /dev/null +++ b/tests/test_transactions.py @@ -0,0 +1,157 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from sqlalchemy import Column, Integer, MetaData, String, select, text + +from snowflake.sqlalchemy import SnowflakeTable + +CURRENT_TRANSACTION = text("SELECT CURRENT_TRANSACTION()") + + +def test_connect_read_commited(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_connect_read_commited" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="READ COMMITTED" + ) as connection: + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] != ( + None, + ), "AUTOCOMMIT DISABLED, transaction should be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 0, results # No insert commited + assert_text_in_buf("ROLLBACK", occurrences=1) + finally: + metadata.drop_all(engine_testaccount) + + +def test_begin_read_commited(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_begin_read_commited" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="READ COMMITTED" + ) as connection, connection.begin(): + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] != ( + None, + ), "AUTOCOMMIT DISABLED, transaction should be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results # Insert commited + assert_text_in_buf("COMMIT", occurrences=2) + finally: + metadata.drop_all(engine_testaccount) + + +def test_connect_autocommit(engine_testaccount, assert_text_in_buf): + metadata = MetaData() + table_name = "test_connect_autocommit" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="AUTOCOMMIT" + ) as connection: + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == ( + None, + ), "Autocommit enabled, transaction should not be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results + assert_text_in_buf( + "ROLLBACK using DBAPI connection.rollback(), DBAPI should ignore due to autocommit mode", + occurrences=1, + ) + + finally: + metadata.drop_all(engine_testaccount) + + +def test_begin_autocommit(engine_testaccount, 
assert_text_in_buf): + metadata = MetaData() + table_name = "test_begin_autocommit" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + try: + with engine_testaccount.connect().execution_options( + isolation_level="AUTOCOMMIT" + ) as connection, connection.begin(): + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == (None,), result + ins = test_table_1.insert().values(id=1, name="test") + connection.execute(ins) + result = connection.execute(CURRENT_TRANSACTION).fetchall() + assert result[0] == ( + None, + ), "Autocommit enabled, transaction should not be started" + + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results = conn.execute(s).fetchall() + assert len(results) == 1, results + assert_text_in_buf( + "COMMIT using DBAPI connection.commit(), DBAPI should ignore due to autocommit mode", + occurrences=1, + ) + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py new file mode 100644 index 00000000..472ce2e6 --- /dev/null +++ b/tests/test_unit_structured_types.py @@ -0,0 +1,81 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest + +from snowflake.sqlalchemy import NUMBER +from snowflake.sqlalchemy.custom_types import MAP, TEXT +from src.snowflake.sqlalchemy.parser.custom_type_parser import ( + parse_type, + tokenize_parameters, +) + + +def test_compile_map_with_not_null(snapshot): + user_table = MAP(NUMBER(10, 0), TEXT(), not_null=True) + assert user_table.compile() == snapshot + + +def test_extract_parameters(): + example = "a, b(c, d, f), d" + assert tokenize_parameters(example) == ["a", "b(c, d, f)", "d"] + + +@pytest.mark.parametrize( + "input_type, expected_type", + [ + ("BIGINT", "BIGINT"), + ("BINARY(16)", "BINARY(16)"), + ("BOOLEAN", "BOOLEAN"), + ("CHAR(5)", "CHAR(5)"), + ("CHARACTER(5)", "CHAR(5)"), + ("DATE", "DATE"), + ("DATETIME(3)", "DATETIME"), + ("DECIMAL(10, 2)", "DECIMAL(10, 2)"), + ("DEC(10, 2)", "DECIMAL(10, 2)"), + ("DOUBLE", "FLOAT"), + ("FLOAT", "FLOAT"), + ("FIXED(10, 2)", "DECIMAL(10, 2)"), + ("INT", "INTEGER"), + ("INTEGER", "INTEGER"), + ("NUMBER(12, 4)", "DECIMAL(12, 4)"), + ("REAL", "REAL"), + ("BYTEINT", "SMALLINT"), + ("SMALLINT", "SMALLINT"), + ("STRING(255)", "VARCHAR(255)"), + ("TEXT(255)", "VARCHAR(255)"), + ("VARCHAR(255)", "VARCHAR(255)"), + ("TIME(6)", "TIME"), + ("TIMESTAMP(3)", "TIMESTAMP"), + ("TIMESTAMP_TZ(3)", "TIMESTAMP_TZ"), + ("TIMESTAMP_LTZ(3)", "TIMESTAMP_LTZ"), + ("TIMESTAMP_NTZ(3)", "TIMESTAMP_NTZ"), + ("TINYINT", "SMALLINT"), + ("VARBINARY(16)", "BINARY(16)"), + ("VARCHAR(255)", "VARCHAR(255)"), + ("VARIANT", "VARIANT"), + ( + "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + ), + ( + "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR))", + "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR))", + ), + ("MAP(DECIMAL(10, 0), VARIANT)", "MAP(DECIMAL(10, 0), VARIANT)"), + ("OBJECT", "OBJECT"), + ( + "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)", + "OBJECT(a DECIMAL(10, 0) NOT NULL, b DECIMAL(10, 0), c VARCHAR NOT NULL)", + ), + ("ARRAY", "ARRAY"), + ( + "ARRAY(MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + "ARRAY(MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + ), + ("GEOGRAPHY", "GEOGRAPHY"), + 
("GEOMETRY", "GEOMETRY"), + ], +) +def test_snowflake_data_types(input_type, expected_type): + assert parse_type(input_type).compile() == expected_type diff --git a/tests/util.py b/tests/util.py index f53333c4..264478ff 100644 --- a/tests/util.py +++ b/tests/util.py @@ -28,6 +28,8 @@ from snowflake.sqlalchemy.custom_types import ( ARRAY, GEOGRAPHY, + GEOMETRY, + MAP, OBJECT, TIMESTAMP_LTZ, TIMESTAMP_NTZ, @@ -70,6 +72,8 @@ "OBJECT": OBJECT, "ARRAY": ARRAY, "GEOGRAPHY": GEOGRAPHY, + "GEOMETRY": GEOMETRY, + "MAP": MAP, } diff --git a/tox.ini b/tox.ini index 99891d22..102e2273 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,6 @@ min_version = 4.0.0 envlist = fix_lint, py{37,38,39,310,311}{,-pandas}, coverage, - connector_regression skip_missing_interpreters = true [testenv] @@ -35,29 +34,17 @@ passenv = setenv = COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}} SQLALCHEMY_WARN_20 = 1 - ci: SNOWFLAKE_PYTEST_OPTS = -vvv + ci: SNOWFLAKE_PYTEST_OPTS = -vvv --tb=long commands = pytest \ {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" \ --junitxml {toxworkdir}/junit_{envname}.xml \ + --ignore=tests/sqlalchemy_test_suite \ {posargs:tests} pytest {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" --cov-append \ --junitxml {toxworkdir}/junit_{envname}.xml \ {posargs:tests/sqlalchemy_test_suite} - pytest \ - {env:SNOWFLAKE_PYTEST_OPTS:} \ - --cov "snowflake.sqlalchemy" --cov-append \ - --junitxml {toxworkdir}/junit_{envname}.xml \ - --run_v20_sqlalchemy \ - {posargs:tests} - -[testenv:connector_regression] -deps = pendulum -commands = pytest \ - {env:SNOWFLAKE_PYTEST_OPTS:} \ - -m "not gcp and not azure" \ - {posargs:tests/connector_regression/test} [testenv:.pkg_external] deps = build @@ -88,13 +75,14 @@ passenv = PROGRAMDATA deps = {[testenv]deps} + tomlkit pre-commit >= 2.9.0 skip_install = True commands = pre-commit run --all-files python -c 'import pathlib; print("hint: run \{\} install to add checks as pre-commit hook".format(pathlib.Path(r"{envdir}") / "bin" / "pre-commit"))' [pytest] -addopts = -ra --strict-markers --ignore=tests/sqlalchemy_test_suite --ignore=tests/connector_regression +addopts = -ra --ignore=tests/sqlalchemy_test_suite junit_family = legacy log_level = info markers =