diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 0c98405d20..c9d3f1c125 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -10,10 +10,7 @@ on: - v* pull_request: branches: - - master - - main - - prep-** - - dev/aio-connector + - '**' workflow_dispatch: inputs: logLevel: @@ -24,7 +21,7 @@ on: description: "Test scenario tags" concurrency: - # older builds for the same pull request numer or branch should be cancelled + # older builds for the same pull request number or branch should be cancelled cancel-in-progress: true group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 54d3b33807..3f8686eea4 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -15,9 +15,11 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne - Added handling of PAT provided in `password` field. - Improved error message for client-side query cancellations due to timeouts. - Added support of GCS regional endpoints. - - Added `gcs_use_virtual_endpoints` connection property that forces the usage of the virtual GCS usage. See more: https://cloud.google.com/storage/docs/request-endpoints#xml-api + - Added `gcs_use_virtual_endpoints` connection property that forces the usage of the virtual GCS usage. Thanks to this it should be possible to set up private DNS entry for the GCS endpoint. See more: https://cloud.google.com/storage/docs/request-endpoints#xml-api - Fixed a bug that caused driver to fail silently on `TO_DATE` arrow to python conversion when invalid date was followed by the correct one. - Added `check_arrow_conversion_error_on_every_column` connection property that can be set to `False` to restore previous behaviour in which driver will ignore errors until it occurs in the last column. This flag's purpose is to unblock workflows that may be impacted by the bugfix and will be removed in later releases. 
+ - Lowered log levels from info to debug for some of the messages to make the output easier to follow. + - Allowed the connector to inherit a UUID4 generated upstream, provided in statement parameters (field: `requestId`), rather than automatically generate a UUID4 to use for the HTTP Request ID. - v3.14.0(March 03, 2025) - Bumped pyOpenSSL dependency upper boundary from <25.0.0 to <26.0.0. diff --git a/Jenkinsfile b/Jenkinsfile index bc16773aa4..699a514970 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -38,10 +38,10 @@ timestamps { stage('Test') { try { def commit_hash = "main" // default which we want to override - def bptp_tag = "bptp-built" + def bptp_tag = "bptp-stable" def response = authenticatedGithubCall("https://api.github.com/repos/snowflakedb/snowflake/git/ref/tags/${bptp_tag}") commit_hash = response.object.sha - // Append the bptp-built commit sha to params + // Append the bptp-stable commit sha to params params += [string(name: 'svn_revision', value: commit_hash)] } catch(Exception e) { println("Exception computing commit hash from: ${response}") diff --git a/benchmark/benchmark_unit_converter.py b/benchmark/benchmark_unit_converter.py index 74895c4c16..fdc199e344 100644 --- a/benchmark/benchmark_unit_converter.py +++ b/benchmark/benchmark_unit_converter.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. -# + from __future__ import annotations from logging import getLogger diff --git a/samples/auth_by_key_pair_from_file.py b/samples/auth_by_key_pair_from_file.py index fa5d830e05..5a33240b7f 100644 --- a/samples/auth_by_key_pair_from_file.py +++ b/samples/auth_by_key_pair_from_file.py @@ -1,7 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# """ This sample shows how to implement a key pair authentication plugin which reads private key from a file diff --git a/setup.cfg b/setup.cfg index dba3420ed4..68d731c138 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,8 @@ python_requires = >=3.9 packages = find_namespace: install_requires = asn1crypto>0.24.0,<2.0.0 + boto3>=1.0 + botocore>=1.0 cffi>=1.9,<2.0.0 cryptography>=3.1.0 pyOpenSSL>=22.0.0,<25.0.0 @@ -98,3 +100,4 @@ secure-local-storage = keyring>=23.1.0,<26.0.0 aio = aiohttp + aioboto3 diff --git a/setup.py b/setup.py index fb54c20046..5a9e364e27 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. -# import os import sys diff --git a/src/snowflake/connector/__init__.py b/src/snowflake/connector/__init__.py index 706757921a..41b5288ac7 100644 --- a/src/snowflake/connector/__init__.py +++ b/src/snowflake/connector/__init__.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - # Python Db API v2 # from __future__ import annotations @@ -16,6 +12,8 @@ import logging from logging import NullHandler +from snowflake.connector.externals_utils.externals_setup import setup_external_libraries + from .connection import SnowflakeConnection from .cursor import DictCursor from .dbapi import ( @@ -48,6 +46,7 @@ from .version import VERSION logging.getLogger(__name__).addHandler(NullHandler()) +setup_external_libraries() @wraps(SnowflakeConnection.__init__) diff --git a/src/snowflake/connector/_query_context_cache.py b/src/snowflake/connector/_query_context_cache.py index 26d35b48f2..43688e2a24 100644 --- a/src/snowflake/connector/_query_context_cache.py +++ b/src/snowflake/connector/_query_context_cache.py @@ -1,6 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# from __future__ import annotations from functools import total_ordering diff --git a/src/snowflake/connector/_sql_util.py b/src/snowflake/connector/_sql_util.py index e5584c1ded..d2ae2d5631 100644 --- a/src/snowflake/connector/_sql_util.py +++ b/src/snowflake/connector/_sql_util.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import re diff --git a/src/snowflake/connector/_utils.py b/src/snowflake/connector/_utils.py index 85ea830739..e22881f103 100644 --- a/src/snowflake/connector/_utils.py +++ b/src/snowflake/connector/_utils.py @@ -1,13 +1,10 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import string from enum import Enum from random import choice from threading import Timer +from uuid import UUID class TempObjectType(Enum): @@ -33,6 +30,8 @@ class TempObjectType(Enum): "PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS" ) +REQUEST_ID_STATEMENT_PARAM_NAME = "requestId" + def generate_random_alphanumeric(length: int = 10) -> str: return "".join(choice(ALPHANUMERIC) for _ in range(length)) @@ -46,6 +45,21 @@ def get_temp_type_for_object(use_scoped_temp_objects: bool) -> str: return SCOPED_TEMPORARY_STRING if use_scoped_temp_objects else TEMPORARY_STRING +def is_uuid4(str_or_uuid: str | UUID) -> bool: + """Check whether provided string str is a valid UUID version4.""" + if isinstance(str_or_uuid, UUID): + return str_or_uuid.version == 4 + + if not isinstance(str_or_uuid, str): + return False + + try: + uuid_str = str(UUID(str_or_uuid, version=4)) + except ValueError: + return False + return uuid_str == str_or_uuid + + class _TrackedQueryCancellationTimer(Timer): def __init__(self, interval, function, args=None, kwargs=None): super().__init__(interval, function, args, kwargs) diff --git a/src/snowflake/connector/aio/__init__.py b/src/snowflake/connector/aio/__init__.py index 628bc2abf1..0b0410ebaa 100644 --- 
a/src/snowflake/connector/aio/__init__.py +++ b/src/snowflake/connector/aio/__init__.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from ._connection import SnowflakeConnection diff --git a/src/snowflake/connector/aio/_azure_storage_client.py b/src/snowflake/connector/aio/_azure_storage_client.py index fa255d1c7a..c1c88a58a0 100644 --- a/src/snowflake/connector/aio/_azure_storage_client.py +++ b/src/snowflake/connector/aio/_azure_storage_client.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import base64 @@ -15,7 +11,6 @@ import aiohttp -from ..azure_storage_client import AzureCredentialFilter from ..azure_storage_client import ( SnowflakeAzureRestClient as SnowflakeAzureRestClientSync, ) @@ -37,8 +32,6 @@ logger = getLogger(__name__) -getLogger("aiohttp").addFilter(AzureCredentialFilter()) - class SnowflakeAzureRestClient( SnowflakeStorageClientAsync, SnowflakeAzureRestClientSync @@ -49,7 +42,6 @@ def __init__( credentials: StorageCredential | None, chunk_size: int, stage_info: dict[str, Any], - use_s3_regional_url: bool = False, unsafe_file_write: bool = False, ) -> None: SnowflakeAzureRestClientSync.__init__( diff --git a/src/snowflake/connector/aio/_build_upload_agent.py b/src/snowflake/connector/aio/_build_upload_agent.py index f6f44511dc..d68d053234 100644 --- a/src/snowflake/connector/aio/_build_upload_agent.py +++ b/src/snowflake/connector/aio/_build_upload_agent.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# + from __future__ import annotations diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index de813d1b5c..c7a2add13d 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -1,6 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# from __future__ import annotations import asyncio @@ -35,6 +32,7 @@ from ..connection import _get_private_bytes_from_file from ..constants import ( _CONNECTIVITY_ERR_MSG, + ENV_VAR_EXPERIMENTAL_AUTHENTICATION, ENV_VAR_PARTNER, PARAMETER_AUTOCOMMIT, PARAMETER_CLIENT_PREFETCH_THREADS, @@ -55,6 +53,7 @@ ER_CONNECTION_IS_CLOSED, ER_FAILED_TO_CONNECT_TO_DB, ER_INVALID_VALUE, + ER_INVALID_WIF_SETTINGS, ) from ..network import ( DEFAULT_AUTHENTICATOR, @@ -64,14 +63,17 @@ PROGRAMMATIC_ACCESS_TOKEN, REQUEST_ID, USR_PWD_MFA_AUTHENTICATOR, + WORKLOAD_IDENTITY_AUTHENTICATOR, ReauthenticationRequest, ) from ..sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_FEATURE_NOT_SUPPORTED from ..telemetry import TelemetryData, TelemetryField from ..time_util import get_time_millis from ..util_text import split_statements +from ..wif_util import AttestationProvider from ._cursor import SnowflakeCursor from ._description import CLIENT_NAME +from ._direct_file_operation_utils import FileOperationParser, StreamDownloader from ._network import SnowflakeRestful from ._telemetry import TelemetryClient from ._time_util import HeartBeatTimer @@ -87,6 +89,7 @@ AuthByPlugin, AuthByUsrPwdMfa, AuthByWebBrowser, + AuthByWorkloadIdentity, ) logger = getLogger(__name__) @@ -116,6 +119,10 @@ def __init__( # check SNOW-1218851 for long term improvement plan to refactor ocsp code atexit.register(self._close_at_exit) + # Set up the file operation parser and stream downloader. 
+ self._file_operation_parser = FileOperationParser(self) + self._stream_downloader = StreamDownloader(self) + def __enter__(self): # async connection does not support sync context manager raise TypeError( @@ -301,8 +308,6 @@ async def __open_connection(self): backoff_generator=self._backoff_generator, ) elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN: - if not self._token and self._password: - self._token = self._password self.auth_class = AuthByPAT(self._token) elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR: self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = ( @@ -320,6 +325,29 @@ async def __open_connection(self): timeout=self.login_timeout, backoff_generator=self._backoff_generator, ) + elif self._authenticator == WORKLOAD_IDENTITY_AUTHENTICATOR: + if ENV_VAR_EXPERIMENTAL_AUTHENTICATION not in os.environ: + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable to use the '{WORKLOAD_IDENTITY_AUTHENTICATOR}' authenticator.", + "errno": ER_INVALID_WIF_SETTINGS, + }, + ) + # Standardize the provider enum. 
+ if self._workload_identity_provider and isinstance( + self._workload_identity_provider, str + ): + self._workload_identity_provider = AttestationProvider.from_string( + self._workload_identity_provider + ) + self.auth_class = AuthByWorkloadIdentity( + provider=self._workload_identity_provider, + token=self._token, + entra_resource=self._workload_identity_entra_resource, + ) else: # okta URL, e.g., https://.okta.com/ self.auth_class = AuthByOkta( @@ -767,7 +795,7 @@ async def close(self, retry: bool = True) -> None: await self._cancel_heartbeat() # close telemetry first, since it needs rest to send remaining data - logger.info("closed") + logger.debug("closed") await self._telemetry.close( send_on_close=bool(retry and self.telemetry_enabled) @@ -776,7 +804,7 @@ async def close(self, retry: bool = True) -> None: await self._all_async_queries_finished() and not self._server_session_keep_alive ): - logger.info("No async queries seem to be running, deleting session") + logger.debug("No async queries seem to be running, deleting session") try: await self.rest.delete_session(retry=retry) except Exception as e: @@ -784,7 +812,7 @@ async def close(self, retry: bool = True) -> None: "Exception encountered in deleting session. ignoring...: %s", e ) else: - logger.info( + logger.debug( "There are {} async queries still running, not deleting session".format( len(self._async_sfqids) ) diff --git a/src/snowflake/connector/aio/_cursor.py b/src/snowflake/connector/aio/_cursor.py index 1a45b9231d..39a9f34791 100644 --- a/src/snowflake/connector/aio/_cursor.py +++ b/src/snowflake/connector/aio/_cursor.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import asyncio @@ -34,6 +30,8 @@ ) from snowflake.connector.aio._result_set import ResultSet, ResultSetIterator from snowflake.connector.constants import ( + CMD_TYPE_DOWNLOAD, + CMD_TYPE_UPLOAD, PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, QueryStatus, ) @@ -58,6 +56,8 @@ from snowflake.connector.telemetry import TelemetryData, TelemetryField from snowflake.connector.time_util import get_time_millis +from .._utils import REQUEST_ID_STATEMENT_PARAM_NAME, is_uuid4 + if TYPE_CHECKING: from pandas import DataFrame from pyarrow import Table @@ -204,7 +204,27 @@ async def _execute_helper( ) self._sequence_counter = await self._connection._next_sequence_counter() - self._request_id = uuid.uuid4() + + # If requestId is contained in statement parameters, use it to set request id. Verify here it is a valid uuid4 + # identifier. + if ( + statement_params is not None + and REQUEST_ID_STATEMENT_PARAM_NAME in statement_params + ): + request_id = statement_params[REQUEST_ID_STATEMENT_PARAM_NAME] + + if not is_uuid4(request_id): + # uuid.UUID will throw an error if invalid, but we explicitly check and throw here. + raise ValueError(f"requestId {request_id} is not a valid UUID4.") + self._request_id = uuid.UUID(str(request_id), version=4) + + # Create a (deep copy) and remove the statement param, there is no need to encode it as extra parameter + # one more time. + statement_params = statement_params.copy() + statement_params.pop(REQUEST_ID_STATEMENT_PARAM_NAME) + else: + # Generate UUID for query. 
+ self._request_id = uuid.uuid4() logger.debug(f"Request id: {self._request_id}") @@ -663,6 +683,7 @@ async def execute( multipart_threshold=data.get("threshold"), use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1, unsafe_file_write=self._connection.unsafe_file_write, + gcs_use_virtual_endpoints=self._connection.gcs_use_virtual_endpoints, ) await sf_file_transfer_agent.execute() data = sf_file_transfer_agent.result() @@ -685,7 +706,15 @@ async def execute( logger.debug(ret) err = ret["message"] code = ret.get("code", -1) - if self._timebomb and self._timebomb.result(): + if ( + self._timebomb + and self._timebomb.result() + and "SQL execution canceled" in err + ): + # Modify the error message only if the server error response indicates the query was canceled. + # If the error occurs before the cancellation request reaches the backend + # (e.g., due to a very short timeout), we retain the original error message + # as the query might have encountered an issue prior to cancellation. err = ( f"SQL execution was cancelled by the client due to a timeout. " f"Error message received from the server: {err}" @@ -1034,6 +1063,155 @@ async def get_result_batches(self) -> list[ResultBatch] | None: ) return self._result_set.batches + async def _download( + self, + stage_location: str, + target_directory: str, + options: dict[str, Any], + _do_reset: bool = True, + ) -> None: + """Downloads from the stage location to the target directory. + + Args: + stage_location (str): The location of the stage to download from. + target_directory (str): The destination directory to download into. + options (dict[str, Any]): The download options. + _do_reset (bool, optional): Whether to reset the cursor before + downloading, by default we will reset the cursor. + """ + from ._file_transfer_agent import SnowflakeFileTransferAgent + + if _do_reset: + self.reset() + + # Interpret the file operation. 
+ ret = self.connection._file_operation_parser.parse_file_operation( + stage_location=stage_location, + local_file_name=None, + target_directory=target_directory, + command_type=CMD_TYPE_DOWNLOAD, + options=options, + ) + + # Execute the file operation based on the interpretation above. + file_transfer_agent = SnowflakeFileTransferAgent( + self, + "", # empty command because it is triggered by directly calling this util not by a SQL query + ret, + ) + await file_transfer_agent.execute() + await self._init_result_and_meta(file_transfer_agent.result()) + + async def _upload( + self, + local_file_name: str, + stage_location: str, + options: dict[str, Any], + _do_reset: bool = True, + ) -> None: + """Uploads the local file to the stage location. + + Args: + local_file_name (str): The local file to be uploaded. + stage_location (str): The stage location to upload the local file to. + options (dict[str, Any]): The upload options. + _do_reset (bool, optional): Whether to reset the cursor before + uploading, by default we will reset the cursor. + """ + from ._file_transfer_agent import SnowflakeFileTransferAgent + + if _do_reset: + self.reset() + + # Interpret the file operation. + ret = self.connection._file_operation_parser.parse_file_operation( + stage_location=stage_location, + local_file_name=local_file_name, + target_directory=None, + command_type=CMD_TYPE_UPLOAD, + options=options, + ) + + # Execute the file operation based on the interpretation above. + file_transfer_agent = SnowflakeFileTransferAgent( + self, + "", # empty command because it is triggered by directly calling this util not by a SQL query + ret, + force_put_overwrite=False, # _upload should respect user decision on overwriting + ) + await file_transfer_agent.execute() + await self._init_result_and_meta(file_transfer_agent.result()) + + async def _download_stream( + self, stage_location: str, decompress: bool = False + ) -> IO[bytes]: + """Downloads from the stage location as a stream. 
+ + Args: + stage_location (str): The location of the stage to download from. + decompress (bool, optional): Whether to decompress the file, by + default we do not decompress. + + Returns: + IO[bytes]: A stream to read from. + """ + # Interpret the file operation. + ret = self.connection._file_operation_parser.parse_file_operation( + stage_location=stage_location, + local_file_name=None, + target_directory=None, + command_type=CMD_TYPE_DOWNLOAD, + options=None, + has_source_from_stream=True, + ) + + # Set up stream downloading based on the interpretation and return the stream for reading. + return await self.connection._stream_downloader.download_as_stream( + ret, decompress + ) + + async def _upload_stream( + self, + input_stream: IO[bytes], + stage_location: str, + options: dict[str, Any], + _do_reset: bool = True, + ) -> None: + """Uploads content in the input stream to the stage location. + + Args: + input_stream (IO[bytes]): A stream to read from. + stage_location (str): The location of the stage to upload to. + options (dict[str, Any]): The upload options. + _do_reset (bool, optional): Whether to reset the cursor before + uploading, by default we will reset the cursor. + """ + from ._file_transfer_agent import SnowflakeFileTransferAgent + + if _do_reset: + self.reset() + + # Interpret the file operation. + ret = self.connection._file_operation_parser.parse_file_operation( + stage_location=stage_location, + local_file_name=None, + target_directory=None, + command_type=CMD_TYPE_UPLOAD, + options=options, + has_source_from_stream=input_stream, + ) + + # Execute the file operation based on the interpretation above. 
+ file_transfer_agent = SnowflakeFileTransferAgent( + self, + "", # empty command because it is triggered by directly calling this util not by a SQL query + ret, + source_from_stream=input_stream, + force_put_overwrite=False, # _upload should respect user decision on overwriting + ) + await file_transfer_agent.execute() + await self._init_result_and_meta(file_transfer_agent.result()) + async def get_results_from_sfqid(self, sfqid: str) -> None: """Gets the results from previously ran query. This methods differs from ``SnowflakeCursor.query_result`` in that it monitors the ``sfqid`` until it is no longer running, and then retrieves the results. @@ -1120,7 +1298,7 @@ async def query_result(self, qid: str) -> SnowflakeCursor: data = ret.get("data") await self._init_result_and_meta(data) else: - logger.info("failed") + logger.debug("failed") logger.debug(ret) err = ret["message"] code = ret.get("code", -1) diff --git a/src/snowflake/connector/aio/_description.py b/src/snowflake/connector/aio/_description.py index 9b5f175408..0095129906 100644 --- a/src/snowflake/connector/aio/_description.py +++ b/src/snowflake/connector/aio/_description.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - """Various constants.""" from __future__ import annotations diff --git a/src/snowflake/connector/aio/_direct_file_operation_utils.py b/src/snowflake/connector/aio/_direct_file_operation_utils.py new file mode 100644 index 0000000000..e63bd14d63 --- /dev/null +++ b/src/snowflake/connector/aio/_direct_file_operation_utils.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class FileOperationParserBase(ABC): + """The interface of internal utility functions for file operation parsing.""" + + @abstractmethod + def __init__(self, connection): + pass + + @abstractmethod + async def parse_file_operation( + self, + stage_location, + local_file_name, + target_directory, + command_type, + options, + has_source_from_stream=False, + ): + """Converts the file operation details into a SQL and returns the SQL parsing result.""" + pass + + +class StreamDownloaderBase(ABC): + """The interface of internal utility functions for stream downloading of file.""" + + @abstractmethod + def __init__(self, connection): + pass + + @abstractmethod + async def download_as_stream(self, ret, decompress=False): + pass + + +class FileOperationParser(FileOperationParserBase): + def __init__(self, connection): + pass + + async def parse_file_operation( + self, + stage_location, + local_file_name, + target_directory, + command_type, + options, + has_source_from_stream=False, + ): + raise NotImplementedError("parse_file_operation is not yet supported") + + +class StreamDownloader(StreamDownloaderBase): + def __init__(self, connection): + pass + + async def download_as_stream(self, ret, decompress=False): + raise NotImplementedError("download_as_stream is not yet supported") diff --git a/src/snowflake/connector/aio/_file_transfer_agent.py b/src/snowflake/connector/aio/_file_transfer_agent.py index 80b4829bb5..a42c7cd879 100644 --- a/src/snowflake/connector/aio/_file_transfer_agent.py +++ b/src/snowflake/connector/aio/_file_transfer_agent.py @@ -1,7 
+1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import asyncio @@ -63,6 +59,7 @@ def __init__( source_from_stream: IO[bytes] | None = None, use_s3_regional_url: bool = False, unsafe_file_write: bool = False, + gcs_use_virtual_endpoints: bool = False, ) -> None: super().__init__( cursor=cursor, @@ -82,6 +79,7 @@ def __init__( source_from_stream=source_from_stream, use_s3_regional_url=use_s3_regional_url, unsafe_file_write=unsafe_file_write, + gcs_use_virtual_endpoints=gcs_use_virtual_endpoints, ) async def execute(self) -> None: @@ -193,7 +191,7 @@ def transfer_done_cb( ) -> None: # Note: chunk_id is 0 based while num_of_chunks is count logger.debug( - f"Chunk {chunk_id}/{done_client.num_of_chunks} of file {done_client.meta.name} reached callback" + f"Chunk(id: {chunk_id}) {chunk_id+1}/{done_client.num_of_chunks} of file {done_client.meta.name} reached callback" ) if task.exception(): done_client.failed_transfers += 1 @@ -281,7 +279,6 @@ async def _create_file_transfer_client( self._credentials, AZURE_CHUNK_SIZE, self._stage_info, - use_s3_regional_url=self._use_s3_regional_url, unsafe_file_write=self._unsafe_file_write, ) elif self._stage_location_type == S3_FS: @@ -303,8 +300,8 @@ async def _create_file_transfer_client( self._stage_info, self._cursor._connection, self._command, - use_s3_regional_url=self._use_s3_regional_url, unsafe_file_write=self._unsafe_file_write, + use_virtual_endpoints=self._gcs_use_virtual_endpoints, ) if client.security_token: logger.debug(f"len(GCS_ACCESS_TOKEN): {len(client.security_token)}") diff --git a/src/snowflake/connector/aio/_gcs_storage_client.py b/src/snowflake/connector/aio/_gcs_storage_client.py index 8683e7d4c3..22a360e44c 100644 --- a/src/snowflake/connector/aio/_gcs_storage_client.py +++ b/src/snowflake/connector/aio/_gcs_storage_client.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. 
All rights reserved. -# + from __future__ import annotations @@ -27,6 +25,7 @@ GCS_METADATA_ENCRYPTIONDATAPROP, GCS_METADATA_MATDESC_KEY, GCS_METADATA_SFC_DIGEST, + GCS_REGION_ME_CENTRAL_2, ) @@ -38,8 +37,8 @@ def __init__( stage_info: dict[str, Any], cnx: SnowflakeConnection, command: str, - use_s3_regional_url: bool = False, unsafe_file_write: bool = False, + use_virtual_endpoints: bool = False, ) -> None: """Creates a client object with given stage credentials. @@ -65,6 +64,16 @@ def __init__( # presigned_url in meta is for downloading self.presigned_url: str = meta.presigned_url or stage_info.get("presignedUrl") self.security_token = credentials.creds.get("GCS_ACCESS_TOKEN") + self.use_regional_url = ( + "region" in stage_info + and stage_info["region"].lower() == GCS_REGION_ME_CENTRAL_2 + or "useRegionalUrl" in stage_info + and stage_info["useRegionalUrl"] + ) + self.endpoint: str | None = ( + None if "endPoint" not in stage_info else stage_info["endPoint"] + ) + self.use_virtual_endpoints: bool = use_virtual_endpoints async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool: return self.security_token and response.status == 401 @@ -73,7 +82,7 @@ async def _has_expired_presigned_url( self, response: aiohttp.ClientResponse ) -> bool: # Presigned urls can be generated for any xml-api operation - # offered by GCS. Hence the error codes expected are similar + # offered by GCS. Hence, the error codes expected are similar # to xml api. 
# https://cloud.google.com/storage/docs/xml-api/reference-status @@ -132,7 +141,16 @@ def generate_url_and_rest_args() -> ( ): if not self.presigned_url: upload_url = self.generate_file_url( - self.stage_info["location"], meta.dst_file_name.lstrip("/") + self.stage_info["location"], + meta.dst_file_name.lstrip("/"), + self.use_regional_url, + ( + None + if "region" not in self.stage_info + else self.stage_info["region"] + ), + self.endpoint, + self.use_virtual_endpoints, ) access_token = self.security_token else: @@ -162,7 +180,16 @@ def generate_url_and_rest_args() -> ( gcs_headers = {} if not self.presigned_url: download_url = self.generate_file_url( - self.stage_info["location"], meta.src_file_name.lstrip("/") + self.stage_info["location"], + meta.src_file_name.lstrip("/"), + self.use_regional_url, + ( + None + if "region" not in self.stage_info + else self.stage_info["region"] + ), + self.endpoint, + self.use_virtual_endpoints, ) access_token = self.security_token gcs_headers["Authorization"] = f"Bearer {access_token}" @@ -279,7 +306,16 @@ async def get_file_header(self, filename: str) -> FileHeader | None: def generate_url_and_authenticated_headers(): url = self.generate_file_url( - self.stage_info["location"], filename.lstrip("/") + self.stage_info["location"], + filename.lstrip("/"), + self.use_regional_url, + ( + None + if "region" not in self.stage_info + else self.stage_info["region"] + ), + self.endpoint, + self.use_virtual_endpoints, ) gcs_headers = {"Authorization": f"Bearer {self.security_token}"} rest_args = {"headers": gcs_headers} diff --git a/src/snowflake/connector/aio/_network.py b/src/snowflake/connector/aio/_network.py index 7ec0d1f003..194469a385 100644 --- a/src/snowflake/connector/aio/_network.py +++ b/src/snowflake/connector/aio/_network.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import asyncio @@ -71,7 +67,12 @@ ) from ..network import SessionPool as SessionPoolSync from ..network import SnowflakeRestful as SnowflakeRestfulSync -from ..network import get_http_retryable_error, is_login_request, is_retryable_http_code +from ..network import ( + SnowflakeRestfulJsonEncoder, + get_http_retryable_error, + is_login_request, + is_retryable_http_code, +) from ..secret_detector import SecretDetector from ..sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, @@ -236,7 +237,7 @@ async def request( return await self._post_request( url, headers, - json.dumps(body), + json.dumps(body, cls=SnowflakeRestfulJsonEncoder), token=self.token, _no_results=_no_results, timeout=timeout, @@ -298,7 +299,7 @@ async def _token_request(self, request_type): ret = await self._post_request( url, headers, - json.dumps(body), + json.dumps(body, cls=SnowflakeRestfulJsonEncoder), token=header_token, ) if ret.get("success") and ret.get("data", {}).get("sessionToken"): @@ -396,7 +397,7 @@ async def delete_session(self, retry: bool = False) -> None: ret = await self._post_request( url, headers, - json.dumps(body), + json.dumps(body, cls=SnowflakeRestfulJsonEncoder), token=self.token, timeout=5, no_retry=True, diff --git a/src/snowflake/connector/aio/_ocsp_asn1crypto.py b/src/snowflake/connector/aio/_ocsp_asn1crypto.py index 28622c5039..0428ce0040 100644 --- a/src/snowflake/connector/aio/_ocsp_asn1crypto.py +++ b/src/snowflake/connector/aio/_ocsp_asn1crypto.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import ssl diff --git a/src/snowflake/connector/aio/_ocsp_snowflake.py b/src/snowflake/connector/aio/_ocsp_snowflake.py index 8cff5d5d7d..d7fd8ff04a 100644 --- a/src/snowflake/connector/aio/_ocsp_snowflake.py +++ b/src/snowflake/connector/aio/_ocsp_snowflake.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import asyncio diff --git a/src/snowflake/connector/aio/_result_batch.py b/src/snowflake/connector/aio/_result_batch.py index 3bf9565ee7..d258593e03 100644 --- a/src/snowflake/connector/aio/_result_batch.py +++ b/src/snowflake/connector/aio/_result_batch.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import abc diff --git a/src/snowflake/connector/aio/_result_set.py b/src/snowflake/connector/aio/_result_set.py index 2ac9639947..1608e5a81a 100644 --- a/src/snowflake/connector/aio/_result_set.py +++ b/src/snowflake/connector/aio/_result_set.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# + from __future__ import annotations diff --git a/src/snowflake/connector/aio/_s3_storage_client.py b/src/snowflake/connector/aio/_s3_storage_client.py index 1f72166c68..fbeb54206f 100644 --- a/src/snowflake/connector/aio/_s3_storage_client.py +++ b/src/snowflake/connector/aio/_s3_storage_client.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import xml.etree.ElementTree as ET @@ -74,7 +70,13 @@ def __init__( self.stage_info["location"] ) ) - self.use_s3_regional_url = use_s3_regional_url + self.use_s3_regional_url = ( + use_s3_regional_url + or "useS3RegionalUrl" in stage_info + and stage_info["useS3RegionalUrl"] + or "useRegionalUrl" in stage_info + and stage_info["useRegionalUrl"] + ) self.location_type = stage_info.get("locationType") # if GS sends us an endpoint, it's likely for FIPS. Use it. 
@@ -121,6 +123,9 @@ def generate_authenticated_url_and_args_v4() -> tuple[str, dict[str, bytes]]: amzdate = t.strftime("%Y%m%dT%H%M%SZ") short_amzdate = amzdate[:8] x_amz_headers["x-amz-date"] = amzdate + x_amz_headers["x-amz-security-token"] = self.credentials.creds.get( + "AWS_TOKEN", "" + ) ( canonical_request, diff --git a/src/snowflake/connector/aio/_ssl_connector.py b/src/snowflake/connector/aio/_ssl_connector.py index b7ab50e6ec..2fae526b4d 100644 --- a/src/snowflake/connector/aio/_ssl_connector.py +++ b/src/snowflake/connector/aio/_ssl_connector.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/src/snowflake/connector/aio/_storage_client.py b/src/snowflake/connector/aio/_storage_client.py index 1e2265bba9..3d27222aab 100644 --- a/src/snowflake/connector/aio/_storage_client.py +++ b/src/snowflake/connector/aio/_storage_client.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import asyncio @@ -193,6 +189,7 @@ async def _send_request_with_retry( conn = self.meta.sfagent._cursor._connection while self.retry_count[retry_id] < self.max_retry: + logger.debug(f"retry #{self.retry_count[retry_id]}") cur_timestamp = self.credentials.timestamp url, rest_kwargs = get_request_args() # rest_kwargs["timeout"] = (REQUEST_CONNECTION_TIMEOUT, REQUEST_READ_TIMEOUT) @@ -208,10 +205,14 @@ async def _send_request_with_retry( ) if await self._has_expired_presigned_url(response): + logger.debug( + "presigned url expired. trying to update presigned url." 
+ ) await self._update_presigned_url() else: self.last_err_is_presigned_url = False if response.status in self.TRANSIENT_HTTP_ERR: + logger.debug(f"transient error: {response.status}") await asyncio.sleep( min( # TODO should SLEEP_UNIT come from the parent @@ -222,7 +223,9 @@ async def _send_request_with_retry( ) self.retry_count[retry_id] += 1 elif await self._has_expired_token(response): + logger.debug("token is expired. trying to update token") self.credentials.update(cur_timestamp) + self.retry_count[retry_id] += 1 else: return response except self.TRANSIENT_ERRORS as e: diff --git a/src/snowflake/connector/aio/_telemetry.py b/src/snowflake/connector/aio/_telemetry.py index f5aa5d4254..b9b46f2301 100644 --- a/src/snowflake/connector/aio/_telemetry.py +++ b/src/snowflake/connector/aio/_telemetry.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# + from __future__ import annotations diff --git a/src/snowflake/connector/aio/_time_util.py b/src/snowflake/connector/aio/_time_util.py index c11f19728f..d21eae30bb 100644 --- a/src/snowflake/connector/aio/_time_util.py +++ b/src/snowflake/connector/aio/_time_util.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import asyncio diff --git a/src/snowflake/connector/aio/_wif_util.py b/src/snowflake/connector/aio/_wif_util.py new file mode 100644 index 0000000000..ebb74d48d8 --- /dev/null +++ b/src/snowflake/connector/aio/_wif_util.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +from base64 import b64encode + +import aioboto3 +import aiohttp +from aiobotocore.utils import AioInstanceMetadataRegionFetcher +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest + +from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND +from ..errors import ProgrammingError +from ..wif_util import ( + DEFAULT_ENTRA_SNOWFLAKE_RESOURCE, + SNOWFLAKE_AUDIENCE, + AttestationProvider, + WorkloadIdentityAttestation, + create_oidc_attestation, + extract_iss_and_sub_without_signature_verification, +) + +logger = logging.getLogger(__name__) + + +async def try_metadata_service_call( + method: str, url: str, headers: dict, timeout_sec: int = 3 +) -> aiohttp.ClientResponse | None: + """Tries to make a HTTP request to the metadata service with the given URL, method, headers and timeout. + + If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. 
+ """ + try: + timeout = aiohttp.ClientTimeout(total=timeout_sec) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.request( + method=method, url=url, headers=headers + ) as response: + if not response.ok: + return None + # Create a copy of the response data since the response will be closed + content = await response.read() + response._content = content + return response + except (aiohttp.ClientError, asyncio.TimeoutError): + return None + + +async def get_aws_region() -> str | None: + """Get the current AWS workload's region, if any.""" + if "AWS_REGION" in os.environ: # Lambda + return os.environ["AWS_REGION"] + else: # EC2 + return await AioInstanceMetadataRegionFetcher().retrieve_region() + + +async def get_aws_arn() -> str | None: + """Get the current AWS workload's ARN, if any.""" + session = aioboto3.Session() + async with session.client("sts") as client: + caller_identity = await client.get_caller_identity() + if not caller_identity or "Arn" not in caller_identity: + return None + return caller_identity["Arn"] + + +async def create_aws_attestation() -> WorkloadIdentityAttestation | None: + """Tries to create a workload identity attestation for AWS. + + If the application isn't running on AWS or no credentials were found, returns None. 
+ """ + session = aioboto3.Session() + aws_creds = await session.get_credentials() + if not aws_creds: + logger.debug("No AWS credentials were found.") + return None + + region = await get_aws_region() + if not region: + logger.debug("No AWS region was found.") + return None + + arn = await get_aws_arn() + if not arn: + logger.debug("No AWS caller identity was found.") + return None + + sts_hostname = f"sts.{region}.amazonaws.com" + request = AWSRequest( + method="POST", + url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", + headers={ + "Host": sts_hostname, + "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, + }, + ) + + SigV4Auth(aws_creds, "sts", region).add_auth(request) + + assertion_dict = { + "url": request.url, + "method": request.method, + "headers": dict(request.headers.items()), + } + credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8") + return WorkloadIdentityAttestation( + AttestationProvider.AWS, credential, {"arn": arn} + ) + + +async def create_gcp_attestation() -> WorkloadIdentityAttestation | None: + """Tries to create a workload identity attestation for GCP. + + If the application isn't running on GCP or no credentials were found, returns None. + """ + res = await try_metadata_service_call( + method="GET", + url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}", + headers={ + "Metadata-Flavor": "Google", + }, + ) + if res is None: + # Most likely we're just not running on GCP, which may be expected. + logger.debug("GCP metadata server request was not successful.") + return None + + jwt_str = res._content.decode("utf-8") + issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str) + if not issuer or not subject: + return None + if issuer != "https://accounts.google.com": + # This might happen if we're running on a different platform that responds to the same metadata request signature as GCP. 
+ logger.debug("Unexpected GCP token issuer '%s'", issuer) + return None + + return WorkloadIdentityAttestation( + AttestationProvider.GCP, jwt_str, {"sub": subject} + ) + + +async def create_azure_attestation( + snowflake_entra_resource: str, +) -> WorkloadIdentityAttestation | None: + """Tries to create a workload identity attestation for Azure. + + If the application isn't running on Azure or no credentials were found, returns None. + """ + headers = {"Metadata": "True"} + url_without_query_string = "http://169.254.169.254/metadata/identity/oauth2/token" + query_params = f"api-version=2018-02-01&resource={snowflake_entra_resource}" + + # Check if running in Azure Functions environment + identity_endpoint = os.environ.get("IDENTITY_ENDPOINT") + identity_header = os.environ.get("IDENTITY_HEADER") + is_azure_functions = identity_endpoint is not None + + if is_azure_functions: + if not identity_header: + logger.warning("Managed identity is not enabled on this Azure function.") + return None + + # Azure Functions uses a different endpoint, headers and API version. + url_without_query_string = identity_endpoint + headers = {"X-IDENTITY-HEADER": identity_header} + query_params = f"api-version=2019-08-01&resource={snowflake_entra_resource}" + + # Some Azure Functions environments may require client_id in the URL + managed_identity_client_id = os.environ.get("MANAGED_IDENTITY_CLIENT_ID") + if managed_identity_client_id: + query_params += f"&client_id={managed_identity_client_id}" + + res = await try_metadata_service_call( + method="GET", + url=f"{url_without_query_string}?{query_params}", + headers=headers, + ) + if res is None: + # Most likely we're just not running on Azure, which may be expected. 
+ logger.debug("Azure metadata server request was not successful.") + return None + + try: + response_text = res._content.decode("utf-8") + response_data = json.loads(response_text) + jwt_str = response_data.get("access_token") + if not jwt_str: + # Could be that Managed Identity is disabled. + logger.debug("No access token found in Azure response.") + return None + except (ValueError, KeyError) as e: + logger.debug(f"Error parsing Azure response: {e}") + return None + + issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str) + if not issuer or not subject: + return None + if not ( + issuer.startswith("https://sts.windows.net/") + or issuer.startswith("https://login.microsoftonline.com/") + ): + # This might happen if we're running on a different platform that responds to the same metadata request signature as Azure. + logger.debug("Unexpected Azure token issuer '%s'", issuer) + return None + + return WorkloadIdentityAttestation( + AttestationProvider.AZURE, jwt_str, {"iss": issuer, "sub": subject} + ) + + +async def create_autodetect_attestation( + entra_resource: str, token: str | None = None +) -> WorkloadIdentityAttestation | None: + """Tries to create an attestation using the auto-detected runtime environment. + + If no attestation can be found, returns None. + """ + attestation = create_oidc_attestation(token) + if attestation: + return attestation + + attestation = await create_aws_attestation() + if attestation: + return attestation + + attestation = await create_azure_attestation(entra_resource) + if attestation: + return attestation + + attestation = await create_gcp_attestation() + if attestation: + return attestation + + return None + + +async def create_attestation( + provider: AttestationProvider | None, + entra_resource: str | None = None, + token: str | None = None, +) -> WorkloadIdentityAttestation: + """Entry point to create an attestation using the given provider. 
+ + If the provider is None, this will try to auto-detect a credential from the runtime environment. If the provider fails to detect a credential, + a ProgrammingError will be raised. + + If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. + """ + entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE + + attestation: WorkloadIdentityAttestation = None + if provider == AttestationProvider.AWS: + attestation = await create_aws_attestation() + elif provider == AttestationProvider.AZURE: + attestation = await create_azure_attestation(entra_resource) + elif provider == AttestationProvider.GCP: + attestation = await create_gcp_attestation() + elif provider == AttestationProvider.OIDC: + attestation = create_oidc_attestation(token) + elif provider is None: + attestation = await create_autodetect_attestation(entra_resource, token) + + if not attestation: + provider_str = "auto-detect" if provider is None else provider.value + raise ProgrammingError( + msg=f"No workload identity credential was found for '{provider_str}'.", + errno=ER_WIF_CREDENTIALS_NOT_FOUND, + ) + + return attestation diff --git a/src/snowflake/connector/aio/auth/__init__.py b/src/snowflake/connector/aio/auth/__init__.py index 97eecff7d6..4091bcf06b 100644 --- a/src/snowflake/connector/aio/auth/__init__.py +++ b/src/snowflake/connector/aio/auth/__init__.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations from ...auth.by_plugin import AuthType @@ -16,6 +12,7 @@ from ._pat import AuthByPAT from ._usrpwdmfa import AuthByUsrPwdMfa from ._webbrowser import AuthByWebBrowser +from ._workload_identity import AuthByWorkloadIdentity FIRST_PARTY_AUTHENTICATORS = frozenset( ( @@ -27,6 +24,7 @@ AuthByWebBrowser, AuthByIdToken, AuthByPAT, + AuthByWorkloadIdentity, AuthNoAuth, ) ) @@ -40,6 +38,7 @@ "AuthByOkta", "AuthByUsrPwdMfa", "AuthByWebBrowser", + "AuthByWorkloadIdentity", "AuthNoAuth", "Auth", "AuthType", diff --git a/src/snowflake/connector/aio/auth/_auth.py b/src/snowflake/connector/aio/auth/_auth.py index edb270e49f..8dbb86f963 100644 --- a/src/snowflake/connector/aio/auth/_auth.py +++ b/src/snowflake/connector/aio/auth/_auth.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import asyncio diff --git a/src/snowflake/connector/aio/auth/_by_plugin.py b/src/snowflake/connector/aio/auth/_by_plugin.py index 818769a9f2..d69850f98e 100644 --- a/src/snowflake/connector/aio/auth/_by_plugin.py +++ b/src/snowflake/connector/aio/auth/_by_plugin.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import asyncio diff --git a/src/snowflake/connector/aio/auth/_default.py b/src/snowflake/connector/aio/auth/_default.py index 1466db4d7a..2988d70897 100644 --- a/src/snowflake/connector/aio/auth/_default.py +++ b/src/snowflake/connector/aio/auth/_default.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations from logging import getLogger diff --git a/src/snowflake/connector/aio/auth/_idtoken.py b/src/snowflake/connector/aio/auth/_idtoken.py index 23bca2beaa..f88a647587 100644 --- a/src/snowflake/connector/aio/auth/_idtoken.py +++ b/src/snowflake/connector/aio/auth/_idtoken.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from typing import TYPE_CHECKING, Any diff --git a/src/snowflake/connector/aio/auth/_keypair.py b/src/snowflake/connector/aio/auth/_keypair.py index aff2f207f2..72da132319 100644 --- a/src/snowflake/connector/aio/auth/_keypair.py +++ b/src/snowflake/connector/aio/auth/_keypair.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# + from __future__ import annotations from logging import getLogger diff --git a/src/snowflake/connector/aio/auth/_no_auth.py b/src/snowflake/connector/aio/auth/_no_auth.py index 17a2d3e6d3..d315f612ff 100644 --- a/src/snowflake/connector/aio/auth/_no_auth.py +++ b/src/snowflake/connector/aio/auth/_no_auth.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# + from __future__ import annotations diff --git a/src/snowflake/connector/aio/auth/_oauth.py b/src/snowflake/connector/aio/auth/_oauth.py index 04cd44ba2c..ce63b099ab 100644 --- a/src/snowflake/connector/aio/auth/_oauth.py +++ b/src/snowflake/connector/aio/auth/_oauth.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# + from __future__ import annotations diff --git a/src/snowflake/connector/aio/auth/_okta.py b/src/snowflake/connector/aio/auth/_okta.py index d8cd216df5..9b40d8c2f3 100644 --- a/src/snowflake/connector/aio/auth/_okta.py +++ b/src/snowflake/connector/aio/auth/_okta.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# + from __future__ import annotations diff --git a/src/snowflake/connector/aio/auth/_pat.py b/src/snowflake/connector/aio/auth/_pat.py index 8c88944810..805159a86e 100644 --- a/src/snowflake/connector/aio/auth/_pat.py +++ b/src/snowflake/connector/aio/auth/_pat.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# + from __future__ import annotations diff --git a/src/snowflake/connector/aio/auth/_usrpwdmfa.py b/src/snowflake/connector/aio/auth/_usrpwdmfa.py index 4175bf5015..26ea212304 100644 --- a/src/snowflake/connector/aio/auth/_usrpwdmfa.py +++ b/src/snowflake/connector/aio/auth/_usrpwdmfa.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# + from __future__ import annotations diff --git a/src/snowflake/connector/aio/auth/_webbrowser.py b/src/snowflake/connector/aio/auth/_webbrowser.py index 97e9bbc1b6..c00e9a3293 100644 --- a/src/snowflake/connector/aio/auth/_webbrowser.py +++ b/src/snowflake/connector/aio/auth/_webbrowser.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# + from __future__ import annotations import asyncio diff --git a/src/snowflake/connector/aio/auth/_workload_identity.py b/src/snowflake/connector/aio/auth/_workload_identity.py new file mode 100644 index 0000000000..d1045f6aff --- /dev/null +++ b/src/snowflake/connector/aio/auth/_workload_identity.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any + +from ...auth.workload_identity import ( + AuthByWorkloadIdentity as AuthByWorkloadIdentitySync, +) +from .._wif_util import AttestationProvider, create_attestation +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + + +class AuthByWorkloadIdentity(AuthByPluginAsync, AuthByWorkloadIdentitySync): + """Plugin to authenticate via workload identity.""" + + def __init__( + self, + *, + provider: AttestationProvider | None = None, + token: str | None = None, + entra_resource: str | None = None, + **kwargs, + ) -> None: + """Initializes an instance with workload identity authentication.""" + AuthByWorkloadIdentitySync.__init__( + self, + provider=provider, + token=token, + entra_resource=entra_resource, + **kwargs, + ) + + async def reset_secrets(self) -> None: + AuthByWorkloadIdentitySync.reset_secrets(self) + + async def prepare(self, **kwargs: Any) -> None: + """Fetch the token using async wif_util.""" + self.attestation = await create_attestation( + self.provider, self.entra_resource, self.token + ) + + async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: + """This is only relevant for AuthByIdToken, which uses a web-browser based flow. 
All other auth plugins just call authenticate() again.""" + return {"success": False} + + async def update_body(self, body: dict[Any, Any]) -> None: + AuthByWorkloadIdentitySync.update_body(self, body) diff --git a/src/snowflake/connector/arrow_context.py b/src/snowflake/connector/arrow_context.py index db5a465984..10dc9ea558 100644 --- a/src/snowflake/connector/arrow_context.py +++ b/src/snowflake/connector/arrow_context.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import decimal diff --git a/src/snowflake/connector/auth/__init__.py b/src/snowflake/connector/auth/__init__.py index 1884979239..0874b35ca7 100644 --- a/src/snowflake/connector/auth/__init__.py +++ b/src/snowflake/connector/auth/__init__.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from ._auth import Auth, get_public_key_fingerprint, get_token_from_private_key @@ -15,6 +11,7 @@ from .pat import AuthByPAT from .usrpwdmfa import AuthByUsrPwdMfa from .webbrowser import AuthByWebBrowser +from .workload_identity import AuthByWorkloadIdentity FIRST_PARTY_AUTHENTICATORS = frozenset( ( @@ -26,6 +23,7 @@ AuthByWebBrowser, AuthByIdToken, AuthByPAT, + AuthByWorkloadIdentity, AuthNoAuth, ) ) @@ -39,6 +37,7 @@ "AuthByOkta", "AuthByUsrPwdMfa", "AuthByWebBrowser", + "AuthByWorkloadIdentity", "AuthNoAuth", "Auth", "AuthType", diff --git a/src/snowflake/connector/auth/_auth.py b/src/snowflake/connector/auth/_auth.py index e3b18d42a5..cf3b6b6297 100644 --- a/src/snowflake/connector/auth/_auth.py +++ b/src/snowflake/connector/auth/_auth.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import copy diff --git a/src/snowflake/connector/auth/by_plugin.py b/src/snowflake/connector/auth/by_plugin.py index 3bffd61b81..9068a9ea44 100644 --- a/src/snowflake/connector/auth/by_plugin.py +++ b/src/snowflake/connector/auth/by_plugin.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations """This module implements the base class for authenticator classes. @@ -56,6 +52,7 @@ class AuthType(Enum): OKTA = "OKTA" PAT = "PROGRAMMATIC_ACCESS_TOKEN" NO_AUTH = "NO_AUTH" + WORKLOAD_IDENTITY = "WORKLOAD_IDENTITY" class AuthByPlugin(ABC): diff --git a/src/snowflake/connector/auth/default.py b/src/snowflake/connector/auth/default.py index 3b8c564669..0a7fd7be42 100644 --- a/src/snowflake/connector/auth/default.py +++ b/src/snowflake/connector/auth/default.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from typing import Any diff --git a/src/snowflake/connector/auth/idtoken.py b/src/snowflake/connector/auth/idtoken.py index 927138c960..9ca946230e 100644 --- a/src/snowflake/connector/auth/idtoken.py +++ b/src/snowflake/connector/auth/idtoken.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from typing import TYPE_CHECKING, Any diff --git a/src/snowflake/connector/auth/keypair.py b/src/snowflake/connector/auth/keypair.py index 3fa6b437f4..951e9e7dc5 100644 --- a/src/snowflake/connector/auth/keypair.py +++ b/src/snowflake/connector/auth/keypair.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import base64 diff --git a/src/snowflake/connector/auth/no_auth.py b/src/snowflake/connector/auth/no_auth.py index d7730b26ac..2f58edd916 100644 --- a/src/snowflake/connector/auth/no_auth.py +++ b/src/snowflake/connector/auth/no_auth.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from typing import Any diff --git a/src/snowflake/connector/auth/oauth.py b/src/snowflake/connector/auth/oauth.py index c497415d19..995ed95e4b 100644 --- a/src/snowflake/connector/auth/oauth.py +++ b/src/snowflake/connector/auth/oauth.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from typing import Any diff --git a/src/snowflake/connector/auth/okta.py b/src/snowflake/connector/auth/okta.py index 28452e313a..e0601d9516 100644 --- a/src/snowflake/connector/auth/okta.py +++ b/src/snowflake/connector/auth/okta.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import json diff --git a/src/snowflake/connector/auth/pat.py b/src/snowflake/connector/auth/pat.py index 3eb63fb462..cc61300bd4 100644 --- a/src/snowflake/connector/auth/pat.py +++ b/src/snowflake/connector/auth/pat.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import typing diff --git a/src/snowflake/connector/auth/usrpwdmfa.py b/src/snowflake/connector/auth/usrpwdmfa.py index 4c8f4aaf0a..a632f3a40a 100644 --- a/src/snowflake/connector/auth/usrpwdmfa.py +++ b/src/snowflake/connector/auth/usrpwdmfa.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import logging diff --git a/src/snowflake/connector/auth/webbrowser.py b/src/snowflake/connector/auth/webbrowser.py index b42fa9596d..2f77badf8c 100644 --- a/src/snowflake/connector/auth/webbrowser.py +++ b/src/snowflake/connector/auth/webbrowser.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import base64 diff --git a/src/snowflake/connector/auth/workload_identity.py b/src/snowflake/connector/auth/workload_identity.py new file mode 100644 index 0000000000..3c80c965e4 --- /dev/null +++ b/src/snowflake/connector/auth/workload_identity.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import json +import typing +from enum import Enum, unique + +from ..network import WORKLOAD_IDENTITY_AUTHENTICATOR +from ..wif_util import ( + AttestationProvider, + WorkloadIdentityAttestation, + create_attestation, +) +from .by_plugin import AuthByPlugin, AuthType + + +@unique +class ApiFederatedAuthenticationType(Enum): + """An API-specific enum of the WIF authentication type.""" + + AWS = "AWS" + AZURE = "AZURE" + GCP = "GCP" + OIDC = "OIDC" + + @staticmethod + def from_attestation( + attestation: WorkloadIdentityAttestation, + ) -> ApiFederatedAuthenticationType: + """Maps the internal / driver-specific attestation providers to API authenticator types. + + The AttestationProvider is related to how the driver fetches the credential, while the API authenticator + type is related to how the credential is verified. In most current cases these may be the same, though + in the future we could have, for example, multiple AttestationProviders that all fetch an OIDC ID token. 
+ """ + if attestation.provider == AttestationProvider.AWS: + return ApiFederatedAuthenticationType.AWS + if attestation.provider == AttestationProvider.AZURE: + return ApiFederatedAuthenticationType.AZURE + if attestation.provider == AttestationProvider.GCP: + return ApiFederatedAuthenticationType.GCP + if attestation.provider == AttestationProvider.OIDC: + return ApiFederatedAuthenticationType.OIDC + raise ValueError(f"Unknown attestation provider '{attestation.provider}'") + + +class AuthByWorkloadIdentity(AuthByPlugin): + """Plugin to authenticate via workload identity.""" + + def __init__( + self, + *, + provider: AttestationProvider | None = None, + token: str | None = None, + entra_resource: str | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.provider = provider + self.token = token + self.entra_resource = entra_resource + + self.attestation: WorkloadIdentityAttestation | None = None + + def type_(self) -> AuthType: + return AuthType.WORKLOAD_IDENTITY + + def reset_secrets(self) -> None: + self.attestation = None + + def update_body(self, body: dict[typing.Any, typing.Any]) -> None: + body["data"]["AUTHENTICATOR"] = WORKLOAD_IDENTITY_AUTHENTICATOR + body["data"]["PROVIDER"] = ApiFederatedAuthenticationType.from_attestation( + self.attestation + ).value + body["data"]["TOKEN"] = self.attestation.credential + + def prepare(self, **kwargs: typing.Any) -> None: + """Fetch the token.""" + self.attestation = create_attestation( + self.provider, self.entra_resource, self.token + ) + + def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]: + """This is only relevant for AuthByIdToken, which uses a web-browser based flow. All other auth plugins just call authenticate() again.""" + return {"success": False} + + @property + def assertion_content(self) -> str: + """Returns the CSP provider name and an identifier. 
Used for logging purposes.""" + if not self.attestation: + return "" + properties = self.attestation.user_identifier_components + properties["_provider"] = self.attestation.provider.value + return json.dumps(properties, sort_keys=True, separators=(",", ":")) diff --git a/src/snowflake/connector/azure_storage_client.py b/src/snowflake/connector/azure_storage_client.py index 6ac1c348e5..164dd41f42 100644 --- a/src/snowflake/connector/azure_storage_client.py +++ b/src/snowflake/connector/azure_storage_client.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import base64 @@ -9,7 +5,7 @@ import os import xml.etree.ElementTree as ET from datetime import datetime, timezone -from logging import Filter, getLogger +from logging import getLogger from random import choice from string import hexdigits from typing import TYPE_CHECKING, Any, NamedTuple @@ -41,22 +37,6 @@ class AzureLocation(NamedTuple): MATDESC = "x-ms-meta-matdesc" -class AzureCredentialFilter(Filter): - LEAKY_FMT = '%s://%s:%s "%s %s %s" %s %s' - - def filter(self, record): - if record.msg == AzureCredentialFilter.LEAKY_FMT and len(record.args) == 8: - record.args = ( - record.args[:4] + (record.args[4].split("?")[0],) + record.args[5:] - ) - return True - - -getLogger("snowflake.connector.vendored.urllib3.connectionpool").addFilter( - AzureCredentialFilter() -) - - class SnowflakeAzureRestClient(SnowflakeStorageClient): def __init__( self, @@ -64,7 +44,6 @@ def __init__( credentials: StorageCredential | None, chunk_size: int, stage_info: dict[str, Any], - use_s3_regional_url: bool = False, unsafe_file_write: bool = False, ) -> None: super().__init__( diff --git a/src/snowflake/connector/backoff_policies.py b/src/snowflake/connector/backoff_policies.py index 8813dc1adc..8e6b1010bd 100644 --- a/src/snowflake/connector/backoff_policies.py +++ b/src/snowflake/connector/backoff_policies.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 
2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import random diff --git a/src/snowflake/connector/bind_upload_agent.py b/src/snowflake/connector/bind_upload_agent.py index 694a85b827..b71920d0b4 100644 --- a/src/snowflake/connector/bind_upload_agent.py +++ b/src/snowflake/connector/bind_upload_agent.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import uuid diff --git a/src/snowflake/connector/cache.py b/src/snowflake/connector/cache.py index 5c47813049..86f6a3417c 100644 --- a/src/snowflake/connector/cache.py +++ b/src/snowflake/connector/cache.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import datetime diff --git a/src/snowflake/connector/compat.py b/src/snowflake/connector/compat.py index e138bdb2e0..3458ace0ef 100644 --- a/src/snowflake/connector/compat.py +++ b/src/snowflake/connector/compat.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import collections.abc diff --git a/src/snowflake/connector/config_manager.py b/src/snowflake/connector/config_manager.py index 6c3f7686f1..6e1ad51dfd 100644 --- a/src/snowflake/connector/config_manager.py +++ b/src/snowflake/connector/config_manager.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import itertools diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index f8ee1ba882..2a85965e6c 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import atexit @@ -44,6 +40,7 @@ AuthByPlugin, AuthByUsrPwdMfa, AuthByWebBrowser, + AuthByWorkloadIdentity, AuthNoAuth, ) from .auth.idtoken import AuthByIdToken @@ -55,6 +52,7 @@ from .constants import ( _CONNECTIVITY_ERR_MSG, _DOMAIN_NAME_MAP, + ENV_VAR_EXPERIMENTAL_AUTHENTICATION, ENV_VAR_PARTNER, PARAMETER_AUTOCOMMIT, PARAMETER_CLIENT_PREFETCH_THREADS, @@ -80,6 +78,7 @@ PYTHON_VERSION, SNOWFLAKE_CONNECTOR_VERSION, ) +from .direct_file_operation_utils import FileOperationParser, StreamDownloader from .errorcode import ( ER_CONNECTION_IS_CLOSED, ER_FAILED_PROCESSING_PYFORMAT, @@ -87,6 +86,7 @@ ER_FAILED_TO_CONNECT_TO_DB, ER_INVALID_BACKOFF_POLICY, ER_INVALID_VALUE, + ER_INVALID_WIF_SETTINGS, ER_NO_ACCOUNT_NAME, ER_NO_NUMPY, ER_NO_PASSWORD, @@ -104,6 +104,7 @@ PROGRAMMATIC_ACCESS_TOKEN, REQUEST_ID, USR_PWD_MFA_AUTHENTICATOR, + WORKLOAD_IDENTITY_AUTHENTICATOR, ReauthenticationRequest, SnowflakeRestful, ) @@ -112,9 +113,11 @@ from .time_util import HeartBeatTimer, get_time_millis from .url_util import extract_top_level_domain_from_hostname from .util_text import construct_hostname, parse_account, split_statements +from .wif_util import AttestationProvider DEFAULT_CLIENT_PREFETCH_THREADS = 4 MAX_CLIENT_PREFETCH_THREADS = 10 +MAX_CLIENT_FETCH_THREADS = 1024 DEFAULT_BACKOFF_POLICY = exponential_backoff() @@ -188,12 +191,14 @@ def _get_private_bytes_from_file( "private_key": (None, (type(None), bytes, str, RSAPrivateKey)), "private_key_file": (None, (type(None), str)), "private_key_file_pwd": (None, (type(None), str, bytes)), - "token": (None, (type(None), str)), # OAuth/JWT/PAT Token + "token": (None, (type(None), str)), # OAuth/JWT/PAT/OIDC Token "token_file_path": ( None, (type(None), str, bytes), - ), # OAuth/JWT/PAT Token file path + ), # OAuth/JWT/PAT/OIDC Token file path "authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)), + "workload_identity_provider": (None, (type(None), AttestationProvider)), + 
"workload_identity_entra_resource": (None, (type(None), str)), "mfa_callback": (None, (type(None), Callable)), "password_callback": (None, (type(None), Callable)), "auth_class": (None, (type(None), AuthByPlugin)), @@ -215,6 +220,7 @@ def _get_private_bytes_from_file( (type(None), int), ), # snowflake "client_prefetch_threads": (4, int), # snowflake + "client_fetch_threads": (None, (type(None), int)), "numpy": (False, bool), # snowflake "ocsp_response_cache_filename": (None, (type(None), str)), # snowflake internal "converter_class": (DefaultConverterClass(), SnowflakeConverter), @@ -305,6 +311,14 @@ def _get_private_bytes_from_file( None, (type(None), int), ), # SNOW-1817982: limit iobound TPE sizes when executing PUT/GET + "gcs_use_virtual_endpoints": ( + False, + bool, + ), # use https://{bucket}.storage.googleapis.com instead of https://storage.googleapis.com/{bucket} + "check_arrow_conversion_error_on_every_column": ( + True, + bool, + ), # SNOW-XXXXX: remove the check_arrow_conversion_error_on_every_column flag "unsafe_file_write": ( False, bool, @@ -369,6 +383,7 @@ class SnowflakeConnection: See the backoff_policies module for details and implementation examples. client_session_keep_alive_heartbeat_frequency: Heartbeat frequency to keep connection alive in seconds. client_prefetch_threads: Number of threads to download the result set. + client_fetch_threads: Number of threads to fetch staged query results. rest: Snowflake REST API object. Internal use only. Maybe removed in a later release. application: Application name to communicate with Snowflake as. By default, this is "PythonConnector". errorhandler: Handler used with errors. By default, an exception will be raised on error. @@ -388,6 +403,8 @@ class SnowflakeConnection: before the connector shuts down. Default value is false. token_file_path: The file path of the token file. If both token and token_file_path are provided, the token in token_file_path will be used. 
unsafe_file_write: When true, files downloaded by GET will be saved with 644 permissions. Otherwise, files will be saved with safe - owner-only permissions: 600. + gcs_use_virtual_endpoints: When true, the virtual endpoint url is used, see: https://cloud.google.com/storage/docs/request-endpoints#xml-api + check_arrow_conversion_error_on_every_column: When true, the error check after the conversion from arrow to python types will happen for every column in the row. This is a new behaviour which fixes the bug that caused the type errors to trigger silently when occurring at any place other than last column in a row. To revert the previous (faulty) behaviour, please set this flag to false. """ OCSP_ENV_LOCK = Lock() @@ -497,6 +514,10 @@ def __init__( # check SNOW-1218851 for long term improvement plan to refactor ocsp code atexit.register(self._close_at_exit) + # Set up the file operation parser and stream downloader. + self._file_operation_parser = FileOperationParser(self) + self._stream_downloader = StreamDownloader(self) + # Deprecated @property def insecure_mode(self) -> bool: @@ -627,6 +648,16 @@ def client_prefetch_threads(self, value) -> None: self._client_prefetch_threads = value self._validate_client_prefetch_threads() + @property + def client_fetch_threads(self) -> int | None: + return self._client_fetch_threads + + @client_fetch_threads.setter + def client_fetch_threads(self, value: None | int) -> None: + if value is not None: + value = min(max(1, value), MAX_CLIENT_FETCH_THREADS) + self._client_fetch_threads = value + @property def rest(self) -> SnowflakeRestful | None: return self._rest @@ -776,6 +807,22 @@ def unsafe_file_write(self) -> bool: def unsafe_file_write(self, value: bool) -> None: self._unsafe_file_write = value + @property + def gcs_use_virtual_endpoints(self) -> bool: + return self._gcs_use_virtual_endpoints + + @gcs_use_virtual_endpoints.setter + def gcs_use_virtual_endpoints(self, value: bool) -> None: + self._gcs_use_virtual_endpoints = 
value + + @property + def check_arrow_conversion_error_on_every_column(self) -> bool: + return self._check_arrow_conversion_error_on_every_column + + @check_arrow_conversion_error_on_every_column.setter + def check_arrow_conversion_error_on_every_column(self, value: bool) -> bool: + self._check_arrow_conversion_error_on_every_column = value + def connect(self, **kwargs) -> None: """Establishes connection to Snowflake.""" logger.debug("connect") @@ -836,16 +883,16 @@ def close(self, retry: bool = True) -> None: self._cancel_heartbeat() # close telemetry first, since it needs rest to send remaining data - logger.info("closed") + logger.debug("closed") self._telemetry.close(send_on_close=bool(retry and self.telemetry_enabled)) if ( self._all_async_queries_finished() and not self._server_session_keep_alive ): - logger.info("No async queries seem to be running, deleting session") + logger.debug("No async queries seem to be running, deleting session") self.rest.delete_session(retry=retry) else: - logger.info( + logger.debug( "There are {} async queries still running, not deleting session".format( len(self._async_sfqids) ) @@ -1141,9 +1188,30 @@ def __open_connection(self): backoff_generator=self._backoff_generator, ) elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN: - if not self._token and self._password: - self._token = self._password self.auth_class = AuthByPAT(self._token) + elif self._authenticator == WORKLOAD_IDENTITY_AUTHENTICATOR: + if ENV_VAR_EXPERIMENTAL_AUTHENTICATION not in os.environ: + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable to use the '{WORKLOAD_IDENTITY_AUTHENTICATOR}' authenticator.", + "errno": ER_INVALID_WIF_SETTINGS, + }, + ) + # Standardize the provider enum. 
+ if self._workload_identity_provider and isinstance( + self._workload_identity_provider, str + ): + self._workload_identity_provider = AttestationProvider.from_string( + self._workload_identity_provider + ) + self.auth_class = AuthByWorkloadIdentity( + provider=self._workload_identity_provider, + token=self._token, + entra_resource=self._workload_identity_entra_resource, + ) else: # okta URL, e.g., https://.okta.com/ self.auth_class = AuthByOkta( @@ -1267,6 +1335,7 @@ def __config(self, **kwargs): KEY_PAIR_AUTHENTICATOR, OAUTH_AUTHENTICATOR, USR_PWD_MFA_AUTHENTICATOR, + WORKLOAD_IDENTITY_AUTHENTICATOR, ]: self._authenticator = auth_tmp @@ -1277,14 +1346,19 @@ def __config(self, **kwargs): self._token = f.read() # Set of authenticators allowing empty user. - empty_user_allowed_authenticators = {OAUTH_AUTHENTICATOR, NO_AUTH_AUTHENTICATOR} + empty_user_allowed_authenticators = { + OAUTH_AUTHENTICATOR, + NO_AUTH_AUTHENTICATOR, + WORKLOAD_IDENTITY_AUTHENTICATOR, + PROGRAMMATIC_ACCESS_TOKEN, + } if not (self._master_token and self._session_token): if ( not self.user and self._authenticator not in empty_user_allowed_authenticators ): - # OAuth and NoAuth Authentications does not require a username + # Some authenticators do not require a username Error.errorhandler_wrapper( self, None, @@ -1295,6 +1369,25 @@ def __config(self, **kwargs): if self._private_key or self._private_key_file: self._authenticator = KEY_PAIR_AUTHENTICATOR + workload_identity_dependent_options = [ + "workload_identity_provider", + "workload_identity_entra_resource", + ] + for dependent_option in workload_identity_dependent_options: + if ( + self.__getattribute__(f"_{dependent_option}") is not None + and self._authenticator != WORKLOAD_IDENTITY_AUTHENTICATOR + ): + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": f"{dependent_option} was set but authenticator was not set to {WORKLOAD_IDENTITY_AUTHENTICATOR}", + "errno": ER_INVALID_WIF_SETTINGS, + }, + ) + if ( 
self.auth_class is None and self._authenticator @@ -1303,6 +1396,7 @@ def __config(self, **kwargs): OAUTH_AUTHENTICATOR, KEY_PAIR_AUTHENTICATOR, PROGRAMMATIC_ACCESS_TOKEN, + WORKLOAD_IDENTITY_AUTHENTICATOR, ) and not self._password ): diff --git a/src/snowflake/connector/connection_diagnostic.py b/src/snowflake/connector/connection_diagnostic.py index 227d86015f..61edb99333 100644 --- a/src/snowflake/connector/connection_diagnostic.py +++ b/src/snowflake/connector/connection_diagnostic.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import base64 @@ -583,7 +579,7 @@ def __check_for_proxies(self) -> None: cert_reqs=cert_reqs, ) resp = http.request( - "GET", "https://ireallyshouldnotexistatallanywhere.com", timeout=10.0 + "GET", "https://nonexistentdomain.invalidtld", timeout=10.0 ) # squid does not throw exception. Check HTML diff --git a/src/snowflake/connector/constants.py b/src/snowflake/connector/constants.py index b78198f20f..085ec7a2b3 100644 --- a/src/snowflake/connector/constants.py +++ b/src/snowflake/connector/constants.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from collections import defaultdict @@ -430,6 +426,7 @@ class IterUnit(Enum): # TODO: all env variables definitions should be here ENV_VAR_PARTNER = "SF_PARTNER" ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE" +ENV_VAR_EXPERIMENTAL_AUTHENTICATION = "SF_ENABLE_EXPERIMENTAL_AUTHENTICATION" # Needed to enable new strong auth features during the private preview. 
_DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"} diff --git a/src/snowflake/connector/converter.py b/src/snowflake/connector/converter.py index ac42b12678..8202351990 100644 --- a/src/snowflake/connector/converter.py +++ b/src/snowflake/connector/converter.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import binascii diff --git a/src/snowflake/connector/converter_issue23517.py b/src/snowflake/connector/converter_issue23517.py index 729a65d5aa..e65bc77ead 100644 --- a/src/snowflake/connector/converter_issue23517.py +++ b/src/snowflake/connector/converter_issue23517.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from datetime import datetime, time, timedelta, timezone, tzinfo diff --git a/src/snowflake/connector/converter_null.py b/src/snowflake/connector/converter_null.py index 3d03b1e6da..53ac45b4b7 100644 --- a/src/snowflake/connector/converter_null.py +++ b/src/snowflake/connector/converter_null.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from typing import Any diff --git a/src/snowflake/connector/converter_snowsql.py b/src/snowflake/connector/converter_snowsql.py index 189cd3de71..4da4a5170f 100644 --- a/src/snowflake/connector/converter_snowsql.py +++ b/src/snowflake/connector/converter_snowsql.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import time diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py index 2f5526aafe..e6c3dfdb53 100644 --- a/src/snowflake/connector/cursor.py +++ b/src/snowflake/connector/cursor.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import collections @@ -39,9 +35,15 @@ from . import compat from ._sql_util import get_file_transfer_type -from ._utils import _TrackedQueryCancellationTimer +from ._utils import ( + REQUEST_ID_STATEMENT_PARAM_NAME, + _TrackedQueryCancellationTimer, + is_uuid4, +) from .bind_upload_agent import BindUploadAgent, BindUploadError from .constants import ( + CMD_TYPE_DOWNLOAD, + CMD_TYPE_UPLOAD, FIELD_NAME_TO_ID, PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, FileTransferType, @@ -639,7 +641,27 @@ def _execute_helper( ) self._sequence_counter = self._connection._next_sequence_counter() - self._request_id = uuid.uuid4() + + # If requestId is contained in statement parameters, use it to set request id. Verify here it is a valid uuid4 + # identifier. + if ( + statement_params is not None + and REQUEST_ID_STATEMENT_PARAM_NAME in statement_params + ): + request_id = statement_params[REQUEST_ID_STATEMENT_PARAM_NAME] + + if not is_uuid4(request_id): + # uuid.UUID will throw an error if invalid, but we explicitly check and throw here. + raise ValueError(f"requestId {request_id} is not a valid UUID4.") + self._request_id = uuid.UUID(str(request_id), version=4) + + # Create a (deep copy) and remove the statement param, there is no need to encode it as extra parameter + # one more time. + statement_params = statement_params.copy() + statement_params.pop(REQUEST_ID_STATEMENT_PARAM_NAME) + else: + # Generate UUID for query. 
+ self._request_id = uuid.uuid4() logger.debug(f"Request id: {self._request_id}") @@ -888,8 +910,8 @@ def execute( _exec_async: Whether to execute this query asynchronously. _no_retry: Whether or not to retry on known errors. _do_reset: Whether or not the result set needs to be reset before executing query. - _put_callback: Function to which GET command should call back to. - _put_azure_callback: Function to which an Azure GET command should call back to. + _put_callback: Function to which PUT command should call back to. + _put_azure_callback: Function to which an Azure PUT command should call back to. _put_callback_output_stream: The output stream a PUT command's callback should report on. _get_callback: Function to which GET command should call back to. _get_azure_callback: Function to which an Azure GET command should call back to. @@ -1061,6 +1083,7 @@ def execute( use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1, iobound_tpe_limit=self._connection.iobound_tpe_limit, unsafe_file_write=self._connection.unsafe_file_write, + gcs_use_virtual_endpoints=self._connection.gcs_use_virtual_endpoints, ) sf_file_transfer_agent.execute() data = sf_file_transfer_agent.result() @@ -1083,7 +1106,15 @@ def execute( logger.debug(ret) err = ret["message"] code = ret.get("code", -1) - if self._timebomb and self._timebomb.executed: + if ( + self._timebomb + and self._timebomb.executed + and "SQL execution canceled" in err + ): + # Modify the error message only if the server error response indicates the query was canceled. + # If the error occurs before the cancellation request reaches the backend + # (e.g., due to a very short timeout), we retain the original error message + # as the query might have encountered an issue prior to cancellation. err = ( f"SQL execution was cancelled by the client due to a timeout. 
" f"Error message received from the server: {err}" @@ -1177,7 +1208,8 @@ def _init_result_and_meta(self, data: dict[Any, Any]) -> None: self._result_set = ResultSet( self, result_chunks, - self._connection.client_prefetch_threads, + self._connection.client_fetch_threads + or self._connection.client_prefetch_threads, ) self._rownumber = -1 self._result_state = ResultState.VALID @@ -1274,7 +1306,7 @@ def query_result(self, qid: str) -> SnowflakeCursor: data = ret.get("data") self._init_result_and_meta(data) else: - logger.info("failed") + logger.debug("failed") logger.debug(ret) err = ret["message"] code = ret.get("code", -1) @@ -1730,6 +1762,153 @@ def get_result_batches(self) -> list[ResultBatch] | None: ) return self._result_set.batches + def _download( + self, + stage_location: str, + target_directory: str, + options: dict[str, Any], + _do_reset: bool = True, + ) -> None: + """Downloads from the stage location to the target directory. + + Args: + stage_location (str): The location of the stage to download from. + target_directory (str): The destination directory to download into. + options (dict[str, Any]): The download options. + _do_reset (bool, optional): Whether to reset the cursor before + downloading, by default we will reset the cursor. + """ + from .file_transfer_agent import SnowflakeFileTransferAgent + + if _do_reset: + self.reset() + + # Interpret the file operation. + ret = self.connection._file_operation_parser.parse_file_operation( + stage_location=stage_location, + local_file_name=None, + target_directory=target_directory, + command_type=CMD_TYPE_DOWNLOAD, + options=options, + ) + + # Execute the file operation based on the interpretation above. 
+ file_transfer_agent = SnowflakeFileTransferAgent( + self, + "", # empty command because it is triggered by directly calling this util not by a SQL query + ret, + ) + file_transfer_agent.execute() + self._init_result_and_meta(file_transfer_agent.result()) + + def _upload( + self, + local_file_name: str, + stage_location: str, + options: dict[str, Any], + _do_reset: bool = True, + ) -> None: + """Uploads the local file to the stage location. + + Args: + local_file_name (str): The local file to be uploaded. + stage_location (str): The stage location to upload the local file to. + options (dict[str, Any]): The upload options. + _do_reset (bool, optional): Whether to reset the cursor before + uploading, by default we will reset the cursor. + """ + from .file_transfer_agent import SnowflakeFileTransferAgent + + if _do_reset: + self.reset() + + # Interpret the file operation. + ret = self.connection._file_operation_parser.parse_file_operation( + stage_location=stage_location, + local_file_name=local_file_name, + target_directory=None, + command_type=CMD_TYPE_UPLOAD, + options=options, + ) + + # Execute the file operation based on the interpretation above. + file_transfer_agent = SnowflakeFileTransferAgent( + self, + "", # empty command because it is triggered by directly calling this util not by a SQL query + ret, + force_put_overwrite=False, # _upload should respect user decision on overwriting + ) + file_transfer_agent.execute() + self._init_result_and_meta(file_transfer_agent.result()) + + def _download_stream( + self, stage_location: str, decompress: bool = False + ) -> IO[bytes]: + """Downloads from the stage location as a stream. + + Args: + stage_location (str): The location of the stage to download from. + decompress (bool, optional): Whether to decompress the file, by + default we do not decompress. + + Returns: + IO[bytes]: A stream to read from. + """ + # Interpret the file operation. 
+ ret = self.connection._file_operation_parser.parse_file_operation( + stage_location=stage_location, + local_file_name=None, + target_directory=None, + command_type=CMD_TYPE_DOWNLOAD, + options=None, + has_source_from_stream=True, + ) + + # Set up stream downloading based on the interpretation and return the stream for reading. + return self.connection._stream_downloader.download_as_stream(ret, decompress) + + def _upload_stream( + self, + input_stream: IO[bytes], + stage_location: str, + options: dict[str, Any], + _do_reset: bool = True, + ) -> None: + """Uploads content in the input stream to the stage location. + + Args: + input_stream (IO[bytes]): A stream to read from. + stage_location (str): The location of the stage to upload to. + options (dict[str, Any]): The upload options. + _do_reset (bool, optional): Whether to reset the cursor before + uploading, by default we will reset the cursor. + """ + from .file_transfer_agent import SnowflakeFileTransferAgent + + if _do_reset: + self.reset() + + # Interpret the file operation. + ret = self.connection._file_operation_parser.parse_file_operation( + stage_location=stage_location, + local_file_name=None, + target_directory=None, + command_type=CMD_TYPE_UPLOAD, + options=options, + has_source_from_stream=input_stream, + ) + + # Execute the file operation based on the interpretation above. 
+ file_transfer_agent = SnowflakeFileTransferAgent( + self, + "", # empty command because it is triggered by directly calling this util not by a SQL query + ret, + source_from_stream=input_stream, + force_put_overwrite=False, # _upload_stream should respect user decision on overwriting + ) + file_transfer_agent.execute() + self._init_result_and_meta(file_transfer_agent.result()) + class DictCursor(SnowflakeCursor): """Cursor returning results in a dictionary.""" diff --git a/src/snowflake/connector/dbapi.py b/src/snowflake/connector/dbapi.py index fb9863fdc7..973878a001 100644 --- a/src/snowflake/connector/dbapi.py +++ b/src/snowflake/connector/dbapi.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - """This module implements some constructors and singletons as required by the DB API v2.0 (PEP-249).""" from __future__ import annotations diff --git a/src/snowflake/connector/description.py b/src/snowflake/connector/description.py index e3acbc32f0..a45250e785 100644 --- a/src/snowflake/connector/description.py +++ b/src/snowflake/connector/description.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - """Various constants.""" from __future__ import annotations diff --git a/src/snowflake/connector/direct_file_operation_utils.py b/src/snowflake/connector/direct_file_operation_utils.py new file mode 100644 index 0000000000..2290b8f1e2 --- /dev/null +++ b/src/snowflake/connector/direct_file_operation_utils.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class FileOperationParserBase(ABC): + """The interface of internal utility functions for file operation parsing.""" + + @abstractmethod + def __init__(self, connection): + pass + + @abstractmethod + def parse_file_operation( + self, + stage_location, + local_file_name, + target_directory, + command_type, + options, + has_source_from_stream=False, + ): + """Converts the file operation details into a SQL and returns the SQL parsing result.""" + pass + + +class StreamDownloaderBase(ABC): + """The interface of internal utility functions for stream downloading of file.""" + + @abstractmethod + def __init__(self, connection): + pass + + @abstractmethod + def download_as_stream(self, ret, decompress=False): + pass + + +class FileOperationParser(FileOperationParserBase): + def __init__(self, connection): + pass + + def parse_file_operation( + self, + stage_location, + local_file_name, + target_directory, + command_type, + options, + has_source_from_stream=False, + ): + raise NotImplementedError("parse_file_operation is not yet supported") + + +class StreamDownloader(StreamDownloaderBase): + def __init__(self, connection): + pass + + def download_as_stream(self, ret, decompress=False): + raise NotImplementedError("download_as_stream is not yet supported") diff --git a/src/snowflake/connector/encryption_util.py b/src/snowflake/connector/encryption_util.py index 78d54497cf..a1efd040ee 100644 --- a/src/snowflake/connector/encryption_util.py +++ b/src/snowflake/connector/encryption_util.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake 
Computing Inc. All rights reserved. -# - from __future__ import annotations import base64 diff --git a/src/snowflake/connector/errorcode.py b/src/snowflake/connector/errorcode.py index 513b9d408f..1bc9138df2 100644 --- a/src/snowflake/connector/errorcode.py +++ b/src/snowflake/connector/errorcode.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations # network @@ -31,6 +27,8 @@ ER_JWT_RETRY_EXPIRED = 251010 ER_CONNECTION_TIMEOUT = 251011 ER_RETRYABLE_CODE = 251012 +ER_INVALID_WIF_SETTINGS = 251013 +ER_WIF_CREDENTIALS_NOT_FOUND = 251014 # cursor ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT = 252001 diff --git a/src/snowflake/connector/errors.py b/src/snowflake/connector/errors.py index e7355105fc..d7e8e8c985 100644 --- a/src/snowflake/connector/errors.py +++ b/src/snowflake/connector/errors.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/src/snowflake/connector/externals_utils/__init__.py b/src/snowflake/connector/externals_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/snowflake/connector/externals_utils/externals_setup.py b/src/snowflake/connector/externals_utils/externals_setup.py new file mode 100644 index 0000000000..1b0147cee8 --- /dev/null +++ b/src/snowflake/connector/externals_utils/externals_setup.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from snowflake.connector.logging_utils.filters import ( + SecretMaskingFilter, + add_filter_to_logger_and_children, +) + +MODULES_TO_MASK_LOGS_NAMES = [ + "snowflake.connector.vendored.urllib3", + "botocore", + "boto3", +] +# TODO: after migration to the external urllib3 from the vendored one (SNOW-2041970), +# we should change filters here immediately to the below module's logger: +# MODULES_TO_MASK_LOGS_NAMES = [ "urllib3", ... 
] + + +def add_filters_to_external_loggers(): + for module_name in MODULES_TO_MASK_LOGS_NAMES: + add_filter_to_logger_and_children(module_name, SecretMaskingFilter()) + + +def setup_external_libraries(): + """ + Assures proper setup and injections before any external libraries are used. + """ + add_filters_to_external_loggers() diff --git a/src/snowflake/connector/feature.py b/src/snowflake/connector/feature.py index 6cbdd11184..5056359c56 100644 --- a/src/snowflake/connector/feature.py +++ b/src/snowflake/connector/feature.py @@ -1,7 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# # Feature flags feature_use_pyopenssl = True # use pyopenssl API or openssl command diff --git a/src/snowflake/connector/file_compression_type.py b/src/snowflake/connector/file_compression_type.py index ca33b7117a..b936658f3c 100644 --- a/src/snowflake/connector/file_compression_type.py +++ b/src/snowflake/connector/file_compression_type.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from typing import NamedTuple diff --git a/src/snowflake/connector/file_transfer_agent.py b/src/snowflake/connector/file_transfer_agent.py index 2a7addb872..393d88c429 100644 --- a/src/snowflake/connector/file_transfer_agent.py +++ b/src/snowflake/connector/file_transfer_agent.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import binascii @@ -319,6 +315,9 @@ def __init__( def update(self, cur_timestamp) -> None: with self.lock: if cur_timestamp < self.timestamp: + logger.debug( + "Omitting renewal of storage token, as it already happened." 
+ ) return logger.debug("Renewing expired storage token.") ret = self.connection.cursor()._execute_helper(self._command) @@ -356,6 +355,7 @@ def __init__( use_s3_regional_url: bool = False, iobound_tpe_limit: int | None = None, unsafe_file_write: bool = False, + gcs_use_virtual_endpoints: bool = False, ) -> None: self._cursor = cursor self._command = command @@ -388,6 +388,7 @@ def __init__( self._credentials: StorageCredential | None = None self._iobound_tpe_limit = iobound_tpe_limit self._unsafe_file_write = unsafe_file_write + self._gcs_use_virtual_endpoints = gcs_use_virtual_endpoints def execute(self) -> None: self._parse_command() @@ -538,7 +539,7 @@ def transfer_done_cb( ) -> None: # Note: chunk_id is 0 based while num_of_chunks is count logger.debug( - f"Chunk {chunk_id}/{done_client.num_of_chunks} of file {done_client.meta.name} reached callback" + f"Chunk(id: {chunk_id}) {chunk_id+1}/{done_client.num_of_chunks} of file {done_client.meta.name} reached callback" ) with cv_chunk_process: transfer_metadata.chunks_in_queue -= 1 @@ -683,7 +684,6 @@ def _create_file_transfer_client( self._credentials, AZURE_CHUNK_SIZE, self._stage_info, - use_s3_regional_url=self._use_s3_regional_url, unsafe_file_write=self._unsafe_file_write, ) elif self._stage_location_type == S3_FS: @@ -703,8 +703,8 @@ def _create_file_transfer_client( self._stage_info, self._cursor._connection, self._command, - use_s3_regional_url=self._use_s3_regional_url, unsafe_file_write=self._unsafe_file_write, + use_virtual_endpoints=self._gcs_use_virtual_endpoints, ) raise Exception(f"{self._stage_location_type} is an unknown stage type") diff --git a/src/snowflake/connector/file_util.py b/src/snowflake/connector/file_util.py index 04744f76e8..f1f336e1c8 100644 --- a/src/snowflake/connector/file_util.py +++ b/src/snowflake/connector/file_util.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import base64 diff --git a/src/snowflake/connector/gcs_storage_client.py b/src/snowflake/connector/gcs_storage_client.py index e7db2f423e..06c5bd9a87 100644 --- a/src/snowflake/connector/gcs_storage_client.py +++ b/src/snowflake/connector/gcs_storage_client.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import json @@ -36,6 +32,7 @@ GCS_FILE_HEADER_DIGEST = "gcs-file-header-digest" GCS_FILE_HEADER_CONTENT_LENGTH = "gcs-file-header-content-length" GCS_FILE_HEADER_ENCRYPTION_METADATA = "gcs-file-header-encryption-metadata" +GCS_REGION_ME_CENTRAL_2 = "me-central2" CONTENT_CHUNK_SIZE = 10 * kilobyte ACCESS_TOKEN = "GCS_ACCESS_TOKEN" @@ -43,6 +40,7 @@ class GcsLocation(NamedTuple): bucket_name: str path: str + endpoint: str = "https://storage.googleapis.com" class SnowflakeGCSRestClient(SnowflakeStorageClient): @@ -53,8 +51,8 @@ def __init__( stage_info: dict[str, Any], cnx: SnowflakeConnection, command: str, - use_s3_regional_url: bool = False, unsafe_file_write: bool = False, + use_virtual_endpoints: bool = False, ) -> None: """Creates a client object with given stage credentials. 
@@ -79,6 +77,16 @@ def __init__( # presigned_url in meta is for downloading self.presigned_url: str = meta.presigned_url or stage_info.get("presignedUrl") self.security_token = credentials.creds.get("GCS_ACCESS_TOKEN") + self.use_regional_url = ( + "region" in stage_info + and stage_info["region"].lower() == GCS_REGION_ME_CENTRAL_2 + or "useRegionalUrl" in stage_info + and stage_info["useRegionalUrl"] + ) + self.endpoint: str | None = ( + None if "endPoint" not in stage_info else stage_info["endPoint"] + ) + self.use_virtual_endpoints: bool = use_virtual_endpoints if self.security_token: logger.debug(f"len(GCS_ACCESS_TOKEN): {len(self.security_token)}") @@ -91,7 +99,7 @@ def _has_expired_token(self, response: requests.Response) -> bool: def _has_expired_presigned_url(self, response: requests.Response) -> bool: # Presigned urls can be generated for any xml-api operation - # offered by GCS. Hence the error codes expected are similar + # offered by GCS. Hence, the error codes expected are similar # to xml api. 
# https://cloud.google.com/storage/docs/xml-api/reference-status @@ -152,7 +160,16 @@ def generate_url_and_rest_args() -> ( ): if not self.presigned_url: upload_url = self.generate_file_url( - self.stage_info["location"], meta.dst_file_name.lstrip("/") + self.stage_info["location"], + meta.dst_file_name.lstrip("/"), + self.use_regional_url, + ( + None + if "region" not in self.stage_info + else self.stage_info["region"] + ), + self.endpoint, + self.use_virtual_endpoints, ) access_token = self.security_token else: @@ -182,7 +199,16 @@ def generate_url_and_rest_args() -> ( gcs_headers = {} if not self.presigned_url: download_url = self.generate_file_url( - self.stage_info["location"], meta.src_file_name.lstrip("/") + self.stage_info["location"], + meta.src_file_name.lstrip("/"), + self.use_regional_url, + ( + None + if "region" not in self.stage_info + else self.stage_info["region"] + ), + self.endpoint, + self.use_virtual_endpoints, ) access_token = self.security_token gcs_headers["Authorization"] = f"Bearer {access_token}" @@ -339,7 +365,16 @@ def get_file_header(self, filename: str) -> FileHeader | None: def generate_url_and_authenticated_headers(): url = self.generate_file_url( - self.stage_info["location"], filename.lstrip("/") + self.stage_info["location"], + filename.lstrip("/"), + self.use_regional_url, + ( + None + if "region" not in self.stage_info + else self.stage_info["region"] + ), + self.endpoint, + self.use_virtual_endpoints, ) gcs_headers = {"Authorization": f"Bearer {self.security_token}"} rest_args = {"headers": gcs_headers} @@ -383,7 +418,13 @@ def generate_url_and_authenticated_headers(): return None @staticmethod - def extract_bucket_name_and_path(stage_location: str) -> GcsLocation: + def get_location( + stage_location: str, + use_regional_url: str = False, + region: str = None, + endpoint: str = None, + use_virtual_endpoints: bool = False, + ) -> GcsLocation: container_name = stage_location path = "" @@ -393,13 +434,40 @@ def 
extract_bucket_name_and_path(stage_location: str) -> GcsLocation: path = stage_location[stage_location.index("/") + 1 :] if path and not path.endswith("/"): path += "/" - - return GcsLocation(bucket_name=container_name, path=path) + if endpoint: + if endpoint.endswith("/"): + endpoint = endpoint[:-1] + return GcsLocation(bucket_name=container_name, path=path, endpoint=endpoint) + elif use_virtual_endpoints: + return GcsLocation( + bucket_name=container_name, + path=path, + endpoint=f"https://{container_name}.storage.googleapis.com", + ) + elif use_regional_url: + return GcsLocation( + bucket_name=container_name, + path=path, + endpoint=f"https://storage.{region.lower()}.rep.googleapis.com", + ) + else: + return GcsLocation(bucket_name=container_name, path=path) @staticmethod - def generate_file_url(stage_location: str, filename: str) -> str: - gcs_location = SnowflakeGCSRestClient.extract_bucket_name_and_path( - stage_location + def generate_file_url( + stage_location: str, + filename: str, + use_regional_url: str = False, + region: str = None, + endpoint: str = None, + use_virtual_endpoints: bool = False, + ) -> str: + gcs_location = SnowflakeGCSRestClient.get_location( + stage_location, use_regional_url, region, endpoint, use_virtual_endpoints ) full_file_path = f"{gcs_location.path}{filename}" - return f"https://storage.googleapis.com/{gcs_location.bucket_name}/{quote(full_file_path)}" + + if use_virtual_endpoints: + return f"{gcs_location.endpoint}/{quote(full_file_path)}" + else: + return f"{gcs_location.endpoint}/{gcs_location.bucket_name}/{quote(full_file_path)}" diff --git a/src/snowflake/connector/gzip_decoder.py b/src/snowflake/connector/gzip_decoder.py index 6c370bc6df..4a6cd7e0bc 100644 --- a/src/snowflake/connector/gzip_decoder.py +++ b/src/snowflake/connector/gzip_decoder.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import io diff --git a/src/snowflake/connector/local_storage_client.py b/src/snowflake/connector/local_storage_client.py index 2d5152831c..eae85f98c9 100644 --- a/src/snowflake/connector/local_storage_client.py +++ b/src/snowflake/connector/local_storage_client.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/src/snowflake/connector/log_configuration.py b/src/snowflake/connector/log_configuration.py index 35a914c6bd..476ab89610 100644 --- a/src/snowflake/connector/log_configuration.py +++ b/src/snowflake/connector/log_configuration.py @@ -1,8 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - - from __future__ import annotations import logging diff --git a/src/snowflake/connector/logging_utils/__init__.py b/src/snowflake/connector/logging_utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/snowflake/connector/logging_utils/filters.py b/src/snowflake/connector/logging_utils/filters.py new file mode 100644 index 0000000000..3c6cf73568 --- /dev/null +++ b/src/snowflake/connector/logging_utils/filters.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import logging + +from snowflake.connector.secret_detector import SecretDetector + + +def add_filter_to_logger_and_children( + base_logger_name: str, filter_instance: logging.Filter +) -> None: + # Ensure the base logger exists and apply filter + base_logger = logging.getLogger(base_logger_name) + if filter_instance not in base_logger.filters: + base_logger.addFilter(filter_instance) + + all_loggers_pairs = logging.root.manager.loggerDict.items() + for name, obj in all_loggers_pairs: + if not name.startswith(base_logger_name + "."): + continue + + if not isinstance(obj, logging.Logger): + continue # Skip placeholders + + if filter_instance not in obj.filters: + 
obj.addFilter(filter_instance) + + +class SecretMaskingFilter(logging.Filter): + """ + A logging filter that masks sensitive information in log messages using the SecretDetector utility. + + This filter is designed for scenarios where you want to avoid applying SecretDetector globally + as a formatter on all logging handlers. Global masking can introduce unnecessary computational + overhead, particularly for internal logs where secrets are already handled explicitly. + It would also be easy to bypass unintentionally by simply adding a neighbouring handler to a logger + - without SecretDetector set as a formatter. + + On the other hand, libraries or submodules often do not have any handler attached, so formatting can't be + configured at that level, while attaching a new handler for that can cause unintended log output or its duplication. + + ⚠ Important: + - Logging filters do **not** propagate down the logger hierarchy. + To apply this filter across a hierarchy, use the `add_filter_to_logger_and_children` utility. + - This filter causes **early formatting** of the log message (`record.getMessage()`), + meaning `record.args` are merged into `record.msg` prematurely. + If you rely on `record.args`, ensure this is the **last** filter in the chain. + + Notes: + - The filter directly modifies `record.msg` with the masked version of the message. + - It clears `record.args` to prevent re-formatting and ensure safe message output. 
+ + Example: + logger.addFilter(SecretMaskingFilter()) + handler.addFilter(SecretMaskingFilter()) + """ + + def filter(self, record: logging.LogRecord) -> bool: + try: + # Format the message as it would be + message = record.getMessage() + + # Run masking on the whole message + masked_data = SecretDetector.mask_secrets(message) + record.msg = masked_data.masked_text + except Exception as ex: + record.msg = SecretDetector.create_formatting_error_log( + record, "EXCEPTION - " + str(ex) + ) + finally: + record.args = () # Avoid format re-application of formatting + + return True # allow all logs through diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ArrayConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ArrayConverter.cpp index 86e633661f..0c2fd05edd 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ArrayConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ArrayConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "ArrayConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ArrayConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ArrayConverter.hpp index b4c3712bf3..0df105dce1 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ArrayConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ArrayConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_ARRAYCONVERTER_HPP #define PC_ARRAYCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BinaryConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BinaryConverter.cpp index 401420965c..79f89080dd 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BinaryConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BinaryConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "BinaryConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BinaryConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BinaryConverter.hpp index 6d027677c8..9d6ce73e50 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BinaryConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BinaryConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_BINARYCONVERTER_HPP #define PC_BINARYCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BooleanConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BooleanConverter.cpp index f9b832fe5b..44ef88e3d3 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BooleanConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BooleanConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "BooleanConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BooleanConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BooleanConverter.hpp index 23dd53ec82..aacb629f0d 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BooleanConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/BooleanConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. 
All rights reserved. -// - #ifndef PC_BOOLEANCONVERTER_HPP #define PC_BOOLEANCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp index bdc4d9aada..95ac959c8a 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "CArrowChunkIterator.hpp" #include @@ -27,7 +23,8 @@ namespace sf { CArrowChunkIterator::CArrowChunkIterator(PyObject* context, char* arrow_bytes, int64_t arrow_bytes_size, - PyObject* use_numpy) + PyObject* use_numpy, + PyObject* check_error_on_every_column) : CArrowIterator(arrow_bytes, arrow_bytes_size), m_latestReturnedRow(nullptr), m_context(context) { @@ -39,6 +36,7 @@ CArrowChunkIterator::CArrowChunkIterator(PyObject* context, char* arrow_bytes, m_rowCountInBatch = 0; m_latestReturnedRow.reset(); m_useNumpy = PyObject_IsTrue(use_numpy); + m_checkErrorOnEveryColumn = PyObject_IsTrue(check_error_on_every_column); m_batchCount = m_ipcArrowArrayVec.size(); m_columnCount = m_batchCount > 0 ? 
m_ipcArrowSchema->n_children : 0; @@ -92,6 +90,9 @@ void CArrowChunkIterator::createRowPyObject() { PyTuple_SET_ITEM( m_latestReturnedRow.get(), i, m_currentBatchConverters[i]->toPyObject(m_rowIndexInBatch)); + if (m_checkErrorOnEveryColumn && py::checkPyError()) { + return; + } } return; } @@ -505,7 +506,8 @@ DictCArrowChunkIterator::DictCArrowChunkIterator(PyObject* context, char* arrow_bytes, int64_t arrow_bytes_size, PyObject* use_numpy) - : CArrowChunkIterator(context, arrow_bytes, arrow_bytes_size, use_numpy) {} + : CArrowChunkIterator(context, arrow_bytes, arrow_bytes_size, use_numpy, + Py_False) {} void DictCArrowChunkIterator::createRowPyObject() { m_latestReturnedRow.reset(PyDict_New()); diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.hpp index b4f0e4b62f..c8f770decf 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_ARROWCHUNKITERATOR_HPP #define PC_ARROWCHUNKITERATOR_HPP @@ -33,7 +29,8 @@ class CArrowChunkIterator : public CArrowIterator { * Constructor */ CArrowChunkIterator(PyObject* context, char* arrow_bytes, - int64_t arrow_bytes_size, PyObject* use_numpy); + int64_t arrow_bytes_size, PyObject* use_numpy, + PyObject* check_error_on_every_column); /** * Destructor @@ -78,6 +75,10 @@ class CArrowChunkIterator : public CArrowIterator { /** true if return numpy int64 float64 datetime*/ bool m_useNumpy; + /** a flag that ensures running py::checkPyError after each column processing + * in order to fail early on first python processing error */ + bool m_checkErrorOnEveryColumn; + void initColumnConverters(); }; diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowIterator.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowIterator.cpp index 4c33f1a7ba..9ba4499b97 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowIterator.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowIterator.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "CArrowIterator.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowIterator.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowIterator.hpp index 977d1d60aa..d24304fe05 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowIterator.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowIterator.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_ARROWITERATOR_HPP #define PC_ARROWITERATOR_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp index 2eb1b6ee46..09e495bb1e 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "CArrowTableIterator.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.hpp index 900fb542c5..7615ed264d 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_ARROWTABLEITERATOR_HPP #define PC_ARROWTABLEITERATOR_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DateConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DateConverter.cpp index 1e6c225f52..237b56da50 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DateConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DateConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #include "DateConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DateConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DateConverter.hpp index d7fb463b26..2adc1aa632 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DateConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DateConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_DATECONVERTER_HPP #define PC_DATECONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecFloatConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecFloatConverter.cpp index 40f73c3f88..1f2eddf813 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecFloatConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecFloatConverter.cpp @@ -1,8 +1,4 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "DecFloatConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecFloatConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecFloatConverter.hpp index e0b738aa93..65a5b38ae3 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecFloatConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecFloatConverter.hpp @@ -1,8 +1,4 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_DECFLOATCONVERTER_HPP #define PC_DECFLOATCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecimalConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecimalConverter.cpp index ddb334bf8e..5619ecc303 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecimalConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecimalConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. 
All rights reserved. -// - #include "DecimalConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecimalConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecimalConverter.hpp index e48094b6b3..62cef9c4ad 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecimalConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/DecimalConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_DECIMALCONVERTER_HPP #define PC_DECIMALCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FixedSizeListConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FixedSizeListConverter.cpp index 8bfaa079e4..f9418166ef 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FixedSizeListConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FixedSizeListConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "FixedSizeListConverter.hpp" namespace sf { diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FixedSizeListConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FixedSizeListConverter.hpp index 757fd63f1a..9242c77167 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FixedSizeListConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FixedSizeListConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_FIXEDSIZELISTCONVERTER_HPP #define PC_FIXEDSIZELISTCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FloatConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FloatConverter.cpp index 7b8c53c26b..8166797dc9 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FloatConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FloatConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "FloatConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FloatConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FloatConverter.hpp index 81dd3b9333..eb68b5e9b0 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FloatConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/FloatConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_FLOATCONVERTER_HPP #define PC_FLOATCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IColumnConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IColumnConverter.hpp index 1f32b9dc9c..b3fca27221 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IColumnConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IColumnConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_ICOLUMNCONVERTER_HPP #define PC_ICOLUMNCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IntConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IntConverter.cpp index a405c289e7..2523727fbf 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IntConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IntConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "IntConverter.hpp" namespace sf { diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IntConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IntConverter.hpp index b0f59e101d..69f6e1b681 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IntConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/IntConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_INTCONVERTER_HPP #define PC_INTCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/MapConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/MapConverter.cpp index da4e5ccdb8..8fae45c3df 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/MapConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/MapConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "MapConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/MapConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/MapConverter.hpp index 995fe1aba6..6baf2dd19a 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/MapConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/MapConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_MAPCONVERTER_HPP #define PC_MAPCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ObjectConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ObjectConverter.cpp index 683fffc9a1..bd412b1d10 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ObjectConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ObjectConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "ObjectConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ObjectConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ObjectConverter.hpp index 5db0e0f2fd..e2ea788833 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ObjectConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/ObjectConverter.hpp @@ -1,6 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// #ifndef PC_OBJECTCONVERTER_HPP #define PC_OBJECTCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Common.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Common.cpp index be2d7e28f4..2f5d365dcd 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Common.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Common.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "Common.hpp" namespace sf { diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Common.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Common.hpp index ea0b1aa437..2f24d85cbb 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Common.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Common.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_PYTHON_COMMON_HPP #define PC_PYTHON_COMMON_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Helpers.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Helpers.cpp index b8fe7791b8..05231479a9 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Helpers.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Helpers.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "Helpers.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Helpers.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Helpers.hpp index 1fcb497a31..5baec725ed 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Helpers.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Python/Helpers.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_PYTHON_HELPERS_HPP #define PC_PYTHON_HELPERS_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp index bc8286baa6..6361f97597 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "SnowflakeType.hpp" namespace sf { diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.hpp index 76ec4169ab..b01a152a95 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_SNOWFLAKETYPE_HPP #define PC_SNOWFLAKETYPE_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/StringConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/StringConverter.cpp index ee220cb1be..5c0b7eab89 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/StringConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/StringConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "StringConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/StringConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/StringConverter.hpp index 77d6c9723c..aaaa7233fb 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/StringConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/StringConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_STRINGCONVERTER_HPP #define PC_STRINGCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeConverter.cpp index 2d79e78372..6fa9e66f1b 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "TimeConverter.hpp" namespace sf { diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeConverter.hpp index 283ad2908d..a3c18f4d55 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-// - #ifndef PC_TIMECONVERTER_HPP #define PC_TIMECONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeStampConverter.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeStampConverter.cpp index 2c3b82871a..1bc505b26b 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeStampConverter.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeStampConverter.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "TimeStampConverter.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeStampConverter.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeStampConverter.hpp index 9e522b44c4..73f5e151b5 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeStampConverter.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/TimeStampConverter.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_TIMESTAMPCONVERTER_HPP #define PC_TIMESTAMPCONVERTER_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/macros.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/macros.hpp index 5890364ed8..e93ad688ca 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/macros.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/macros.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_UTIL_MACROS_HPP #define PC_UTIL_MACROS_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/time.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/time.cpp index 883352577f..f81dbaab07 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/time.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/time.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. 
All rights reserved. -// - #include "time.hpp" namespace sf { diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/time.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/time.hpp index ab276e8866..d08ccd86a1 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/time.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/Util/time.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_UTIL_TIME_HPP #define PC_UTIL_TIME_HPP diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/nanoarrow_arrow_iterator.pyx b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/nanoarrow_arrow_iterator.pyx index e2daa5ba1b..9113157761 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/nanoarrow_arrow_iterator.pyx +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/nanoarrow_arrow_iterator.pyx @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - # distutils: language = c++ # cython: language_level=3 @@ -50,6 +46,7 @@ cdef extern from "CArrowChunkIterator.hpp" namespace "sf": char* arrow_bytes, int64_t arrow_bytes_size, PyObject* use_numpy, + PyObject* check_error_on_every_column, ) except + cdef cppclass DictCArrowChunkIterator(CArrowChunkIterator): @@ -100,6 +97,7 @@ cdef class PyArrowIterator(EmptyPyArrowIterator): # still be converted into native python types. 
# https://docs.snowflake.com/en/user-guide/sqlalchemy.html#numpy-data-type-support cdef object use_numpy + cdef object check_error_on_every_column cdef object number_to_decimal cdef object pyarrow_table @@ -111,12 +109,14 @@ cdef class PyArrowIterator(EmptyPyArrowIterator): object use_dict_result, object numpy, object number_to_decimal, + object check_error_on_every_column ): self.context = arrow_context self.cIterator = NULL self.use_dict_result = use_dict_result self.cursor = cursor self.use_numpy = numpy + self.check_error_on_every_column = check_error_on_every_column self.number_to_decimal = number_to_decimal self.pyarrow_table = None self.table_returned = False @@ -139,8 +139,9 @@ cdef class PyArrowRowIterator(PyArrowIterator): object use_dict_result, object numpy, object number_to_decimal, + object check_error_on_every_column, ): - super().__init__(cursor, py_inputstream, arrow_context, use_dict_result, numpy, number_to_decimal) + super().__init__(cursor, py_inputstream, arrow_context, use_dict_result, numpy, number_to_decimal, check_error_on_every_column) if self.cIterator is not NULL: return @@ -155,7 +156,8 @@ cdef class PyArrowRowIterator(PyArrowIterator): self.context, self.arrow_bytes, self.arrow_bytes_size, - self.use_numpy + self.use_numpy, + self.check_error_on_every_column ) cdef ReturnVal cret = self.cIterator.checkInitializationStatus() if cret.exception: @@ -200,8 +202,9 @@ cdef class PyArrowTableIterator(PyArrowIterator): object use_dict_result, object numpy, object number_to_decimal, + object check_error_on_every_column ): - super().__init__(cursor, py_inputstream, arrow_context, use_dict_result, numpy, number_to_decimal) + super().__init__(cursor, py_inputstream, arrow_context, use_dict_result, numpy, number_to_decimal, check_error_on_every_column) if not INSTALLED_PYARROW: raise Error.errorhandler_make_exception( ProgrammingError, diff --git a/src/snowflake/connector/nanoarrow_cpp/Logging/logging.cpp 
b/src/snowflake/connector/nanoarrow_cpp/Logging/logging.cpp index f5c410cd13..bf48c05398 100644 --- a/src/snowflake/connector/nanoarrow_cpp/Logging/logging.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/Logging/logging.cpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #include "logging.hpp" #include diff --git a/src/snowflake/connector/nanoarrow_cpp/Logging/logging.hpp b/src/snowflake/connector/nanoarrow_cpp/Logging/logging.hpp index ac55bbcc8d..798b9a3e9e 100644 --- a/src/snowflake/connector/nanoarrow_cpp/Logging/logging.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/Logging/logging.hpp @@ -1,7 +1,3 @@ -// -// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -// - #ifndef PC_LOGGING_HPP #define PC_LOGGING_HPP diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 22222d9a11..adffc4b6b9 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import collections @@ -189,6 +185,7 @@ USR_PWD_MFA_AUTHENTICATOR = "USERNAME_PASSWORD_MFA" PROGRAMMATIC_ACCESS_TOKEN = "PROGRAMMATIC_ACCESS_TOKEN" NO_AUTH_AUTHENTICATOR = "NO_AUTH" +WORKLOAD_IDENTITY_AUTHENTICATOR = "WORKLOAD_IDENTITY" def is_retryable_http_code(code: int) -> bool: @@ -356,6 +353,15 @@ def close(self) -> None: self._idle_sessions.clear() +# Customizable JSONEncoder to support additional types. 
+class SnowflakeRestfulJsonEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, uuid.UUID): + return str(o) + + return super().default(o) + + class SnowflakeRestful: """Snowflake Restful class.""" @@ -502,7 +508,7 @@ def request( return self._post_request( url, headers, - json.dumps(body), + json.dumps(body, cls=SnowflakeRestfulJsonEncoder), token=self.token, _no_results=_no_results, timeout=timeout, @@ -564,7 +570,7 @@ def _token_request(self, request_type): ret = self._post_request( url, headers, - json.dumps(body), + json.dumps(body, cls=SnowflakeRestfulJsonEncoder), token=header_token, ) if ret.get("success") and ret.get("data", {}).get("sessionToken"): @@ -662,7 +668,7 @@ def delete_session(self, retry: bool = False) -> None: ret = self._post_request( url, headers, - json.dumps(body), + json.dumps(body, cls=SnowflakeRestfulJsonEncoder), token=self.token, timeout=5, no_retry=True, diff --git a/src/snowflake/connector/ocsp_asn1crypto.py b/src/snowflake/connector/ocsp_asn1crypto.py index a664cd8920..54004b5c59 100644 --- a/src/snowflake/connector/ocsp_asn1crypto.py +++ b/src/snowflake/connector/ocsp_asn1crypto.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import typing diff --git a/src/snowflake/connector/ocsp_snowflake.py b/src/snowflake/connector/ocsp_snowflake.py index 4244bda695..64b0482f46 100644 --- a/src/snowflake/connector/ocsp_snowflake.py +++ b/src/snowflake/connector/ocsp_snowflake.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import codecs diff --git a/src/snowflake/connector/options.py b/src/snowflake/connector/options.py index be9f73cc9c..8454ab1699 100644 --- a/src/snowflake/connector/options.py +++ b/src/snowflake/connector/options.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. 
All rights reserved. -# - from __future__ import annotations import importlib diff --git a/src/snowflake/connector/pandas_tools.py b/src/snowflake/connector/pandas_tools.py index f58bb2a982..5c1626954e 100644 --- a/src/snowflake/connector/pandas_tools.py +++ b/src/snowflake/connector/pandas_tools.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import collections.abc diff --git a/src/snowflake/connector/proxy.py b/src/snowflake/connector/proxy.py index 1729bf4131..6b54e29ee5 100644 --- a/src/snowflake/connector/proxy.py +++ b/src/snowflake/connector/proxy.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/src/snowflake/connector/result_batch.py b/src/snowflake/connector/result_batch.py index d2efd52b7a..86de908a6d 100644 --- a/src/snowflake/connector/result_batch.py +++ b/src/snowflake/connector/result_batch.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import abc @@ -62,6 +58,7 @@ def _create_nanoarrow_iterator( numpy: bool, number_to_decimal: bool, row_unit: IterUnit, + check_error_on_every_column: bool = True, ): from .nanoarrow_arrow_iterator import PyArrowRowIterator, PyArrowTableIterator @@ -74,6 +71,7 @@ def _create_nanoarrow_iterator( use_dict_result, numpy, number_to_decimal, + check_error_on_every_column, ) if row_unit == IterUnit.ROW_UNIT else PyArrowTableIterator( @@ -83,6 +81,7 @@ def _create_nanoarrow_iterator( use_dict_result, numpy, number_to_decimal, + check_error_on_every_column, ) ) @@ -614,7 +613,7 @@ def _load( ) def _from_data( - self, data: str, iter_unit: IterUnit + self, data: str, iter_unit: IterUnit, check_error_on_every_column: bool = True ) -> Iterator[dict | Exception] | Iterator[tuple | Exception]: """Creates a ``PyArrowIterator`` files from a str. @@ -631,6 +630,7 @@ def _from_data( self._numpy, self._number_to_decimal, iter_unit, + check_error_on_every_column, ) @classmethod @@ -665,7 +665,15 @@ def _create_iter( """Create an iterator for the ResultBatch. Used by get_arrow_iter.""" if self._local: try: - return self._from_data(self._data, iter_unit) + return self._from_data( + self._data, + iter_unit, + ( + connection.check_arrow_conversion_error_on_every_column + if connection + else None + ), + ) except Exception: if connection and getattr(connection, "_debug_arrow_chunk", False): logger.debug(f"arrow data can not be parsed: {self._data}") diff --git a/src/snowflake/connector/result_set.py b/src/snowflake/connector/result_set.py index 25d3560bd0..b633b41a07 100644 --- a/src/snowflake/connector/result_set.py +++ b/src/snowflake/connector/result_set.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import inspect diff --git a/src/snowflake/connector/s3_storage_client.py b/src/snowflake/connector/s3_storage_client.py index 1103fd9697..d2e49389d1 100644 --- a/src/snowflake/connector/s3_storage_client.py +++ b/src/snowflake/connector/s3_storage_client.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import binascii @@ -86,7 +82,13 @@ def __init__( self.stage_info["location"] ) ) - self.use_s3_regional_url = use_s3_regional_url + self.use_s3_regional_url = ( + use_s3_regional_url + or "useS3RegionalUrl" in stage_info + and stage_info["useS3RegionalUrl"] + or "useRegionalUrl" in stage_info + and stage_info["useRegionalUrl"] + ) self.location_type = stage_info.get("locationType") # if GS sends us an endpoint, it's likely for FIPS. Use it. @@ -327,6 +329,9 @@ def generate_authenticated_url_and_args_v4() -> tuple[bytes, dict[str, bytes]]: amzdate = t.strftime("%Y%m%dT%H%M%SZ") short_amzdate = amzdate[:8] x_amz_headers["x-amz-date"] = amzdate + x_amz_headers["x-amz-security-token"] = self.credentials.creds.get( + "AWS_TOKEN", "" + ) ( canonical_request, diff --git a/src/snowflake/connector/secret_detector.py b/src/snowflake/connector/secret_detector.py index a9e3d8123e..643a7e8fb9 100644 --- a/src/snowflake/connector/secret_detector.py +++ b/src/snowflake/connector/secret_detector.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - """The secret detector detects sensitive information. 
It masks secrets that might be leaked from two potential avenues @@ -14,11 +10,18 @@ import logging import os import re +from typing import NamedTuple MIN_TOKEN_LEN = os.getenv("MIN_TOKEN_LEN", 32) MIN_PWD_LEN = os.getenv("MIN_PWD_LEN", 8) +class MaskedMessageData(NamedTuple): + is_masked: bool = False + masked_text: str | None = None + error_str: str | None = None + + class SecretDetector(logging.Formatter): AWS_KEY_PATTERN = re.compile( r"(aws_key_id|aws_secret_key|access_key_id|secret_access_key)\s*=\s*'([^']+)'", @@ -52,21 +55,31 @@ class SecretDetector(logging.Formatter): flags=re.IGNORECASE, ) + SECRET_STARRED_MASK_STR = "****" + @staticmethod def mask_connection_token(text: str) -> str: - return SecretDetector.CONNECTION_TOKEN_PATTERN.sub(r"\1\2****", text) + return SecretDetector.CONNECTION_TOKEN_PATTERN.sub( + r"\1\2" + f"{SecretDetector.SECRET_STARRED_MASK_STR}", text + ) @staticmethod def mask_password(text: str) -> str: - return SecretDetector.PASSWORD_PATTERN.sub(r"\1\2****", text) + return SecretDetector.PASSWORD_PATTERN.sub( + r"\1\2" + f"{SecretDetector.SECRET_STARRED_MASK_STR}", text + ) @staticmethod def mask_aws_keys(text: str) -> str: - return SecretDetector.AWS_KEY_PATTERN.sub(r"\1='****'", text) + return SecretDetector.AWS_KEY_PATTERN.sub( + r"\1=" + f"'{SecretDetector.SECRET_STARRED_MASK_STR}'", text + ) @staticmethod def mask_sas_tokens(text: str) -> str: - return SecretDetector.SAS_TOKEN_PATTERN.sub(r"\1=****", text) + return SecretDetector.SAS_TOKEN_PATTERN.sub( + r"\1=" + f"{SecretDetector.SECRET_STARRED_MASK_STR}", text + ) @staticmethod def mask_aws_tokens(text: str) -> str: @@ -85,17 +98,17 @@ def mask_private_key_data(text: str) -> str: ) @staticmethod - def mask_secrets(text: str) -> tuple[bool, str, str | None]: + def mask_secrets(text: str) -> MaskedMessageData: """Masks any secrets. This is the method that should be used by outside classes. Args: text: A string which may contain a secret. Returns: - The masked string. 
+ The masked string data in MaskedMessageData. """ if text is None: - return (False, None, None) + return MaskedMessageData() masked = False err_str = None @@ -123,7 +136,20 @@ def mask_secrets(text: str) -> tuple[bool, str, str | None]: masked_text = str(ex) err_str = str(ex) - return masked, masked_text, err_str + return MaskedMessageData(masked, masked_text, err_str) + + @staticmethod + def create_formatting_error_log( + original_record: logging.LogRecord, error_message: str + ) -> str: + return "{} - {} {} - {} - {} - {}".format( + original_record.asctime, + original_record.threadName, + "secret_detector.py", + "sanitize_log_str", + original_record.levelname, + error_message, + ) def format(self, record: logging.LogRecord) -> str: """Wrapper around logging module's formatter. @@ -138,25 +164,18 @@ def format(self, record: logging.LogRecord) -> str: """ try: unsanitized_log = super().format(record) - masked, sanitized_log, err_str = SecretDetector.mask_secrets( + masked, optional_sanitized_log, err_str = SecretDetector.mask_secrets( unsanitized_log ) + # Added to comply with type hints (Optional[str] is not accepted for str) + sanitized_log = optional_sanitized_log or "" + if masked and err_str is not None: - sanitized_log = "{} - {} {} - {} - {} - {}".format( - record.asctime, - record.threadName, - "secret_detector.py", - "sanitize_log_str", - record.levelname, - err_str, - ) + sanitized_log = self.create_formatting_error_log(record, err_str) + except Exception as ex: - sanitized_log = "{} - {} {} - {} - {} - {}".format( - record.asctime, - record.threadName, - "secret_detector.py", - "sanitize_log_str", - record.levelname, - "EXCEPTION - " + str(ex), + sanitized_log = self.create_formatting_error_log( + record, "EXCEPTION - " + str(ex) ) + return sanitized_log diff --git a/src/snowflake/connector/sf_dirs.py b/src/snowflake/connector/sf_dirs.py index 09164affba..e8b035f7aa 100644 --- a/src/snowflake/connector/sf_dirs.py +++ b/src/snowflake/connector/sf_dirs.py 
@@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/src/snowflake/connector/sfbinaryformat.py b/src/snowflake/connector/sfbinaryformat.py index 006caeb927..1b03c843d3 100644 --- a/src/snowflake/connector/sfbinaryformat.py +++ b/src/snowflake/connector/sfbinaryformat.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from base64 import b16decode, b16encode, standard_b64encode diff --git a/src/snowflake/connector/sfdatetime.py b/src/snowflake/connector/sfdatetime.py index cc7e652874..c1f5a92da7 100644 --- a/src/snowflake/connector/sfdatetime.py +++ b/src/snowflake/connector/sfdatetime.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import time diff --git a/src/snowflake/connector/snow_logging.py b/src/snowflake/connector/snow_logging.py index 2e639f2c23..2ec115e2ba 100644 --- a/src/snowflake/connector/snow_logging.py +++ b/src/snowflake/connector/snow_logging.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/src/snowflake/connector/sqlstate.py b/src/snowflake/connector/sqlstate.py index 0746f1db3f..a4d9f123f3 100644 --- a/src/snowflake/connector/sqlstate.py +++ b/src/snowflake/connector/sqlstate.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED = "08001" SQLSTATE_CONNECTION_ALREADY_EXISTS = "08002" SQLSTATE_CONNECTION_NOT_EXISTS = "08003" diff --git a/src/snowflake/connector/ssd_internal_keys.py b/src/snowflake/connector/ssd_internal_keys.py index f8d9951c42..077b2c742a 100644 --- a/src/snowflake/connector/ssd_internal_keys.py +++ b/src/snowflake/connector/ssd_internal_keys.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from binascii import unhexlify diff --git a/src/snowflake/connector/ssl_wrap_socket.py b/src/snowflake/connector/ssl_wrap_socket.py index f6a2e96579..f1016dbce1 100644 --- a/src/snowflake/connector/ssl_wrap_socket.py +++ b/src/snowflake/connector/ssl_wrap_socket.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations # diff --git a/src/snowflake/connector/storage_client.py b/src/snowflake/connector/storage_client.py index 7b178bf740..7fc8b67dfa 100644 --- a/src/snowflake/connector/storage_client.py +++ b/src/snowflake/connector/storage_client.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os @@ -286,6 +282,7 @@ def _send_request_with_retry( conn = self.meta.sfagent._cursor.connection while self.retry_count[retry_id] < self.max_retry: + logger.debug(f"retry #{self.retry_count[retry_id]}") cur_timestamp = self.credentials.timestamp url, rest_kwargs = get_request_args() rest_kwargs["timeout"] = (REQUEST_CONNECTION_TIMEOUT, REQUEST_READ_TIMEOUT) @@ -299,10 +296,14 @@ def _send_request_with_retry( response = rest_call(url, **rest_kwargs) if self._has_expired_presigned_url(response): + logger.debug( + "presigned url expired. trying to update presigned url." 
+ ) self._update_presigned_url() else: self.last_err_is_presigned_url = False if response.status_code in self.TRANSIENT_HTTP_ERR: + logger.debug(f"transient error: {response.status_code}") time.sleep( min( # TODO should SLEEP_UNIT come from the parent @@ -313,7 +314,9 @@ def _send_request_with_retry( ) self.retry_count[retry_id] += 1 elif self._has_expired_token(response): + logger.debug("token is expired. trying to update token") self.credentials.update(cur_timestamp) + self.retry_count[retry_id] += 1 else: return response except self.TRANSIENT_ERRORS as e: diff --git a/src/snowflake/connector/telemetry.py b/src/snowflake/connector/telemetry.py index 933fc489ad..bec64bf72c 100644 --- a/src/snowflake/connector/telemetry.py +++ b/src/snowflake/connector/telemetry.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/src/snowflake/connector/telemetry_oob.py b/src/snowflake/connector/telemetry_oob.py index ddf33ffd32..1db611db75 100644 --- a/src/snowflake/connector/telemetry_oob.py +++ b/src/snowflake/connector/telemetry_oob.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import datetime diff --git a/src/snowflake/connector/test_util.py b/src/snowflake/connector/test_util.py index 5516093420..5af3b35a18 100644 --- a/src/snowflake/connector/test_util.py +++ b/src/snowflake/connector/test_util.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import logging diff --git a/src/snowflake/connector/time_util.py b/src/snowflake/connector/time_util.py index ee758c3683..3fb5372b5a 100644 --- a/src/snowflake/connector/time_util.py +++ b/src/snowflake/connector/time_util.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import time diff --git a/src/snowflake/connector/token_cache.py b/src/snowflake/connector/token_cache.py index 1c45aec007..40a55f9e8b 100644 --- a/src/snowflake/connector/token_cache.py +++ b/src/snowflake/connector/token_cache.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import codecs diff --git a/src/snowflake/connector/tool/__init__.py b/src/snowflake/connector/tool/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/src/snowflake/connector/tool/__init__.py +++ b/src/snowflake/connector/tool/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# diff --git a/src/snowflake/connector/tool/dump_certs.py b/src/snowflake/connector/tool/dump_certs.py index 1d715da54b..cffcad870e 100644 --- a/src/snowflake/connector/tool/dump_certs.py +++ b/src/snowflake/connector/tool/dump_certs.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/src/snowflake/connector/tool/dump_ocsp_response.py b/src/snowflake/connector/tool/dump_ocsp_response.py index 8cb55c3a73..69357ebddb 100644 --- a/src/snowflake/connector/tool/dump_ocsp_response.py +++ b/src/snowflake/connector/tool/dump_ocsp_response.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import logging diff --git a/src/snowflake/connector/tool/dump_ocsp_response_cache.py b/src/snowflake/connector/tool/dump_ocsp_response_cache.py index 0c0d74cc29..2e195eb50b 100644 --- a/src/snowflake/connector/tool/dump_ocsp_response_cache.py +++ b/src/snowflake/connector/tool/dump_ocsp_response_cache.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import json diff --git a/src/snowflake/connector/tool/probe_connection.py b/src/snowflake/connector/tool/probe_connection.py index a38422393e..81546ce14f 100644 --- a/src/snowflake/connector/tool/probe_connection.py +++ b/src/snowflake/connector/tool/probe_connection.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from socket import gaierror, gethostbyname_ex diff --git a/src/snowflake/connector/url_util.py b/src/snowflake/connector/url_util.py index 36a5a24371..788a9d52ad 100644 --- a/src/snowflake/connector/url_util.py +++ b/src/snowflake/connector/url_util.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import re diff --git a/src/snowflake/connector/util_text.py b/src/snowflake/connector/util_text.py index 2c24ae577f..39762c2111 100644 --- a/src/snowflake/connector/util_text.py +++ b/src/snowflake/connector/util_text.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import base64 diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py new file mode 100644 index 0000000000..f735b00eb4 --- /dev/null +++ b/src/snowflake/connector/wif_util.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import json +import logging +import os +from base64 import b64encode +from dataclasses import dataclass +from enum import Enum, unique + +import boto3 +import jwt +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from botocore.utils import InstanceMetadataRegionFetcher + +from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND +from .errors import ProgrammingError +from .vendored import requests +from .vendored.requests import Response + +logger = logging.getLogger(__name__) +SNOWFLAKE_AUDIENCE = "snowflakecomputing.com" +DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" + + +@unique +class AttestationProvider(Enum): + """A WIF provider implementation that can produce an attestation.""" + + AWS = "AWS" + """Provider that builds an encoded pre-signed GetCallerIdentity request using the current workload's IAM role.""" + AZURE = "AZURE" + """Provider that requests an OAuth access token for the workload's managed identity.""" + GCP = "GCP" + """Provider that requests an ID token for the workload's attached service account.""" + OIDC = "OIDC" + """Provider that looks for an OIDC ID token.""" + + @staticmethod + def from_string(provider: str) -> AttestationProvider: + """Converts a string to a strongly-typed enum value of AttestationProvider.""" + return AttestationProvider[provider.upper()] + + +@dataclass +class WorkloadIdentityAttestation: + provider: AttestationProvider + credential: str + user_identifier_components: dict + + +def try_metadata_service_call( + method: str, url: str, headers: dict, timeout_sec: int = 3 +) -> Response | None: + """Tries to make a HTTP request to the metadata service with the given 
URL, method, headers and timeout. + + If we receive an error response or any exceptions are raised, returns None. Otherwise returns the response. + """ + try: + res: Response = requests.request( + method=method, url=url, headers=headers, timeout=timeout_sec + ) + if not res.ok: + return None + except requests.RequestException: + return None + return res + + +def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[str, str]: + """Extracts the 'iss' and 'sub' claims from the given JWT, without verifying the signature. + + Note: the real token verification (including signature verification) happens on the Snowflake side. The driver doesn't have + the keys to verify these JWTs, and in any case that's not where the security boundary is drawn. + + We only decode the JWT here to get some basic claims, which will be used for a) a quick smoke test to ensure we got the right + issuer, and b) to find the unique user being asserted and populate assertion_content. The latter may be used for logging + and possibly caching. + + If there are any errors in parsing the token or extracting iss and sub, this will return (None, None). 
+ """ + try: + claims = jwt.decode(jwt_str, options={"verify_signature": False}) + except jwt.exceptions.InvalidTokenError: + logger.warning("Token is not a valid JWT.", exc_info=True) + return None, None + + if not ("iss" in claims and "sub" in claims): + logger.warning("Token is missing 'iss' or 'sub' claims.") + return None, None + + return claims["iss"], claims["sub"] + + +def get_aws_region() -> str | None: + """Get the current AWS workload's region, if any.""" + if "AWS_REGION" in os.environ: # Lambda + return os.environ["AWS_REGION"] + else: # EC2 + return InstanceMetadataRegionFetcher().retrieve_region() + + +def get_aws_arn() -> str | None: + """Get the current AWS workload's ARN, if any.""" + caller_identity = boto3.client("sts").get_caller_identity() + if not caller_identity or "Arn" not in caller_identity: + return None + return caller_identity["Arn"] + + +def create_aws_attestation() -> WorkloadIdentityAttestation | None: + """Tries to create a workload identity attestation for AWS. + + If the application isn't running on AWS or no credentials were found, returns None. 
+ """ + aws_creds = boto3.session.Session().get_credentials() + if not aws_creds: + logger.debug("No AWS credentials were found.") + return None + region = get_aws_region() + if not region: + logger.debug("No AWS region was found.") + return None + arn = get_aws_arn() + if not arn: + logger.debug("No AWS caller identity was found.") + return None + + sts_hostname = f"sts.{region}.amazonaws.com" + request = AWSRequest( + method="POST", + url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", + headers={ + "Host": sts_hostname, + "X-Snowflake-Audience": SNOWFLAKE_AUDIENCE, + }, + ) + + SigV4Auth(aws_creds, "sts", region).add_auth(request) + + assertion_dict = { + "url": request.url, + "method": request.method, + "headers": dict(request.headers.items()), + } + credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8") + return WorkloadIdentityAttestation( + AttestationProvider.AWS, credential, {"arn": arn} + ) + + +def create_gcp_attestation() -> WorkloadIdentityAttestation | None: + """Tries to create a workload identity attestation for GCP. + + If the application isn't running on GCP or no credentials were found, returns None. + """ + res = try_metadata_service_call( + method="GET", + url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}", + headers={ + "Metadata-Flavor": "Google", + }, + ) + if res is None: + # Most likely we're just not running on GCP, which may be expected. + logger.debug("GCP metadata server request was not successful.") + return None + + jwt_str = res.content.decode("utf-8") + issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str) + if not issuer or not subject: + return None + if issuer != "https://accounts.google.com": + # This might happen if we're running on a different platform that responds to the same metadata request signature as GCP. 
+ logger.debug("Unexpected GCP token issuer '%s'", issuer) + return None + + return WorkloadIdentityAttestation( + AttestationProvider.GCP, jwt_str, {"sub": subject} + ) + + +def create_azure_attestation( + snowflake_entra_resource: str, +) -> WorkloadIdentityAttestation | None: + """Tries to create a workload identity attestation for Azure. + + If the application isn't running on Azure or no credentials were found, returns None. + """ + headers = {"Metadata": "True"} + url_without_query_string = "http://169.254.169.254/metadata/identity/oauth2/token" + query_params = f"api-version=2018-02-01&resource={snowflake_entra_resource}" + + # Check if running in Azure Functions environment + identity_endpoint = os.environ.get("IDENTITY_ENDPOINT") + identity_header = os.environ.get("IDENTITY_HEADER") + is_azure_functions = identity_endpoint is not None + + if is_azure_functions: + if not identity_header: + logger.warning("Managed identity is not enabled on this Azure function.") + return None + + # Azure Functions uses a different endpoint, headers and API version. + url_without_query_string = identity_endpoint + headers = {"X-IDENTITY-HEADER": identity_header} + query_params = f"api-version=2019-08-01&resource={snowflake_entra_resource}" + + # Some Azure Functions environments may require client_id in the URL + managed_identity_client_id = os.environ.get("MANAGED_IDENTITY_CLIENT_ID") + if managed_identity_client_id: + query_params += f"&client_id={managed_identity_client_id}" + + res = try_metadata_service_call( + method="GET", + url=f"{url_without_query_string}?{query_params}", + headers=headers, + ) + if res is None: + # Most likely we're just not running on Azure, which may be expected. + logger.debug("Azure metadata server request was not successful.") + return None + + try: + jwt_str = res.json().get("access_token") + if not jwt_str: + # Could be that Managed Identity is disabled. 
+ logger.debug("No access token found in Azure response.") + return None + except (ValueError, KeyError) as e: + logger.debug(f"Error parsing Azure response: {e}") + return None + + issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str) + if not issuer or not subject: + return None + if not ( + issuer.startswith("https://sts.windows.net/") + or issuer.startswith("https://login.microsoftonline.com/") + ): + # This might happen if we're running on a different platform that responds to the same metadata request signature as Azure. + logger.debug("Unexpected Azure token issuer '%s'", issuer) + return None + + return WorkloadIdentityAttestation( + AttestationProvider.AZURE, jwt_str, {"iss": issuer, "sub": subject} + ) + + +def create_oidc_attestation(token: str | None) -> WorkloadIdentityAttestation | None: + """Tries to create an attestation using the given token. + + If this is not populated, returns None. + """ + if not token: + logger.debug("No OIDC token was specified.") + return None + + issuer, subject = extract_iss_and_sub_without_signature_verification(token) + if not issuer or not subject: + return None + + return WorkloadIdentityAttestation( + AttestationProvider.OIDC, token, {"iss": issuer, "sub": subject} + ) + + +def create_autodetect_attestation( + entra_resource: str, token: str | None = None +) -> WorkloadIdentityAttestation | None: + """Tries to create an attestation using the auto-detected runtime environment. + + If no attestation can be found, returns None. 
+ """ + attestation = create_oidc_attestation(token) + if attestation: + return attestation + + attestation = create_aws_attestation() + if attestation: + return attestation + + attestation = create_azure_attestation(entra_resource) + if attestation: + return attestation + + attestation = create_gcp_attestation() + if attestation: + return attestation + + return None + + +def create_attestation( + provider: AttestationProvider | None, + entra_resource: str | None = None, + token: str | None = None, +) -> WorkloadIdentityAttestation: + """Entry point to create an attestation using the given provider. + + If the provider is None, this will try to auto-detect a credential from the runtime environment. If the provider fails to detect a credential, + a ProgrammingError will be raised. + + If an explicit entra_resource was provided to the connector, this will be used. Otherwise, the default Snowflake Entra resource will be used. + """ + entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE + + attestation: WorkloadIdentityAttestation = None + if provider == AttestationProvider.AWS: + attestation = create_aws_attestation() + elif provider == AttestationProvider.AZURE: + attestation = create_azure_attestation(entra_resource) + elif provider == AttestationProvider.GCP: + attestation = create_gcp_attestation() + elif provider == AttestationProvider.OIDC: + attestation = create_oidc_attestation(token) + elif provider is None: + attestation = create_autodetect_attestation(entra_resource, token) + + if not attestation: + provider_str = "auto-detect" if provider is None else provider.value + raise ProgrammingError( + msg=f"No workload identity credential was found for '{provider_str}'.", + errno=ER_WIF_CREDENTIALS_NOT_FOUND, + ) + + return attestation diff --git a/test/__init__.py b/test/__init__.py index 49c0cb56ad..976bb38cd6 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. 
All rights reserved. -# - from __future__ import annotations # This file houses functions and constants shared by both integration and unit tests diff --git a/test/conftest.py b/test/conftest.py index 59b46690b8..41c7c33bce 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/test/csp_helpers.py b/test/csp_helpers.py new file mode 100644 index 0000000000..b793215359 --- /dev/null +++ b/test/csp_helpers.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python +import datetime +import json +import logging +import os +from abc import ABC, abstractmethod +from time import time +from unittest import mock +from urllib.parse import parse_qs, urlparse + +import jwt +from botocore.awsrequest import AWSRequest +from botocore.credentials import Credentials + +from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError +from snowflake.connector.vendored.requests.models import Response + +logger = logging.getLogger(__name__) + + +def gen_dummy_id_token( + sub="test-subject", iss="test-issuer", aud="snowflakecomputing.com" +) -> str: + """Generates a dummy ID token using the given subject and issuer.""" + now = int(time()) + key = "secret" + payload = { + "sub": sub, + "iss": iss, + "aud": aud, + "iat": now, + "exp": now + 60 * 60, + } + logger.debug(f"Generating dummy token with the following claims:\n{str(payload)}") + return jwt.encode( + payload=payload, + key=key, + algorithm="HS256", + ) + + +def build_response(content: bytes, status_code: int = 200) -> Response: + """Builds a requests.Response object with the given status code and content.""" + response = Response() + response.status_code = status_code + response._content = content + return response + + +class FakeMetadataService(ABC): + """Base class for fake metadata service implementations.""" + + def __init__(self): + 
self.reset_defaults() + + @abstractmethod + def reset_defaults(self): + """Resets any default values for test parameters. + + This is called in the constructor and when entering as a context manager. + """ + pass + + @property + @abstractmethod + def expected_hostname(self): + """Hostname at which this metadata service is listening. + + Used to raise a ConnectTimeout for requests not targeted to this hostname. + """ + pass + + @abstractmethod + def handle_request(self, method, parsed_url, headers, timeout): + """Main business logic for handling this request. Should return a Response object.""" + pass + + def __call__(self, method, url, headers, timeout): + """Entry point for the requests mock.""" + logger.debug(f"Received request: {method} {url} {str(headers)}") + parsed_url = urlparse(url) + + if not parsed_url.hostname == self.expected_hostname: + logger.debug( + f"Received request to unexpected hostname {parsed_url.hostname}" + ) + raise ConnectTimeout() + + return self.handle_request(method, parsed_url, headers, timeout) + + def __enter__(self): + """Patches the relevant HTTP calls when entering as a context manager.""" + self.reset_defaults() + self.patchers = [] + # requests.request is used by the direct metadata service API calls from our code. This is the main + # thing being faked here. + self.patchers.append( + mock.patch( + "snowflake.connector.vendored.requests.request", side_effect=self + ) + ) + + # HTTPConnection.request is used by the AWS boto libraries. We're not mocking those calls here, so we + # simply raise a ConnectTimeout to avoid making real network calls. 
+ self.patchers.append(
+ mock.patch(
+ "urllib3.connection.HTTPConnection.request",
+ side_effect=ConnectTimeout(),
+ )
+ )
+ for patcher in self.patchers:
+ patcher.__enter__()
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ for patcher in self.patchers:
+ patcher.__exit__(*args, **kwargs)
+
+
+class NoMetadataService(FakeMetadataService):
+ """Emulates an environment without any metadata service."""
+
+ def reset_defaults(self):
+ pass
+
+ @property
+ def expected_hostname(self):
+ return None # Always raise a ConnectTimeout.
+
+ def handle_request(self, method, parsed_url, headers, timeout):
+ # This should never be called because we always raise a ConnectTimeout.
+ pass
+
+
+class FakeAzureVmMetadataService(FakeMetadataService):
+ """Emulates an environment with the Azure VM metadata service."""
+
+ def reset_defaults(self):
+ # Defaults used for generating an Entra ID token. Can be overridden in individual tests.
+ self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269"
+ self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd"
+
+ @property
+ def expected_hostname(self):
+ return "169.254.169.254"
+
+ def handle_request(self, method, parsed_url, headers, timeout):
+ query_string = parse_qs(parsed_url.query)
+
+ # Reject malformed requests.
+ if not (
+ method == "GET"
+ and parsed_url.path == "/metadata/identity/oauth2/token"
+ and headers.get("Metadata") == "True"
+ and query_string["resource"]
+ ):
+ raise HTTPError()
+
+ logger.debug("Received request for Azure VM metadata service")
+
+ resource = query_string["resource"][0]
+ self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource)
+ return build_response(json.dumps({"access_token": self.token}).encode("utf-8"))
+
+
+class FakeAzureFunctionMetadataService(FakeMetadataService):
+ """Emulates an environment with the Azure Function metadata service."""
+
+ def reset_defaults(self):
+ # Defaults used for generating an Entra ID token. Can be overridden in individual tests.
+ self.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269"
+ self.iss = "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd"
+
+ self.identity_endpoint = "http://169.254.255.2:8081/msi/token"
+ self.identity_header = "FD80F6DA783A4881BE9FAFA365F58E7A"
+ self.parsed_identity_endpoint = urlparse(self.identity_endpoint)
+
+ @property
+ def expected_hostname(self):
+ return self.parsed_identity_endpoint.hostname
+
+ def handle_request(self, method, parsed_url, headers, timeout):
+ query_string = parse_qs(parsed_url.query)
+
+ # Reject malformed requests.
+ if not (
+ method == "GET"
+ and parsed_url.path == self.parsed_identity_endpoint.path
+ and headers.get("X-IDENTITY-HEADER") == self.identity_header
+ and query_string["resource"]
+ ):
+ logger.warning(
+ f"Received malformed request: {method} {parsed_url.path} {str(headers)} {str(query_string)}"
+ )
+ raise HTTPError()
+
+ logger.debug("Received request for Azure Functions metadata service")
+
+ resource = query_string["resource"][0]
+ self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=resource)
+ return build_response(json.dumps({"access_token": self.token}).encode("utf-8"))
+
+ def __enter__(self):
+ # In addition to the normal patching, we need to set the environment variables that Azure Functions would set.
+ os.environ["IDENTITY_ENDPOINT"] = self.identity_endpoint
+ os.environ["IDENTITY_HEADER"] = self.identity_header
+ return super().__enter__()
+
+ def __exit__(self, *args, **kwargs):
+ os.environ.pop("IDENTITY_ENDPOINT")
+ os.environ.pop("IDENTITY_HEADER")
+ return super().__exit__(*args, **kwargs)
+
+
+class FakeGceMetadataService(FakeMetadataService):
+ """Emulates an environment with the GCE metadata service."""
+
+ def reset_defaults(self):
+ # Defaults used for generating a token. Can be overridden in individual tests.
+ self.sub = "123"
+ self.iss = "https://accounts.google.com"
+
+ @property
+ def expected_hostname(self):
+ return "169.254.169.254"
+
+ def handle_request(self, method, parsed_url, headers, timeout):
+ query_string = parse_qs(parsed_url.query)
+
+ # Reject malformed requests.
+ if not (
+ method == "GET"
+ and parsed_url.path
+ == "/computeMetadata/v1/instance/service-accounts/default/identity"
+ and headers.get("Metadata-Flavor") == "Google"
+ and query_string["audience"]
+ ):
+ raise HTTPError()
+
+ logger.debug("Received request for GCE metadata service")
+
+ audience = query_string["audience"][0]
+ self.token = gen_dummy_id_token(sub=self.sub, iss=self.iss, aud=audience)
+ return build_response(self.token.encode("utf-8"))
+
+
+class FakeAwsEnvironment:
+ """Emulates the AWS environment-specific functions used in wif_util.py.
+
+ Unlike the other metadata services, the HTTP calls made by AWS are deep within boto libraries, so
+ emulating them here would be complex and fragile. Instead, we emulate the higher-level functions
+ called by the connector code.
+ """
+
+ def __init__(self):
+ # Defaults used for generating a token. Can be overridden in individual tests.
+ self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
+ self.region = "us-east-1"
+ self.credentials = Credentials(access_key="ak", secret_key="sk")
+
+ def get_region(self):
+ return self.region
+
+ def get_arn(self):
+ return self.arn
+
+ def get_credentials(self):
+ return self.credentials
+
+ def sign_request(self, request: AWSRequest):
+ request.headers.add_header(
+ "X-Amz-Date", datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
+ )
+ request.headers.add_header("X-Amz-Security-Token", "")
+ request.headers.add_header(
+ "Authorization",
+ f"AWS4-HMAC-SHA256 Credential=, SignedHeaders={';'.join(request.headers.keys())}, Signature=",
+ )
+
+ def __enter__(self):
+ # Patch the relevant functions to do what we want.
+ self.patchers = [] + + # Patch sync boto3 calls + self.patchers.append( + mock.patch( + "boto3.session.Session.get_credentials", + side_effect=self.get_credentials, + ) + ) + self.patchers.append( + mock.patch( + "botocore.auth.SigV4Auth.add_auth", side_effect=self.sign_request + ) + ) + self.patchers.append( + mock.patch( + "snowflake.connector.wif_util.get_aws_region", + side_effect=self.get_region, + ) + ) + self.patchers.append( + mock.patch( + "snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn + ) + ) + + for patcher in self.patchers: + patcher.__enter__() + return self + + def __exit__(self, *args, **kwargs): + for patcher in self.patchers: + patcher.__exit__(*args, **kwargs) diff --git a/test/data/wiremock/mappings/auth/pat/invalid_token.json b/test/data/wiremock/mappings/auth/pat/invalid_token.json index 5014a2b170..ca6f9329fb 100644 --- a/test/data/wiremock/mappings/auth/pat/invalid_token.json +++ b/test/data/wiremock/mappings/auth/pat/invalid_token.json @@ -11,7 +11,6 @@ { "equalToJson": { "data": { - "LOGIN_NAME": "testUser", "AUTHENTICATOR": "PROGRAMMATIC_ACCESS_TOKEN", "TOKEN": "some PAT" } diff --git a/test/data/wiremock/mappings/auth/pat/successful_flow.json b/test/data/wiremock/mappings/auth/pat/successful_flow.json index 10b138f078..323057f330 100644 --- a/test/data/wiremock/mappings/auth/pat/successful_flow.json +++ b/test/data/wiremock/mappings/auth/pat/successful_flow.json @@ -11,7 +11,6 @@ { "equalToJson": { "data": { - "LOGIN_NAME": "testUser", "AUTHENTICATOR": "PROGRAMMATIC_ACCESS_TOKEN", "TOKEN": "some PAT" } diff --git a/test/extras/__init__.py b/test/extras/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/test/extras/__init__.py +++ b/test/extras/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# diff --git a/test/extras/run.py b/test/extras/run.py index 1dab55162f..e29bfecc75 100644 --- a/test/extras/run.py +++ b/test/extras/run.py @@ -1,6 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# import os import pathlib import platform diff --git a/test/extras/simple_select1.py b/test/extras/simple_select1.py index 957cf88ed6..b4c7856c82 100644 --- a/test/extras/simple_select1.py +++ b/test/extras/simple_select1.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from snowflake.connector import connect from ..parameters import CONNECTION_PARAMETERS diff --git a/test/generate_test_files.py b/test/generate_test_files.py index 38e46a0b9b..4f4fb4472d 100644 --- a/test/generate_test_files.py +++ b/test/generate_test_files.py @@ -1,8 +1,4 @@ #!/usr/bin/env python3 -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import argparse diff --git a/test/helpers.py b/test/helpers.py index 98f1db898a..6562aa83f4 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import asyncio @@ -175,6 +171,7 @@ def create_nanoarrow_pyarrow_iterator(input_data, use_table_iterator): False, False, False, + True, ) if not use_table_iterator else NanoarrowPyArrowTableIterator( @@ -186,6 +183,7 @@ def create_nanoarrow_pyarrow_iterator(input_data, use_table_iterator): False, False, False, + False, ) ) diff --git a/test/integ/__init__.py b/test/integ/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/test/integ/__init__.py +++ b/test/integ/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# diff --git a/test/integ/aio/test_connection_async.py b/test/integ/aio/test_connection_async.py index bb2a852b5d..c8d7ea6a4d 100644 --- a/test/integ/aio/test_connection_async.py +++ b/test/integ/aio/test_connection_async.py @@ -1686,3 +1686,34 @@ async def test_no_auth_connection_negative_case(): await conn.execute_string("select 1") await conn.close() + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "value", + [ + True, + False, + ], +) +async def test_gcs_use_virtual_endpoints(value): + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful.fetch", + return_value={"data": {"token": None, "masterToken": None}, "success": True}, + ): + cnx = snowflake.connector.aio.SnowflakeConnection( + user="test-user", + password="test-password", + host="test-host", + port="443", + account="test-account", + gcs_use_virtual_endpoints=value, + ) + try: + await cnx.connect() + cnx.commit = cnx.rollback = ( + lambda: None + ) # Skip tear down, there's only a mocked rest api + assert cnx.gcs_use_virtual_endpoints == value + finally: + await cnx.close() diff --git a/test/integ/aio/test_cursor_async.py b/test/integ/aio/test_cursor_async.py index e437d942d2..878d2d4085 100644 --- a/test/integ/aio/test_cursor_async.py +++ b/test/integ/aio/test_cursor_async.py @@ -12,9 +12,11 @@ import os import pickle import time +import uuid from datetime import date, datetime, timezone from typing import NamedTuple from unittest import mock +from unittest.mock import MagicMock import pytest import pytz @@ -792,6 +794,7 @@ async def test_invalid_bind_data_type(conn_cnx): await cnx.cursor().execute("select 1 from dual where 1=%s", ([1, 2, 3],)) +@pytest.mark.skipolddriver async def test_timeout_query(conn_cnx): async with conn_cnx() as cnx: async with cnx.cursor() as c: @@ -802,8 +805,30 @@ async def test_timeout_query(conn_cnx): ) assert err.value.errno == 604, ( "Invalid error code" - and "SQL execution was cancelled by the client due to a timeout" + and "SQL execution was 
cancelled by the client due to a timeout. Error message received from the server: SQL execution canceled" + in err.value.msg + ) + + with pytest.raises(errors.ProgrammingError) as err: + # we can not precisely control the timing to send cancel query request right after server + # executes the query but before returning the results back to client + # it depends on python scheduling and server processing speed, so we mock here + mock_timebomb = MagicMock() + mock_timebomb.result.return_value = True + + with mock.patch.object(c, "_timebomb", mock_timebomb): + await c.execute( + "select 123'", + timeout=0.1, + ) + assert ( + mock_timebomb.result.return_value is True and err.value.errno == 1003 + ), ( + "Invalid error code" + and "SQL compilation error:\nsyntax error line 1 at position 10 unexpected '''." in err.value.msg + and "SQL execution was cancelled by the client due to a timeout" + not in err.value.msg ) @@ -1725,6 +1750,24 @@ async def test_out_of_range_year(conn_cnx, result_format, cursor_type, fetch_met await fetch_next_fn() +@pytest.mark.skipolddriver +@pytest.mark.parametrize("result_format", ("json", "arrow")) +async def test_out_of_range_year_followed_by_correct_year(conn_cnx, result_format): + """Tests whether the year 10000 is out of range exception is raised as expected.""" + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: result_format + } + ) as con: + async with con.cursor() as cur: + await cur.execute("select TO_DATE('10000-01-01'), TO_DATE('9999-01-01')") + with pytest.raises( + InterfaceError, + match="out of range", + ): + await cur.fetchall() + + async def test_describe(conn_cnx): async with conn_cnx() as con: async with con.cursor() as cur: @@ -1879,3 +1922,39 @@ async def test_fetch_download_timeout_setting(conn_cnx): sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v" async with conn_cnx() as con, con.cursor() as cur: assert len(await (await 
cur.execute(sql)).fetchall()) == 100000 + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "request_id", + [ + "THIS IS NOT VALID", + uuid.uuid1(), + uuid.uuid3(uuid.NAMESPACE_URL, "www.snowflake.com"), + uuid.uuid5(uuid.NAMESPACE_URL, "www.snowflake.com"), + ], +) +async def test_custom_request_id_negative(request_id, conn_cnx): + + # Ensure that invalid request_ids (non uuid4) do not compromise interface. + with pytest.raises(ValueError, match="requestId"): + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + "select seq4() as foo from table(generator(rowcount=>5))", + _statement_params={"requestId": request_id}, + ) + + +@pytest.mark.skipolddriver +async def test_custom_request_id(conn_cnx): + request_id = uuid.uuid4() + + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + "select seq4() as foo from table(generator(rowcount=>5))", + _statement_params={"requestId": request_id}, + ) + + assert cur._sfqid is not None, "Query must execute successfully." diff --git a/test/integ/conftest.py b/test/integ/conftest.py index 8658549568..9dbc930c5c 100644 --- a/test/integ/conftest.py +++ b/test/integ/conftest.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os @@ -173,16 +169,22 @@ def init_test_schema(db_parameters) -> Generator[None]: This is automatically called per test session. 
""" - ret = db_parameters - with snowflake.connector.connect( - user=ret["user"], - password=ret["password"], - host=ret["host"], - port=ret["port"], - database=ret["database"], - account=ret["account"], - protocol=ret["protocol"], - ) as con: + connection_params = { + "user": db_parameters["user"], + "password": db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "database": db_parameters["database"], + "account": db_parameters["account"], + "protocol": db_parameters["protocol"], + } + + # Role may be needed when running on preprod, but is not present on Jenkins jobs + optional_role = db_parameters.get("role") + if optional_role is not None: + connection_params.update(role=optional_role) + + with snowflake.connector.connect(**connection_params) as con: con.cursor().execute(f"CREATE SCHEMA IF NOT EXISTS {TEST_SCHEMA}") yield con.cursor().execute(f"DROP SCHEMA IF EXISTS {TEST_SCHEMA}") diff --git a/test/integ/lambda/__init__.py b/test/integ/lambda/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/test/integ/lambda/__init__.py +++ b/test/integ/lambda/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# diff --git a/test/integ/lambda/test_basic_query.py b/test/integ/lambda/test_basic_query.py index 83236554e0..e3964641a0 100644 --- a/test/integ/lambda/test_basic_query.py +++ b/test/integ/lambda/test_basic_query.py @@ -1,9 +1,5 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - def test_connection(conn_cnx): """Test basic connection.""" diff --git a/test/integ/pandas/__init__.py b/test/integ/pandas/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/test/integ/pandas/__init__.py +++ b/test/integ/pandas/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# diff --git a/test/integ/pandas/test_arrow_chunk_iterator.py b/test/integ/pandas/test_arrow_chunk_iterator.py index 090f4d152a..d19fd5644c 100644 --- a/test/integ/pandas/test_arrow_chunk_iterator.py +++ b/test/integ/pandas/test_arrow_chunk_iterator.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import datetime import random from typing import Callable diff --git a/test/integ/pandas/test_arrow_pandas.py b/test/integ/pandas/test_arrow_pandas.py index 3d10bb2a7c..2bb41e8af4 100644 --- a/test/integ/pandas/test_arrow_pandas.py +++ b/test/integ/pandas/test_arrow_pandas.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import decimal diff --git a/test/integ/pandas/test_error_arrow_pandas_stream.py b/test/integ/pandas/test_error_arrow_pandas_stream.py index f89b8ee37f..777f9f483c 100644 --- a/test/integ/pandas/test_error_arrow_pandas_stream.py +++ b/test/integ/pandas/test_error_arrow_pandas_stream.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import pytest from ...helpers import ( diff --git a/test/integ/pandas/test_logging.py b/test/integ/pandas/test_logging.py index b7e8d81a25..19e79c2cf5 100644 --- a/test/integ/pandas/test_logging.py +++ b/test/integ/pandas/test_logging.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/integ/pandas/test_pandas_tools.py b/test/integ/pandas/test_pandas_tools.py index e53afc5335..8d69fd1a9f 100644 --- a/test/integ/pandas/test_pandas_tools.py +++ b/test/integ/pandas/test_pandas_tools.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import math diff --git a/test/integ/pandas/test_unit_arrow_chunk_iterator.py b/test/integ/pandas/test_unit_arrow_chunk_iterator.py index 9f7a836e4a..73e4dfa540 100644 --- a/test/integ/pandas/test_unit_arrow_chunk_iterator.py +++ b/test/integ/pandas/test_unit_arrow_chunk_iterator.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import datetime @@ -430,7 +426,9 @@ def iterate_over_test_chunk( stream.seek(0) context = ArrowConverterContext() - it = NanoarrowPyArrowRowIterator(None, stream.read(), context, False, False, False) + it = NanoarrowPyArrowRowIterator( + None, stream.read(), context, False, False, False, True + ) count = 0 while True: diff --git a/test/integ/pandas/test_unit_options.py b/test/integ/pandas/test_unit_options.py index e992b2cb2f..9038e98d7c 100644 --- a/test/integ/pandas/test_unit_options.py +++ b/test/integ/pandas/test_unit_options.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/integ/sso/__init__.py b/test/integ/sso/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/test/integ/sso/__init__.py +++ b/test/integ/sso/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# diff --git a/test/integ/sso/test_connection_manual.py b/test/integ/sso/test_connection_manual.py index 55bd750079..2808b759c8 100644 --- a/test/integ/sso/test_connection_manual.py +++ b/test/integ/sso/test_connection_manual.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations # This test requires the SSO and Snowflake admin connection parameters. 
diff --git a/test/integ/sso/test_unit_mfa_cache.py b/test/integ/sso/test_unit_mfa_cache.py index 03f302fe64..15c13029a5 100644 --- a/test/integ/sso/test_unit_mfa_cache.py +++ b/test/integ/sso/test_unit_mfa_cache.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import json diff --git a/test/integ/sso/test_unit_sso_connection.py b/test/integ/sso/test_unit_sso_connection.py index 5c57d70b7d..4c02499d2a 100644 --- a/test/integ/sso/test_unit_sso_connection.py +++ b/test/integ/sso/test_unit_sso_connection.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/test/integ/test_arrow_result.py b/test/integ/test_arrow_result.py index 5cdd3bb341..02f11ccbc4 100644 --- a/test/integ/test_arrow_result.py +++ b/test/integ/test_arrow_result.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import base64 diff --git a/test/integ/test_async.py b/test/integ/test_async.py index 4ad2726a1d..41047b5f35 100644 --- a/test/integ/test_async.py +++ b/test/integ/test_async.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/integ/test_autocommit.py b/test/integ/test_autocommit.py index 94baf0ad22..a182f243f3 100644 --- a/test/integ/test_autocommit.py +++ b/test/integ/test_autocommit.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import snowflake.connector diff --git a/test/integ/test_bindings.py b/test/integ/test_bindings.py index b9ca1870a6..e5820c199b 100644 --- a/test/integ/test_bindings.py +++ b/test/integ/test_bindings.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import calendar diff --git a/test/integ/test_boolean.py b/test/integ/test_boolean.py index 6d72753358..887c0ca012 100644 --- a/test/integ/test_boolean.py +++ b/test/integ/test_boolean.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations diff --git a/test/integ/test_client_session_keep_alive.py b/test/integ/test_client_session_keep_alive.py index 027d364bc0..0037742729 100644 --- a/test/integ/test_client_session_keep_alive.py +++ b/test/integ/test_client_session_keep_alive.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import time diff --git a/test/integ/test_concurrent_create_objects.py b/test/integ/test_concurrent_create_objects.py index 0434829149..305c10bc45 100644 --- a/test/integ/test_concurrent_create_objects.py +++ b/test/integ/test_concurrent_create_objects.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from concurrent.futures.thread import ThreadPoolExecutor diff --git a/test/integ/test_concurrent_insert.py b/test/integ/test_concurrent_insert.py index e66999ac99..094c7f5e25 100644 --- a/test/integ/test_concurrent_insert.py +++ b/test/integ/test_concurrent_insert.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations from concurrent.futures.thread import ThreadPoolExecutor diff --git a/test/integ/test_connection.py b/test/integ/test_connection.py index 8a4f833158..0df386afca 100644 --- a/test/integ/test_connection.py +++ b/test/integ/test_connection.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import gc @@ -1144,6 +1140,15 @@ def test_client_prefetch_threads_setting(conn_cnx): assert conn.client_prefetch_threads == new_thread_count +@pytest.mark.skipolddriver +def test_client_fetch_threads_setting(conn_cnx): + """Tests whether client_fetch_threads is None by default and setting the parameter has effect.""" + with conn_cnx() as conn: + assert conn.client_fetch_threads is None + conn.client_fetch_threads = 32 + assert conn.client_fetch_threads == 32 + + @pytest.mark.external def test_client_failover_connection_url(conn_cnx): with conn_cnx("client_failover") as conn: @@ -1369,6 +1374,34 @@ def test_server_session_keep_alive(conn_cnx): mock_delete_session.assert_called_once() +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "value", + [ + True, + False, + ], +) +def test_gcs_use_virtual_endpoints(conn_cnx, value): + with mock.patch( + "snowflake.connector.network.SnowflakeRestful.fetch", + return_value={"data": {"token": None, "masterToken": None}, "success": True}, + ): + with snowflake.connector.connect( + user="test-user", + password="test-password", + host="test-host", + port="443", + account="test-account", + gcs_use_virtual_endpoints=value, + ) as cnx: + assert cnx + cnx.commit = cnx.rollback = ( + lambda: None + ) # Skip tear down, there's only a mocked rest api + assert cnx.gcs_use_virtual_endpoints == value + + @pytest.mark.skipolddriver def test_ocsp_mode_disable_ocsp_checks( conn_cnx, is_public_test, is_local_dev_setup, caplog @@ -1591,3 +1624,12 @@ def test_no_auth_connection_negative_case(): # connection is not 
able to run any query with pytest.raises(DatabaseError, match="Connection is closed"): conn.execute_string("select 1") + + +# _file_operation_parser and _stream_downloader are newly introduced and +# therefore should not be tested on old drivers. +@pytest.mark.skipolddriver +def test_file_utils_sanity_check(): + conn = create_connection("default") + assert hasattr(conn._file_operation_parser, "parse_file_operation") + assert hasattr(conn._stream_downloader, "download_as_stream") diff --git a/test/integ/test_converter.py b/test/integ/test_converter.py index 10628e102a..c944eea01a 100644 --- a/test/integ/test_converter.py +++ b/test/integ/test_converter.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from datetime import time, timedelta diff --git a/test/integ/test_converter_more_timestamp.py b/test/integ/test_converter_more_timestamp.py index c70ed5e139..2ef975bd92 100644 --- a/test/integ/test_converter_more_timestamp.py +++ b/test/integ/test_converter_more_timestamp.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from datetime import datetime, timedelta diff --git a/test/integ/test_converter_null.py b/test/integ/test_converter_null.py index 0297c625b5..671656dccf 100644 --- a/test/integ/test_converter_null.py +++ b/test/integ/test_converter_null.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import re diff --git a/test/integ/test_cursor.py b/test/integ/test_cursor.py index 85362ce829..9d7d4e6c55 100644 --- a/test/integ/test_cursor.py +++ b/test/integ/test_cursor.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import decimal @@ -11,6 +7,7 @@ import os import pickle import time +import uuid from datetime import date, datetime, timezone from typing import TYPE_CHECKING, NamedTuple from unittest import mock @@ -826,6 +823,7 @@ def test_invalid_bind_data_type(conn_cnx): cnx.cursor().execute("select 1 from dual where 1=%s", ([1, 2, 3],)) +@pytest.mark.skipolddriver def test_timeout_query(conn_cnx): with conn_cnx() as cnx: with cnx.cursor() as c: @@ -836,10 +834,31 @@ def test_timeout_query(conn_cnx): ) assert err.value.errno == 604, ( "Invalid error code" - and "SQL execution was cancelled by the client due to a timeout" + and "SQL execution was cancelled by the client due to a timeout. Error message received from the server: SQL execution canceled" in err.value.msg ) + with pytest.raises(errors.ProgrammingError) as err: + # we can not precisely control the timing to send cancel query request right after server + # executes the query but before returning the results back to client + # it depends on python scheduling and server processing speed, so we mock here + with mock.patch( + "snowflake.connector.cursor._TrackedQueryCancellationTimer", + autospec=True, + ) as mock_timebomb: + mock_timebomb.return_value.executed = True + c.execute( + "select 123'", + timeout=0.1, + ) + assert c._timebomb.executed is True and err.value.errno == 1003, ( + "Invalid error code" + and "SQL compilation error:\nsyntax error line 1 at position 10 unexpected '''." + in err.value.msg + and "SQL execution was cancelled by the client due to a timeout" + not in err.value.msg + ) + def test_executemany(conn, db_parameters): """Executes many statements. Client binding is supported by either dict, or list data types. 
@@ -1768,6 +1787,24 @@ def test_out_of_range_year(conn_cnx, result_format, cursor_type, fetch_method): fetch_next_fn() +@pytest.mark.skipolddriver +@pytest.mark.parametrize("result_format", ("json", "arrow")) +def test_out_of_range_year_followed_by_correct_year(conn_cnx, result_format): + """Tests whether the year 10000 is out of range exception is raised as expected.""" + with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: result_format + } + ) as con: + with con.cursor() as cur: + cur.execute("select TO_DATE('10000-01-01'), TO_DATE('9999-01-01')") + with pytest.raises( + InterfaceError, + match="out of range", + ): + cur.fetchall() + + @pytest.mark.skipolddriver def test_describe(conn_cnx): with conn_cnx() as con: @@ -1928,3 +1965,39 @@ def test_nanoarrow_usage_deprecation(): and "snowflake.connector.cursor.NanoarrowUsage has been deprecated" in str(record[2].message) ) + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "request_id", + [ + "THIS IS NOT VALID", + uuid.uuid1(), + uuid.uuid3(uuid.NAMESPACE_URL, "www.snowflake.com"), + uuid.uuid5(uuid.NAMESPACE_URL, "www.snowflake.com"), + ], +) +def test_custom_request_id_negative(request_id, conn_cnx): + + # Ensure that invalid request_ids (non uuid4) do not compromise interface. + with pytest.raises(ValueError, match="requestId"): + with conn_cnx() as con: + with con.cursor() as cur: + cur.execute( + "select seq4() as foo from table(generator(rowcount=>5))", + _statement_params={"requestId": request_id}, + ) + + +@pytest.mark.skipolddriver +def test_custom_request_id(conn_cnx): + request_id = uuid.uuid4() + + with conn_cnx() as con: + with con.cursor() as cur: + cur.execute( + "select seq4() as foo from table(generator(rowcount=>5))", + _statement_params={"requestId": request_id}, + ) + + assert cur._sfqid is not None, "Query must execute successfully." 
diff --git a/test/integ/test_cursor_binding.py b/test/integ/test_cursor_binding.py index eb0f55aa0c..15ace863e2 100644 --- a/test/integ/test_cursor_binding.py +++ b/test/integ/test_cursor_binding.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import pytest diff --git a/test/integ/test_cursor_context_manager.py b/test/integ/test_cursor_context_manager.py index 2d288fb2f9..f9ee44d56d 100644 --- a/test/integ/test_cursor_context_manager.py +++ b/test/integ/test_cursor_context_manager.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from logging import getLogger diff --git a/test/integ/test_dataintegrity.py b/test/integ/test_dataintegrity.py index 0964d8ead6..4cca91f303 100644 --- a/test/integ/test_dataintegrity.py +++ b/test/integ/test_dataintegrity.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -O -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - """Script to test database capabilities and the DB-API interface. It tests for functionality and data integrity for some of the basic data types. Adapted from a script diff --git a/test/integ/test_daylight_savings.py b/test/integ/test_daylight_savings.py index 45ec281dc5..6f8862bdde 100644 --- a/test/integ/test_daylight_savings.py +++ b/test/integ/test_daylight_savings.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from datetime import datetime diff --git a/test/integ/test_dbapi.py b/test/integ/test_dbapi.py index 97d3c6e47f..f27ea39d02 100644 --- a/test/integ/test_dbapi.py +++ b/test/integ/test_dbapi.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - """Script to test database capabilities and the DB-API interface for functionality and data integrity. Adapted from a script by M-A Lemburg and taken from the MySQL python driver. diff --git a/test/integ/test_decfloat.py b/test/integ/test_decfloat.py index 1a9224d920..b776dc007b 100644 --- a/test/integ/test_decfloat.py +++ b/test/integ/test_decfloat.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import decimal diff --git a/test/integ/test_easy_logging.py b/test/integ/test_easy_logging.py index ce89177699..b035ca278c 100644 --- a/test/integ/test_easy_logging.py +++ b/test/integ/test_easy_logging.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import stat from test.integ.conftest import create_connection diff --git a/test/integ/test_errors.py b/test/integ/test_errors.py index f4e8a699bc..9ec63e7802 100644 --- a/test/integ/test_errors.py +++ b/test/integ/test_errors.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import traceback diff --git a/test/integ/test_execute_multi_statements.py b/test/integ/test_execute_multi_statements.py index 5b143313b2..fb70045610 100644 --- a/test/integ/test_execute_multi_statements.py +++ b/test/integ/test_execute_multi_statements.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import codecs diff --git a/test/integ/test_key_pair_authentication.py b/test/integ/test_key_pair_authentication.py index c3ebb4b448..1273ee0036 100644 --- a/test/integ/test_key_pair_authentication.py +++ b/test/integ/test_key_pair_authentication.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import base64 diff --git a/test/integ/test_large_put.py b/test/integ/test_large_put.py index e27c784b8e..e9687fc5c8 100644 --- a/test/integ/test_large_put.py +++ b/test/integ/test_large_put.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/test/integ/test_large_result_set.py b/test/integ/test_large_result_set.py index 481c7220c9..cbffcac107 100644 --- a/test/integ/test_large_result_set.py +++ b/test/integ/test_large_result_set.py @@ -1,14 +1,12 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations +import logging from unittest.mock import Mock import pytest +from snowflake.connector.secret_detector import SecretDetector from snowflake.connector.telemetry import TelemetryField NUMBER_OF_ROWS = 50000 @@ -115,8 +113,9 @@ def test_query_large_result_set_n_threads( @pytest.mark.aws @pytest.mark.skipolddriver -def test_query_large_result_set(conn_cnx, db_parameters, ingest_data): +def test_query_large_result_set(conn_cnx, db_parameters, ingest_data, caplog): """[s3] Gets Large Result set.""" + caplog.set_level(logging.DEBUG) sql = "select * from {name} order by 1".format(name=db_parameters["name"]) with conn_cnx() as cnx: telemetry_data = [] @@ -165,3 +164,19 @@ def test_query_large_result_set(conn_cnx, db_parameters, ingest_data): "Expected three telemetry logs (one per query) " "for log type {}".format(field.value) ) + + aws_request_present = False + expected_token_prefix = "X-Amz-Signature=" + for line in caplog.text.splitlines(): + if expected_token_prefix in line: + aws_request_present = True + # getattr is used to stay compatible with old driver - before SECRET_STARRED_MASK_STR was added + assert ( + expected_token_prefix + + getattr(SecretDetector, "SECRET_STARRED_MASK_STR", "****") + in line + ), 
"connectionpool logger is leaking sensitive information" + + assert ( + aws_request_present + ), "AWS URL was not found in logs, so it can't be assumed that no leaks happened in it" diff --git a/test/integ/test_load_unload.py b/test/integ/test_load_unload.py index cdbb063145..afcfa8ceef 100644 --- a/test/integ/test_load_unload.py +++ b/test/integ/test_load_unload.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/test/integ/test_multi_statement.py b/test/integ/test_multi_statement.py index 4b461325fe..3fd80485d1 100644 --- a/test/integ/test_multi_statement.py +++ b/test/integ/test_multi_statement.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import pytest from snowflake.connector.version import VERSION diff --git a/test/integ/test_network.py b/test/integ/test_network.py index bf4ab44ac9..4f2f550eb5 100644 --- a/test/integ/test_network.py +++ b/test/integ/test_network.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/integ/test_numpy_binding.py b/test/integ/test_numpy_binding.py index 5ccd65e6cd..f210d9eec2 100644 --- a/test/integ/test_numpy_binding.py +++ b/test/integ/test_numpy_binding.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import datetime diff --git a/test/integ/test_pickle_timestamp_tz.py b/test/integ/test_pickle_timestamp_tz.py index 2c0332aacf..b6ceb239f9 100644 --- a/test/integ/test_pickle_timestamp_tz.py +++ b/test/integ/test_pickle_timestamp_tz.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import os diff --git a/test/integ/test_put_get.py b/test/integ/test_put_get.py index 74138bc606..67d020508a 100644 --- a/test/integ/test_put_get.py +++ b/test/integ/test_put_get.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import filecmp diff --git a/test/integ/test_put_get_compress_enc.py b/test/integ/test_put_get_compress_enc.py index 9caab8f231..efe8c209b5 100644 --- a/test/integ/test_put_get_compress_enc.py +++ b/test/integ/test_put_get_compress_enc.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import filecmp diff --git a/test/integ/test_put_get_medium.py b/test/integ/test_put_get_medium.py index fcc9becdb6..2b6c2ee6c0 100644 --- a/test/integ/test_put_get_medium.py +++ b/test/integ/test_put_get_medium.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import datetime diff --git a/test/integ/test_put_get_snow_4525.py b/test/integ/test_put_get_snow_4525.py index 9d8f38d98e..5c21b4f138 100644 --- a/test/integ/test_put_get_snow_4525.py +++ b/test/integ/test_put_get_snow_4525.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/test/integ/test_put_get_user_stage.py b/test/integ/test_put_get_user_stage.py index 8cf41e77b1..b10a5d73c2 100644 --- a/test/integ/test_put_get_user_stage.py +++ b/test/integ/test_put_get_user_stage.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import mimetypes diff --git a/test/integ/test_put_get_with_aws_token.py b/test/integ/test_put_get_with_aws_token.py index 6dc3f63509..7b9a64e87a 100644 --- a/test/integ/test_put_get_with_aws_token.py +++ b/test/integ/test_put_get_with_aws_token.py @@ -1,17 +1,16 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import glob import gzip import os +from logging import DEBUG import pytest from snowflake.connector.constants import UTF8 +from snowflake.connector.file_transfer_agent import SnowflakeS3ProgressPercentage +from snowflake.connector.secret_detector import SecretDetector try: # pragma: no cover from snowflake.connector.vendored import requests @@ -42,9 +41,10 @@ @pytest.mark.parametrize( "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] ) -def test_put_get_with_aws(tmpdir, conn_cnx, from_path): +def test_put_get_with_aws(tmpdir, conn_cnx, from_path, caplog): """[s3] Puts and Gets a small text using AWS S3.""" # create a data file + caplog.set_level(DEBUG) fname = str(tmpdir.join("test_put_get_with_aws_token.txt.gz")) original_contents = "123,test1\n456,test2\n" with gzip.open(fname, "wb") as f: @@ -54,8 +54,8 @@ def test_put_get_with_aws(tmpdir, conn_cnx, from_path): with conn_cnx() as cnx: with cnx.cursor() as csr: + csr.execute(f"create or replace table {table_name} (a int, b string)") try: - csr.execute(f"create or replace table {table_name} (a int, b string)") file_stream = None if from_path else open(fname, "rb") put( csr, @@ -63,6 +63,8 @@ def test_put_get_with_aws(tmpdir, conn_cnx, from_path): f"%{table_name}", from_path, sql_options=" auto_compress=true parallel=30", + _put_callback=SnowflakeS3ProgressPercentage, + _get_callback=SnowflakeS3ProgressPercentage, file_stream=file_stream, ) rec = csr.fetchone() @@ -74,17 +76,38 @@ def test_put_get_with_aws(tmpdir, conn_cnx, from_path): f"copy into 
@%{table_name} from {table_name} " "file_format=(type=csv compression='gzip')" ) - csr.execute(f"get @%{table_name} file://{tmp_dir}") + csr.execute( + f"get @%{table_name} file://{tmp_dir}", + _put_callback=SnowflakeS3ProgressPercentage, + _get_callback=SnowflakeS3ProgressPercentage, + ) rec = csr.fetchone() assert rec[0].startswith("data_"), "A file downloaded by GET" assert rec[1] == 36, "Return right file size" assert rec[2] == "DOWNLOADED", "Return DOWNLOADED status" assert rec[3] == "", "Return no error message" finally: - csr.execute(f"drop table {table_name}") + csr.execute(f"drop table if exists {table_name}") if file_stream: file_stream.close() + aws_request_present = False + expected_token_prefix = "X-Amz-Signature=" + for line in caplog.text.splitlines(): + if ".amazonaws." in line: + aws_request_present = True + # getattr is used to stay compatible with old driver - before SECRET_STARRED_MASK_STR was added + assert ( + expected_token_prefix + + getattr(SecretDetector, "SECRET_STARRED_MASK_STR", "****") + in line + or expected_token_prefix not in line + ), "connectionpool logger is leaking sensitive information" + + assert ( + aws_request_present + ), "AWS URL was not found in logs, so it can't be assumed that no leaks happened in it" + files = glob.glob(os.path.join(tmp_dir, "data_*")) with gzip.open(files[0], "rb") as fd: contents = fd.read().decode(UTF8) diff --git a/test/integ/test_put_get_with_azure_token.py b/test/integ/test_put_get_with_azure_token.py index c3a8957b3e..7e2e011c72 100644 --- a/test/integ/test_put_get_with_azure_token.py +++ b/test/integ/test_put_get_with_azure_token.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import glob @@ -19,6 +15,7 @@ SnowflakeAzureProgressPercentage, SnowflakeProgressPercentage, ) +from snowflake.connector.secret_detector import SecretDetector try: from snowflake.connector.util_text import random_string @@ -84,14 +81,24 @@ def test_put_get_with_azure(tmpdir, conn_cnx, from_path, caplog): finally: if file_stream: file_stream.close() - csr.execute(f"drop table {table_name}") + csr.execute(f"drop table if exists {table_name}") + azure_request_present = False + expected_token_prefix = "sig=" for line in caplog.text.splitlines(): - if "blob.core.windows.net" in line: + if "blob.core.windows.net" in line and expected_token_prefix in line: + azure_request_present = True + # getattr is used to stay compatible with old driver - before SECRET_STARRED_MASK_STR was added assert ( - "sig=" not in line + expected_token_prefix + + getattr(SecretDetector, "SECRET_STARRED_MASK_STR", "****") + in line ), "connectionpool logger is leaking sensitive information" + assert ( + azure_request_present + ), "Azure URL was not found in logs, so it can't be assumed that no leaks happened in it" + files = glob.glob(os.path.join(tmp_dir, "data_*")) with gzip.open(files[0], "rb") as fd: contents = fd.read().decode(UTF8) diff --git a/test/integ/test_put_get_with_gcp_account.py b/test/integ/test_put_get_with_gcp_account.py index d02643db43..06a77bc371 100644 --- a/test/integ/test_put_get_with_gcp_account.py +++ b/test/integ/test_put_get_with_gcp_account.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import glob diff --git a/test/integ/test_put_windows_path.py b/test/integ/test_put_windows_path.py index 2785ab14c6..f12b0d1a3c 100644 --- a/test/integ/test_put_windows_path.py +++ b/test/integ/test_put_windows_path.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import os diff --git a/test/integ/test_qmark.py b/test/integ/test_qmark.py index 9459e5062d..861a1795d3 100644 --- a/test/integ/test_qmark.py +++ b/test/integ/test_qmark.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import pytest diff --git a/test/integ/test_query_cancelling.py b/test/integ/test_query_cancelling.py index 77f28c5073..dbab9aefdd 100644 --- a/test/integ/test_query_cancelling.py +++ b/test/integ/test_query_cancelling.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/integ/test_results.py b/test/integ/test_results.py index 3ce3dcddd6..3f3e63edb9 100644 --- a/test/integ/test_results.py +++ b/test/integ/test_results.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import pytest diff --git a/test/integ/test_reuse_cursor.py b/test/integ/test_reuse_cursor.py index c550deeb5c..1c5d359df6 100644 --- a/test/integ/test_reuse_cursor.py +++ b/test/integ/test_reuse_cursor.py @@ -1,9 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - - def test_reuse_cursor(conn_cnx, db_parameters): """Ensures only the last executed command/query's result sets are returned.""" with conn_cnx() as cnx: diff --git a/test/integ/test_session_parameters.py b/test/integ/test_session_parameters.py index 73ae5fa650..10502f8585 100644 --- a/test/integ/test_session_parameters.py +++ b/test/integ/test_session_parameters.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import pytest diff --git a/test/integ/test_snowsql_timestamp_format.py b/test/integ/test_snowsql_timestamp_format.py index 6681069818..9f1d0257d7 100644 --- a/test/integ/test_snowsql_timestamp_format.py +++ b/test/integ/test_snowsql_timestamp_format.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import pytest diff --git a/test/integ/test_statement_parameter_binding.py b/test/integ/test_statement_parameter_binding.py index 63e325aa76..4c553fe60d 100644 --- a/test/integ/test_statement_parameter_binding.py +++ b/test/integ/test_statement_parameter_binding.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from datetime import datetime diff --git a/test/integ/test_structured_types.py b/test/integ/test_structured_types.py index 1efa72164b..8b32bb0898 100644 --- a/test/integ/test_structured_types.py +++ b/test/integ/test_structured_types.py @@ -1,7 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# from __future__ import annotations from textwrap import dedent diff --git a/test/integ/test_transaction.py b/test/integ/test_transaction.py index c36b2a0419..49196c570d 100644 --- a/test/integ/test_transaction.py +++ b/test/integ/test_transaction.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import snowflake.connector diff --git a/test/integ/test_vendored_urllib.py b/test/integ/test_vendored_urllib.py index bf178b214b..ec83e62f3e 100644 --- a/test/integ/test_vendored_urllib.py +++ b/test/integ/test_vendored_urllib.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - import pytest try: diff --git a/test/integ_helpers.py b/test/integ_helpers.py index d4e32a4e50..0f0d20d5dc 100644 --- a/test/integ_helpers.py +++ b/test/integ_helpers.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/test/lazy_var.py b/test/lazy_var.py index 44897d5abc..a0439c8074 100644 --- a/test/lazy_var.py +++ b/test/lazy_var.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from typing import Callable, Generic, TypeVar diff --git a/test/randomize.py b/test/randomize.py index 59b259be44..963317d6c5 100644 --- a/test/randomize.py +++ b/test/randomize.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - """ This module was added back to the repository for compatibility with the old driver tests that rely on random_string from this file for functionality. diff --git a/test/stress/__init__.py b/test/stress/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/test/stress/__init__.py +++ b/test/stress/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# diff --git a/test/stress/e2e_iterator.py b/test/stress/e2e_iterator.py index 662ac0aa15..0829598317 100644 --- a/test/stress/e2e_iterator.py +++ b/test/stress/e2e_iterator.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - """ This script is used for end-to-end performance test. It tracks the processing time from cursor fetching data till all data are converted to python objects. diff --git a/test/stress/local_iterator.py b/test/stress/local_iterator.py index 31efa5bfe3..8bba1adf5a 100644 --- a/test/stress/local_iterator.py +++ b/test/stress/local_iterator.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - """ This script is used for PyArrowIterator performance test. It tracks the processing time of PyArrowIterator converting data to python objects. diff --git a/test/stress/util.py b/test/stress/util.py index 8f7d2c88db..f4bf8cebf2 100644 --- a/test/stress/util.py +++ b/test/stress/util.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import time import psutil diff --git a/test/unit/__init__.py b/test/unit/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/test/unit/__init__.py +++ b/test/unit/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# diff --git a/test/unit/aio/conftest.py b/test/unit/aio/conftest.py new file mode 100644 index 0000000000..ee2b3dd0ba --- /dev/null +++ b/test/unit/aio/conftest.py @@ -0,0 +1,45 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import pytest + +from .csp_helpers_async import ( + FakeAwsEnvironmentAsync, + FakeAzureFunctionMetadataServiceAsync, + FakeAzureVmMetadataServiceAsync, + FakeGceMetadataServiceAsync, + NoMetadataServiceAsync, +) + + +@pytest.fixture +def no_metadata_service(): + """Emulates an environment without any metadata service.""" + with NoMetadataServiceAsync() as server: + yield server + + +@pytest.fixture +def fake_aws_environment(): + with FakeAwsEnvironmentAsync() as env: + yield env + + +@pytest.fixture( + params=[FakeAzureFunctionMetadataServiceAsync(), FakeAzureVmMetadataServiceAsync()], + ids=["azure_function", "azure_vm"], +) +def fake_azure_metadata_service(request): + """Parameterized fixture that emulates both the Azure VM and Azure Functions metadata services.""" + with request.param as server: + yield server + + +@pytest.fixture +def fake_gce_metadata_service(): + """Emulates the GCE metadata service, returning a dummy token.""" + with FakeGceMetadataServiceAsync() as server: + yield server diff --git 
a/test/unit/aio/csp_helpers_async.py b/test/unit/aio/csp_helpers_async.py new file mode 100644 index 0000000000..5e50dae72d --- /dev/null +++ b/test/unit/aio/csp_helpers_async.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import logging +import os +from unittest import mock +from urllib.parse import urlparse + +from snowflake.connector.vendored.requests.exceptions import ConnectTimeout, HTTPError + +logger = logging.getLogger(__name__) + + +# Import shared functions +from ...csp_helpers import ( + FakeAwsEnvironment, + FakeAzureFunctionMetadataService, + FakeAzureVmMetadataService, + FakeGceMetadataService, + FakeMetadataService, + NoMetadataService, +) + + +def build_response(content: bytes, status_code: int = 200): + """Builds an aiohttp-compatible response object with the given status code and content.""" + + class AsyncResponse: + def __init__(self, content, status_code): + self.ok = status_code < 400 + self.status = status_code + self._content = content + + async def read(self): + return self._content + + return AsyncResponse(content, status_code) + + +class FakeMetadataServiceAsync(FakeMetadataService): + def _async_request(self, method, url, headers=None, timeout=None): + """Entry point for the aiohttp mock.""" + logger.debug(f"Received async request: {method} {url} {str(headers)}") + parsed_url = urlparse(url) + + # Create async context manager for aiohttp response + class AsyncResponseContextManager: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + # Create aiohttp-compatible response mock + class AsyncResponse: + def __init__(self, requests_response): + self.ok = requests_response.ok + self.status = requests_response.status_code + self._content = requests_response.content + + async def read(self): + return self._content + + if not 
parsed_url.hostname == self.expected_hostname: + logger.debug( + f"Received async request to unexpected hostname {parsed_url.hostname}" + ) + import aiohttp + + raise aiohttp.ClientError() + + # Get the response from the subclass handler, catch exceptions and convert them + try: + sync_response = self.handle_request(method, parsed_url, headers, timeout) + async_response = AsyncResponse(sync_response) + return AsyncResponseContextManager(async_response) + except (HTTPError, ConnectTimeout) as e: + import aiohttp + + # Convert requests exceptions to aiohttp exceptions so they get caught properly + raise aiohttp.ClientError() from e + + def __enter__(self): + self.reset_defaults() + self.patchers = [] + # Mock aiohttp for async requests + self.patchers.append( + mock.patch("aiohttp.ClientSession.request", side_effect=self._async_request) + ) + for patcher in self.patchers: + patcher.__enter__() + return self + + +class NoMetadataServiceAsync(FakeMetadataServiceAsync, NoMetadataService): + pass + + +class FakeAzureVmMetadataServiceAsync( + FakeMetadataServiceAsync, FakeAzureVmMetadataService +): + pass + + +class FakeAzureFunctionMetadataServiceAsync( + FakeMetadataServiceAsync, FakeAzureFunctionMetadataService +): + def __enter__(self): + # Set environment variables first (like Azure Function service) + os.environ["IDENTITY_ENDPOINT"] = self.identity_endpoint + os.environ["IDENTITY_HEADER"] = self.identity_header + + # Then set up the metadata service mocks + FakeMetadataServiceAsync.__enter__(self) + return self + + def __exit__(self, *args, **kwargs): + # Clean up async mocks first + FakeMetadataServiceAsync.__exit__(self, *args, **kwargs) + + # Then clean up environment variables + os.environ.pop("IDENTITY_ENDPOINT", None) + os.environ.pop("IDENTITY_HEADER", None) + + +class FakeGceMetadataServiceAsync(FakeMetadataServiceAsync, FakeGceMetadataService): + pass + + +class FakeAwsEnvironmentAsync(FakeAwsEnvironment): + """Emulates the AWS environment-specific 
functions used in async wif_util.py. + + Unlike the other metadata services, the HTTP calls made by AWS are deep within boto libraries, so + emulating them here would be complex and fragile. Instead, we emulate the higher-level functions + called by the connector code. + """ + + async def get_region(self): + return self.region + + async def get_arn(self): + return self.arn + + async def get_credentials(self): + return self.credentials + + def __enter__(self): + # First call the parent's __enter__ to get base functionality + super().__enter__() + + # Then add async-specific patches + async def async_get_credentials(): + return self.credentials + + async def async_get_caller_identity(): + return {"Arn": self.arn} + + async def async_get_region(): + return await self.get_region() + + async def async_get_arn(): + return await self.get_arn() + + # Mock aioboto3.Session.get_credentials (IS async) + self.patchers.append( + mock.patch( + "snowflake.connector.aio._wif_util.aioboto3.Session.get_credentials", + side_effect=async_get_credentials, + ) + ) + + # Mock the async AWS region and ARN functions + self.patchers.append( + mock.patch( + "snowflake.connector.aio._wif_util.get_aws_region", + side_effect=async_get_region, + ) + ) + + self.patchers.append( + mock.patch( + "snowflake.connector.aio._wif_util.get_aws_arn", + side_effect=async_get_arn, + ) + ) + + # Mock the async STS client for direct aioboto3 usage + class MockStsClient: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def get_caller_identity(self): + return await async_get_caller_identity() + + def mock_session_client(service_name): + if service_name == "sts": + return MockStsClient() + return None + + self.patchers.append( + mock.patch( + "snowflake.connector.aio._wif_util.aioboto3.Session.client", + side_effect=mock_session_client, + ) + ) + + # Start the additional async patches + for patcher in self.patchers[-4:]: # Only start the new 
patches we just added + patcher.__enter__() + return self + + def __exit__(self, *args, **kwargs): + # Call parent's exit to clean up base patches + super().__exit__(*args, **kwargs) diff --git a/test/unit/aio/test_auth_async.py b/test/unit/aio/test_auth_async.py index b36a64d0eb..ca871d3cb5 100644 --- a/test/unit/aio/test_auth_async.py +++ b/test/unit/aio/test_auth_async.py @@ -330,3 +330,13 @@ async def test_authbyplugin_abc_api(): 'password': , \ 'kwargs': })""" ) + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByDefault.mro().index(AuthByPluginAsync) < AuthByDefault.mro().index( + AuthByPluginSync + ) diff --git a/test/unit/aio/test_auth_keypair_async.py b/test/unit/aio/test_auth_keypair_async.py index 2b7cd6df67..866b8bed1e 100644 --- a/test/unit/aio/test_auth_keypair_async.py +++ b/test/unit/aio/test_auth_keypair_async.py @@ -130,6 +130,16 @@ async def test_renew_token(mockPrepare): assert mockPrepare.called +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByKeyPair.mro().index(AuthByPluginAsync) < AuthByKeyPair.mro().index( + AuthByPluginSync + ) + + def _init_rest(application, post_requset): connection = mock_connection() connection.errorhandler = Mock(return_value=None) diff --git a/test/unit/aio/test_auth_no_auth_async.py b/test/unit/aio/test_auth_no_auth_async.py index 0c5585281b..cc2bb5d530 100644 --- a/test/unit/aio/test_auth_no_auth_async.py +++ b/test/unit/aio/test_auth_no_auth_async.py @@ -39,3 +39,14 @@ async def test_auth_no_auth(): assert ( reauth_response == expected_reauth_response ), f'reauthenticate(foo="bar") 
is expected to return {expected_reauth_response}, but returns {reauth_response}' + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.aio.auth._no_auth import AuthNoAuth + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthNoAuth.mro().index(AuthByPluginAsync) < AuthNoAuth.mro().index( + AuthByPluginSync + ) diff --git a/test/unit/aio/test_auth_oauth_async.py b/test/unit/aio/test_auth_oauth_async.py index 1c99c1f123..fc353224db 100644 --- a/test/unit/aio/test_auth_oauth_async.py +++ b/test/unit/aio/test_auth_oauth_async.py @@ -16,3 +16,13 @@ async def test_auth_oauth(): await auth.update_body(body) assert body["data"]["TOKEN"] == token, body assert body["data"]["AUTHENTICATOR"] == "OAUTH", body + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByOAuth.mro().index(AuthByPluginAsync) < AuthByOAuth.mro().index( + AuthByPluginSync + ) diff --git a/test/unit/aio/test_auth_okta_async.py b/test/unit/aio/test_auth_okta_async.py index c2ceee78d3..0b20f0ec33 100644 --- a/test/unit/aio/test_auth_okta_async.py +++ b/test/unit/aio/test_auth_okta_async.py @@ -346,3 +346,13 @@ async def post_request(url, headers, body, **kwargs): connection._rest = rest rest._post_request = post_request return rest + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByOkta.mro().index(AuthByPluginAsync) < AuthByOkta.mro().index( + AuthByPluginSync + ) diff --git 
a/test/unit/aio/test_auth_pat_async.py b/test/unit/aio/test_auth_pat_async.py index 08c785500c..6927d52290 100644 --- a/test/unit/aio/test_auth_pat_async.py +++ b/test/unit/aio/test_auth_pat_async.py @@ -70,3 +70,13 @@ async def mock_post_request(request, url, headers, json_body, **kwargs): assert isinstance(conn.auth_class, AuthByPAT) await conn.close() + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByPAT.mro().index(AuthByPluginAsync) < AuthByPAT.mro().index( + AuthByPluginSync + ) diff --git a/test/unit/aio/test_auth_usrpwdmfa_async.py b/test/unit/aio/test_auth_usrpwdmfa_async.py new file mode 100644 index 0000000000..5c5ba5dea9 --- /dev/null +++ b/test/unit/aio/test_auth_usrpwdmfa_async.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +from snowflake.connector.aio.auth._usrpwdmfa import AuthByUsrPwdMfa + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByUsrPwdMfa.mro().index(AuthByPluginAsync) < AuthByUsrPwdMfa.mro().index( + AuthByPluginSync + ) diff --git a/test/unit/aio/test_auth_webbrowser_async.py b/test/unit/aio/test_auth_webbrowser_async.py index 758529137f..d93aad0b0c 100644 --- a/test/unit/aio/test_auth_webbrowser_async.py +++ b/test/unit/aio/test_auth_webbrowser_async.py @@ -871,3 +871,17 @@ async def test_auth_webbrowser_socket_reuseport_option_not_set_with_no_flag( assert not rest._connection.errorhandler.called # no error assert auth.assertion_content == ref_token + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByWebBrowser.mro().index( + AuthByPluginAsync + ) < AuthByWebBrowser.mro().index(AuthByPluginSync) + + assert AuthByIdToken.mro().index(AuthByPluginAsync) < AuthByIdToken.mro().index( + AuthByPluginSync + ) diff --git a/test/unit/aio/test_auth_workload_identity_async.py b/test/unit/aio/test_auth_workload_identity_async.py new file mode 100644 index 0000000000..13c073d3be --- /dev/null +++ b/test/unit/aio/test_auth_workload_identity_async.py @@ -0,0 +1,433 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +import asyncio +import json +import logging +from base64 import b64decode +from unittest import mock +from urllib.parse import parse_qs, urlparse + +import aiohttp +import jwt +import pytest + +from snowflake.connector.aio._wif_util import AttestationProvider +from snowflake.connector.aio.auth import AuthByWorkloadIdentity +from snowflake.connector.errors import ProgrammingError + +from ...csp_helpers import gen_dummy_id_token +from .csp_helpers_async import FakeAwsEnvironmentAsync, FakeGceMetadataServiceAsync + +logger = logging.getLogger(__name__) + + +async def extract_api_data(auth_class: AuthByWorkloadIdentity): + """Extracts the 'data' portion of the request body populated by the given auth class.""" + req_body = {"data": {}} + await auth_class.update_body(req_body) + return req_body["data"] + + +def verify_aws_token(token: str, region: str): + """Performs some basic checks on a 'token' produced for AWS, to ensure it includes the expected fields.""" + decoded_token = json.loads(b64decode(token)) + + parsed_url = urlparse(decoded_token["url"]) + assert parsed_url.scheme == "https" + assert parsed_url.hostname == f"sts.{region}.amazonaws.com" + query_string = parse_qs(parsed_url.query) + assert query_string.get("Action")[0] == "GetCallerIdentity" + assert query_string.get("Version")[0] == "2011-06-15" + + assert decoded_token["method"] == "POST" + + headers = decoded_token["headers"] + assert set(headers.keys()) == { + "Host", + "X-Snowflake-Audience", + "X-Amz-Date", + "X-Amz-Security-Token", + "Authorization", + } + assert headers["Host"] == f"sts.{region}.amazonaws.com" + assert headers["X-Snowflake-Audience"] == "snowflakecomputing.com" + + +# -- OIDC Tests -- + + +async def test_explicit_oidc_valid_inline_token_plumbed_to_api(): + dummy_token = gen_dummy_id_token(sub="service-1", iss="issuer-1") + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.OIDC, token=dummy_token + ) + await auth_class.prepare() + + assert await 
extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "OIDC", + "TOKEN": dummy_token, + } + + +async def test_explicit_oidc_valid_inline_token_generates_unique_assertion_content(): + dummy_token = gen_dummy_id_token(sub="service-1", iss="issuer-1") + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.OIDC, token=dummy_token + ) + await auth_class.prepare() + assert ( + auth_class.assertion_content + == '{"_provider":"OIDC","iss":"issuer-1","sub":"service-1"}' + ) + + +async def test_explicit_oidc_invalid_inline_token_raises_error(): + invalid_token = "not-a-jwt" + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.OIDC, token=invalid_token + ) + with pytest.raises(ProgrammingError) as excinfo: + await auth_class.prepare() + assert "No workload identity credential was found for 'OIDC'" in str(excinfo.value) + + +async def test_explicit_oidc_no_token_raises_error(): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.OIDC, token=None) + with pytest.raises(ProgrammingError) as excinfo: + await auth_class.prepare() + assert "No workload identity credential was found for 'OIDC'" in str(excinfo.value) + + +# -- AWS Tests -- + + +async def test_explicit_aws_no_auth_raises_error( + fake_aws_environment: FakeAwsEnvironmentAsync, +): + fake_aws_environment.credentials = None + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + with pytest.raises(ProgrammingError) as excinfo: + await auth_class.prepare() + assert "No workload identity credential was found for 'AWS'" in str(excinfo.value) + + +async def test_explicit_aws_encodes_audience_host_signature_to_api( + fake_aws_environment: FakeAwsEnvironmentAsync, +): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + await auth_class.prepare() + + data = await extract_api_data(auth_class) + assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" + assert data["PROVIDER"] == "AWS" + 
verify_aws_token(data["TOKEN"], fake_aws_environment.region) + + +async def test_explicit_aws_uses_regional_hostname( + fake_aws_environment: FakeAwsEnvironmentAsync, +): + fake_aws_environment.region = "antarctica-northeast-3" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + await auth_class.prepare() + + data = await extract_api_data(auth_class) + decoded_token = json.loads(b64decode(data["TOKEN"])) + hostname_from_url = urlparse(decoded_token["url"]).hostname + hostname_from_header = decoded_token["headers"]["Host"] + + expected_hostname = "sts.antarctica-northeast-3.amazonaws.com" + assert expected_hostname == hostname_from_url + assert expected_hostname == hostname_from_header + + +async def test_explicit_aws_generates_unique_assertion_content( + fake_aws_environment: FakeAwsEnvironmentAsync, +): + fake_aws_environment.arn = ( + "arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab" + ) + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + await auth_class.prepare() + + assert ( + '{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}' + == auth_class.assertion_content + ) + + +# -- GCP Tests -- + + +def _mock_aiohttp_exception(exception): + class MockResponse: + def __init__(self, exception): + self.exception = exception + + async def __aenter__(self): + raise self.exception + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + def mock_request(*args, **kwargs): + return MockResponse(exception) + + return mock_request + + +@pytest.mark.parametrize( + "exception", + [ + aiohttp.ClientError(), + aiohttp.ConnectionTimeoutError(), + asyncio.TimeoutError(), + ], +) +async def test_explicit_gcp_metadata_server_error_raises_auth_error(exception): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) + + mock_request = _mock_aiohttp_exception(exception) + + with mock.patch("aiohttp.ClientSession.request", 
side_effect=mock_request): + with pytest.raises(ProgrammingError) as excinfo: + await auth_class.prepare() + assert "No workload identity credential was found for 'GCP'" in str( + excinfo.value + ) + + +async def test_explicit_gcp_wrong_issuer_raises_error( + fake_gce_metadata_service: FakeGceMetadataServiceAsync, +): + fake_gce_metadata_service.iss = "not-google" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) + with pytest.raises(ProgrammingError) as excinfo: + await auth_class.prepare() + assert "No workload identity credential was found for 'GCP'" in str(excinfo.value) + + +async def test_explicit_gcp_plumbs_token_to_api( + fake_gce_metadata_service: FakeGceMetadataServiceAsync, +): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) + await auth_class.prepare() + + assert await extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "GCP", + "TOKEN": fake_gce_metadata_service.token, + } + + +async def test_explicit_gcp_generates_unique_assertion_content( + fake_gce_metadata_service: FakeGceMetadataServiceAsync, +): + fake_gce_metadata_service.sub = "123456" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) + await auth_class.prepare() + + assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}' + + +# -- Azure Tests -- + + +@pytest.mark.parametrize( + "exception", + [ + aiohttp.ClientError(), + asyncio.TimeoutError(), + aiohttp.ConnectionTimeoutError(), + ], +) +async def test_explicit_azure_metadata_server_error_raises_auth_error(exception): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + + mock_request = _mock_aiohttp_exception(exception) + + with mock.patch("aiohttp.ClientSession.request", side_effect=mock_request): + with pytest.raises(ProgrammingError) as excinfo: + await auth_class.prepare() + assert "No workload identity credential was found for 'AZURE'" in str( + excinfo.value + ) + + +async def 
test_explicit_azure_wrong_issuer_raises_error(fake_azure_metadata_service): + fake_azure_metadata_service.iss = "https://notazure.com" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + with pytest.raises(ProgrammingError) as excinfo: + await auth_class.prepare() + assert "No workload identity credential was found for 'AZURE'" in str(excinfo.value) + + +@pytest.mark.parametrize( + "issuer", + [ + "https://sts.windows.net/067802cd-8f92-4c7c-bceb-ea8f15d31cc5", + "https://login.microsoftonline.com/067802cd-8f92-4c7c-bceb-ea8f15d31cc5", + "https://login.microsoftonline.com/067802cd-8f92-4c7c-bceb-ea8f15d31cc5/v2.0", + ], + ids=["v1", "v2_without_suffix", "v2_with_suffix"], +) +async def test_explicit_azure_v1_and_v2_issuers_accepted( + fake_azure_metadata_service, issuer +): + fake_azure_metadata_service.iss = issuer + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + await auth_class.prepare() + + assert issuer == json.loads(auth_class.assertion_content)["iss"] + + +async def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + await auth_class.prepare() + + assert await extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "AZURE", + "TOKEN": fake_azure_metadata_service.token, + } + + +async def test_explicit_azure_generates_unique_assertion_content( + fake_azure_metadata_service, +): + fake_azure_metadata_service.iss = ( + "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" + ) + fake_azure_metadata_service.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + await auth_class.prepare() + + assert ( + '{"_provider":"AZURE","iss":"https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd","sub":"611ab25b-2e81-4e18-92a7-b21f2bebb269"}' + == auth_class.assertion_content + ) + + +async def 
test_explicit_azure_uses_default_entra_resource_if_unspecified( + fake_azure_metadata_service, +): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + await auth_class.prepare() + + token = fake_azure_metadata_service.token + parsed = jwt.decode(token, options={"verify_signature": False}) + assert ( + parsed["aud"] == "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" + ) # the default entra resource defined in wif_util.py. + + +async def test_explicit_azure_uses_explicit_entra_resource(fake_azure_metadata_service): + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.AZURE, entra_resource="api://non-standard" + ) + await auth_class.prepare() + + token = fake_azure_metadata_service.token + parsed = jwt.decode(token, options={"verify_signature": False}) + assert parsed["aud"] == "api://non-standard" + + +# -- Auto-detect Tests -- + + +async def test_autodetect_aws_present( + no_metadata_service, fake_aws_environment: FakeAwsEnvironmentAsync +): + auth_class = AuthByWorkloadIdentity(provider=None) + await auth_class.prepare() + + data = await extract_api_data(auth_class) + assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" + assert data["PROVIDER"] == "AWS" + verify_aws_token(data["TOKEN"], fake_aws_environment.region) + + +@mock.patch("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher") +async def test_autodetect_gcp_present( + mock_fetcher, + fake_gce_metadata_service: FakeGceMetadataServiceAsync, +): + # Mock AioInstanceMetadataRegionFetcher to return None properly as an async function + async def mock_retrieve_region(): + return None + + mock_fetcher.return_value.retrieve_region.side_effect = mock_retrieve_region + + auth_class = AuthByWorkloadIdentity(provider=None) + await auth_class.prepare() + + assert await extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "GCP", + "TOKEN": fake_gce_metadata_service.token, + } + + 
+@mock.patch("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher") +async def test_autodetect_azure_present(mock_fetcher, fake_azure_metadata_service): + # Mock AioInstanceMetadataRegionFetcher to return None properly as an async function + async def mock_retrieve_region(): + return None + + mock_fetcher.return_value.retrieve_region.side_effect = mock_retrieve_region + + auth_class = AuthByWorkloadIdentity(provider=None) + await auth_class.prepare() + + assert await extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "AZURE", + "TOKEN": fake_azure_metadata_service.token, + } + + +async def test_autodetect_oidc_present(no_metadata_service): + dummy_token = gen_dummy_id_token(sub="service-1", iss="issuer-1") + auth_class = AuthByWorkloadIdentity(provider=None, token=dummy_token) + await auth_class.prepare() + + assert await extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "OIDC", + "TOKEN": dummy_token, + } + + +@mock.patch("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher") +async def test_autodetect_no_provider_raises_error(mock_fetcher, no_metadata_service): + # Mock AioInstanceMetadataRegionFetcher to return None properly as an async function + async def mock_retrieve_region(): + return None + + mock_fetcher.return_value.retrieve_region.side_effect = mock_retrieve_region + + auth_class = AuthByWorkloadIdentity(provider=None, token=None) + with pytest.raises(ProgrammingError) as excinfo: + await auth_class.prepare() + assert "No workload identity credential was found for 'auto-detect" in str( + excinfo.value + ) + + +def test_mro(): + """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin.""" + from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync + from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync + + assert AuthByWorkloadIdentity.mro().index( + AuthByPluginAsync + ) < 
AuthByWorkloadIdentity.mro().index(AuthByPluginSync) diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py index f04ec8aacd..43a6c63324 100644 --- a/test/unit/aio/test_connection_async_unit.py +++ b/test/unit/aio/test_connection_async_unit.py @@ -48,6 +48,7 @@ OperationalError, ProgrammingError, ) +from snowflake.connector.wif_util import AttestationProvider def fake_connector(**kwargs) -> snowflake.connector.aio.SnowflakeConnection: @@ -61,6 +62,13 @@ def fake_connector(**kwargs) -> snowflake.connector.aio.SnowflakeConnection: ) +def write_temp_file(file_path: Path, contents: str) -> Path: + """Write the given string text to the given path, chmods it to be accessible, and returns the same path.""" + file_path.write_text(contents) + file_path.chmod(stat.S_IRUSR | stat.S_IWUSR) + return file_path + + @asynccontextmanager async def fake_db_conn(**kwargs): conn = fake_connector(**kwargs) @@ -567,3 +575,111 @@ async def test_otel_error_message_async(caplog, mock_post_requests): ] assert len(important_records) == 1 assert important_records[0].exc_text is not None + + +@pytest.mark.parametrize( + "dependent_param,value", + [ + ("workload_identity_provider", "AWS"), + ( + "workload_identity_entra_resource", + "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b", + ), + ], +) +async def test_cannot_set_dependent_params_without_wlid_authenticator( + mock_post_requests, dependent_param, value +): + with pytest.raises(ProgrammingError) as excinfo: + await snowflake.connector.aio.connect( + user="user", + account="account", + password="password", + **{dependent_param: value}, + ) + assert ( + f"{dependent_param} was set but authenticator was not set to WORKLOAD_IDENTITY" + in str(excinfo.value) + ) + + +async def test_cannot_set_wlid_authenticator_without_env_variable(mock_post_requests): + with pytest.raises(ProgrammingError) as excinfo: + await snowflake.connector.aio.connect( + account="account", authenticator="WORKLOAD_IDENTITY" + ) 
+ assert ( + "Please set the 'SF_ENABLE_EXPERIMENTAL_AUTHENTICATION' environment variable to use the 'WORKLOAD_IDENTITY' authenticator" + in str(excinfo.value) + ) + + +async def test_connection_params_are_plumbed_into_authbyworkloadidentity(monkeypatch): + async def mock_authenticate(*_): + pass + + with monkeypatch.context() as m: + m.setattr( + "snowflake.connector.aio._connection.SnowflakeConnection._authenticate", + mock_authenticate, + ) + m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "") # Can be set to anything. + + conn = await snowflake.connector.aio.connect( + account="my_account_1", + workload_identity_provider=AttestationProvider.AWS, + workload_identity_entra_resource="api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b", + token="my_token", + authenticator="WORKLOAD_IDENTITY", + ) + assert conn.auth_class.provider == AttestationProvider.AWS + assert ( + conn.auth_class.entra_resource + == "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b" + ) + assert conn.auth_class.token == "my_token" + + +async def test_toml_connection_params_are_plumbed_into_authbyworkloadidentity( + monkeypatch, tmp_path +): + token_file = write_temp_file(tmp_path / "token.txt", contents="my_token") + # On Windows, this path includes backslashes which will result in errors while parsing the TOML. + # Escape the backslashes to ensure it parses correctly. 
+ token_file_path_escaped = str(token_file).replace("\\", "\\\\") + connections_file = write_temp_file( + tmp_path / "connections.toml", + contents=dedent( + f"""\ + [default] + account = "my_account_1" + authenticator = "WORKLOAD_IDENTITY" + workload_identity_provider = "OIDC" + workload_identity_entra_resource = "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b" + token_file_path = "{token_file_path_escaped}" + """ + ), + ) + + async def mock_authenticate(*_): + pass + + with monkeypatch.context() as m: + m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False) + m.delenv("SNOWFLAKE_CONNECTIONS", raising=False) + m.setattr(CONFIG_MANAGER, "conf_file_cache", None) + m.setattr( + "snowflake.connector.aio._connection.SnowflakeConnection._authenticate", + mock_authenticate, + ) + m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "") + + conn = await snowflake.connector.aio.connect( + connections_file_path=connections_file + ) + assert conn.auth_class.provider == AttestationProvider.OIDC + assert ( + conn.auth_class.entra_resource + == "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b" + ) + assert conn.auth_class.token == "my_token" diff --git a/test/unit/aio/test_cursor_async_unit.py b/test/unit/aio/test_cursor_async_unit.py index 3cf5e687a6..95a431c907 100644 --- a/test/unit/aio/test_cursor_async_unit.py +++ b/test/unit/aio/test_cursor_async_unit.py @@ -6,7 +6,8 @@ import asyncio import unittest.mock -from unittest.mock import MagicMock, patch +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -99,3 +100,97 @@ async def mock_cmd_query(*args, **kwargs): # query cancel request should be sent upon timeout assert mockCancelQuery.called + + +# The _upload/_download/_upload_stream/_download_stream are newly introduced +# and therefore should not be tested in old drivers. 
+@pytest.mark.skipolddriver +class TestUploadDownloadMethods(IsolatedAsyncioTestCase): + """Test the _upload/_download/_upload_stream/_download_stream methods.""" + + @patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent") + async def test_download(self, MockFileTransferAgent): + cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks( + MockFileTransferAgent + ) + + # Call _download method + await cursor._download("@st", "/tmp/test.txt", {}) + + # In the process of _download execution, we expect these methods to be called + # - parse_file_operation in connection._file_operation_parser + # - execute in SnowflakeFileTransferAgent + # And we do not expect this method to be involved + # - download_as_stream of connection._stream_downloader + fake_conn._file_operation_parser.parse_file_operation.assert_called_once() + fake_conn._stream_downloader.download_as_stream.assert_not_called() + mock_file_transfer_agent_instance.execute.assert_called_once() + + @patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent") + async def test_upload(self, MockFileTransferAgent): + cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks( + MockFileTransferAgent + ) + + # Call _upload method + await cursor._upload("/tmp/test.txt", "@st", {}) + + # In the process of _upload execution, we expect these methods to be called + # - parse_file_operation in connection._file_operation_parser + # - execute in SnowflakeFileTransferAgent + # And we do not expect this method to be involved + # - download_as_stream of connection._stream_downloader + fake_conn._file_operation_parser.parse_file_operation.assert_called_once() + fake_conn._stream_downloader.download_as_stream.assert_not_called() + mock_file_transfer_agent_instance.execute.assert_called_once() + + @patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent") + async def test_download_stream(self, MockFileTransferAgent): + cursor, 
fake_conn, mock_file_transfer_agent_instance = self._setup_mocks( + MockFileTransferAgent + ) + + # Call _download_stream method + await cursor._download_stream("@st/test.txt", decompress=True) + + # In the process of _download_stream execution, we expect these methods to be called + # - parse_file_operation in connection._file_operation_parser + # - download_as_stream of connection._stream_downloader + # And we do not expect this method to be involved + # - execute in SnowflakeFileTransferAgent + fake_conn._file_operation_parser.parse_file_operation.assert_called_once() + fake_conn._stream_downloader.download_as_stream.assert_called_once() + mock_file_transfer_agent_instance.execute.assert_not_called() + + @patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent") + async def test_upload_stream(self, MockFileTransferAgent): + cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks( + MockFileTransferAgent + ) + + # Call _upload_stream method + fd = MagicMock() + await cursor._upload_stream(fd, "@st/test.txt", {}) + + # In the process of _upload_stream execution, we expect these methods to be called + # - parse_file_operation in connection._file_operation_parser + # - execute in SnowflakeFileTransferAgent + # And we do not expect this method to be involved + # - download_as_stream of connection._stream_downloader + fake_conn._file_operation_parser.parse_file_operation.assert_called_once() + fake_conn._stream_downloader.download_as_stream.assert_not_called() + mock_file_transfer_agent_instance.execute.assert_called_once() + + def _setup_mocks(self, MockFileTransferAgent): + mock_file_transfer_agent_instance = MockFileTransferAgent.return_value + mock_file_transfer_agent_instance.execute = AsyncMock(return_value=None) + + fake_conn = FakeConnection() + fake_conn._file_operation_parser = MagicMock() + fake_conn._stream_downloader = MagicMock() + fake_conn._stream_downloader.download_as_stream = AsyncMock() + + cursor = 
SnowflakeCursor(fake_conn) + cursor.reset = MagicMock() + cursor._init_result_and_meta = AsyncMock() + return cursor, fake_conn, mock_file_transfer_agent_instance diff --git a/test/unit/aio/test_gcs_client_async.py b/test/unit/aio/test_gcs_client_async.py index 4ff648e620..483674238a 100644 --- a/test/unit/aio/test_gcs_client_async.py +++ b/test/unit/aio/test_gcs_client_async.py @@ -330,7 +330,7 @@ async def test_get_file_header_none_with_presigned_url(tmp_path): ) storage_credentials = Mock() storage_credentials.creds = {} - stage_info = Mock() + stage_info: dict[str, any] = dict() connection = Mock() client = SnowflakeGCSRestClient( meta, storage_credentials, stage_info, connection, "" @@ -339,3 +339,102 @@ async def test_get_file_header_none_with_presigned_url(tmp_path): await client._update_presigned_url() file_header = await client.get_file_header(meta.name) assert file_header is None + + +@pytest.mark.parametrize( + "region,return_url,use_regional_url,endpoint,gcs_use_virtual_endpoints", + [ + ( + "US-CENTRAL1", + "https://storage.us-central1.rep.googleapis.com", + True, + None, + False, + ), + ( + "ME-CENTRAL2", + "https://storage.me-central2.rep.googleapis.com", + True, + None, + False, + ), + ("US-CENTRAL1", "https://storage.googleapis.com", False, None, False), + ("US-CENTRAL1", "https://storage.googleapis.com", False, None, False), + ("US-CENTRAL1", "https://location.storage.googleapis.com", False, None, True), + ("US-CENTRAL1", "https://location.storage.googleapis.com", True, None, True), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + False, + "https://overriddenurl.com", + False, + ), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + True, + "https://overriddenurl.com", + False, + ), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + True, + "https://overriddenurl.com", + True, + ), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + False, + "https://overriddenurl.com", + False, + ), + ( + "US-CENTRAL1", + "https://overriddenurl.com", 
+ False, + "https://overriddenurl.com", + True, + ), + ], +) +def test_url(region, return_url, use_regional_url, endpoint, gcs_use_virtual_endpoints): + gcs_location = SnowflakeGCSRestClient.get_location( + stage_location="location", + use_regional_url=use_regional_url, + region=region, + endpoint=endpoint, + use_virtual_endpoints=gcs_use_virtual_endpoints, + ) + assert gcs_location.endpoint == return_url + + +@pytest.mark.parametrize( + "region,use_regional_url,return_value", + [ + ("ME-CENTRAL2", False, True), + ("ME-CENTRAL2", True, True), + ("US-CENTRAL1", False, False), + ("US-CENTRAL1", True, True), + ], +) +def test_use_regional_url(region, use_regional_url, return_value): + meta = SnowflakeFileMeta( + name="path/some_file", + src_file_name="path/some_file", + stage_location_type="GCS", + presigned_url="www.example.com", + ) + storage_credentials = Mock() + storage_credentials.creds = {} + stage_info: dict[str, any] = dict() + stage_info["region"] = region + stage_info["useRegionalUrl"] = use_regional_url + connection = Mock() + + client = SnowflakeGCSRestClient( + meta, storage_credentials, stage_info, connection, "" + ) + + assert client.use_regional_url == return_value diff --git a/test/unit/aio/test_ocsp.py b/test/unit/aio/test_ocsp.py index d200e863aa..afc88c60b8 100644 --- a/test/unit/aio/test_ocsp.py +++ b/test/unit/aio/test_ocsp.py @@ -233,7 +233,6 @@ async def test_ocsp_bad_validity(): del environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"] -@pytest.mark.flaky(reruns=3) async def test_ocsp_single_endpoint(): environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] = "True" SnowflakeOCSP.clear_cache() @@ -257,7 +256,6 @@ async def test_ocsp_by_post_method(): assert await ocsp.validate(url, connection), f"Failed to validate: {url}" -@pytest.mark.flaky(reruns=3) async def test_ocsp_with_file_cache(tmpdir): """OCSP tests and the cache server and file.""" tmp_dir = str(tmpdir.mkdir("ocsp_response_cache")) @@ -271,7 +269,6 @@ async def test_ocsp_with_file_cache(tmpdir): 
assert await ocsp.validate(url, connection), f"Failed to validate: {url}" -@pytest.mark.flaky(reruns=3) async def test_ocsp_with_bogus_cache_files( tmpdir, random_ocsp_response_validation_cache ): @@ -312,7 +309,6 @@ async def test_ocsp_with_bogus_cache_files( ), f"Failed to validate: {hostname}" -@pytest.mark.flaky(reruns=3) async def test_ocsp_with_outdated_cache(tmpdir, random_ocsp_response_validation_cache): with mock.patch( "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE", @@ -372,7 +368,6 @@ async def _store_cache_in_file(tmpdir, target_hosts=None): return filename, target_hosts -@pytest.mark.flaky(reruns=3) async def test_ocsp_with_invalid_cache_file(): """OCSP tests with an invalid cache file.""" SnowflakeOCSP.clear_cache() # reset the memory cache @@ -382,7 +377,6 @@ async def test_ocsp_with_invalid_cache_file(): assert await ocsp.validate(url, connection), f"Failed to validate: {url}" -@pytest.mark.flaky(reruns=3) @mock.patch( "snowflake.connector.aio._ocsp_snowflake.SnowflakeOCSP._fetch_ocsp_response", new_callable=mock.AsyncMock, @@ -406,7 +400,6 @@ async def test_ocsp_cache_when_server_is_down( assert not cache_data, "no cache should present because of broken pipe" -@pytest.mark.flaky(reruns=3) async def test_concurrent_ocsp_requests(tmpdir): """Run OCSP revocation checks in parallel. 
The memory and file caches are deleted randomly.""" cache_file_name = path.join(str(tmpdir), "cache_file.txt") diff --git a/test/unit/aio/test_programmatic_access_token_async.py b/test/unit/aio/test_programmatic_access_token_async.py index 4d4e14f088..a663a55b76 100644 --- a/test/unit/aio/test_programmatic_access_token_async.py +++ b/test/unit/aio/test_programmatic_access_token_async.py @@ -27,7 +27,6 @@ def wiremock_client() -> Generator[WiremockClient | Any, Any, None]: @pytest.mark.skipolddriver -@pytest.mark.asyncio async def test_valid_pat_async(wiremock_client: WiremockClient) -> None: wiremock_data_dir = ( pathlib.Path(__file__).parent.parent.parent @@ -65,7 +64,6 @@ async def test_valid_pat_async(wiremock_client: WiremockClient) -> None: @pytest.mark.skipolddriver -@pytest.mark.asyncio async def test_invalid_pat_async(wiremock_client: WiremockClient) -> None: wiremock_data_dir = ( pathlib.Path(__file__).parent.parent.parent @@ -90,42 +88,3 @@ async def test_invalid_pat_async(wiremock_client: WiremockClient) -> None: await connection.connect() assert str(execinfo.value).endswith("Programmatic access token is invalid.") - - -@pytest.mark.skipolddriver -@pytest.mark.asyncio -async def test_pat_as_password_async(wiremock_client: WiremockClient) -> None: - wiremock_data_dir = ( - pathlib.Path(__file__).parent.parent.parent - / "data" - / "wiremock" - / "mappings" - / "auth" - / "pat" - ) - - wiremock_generic_data_dir = ( - pathlib.Path(__file__).parent.parent.parent - / "data" - / "wiremock" - / "mappings" - / "generic" - ) - - wiremock_client.import_mapping(wiremock_data_dir / "successful_flow.json") - wiremock_client.add_mapping( - wiremock_generic_data_dir / "snowflake_disconnect_successful.json" - ) - - connection = SnowflakeConnection( - user="testUser", - authenticator=PROGRAMMATIC_ACCESS_TOKEN, - token=None, - password="some PAT", - account="testAccount", - protocol="http", - host=wiremock_client.wiremock_host, - port=wiremock_client.wiremock_http_port, - 
) - await connection.connect() - await connection.close() diff --git a/test/unit/aio/test_s3_util_async.py b/test/unit/aio/test_s3_util_async.py index 821246aafb..7c3c299d4c 100644 --- a/test/unit/aio/test_s3_util_async.py +++ b/test/unit/aio/test_s3_util_async.py @@ -29,14 +29,11 @@ SnowflakeFileMeta, StorageCredential, ) - from snowflake.connector.s3_storage_client import ERRORNO_WSAECONNABORTED from snowflake.connector.vendored.requests import HTTPError except ImportError: # Compatibility for olddriver tests from requests import HTTPError - from snowflake.connector.s3_util import ERRORNO_WSAECONNABORTED # NOQA - SnowflakeFileMeta = dict SnowflakeS3RestClient = None RequestExceedMaxRetryError = None @@ -500,3 +497,46 @@ async def test_accelerate_in_china_endpoint(): 8 * megabyte, ) assert not await rest_client.transfer_accelerate_config() + + +@pytest.mark.parametrize( + "use_s3_regional_url,stage_info_flags,expected", + [ + (False, {}, False), + (True, {}, True), + (False, {"useS3RegionalUrl": True}, True), + (False, {"useRegionalUrl": True}, True), + (True, {"useS3RegionalUrl": False}, True), + (False, {"useS3RegionalUrl": True, "useRegionalUrl": False}, True), + (False, {"useS3RegionalUrl": False, "useRegionalUrl": True}, True), + (False, {"useS3RegionalUrl": False, "useRegionalUrl": False}, False), + ], +) +def test_s3_regional_url_logic_async(use_s3_regional_url, stage_info_flags, expected): + """Tests that the async S3 storage client correctly handles regional URL flags from stage_info.""" + if SnowflakeS3RestClient is None: + pytest.skip("S3 storage client not available") + + meta = SnowflakeFileMeta( + name="path/some_file", + src_file_name="path/some_file", + stage_location_type="S3", + ) + storage_credentials = StorageCredential({}, mock.Mock(), "test") + + stage_info = { + "region": "us-west-2", + "location": "test-bucket", + "endPoint": None, + } + stage_info.update(stage_info_flags) + + client = SnowflakeS3RestClient( + meta=meta, + 
credentials=storage_credentials, + stage_info=stage_info, + chunk_size=1024, + use_s3_regional_url=use_s3_regional_url, + ) + + assert client.use_s3_regional_url == expected diff --git a/test/unit/conftest.py b/test/unit/conftest.py index 6a72f8b57e..65c2fb02f6 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -1,13 +1,17 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import pytest from snowflake.connector.telemetry_oob import TelemetryService +from ..csp_helpers import ( + FakeAwsEnvironment, + FakeAzureFunctionMetadataService, + FakeAzureVmMetadataService, + FakeGceMetadataService, + NoMetadataService, +) + @pytest.fixture(autouse=True, scope="session") def disable_oob_telemetry(): @@ -17,3 +21,34 @@ def disable_oob_telemetry(): yield None if original_state: oob_telemetry_service.enable() + + +@pytest.fixture +def no_metadata_service(): + """Emulates an environment without any metadata service.""" + with NoMetadataService() as server: + yield server + + +@pytest.fixture +def fake_aws_environment(): + """Emulates the AWS environment, returning dummy credentials.""" + with FakeAwsEnvironment() as env: + yield env + + +@pytest.fixture( + params=[FakeAzureFunctionMetadataService(), FakeAzureVmMetadataService()], + ids=["azure_function", "azure_vm"], +) +def fake_azure_metadata_service(request): + """Parameterized fixture that emulates both the Azure VM and Azure Functions metadata services.""" + with request.param as server: + yield server + + +@pytest.fixture +def fake_gce_metadata_service(): + """Emulates the GCE metadata service, returning a dummy token.""" + with FakeGceMetadataService() as server: + yield server diff --git a/test/unit/mock_utils.py b/test/unit/mock_utils.py index b6e27d514d..ef4d6de264 100644 --- a/test/unit/mock_utils.py +++ b/test/unit/mock_utils.py @@ -1,6 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# import time from unittest.mock import MagicMock diff --git a/test/unit/test_auth.py b/test/unit/test_auth.py index efd1b43a22..aeef815115 100644 --- a/test/unit/test_auth.py +++ b/test/unit/test_auth.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import inspect diff --git a/test/unit/test_auth_keypair.py b/test/unit/test_auth_keypair.py index 4d7974adbd..8824e822de 100644 --- a/test/unit/test_auth_keypair.py +++ b/test/unit/test_auth_keypair.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from unittest.mock import Mock, PropertyMock, patch diff --git a/test/unit/test_auth_mfa.py b/test/unit/test_auth_mfa.py index 8c7026e553..0deb724b84 100644 --- a/test/unit/test_auth_mfa.py +++ b/test/unit/test_auth_mfa.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from unittest import mock from snowflake.connector import connect diff --git a/test/unit/test_auth_no_auth.py b/test/unit/test_auth_no_auth.py index b63406376b..e89b6b72c5 100644 --- a/test/unit/test_auth_no_auth.py +++ b/test/unit/test_auth_no_auth.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import pytest diff --git a/test/unit/test_auth_oauth.py b/test/unit/test_auth_oauth.py index e10f87cd20..443753ac74 100644 --- a/test/unit/test_auth_oauth.py +++ b/test/unit/test_auth_oauth.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations try: # pragma: no cover diff --git a/test/unit/test_auth_okta.py b/test/unit/test_auth_okta.py index 9066476ba1..efbecfd9eb 100644 --- a/test/unit/test_auth_okta.py +++ b/test/unit/test_auth_okta.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/unit/test_auth_webbrowser.py b/test/unit/test_auth_webbrowser.py index 8a138d8f98..d9dfe47a27 100644 --- a/test/unit/test_auth_webbrowser.py +++ b/test/unit/test_auth_webbrowser.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import base64 diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py new file mode 100644 index 0000000000..b5b0f39881 --- /dev/null +++ b/test/unit/test_auth_workload_identity.py @@ -0,0 +1,370 @@ +import json +import logging +from base64 import b64decode +from unittest import mock +from urllib.parse import parse_qs, urlparse + +import jwt +import pytest + +from snowflake.connector.auth import AuthByWorkloadIdentity +from snowflake.connector.errors import ProgrammingError +from snowflake.connector.vendored.requests.exceptions import ( + ConnectTimeout, + HTTPError, + Timeout, +) +from snowflake.connector.wif_util import AttestationProvider + +from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token + +logger = logging.getLogger(__name__) + + +def extract_api_data(auth_class: AuthByWorkloadIdentity): + """Extracts the 'data' portion of the request body populated by the given auth class.""" + req_body = {"data": {}} + auth_class.update_body(req_body) + return req_body["data"] + + +def verify_aws_token(token: str, region: str): + """Performs some basic checks on a 'token' produced for AWS, to ensure it includes the expected fields.""" + 
decoded_token = json.loads(b64decode(token)) + + parsed_url = urlparse(decoded_token["url"]) + assert parsed_url.scheme == "https" + assert parsed_url.hostname == f"sts.{region}.amazonaws.com" + query_string = parse_qs(parsed_url.query) + assert query_string.get("Action")[0] == "GetCallerIdentity" + assert query_string.get("Version")[0] == "2011-06-15" + + assert decoded_token["method"] == "POST" + + headers = decoded_token["headers"] + assert set(headers.keys()) == { + "Host", + "X-Snowflake-Audience", + "X-Amz-Date", + "X-Amz-Security-Token", + "Authorization", + } + assert headers["Host"] == f"sts.{region}.amazonaws.com" + assert headers["X-Snowflake-Audience"] == "snowflakecomputing.com" + + +# -- OIDC Tests -- + + +def test_explicit_oidc_valid_inline_token_plumbed_to_api(): + dummy_token = gen_dummy_id_token(sub="service-1", iss="issuer-1") + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.OIDC, token=dummy_token + ) + auth_class.prepare() + + assert extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "OIDC", + "TOKEN": dummy_token, + } + + +def test_explicit_oidc_valid_inline_token_generates_unique_assertion_content(): + dummy_token = gen_dummy_id_token(sub="service-1", iss="issuer-1") + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.OIDC, token=dummy_token + ) + auth_class.prepare() + assert ( + auth_class.assertion_content + == '{"_provider":"OIDC","iss":"issuer-1","sub":"service-1"}' + ) + + +def test_explicit_oidc_invalid_inline_token_raises_error(): + invalid_token = "not-a-jwt" + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.OIDC, token=invalid_token + ) + with pytest.raises(ProgrammingError) as excinfo: + auth_class.prepare() + assert "No workload identity credential was found for 'OIDC'" in str(excinfo.value) + + +def test_explicit_oidc_no_token_raises_error(): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.OIDC, token=None) + with 
pytest.raises(ProgrammingError) as excinfo: + auth_class.prepare() + assert "No workload identity credential was found for 'OIDC'" in str(excinfo.value) + + +# -- AWS Tests -- + + +def test_explicit_aws_no_auth_raises_error(fake_aws_environment: FakeAwsEnvironment): + fake_aws_environment.credentials = None + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + with pytest.raises(ProgrammingError) as excinfo: + auth_class.prepare() + assert "No workload identity credential was found for 'AWS'" in str(excinfo.value) + + +def test_explicit_aws_encodes_audience_host_signature_to_api( + fake_aws_environment: FakeAwsEnvironment, +): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + auth_class.prepare() + + data = extract_api_data(auth_class) + assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" + assert data["PROVIDER"] == "AWS" + verify_aws_token(data["TOKEN"], fake_aws_environment.region) + + +def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnvironment): + fake_aws_environment.region = "antarctica-northeast-3" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + auth_class.prepare() + + data = extract_api_data(auth_class) + decoded_token = json.loads(b64decode(data["TOKEN"])) + hostname_from_url = urlparse(decoded_token["url"]).hostname + hostname_from_header = decoded_token["headers"]["Host"] + + expected_hostname = "sts.antarctica-northeast-3.amazonaws.com" + assert expected_hostname == hostname_from_url + assert expected_hostname == hostname_from_header + + +def test_explicit_aws_generates_unique_assertion_content( + fake_aws_environment: FakeAwsEnvironment, +): + fake_aws_environment.arn = ( + "arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab" + ) + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS) + auth_class.prepare() + + assert ( + 
'{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}' + == auth_class.assertion_content + ) + + +# -- GCP Tests -- + + +@pytest.mark.parametrize( + "exception", + [ + HTTPError(), + Timeout(), + ConnectTimeout(), + ], +) +def test_explicit_gcp_metadata_server_error_raises_auth_error(exception): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) + with mock.patch( + "snowflake.connector.vendored.requests.request", side_effect=exception + ): + with pytest.raises(ProgrammingError) as excinfo: + auth_class.prepare() + assert "No workload identity credential was found for 'GCP'" in str( + excinfo.value + ) + + +def test_explicit_gcp_wrong_issuer_raises_error( + fake_gce_metadata_service: FakeGceMetadataService, +): + fake_gce_metadata_service.iss = "not-google" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) + with pytest.raises(ProgrammingError) as excinfo: + auth_class.prepare() + assert "No workload identity credential was found for 'GCP'" in str(excinfo.value) + + +def test_explicit_gcp_plumbs_token_to_api( + fake_gce_metadata_service: FakeGceMetadataService, +): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) + auth_class.prepare() + + assert extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "GCP", + "TOKEN": fake_gce_metadata_service.token, + } + + +def test_explicit_gcp_generates_unique_assertion_content( + fake_gce_metadata_service: FakeGceMetadataService, +): + fake_gce_metadata_service.sub = "123456" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.GCP) + auth_class.prepare() + + assert auth_class.assertion_content == '{"_provider":"GCP","sub":"123456"}' + + +# -- Azure Tests -- + + +@pytest.mark.parametrize( + "exception", + [ + HTTPError(), + Timeout(), + ConnectTimeout(), + ], +) +def test_explicit_azure_metadata_server_error_raises_auth_error(exception): + auth_class = 
AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + with mock.patch( + "snowflake.connector.vendored.requests.request", side_effect=exception + ): + with pytest.raises(ProgrammingError) as excinfo: + auth_class.prepare() + assert "No workload identity credential was found for 'AZURE'" in str( + excinfo.value + ) + + +def test_explicit_azure_wrong_issuer_raises_error(fake_azure_metadata_service): + fake_azure_metadata_service.iss = "https://notazure.com" + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + with pytest.raises(ProgrammingError) as excinfo: + auth_class.prepare() + assert "No workload identity credential was found for 'AZURE'" in str(excinfo.value) + + +@pytest.mark.parametrize( + "issuer", + [ + "https://sts.windows.net/067802cd-8f92-4c7c-bceb-ea8f15d31cc5", + "https://login.microsoftonline.com/067802cd-8f92-4c7c-bceb-ea8f15d31cc5", + "https://login.microsoftonline.com/067802cd-8f92-4c7c-bceb-ea8f15d31cc5/v2.0", + ], + ids=["v1", "v2_without_suffix", "v2_with_suffix"], +) +def test_explicit_azure_v1_and_v2_issuers_accepted(fake_azure_metadata_service, issuer): + fake_azure_metadata_service.iss = issuer + + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + auth_class.prepare() + + assert issuer == json.loads(auth_class.assertion_content)["iss"] + + +def test_explicit_azure_plumbs_token_to_api(fake_azure_metadata_service): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + auth_class.prepare() + + assert extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "AZURE", + "TOKEN": fake_azure_metadata_service.token, + } + + +def test_explicit_azure_generates_unique_assertion_content(fake_azure_metadata_service): + fake_azure_metadata_service.iss = ( + "https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd" + ) + fake_azure_metadata_service.sub = "611ab25b-2e81-4e18-92a7-b21f2bebb269" + + auth_class = 
AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + auth_class.prepare() + + assert ( + '{"_provider":"AZURE","iss":"https://sts.windows.net/2c0183ed-cf17-480d-b3f7-df91bc0a97cd","sub":"611ab25b-2e81-4e18-92a7-b21f2bebb269"}' + == auth_class.assertion_content + ) + + +def test_explicit_azure_uses_default_entra_resource_if_unspecified( + fake_azure_metadata_service, +): + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + auth_class.prepare() + + token = fake_azure_metadata_service.token + parsed = jwt.decode(token, options={"verify_signature": False}) + assert ( + parsed["aud"] == "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad" + ) # the default entra resource defined in wif_util.py. + + +def test_explicit_azure_uses_explicit_entra_resource(fake_azure_metadata_service): + auth_class = AuthByWorkloadIdentity( + provider=AttestationProvider.AZURE, entra_resource="api://non-standard" + ) + auth_class.prepare() + + token = fake_azure_metadata_service.token + parsed = jwt.decode(token, options={"verify_signature": False}) + assert parsed["aud"] == "api://non-standard" + + +# -- Auto-detect Tests -- + + +def test_autodetect_aws_present( + no_metadata_service, fake_aws_environment: FakeAwsEnvironment +): + auth_class = AuthByWorkloadIdentity(provider=None) + auth_class.prepare() + + data = extract_api_data(auth_class) + assert data["AUTHENTICATOR"] == "WORKLOAD_IDENTITY" + assert data["PROVIDER"] == "AWS" + verify_aws_token(data["TOKEN"], fake_aws_environment.region) + + +def test_autodetect_gcp_present(fake_gce_metadata_service: FakeGceMetadataService): + auth_class = AuthByWorkloadIdentity(provider=None) + auth_class.prepare() + + assert extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "GCP", + "TOKEN": fake_gce_metadata_service.token, + } + + +def test_autodetect_azure_present(fake_azure_metadata_service): + auth_class = AuthByWorkloadIdentity(provider=None) + auth_class.prepare() + + assert 
extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "AZURE", + "TOKEN": fake_azure_metadata_service.token, + } + + +def test_autodetect_oidc_present(no_metadata_service): + dummy_token = gen_dummy_id_token(sub="service-1", iss="issuer-1") + auth_class = AuthByWorkloadIdentity(provider=None, token=dummy_token) + auth_class.prepare() + + assert extract_api_data(auth_class) == { + "AUTHENTICATOR": "WORKLOAD_IDENTITY", + "PROVIDER": "OIDC", + "TOKEN": dummy_token, + } + + +def test_autodetect_no_provider_raises_error(no_metadata_service): + auth_class = AuthByWorkloadIdentity(provider=None, token=None) + with pytest.raises(ProgrammingError) as excinfo: + auth_class.prepare() + assert "No workload identity credential was found for 'auto-detect" in str( + excinfo.value + ) diff --git a/test/unit/test_backoff_policies.py b/test/unit/test_backoff_policies.py index ed4fea9e04..064cce145e 100644 --- a/test/unit/test_backoff_policies.py +++ b/test/unit/test_backoff_policies.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import pytest try: diff --git a/test/unit/test_binaryformat.py b/test/unit/test_binaryformat.py index 02ee884ab8..2150301d10 100644 --- a/test/unit/test_binaryformat.py +++ b/test/unit/test_binaryformat.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from snowflake.connector.sfbinaryformat import ( diff --git a/test/unit/test_bind_upload_agent.py b/test/unit/test_bind_upload_agent.py index 7110d36d18..6f9ed64740 100644 --- a/test/unit/test_bind_upload_agent.py +++ b/test/unit/test_bind_upload_agent.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations from unittest import mock diff --git a/test/unit/test_cache.py b/test/unit/test_cache.py index 11d01f7c90..9cd4b0bb92 100644 --- a/test/unit/test_cache.py +++ b/test/unit/test_cache.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import datetime import logging import os diff --git a/test/unit/test_compute_chunk_size.py b/test/unit/test_compute_chunk_size.py index b7d07d5c48..afd68bf8ad 100644 --- a/test/unit/test_compute_chunk_size.py +++ b/test/unit/test_compute_chunk_size.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import pytest pytestmark = pytest.mark.skipolddriver diff --git a/test/unit/test_configmanager.py b/test/unit/test_configmanager.py index c1bfce2bbb..cdb45379b3 100644 --- a/test/unit/test_configmanager.py +++ b/test/unit/test_configmanager.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index d3c0c3259e..5fa43a4224 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import json @@ -30,6 +26,7 @@ ProgrammingError, ) from snowflake.connector.network import SnowflakeRestful +from snowflake.connector.wif_util import AttestationProvider from ..randomize import random_string from .mock_utils import mock_request_with_action, zero_backoff @@ -97,6 +94,13 @@ def mock_post_request(request, url, headers, json_body, **kwargs): return request_body +def write_temp_file(file_path: Path, contents: str) -> Path: + """Write the given string text to the given path, chmods it to be accessible, and returns the same path.""" + file_path.write_text(contents) + file_path.chmod(stat.S_IRUSR | stat.S_IWUSR) + return file_path + + def test_connect_with_service_name(mock_post_requests): assert fake_connector().service_name == "FAKE_SERVICE_NAME" @@ -588,3 +592,98 @@ def test_otel_error_message(caplog, mock_post_requests): ] assert len(important_records) == 1 assert important_records[0].exc_text is not None + + +@pytest.mark.parametrize( + "dependent_param,value", + [ + ("workload_identity_provider", "AWS"), + ( + "workload_identity_entra_resource", + "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b", + ), + ], +) +def test_cannot_set_dependent_params_without_wlid_authenticator( + mock_post_requests, dependent_param, value +): + with pytest.raises(ProgrammingError) as excinfo: + snowflake.connector.connect( + user="user", + account="account", + password="password", + **{dependent_param: value}, + ) + assert ( + f"{dependent_param} was set but authenticator was not set to WORKLOAD_IDENTITY" + in str(excinfo.value) + ) + + +def test_cannot_set_wlid_authenticator_without_env_variable(mock_post_requests): + with pytest.raises(ProgrammingError) as excinfo: + snowflake.connector.connect( + account="account", authenticator="WORKLOAD_IDENTITY" + ) + assert ( + "Please set the 'SF_ENABLE_EXPERIMENTAL_AUTHENTICATION' environment variable to use the 'WORKLOAD_IDENTITY' authenticator" + in str(excinfo.value) + ) + + +def 
test_connection_params_are_plumbed_into_authbyworkloadidentity(monkeypatch): + with monkeypatch.context() as m: + m.setattr( + "snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None + ) + m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "") # Can be set to anything. + + conn = snowflake.connector.connect( + account="my_account_1", + workload_identity_provider=AttestationProvider.AWS, + workload_identity_entra_resource="api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b", + token="my_token", + authenticator="WORKLOAD_IDENTITY", + ) + assert conn.auth_class.provider == AttestationProvider.AWS + assert ( + conn.auth_class.entra_resource + == "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b" + ) + assert conn.auth_class.token == "my_token" + + +def test_toml_connection_params_are_plumbed_into_authbyworkloadidentity( + monkeypatch, tmp_path +): + token_file = write_temp_file(tmp_path / "token.txt", contents="my_token") + # On Windows, this path includes backslashes which will result in errors while parsing the TOML. + # Escape the backslashes to ensure it parses correctly. 
+ token_file_path_escaped = str(token_file).replace("\\", "\\\\") + connections_file = write_temp_file( + tmp_path / "connections.toml", + contents=dedent( + f"""\ + [default] + account = "my_account_1" + authenticator = "WORKLOAD_IDENTITY" + workload_identity_provider = "OIDC" + workload_identity_entra_resource = "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b" + token_file_path = "{token_file_path_escaped}" + """ + ), + ) + + with monkeypatch.context() as m: + m.setattr( + "snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None + ) + m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "") + + conn = snowflake.connector.connect(connections_file_path=connections_file) + assert conn.auth_class.provider == AttestationProvider.OIDC + assert ( + conn.auth_class.entra_resource + == "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b" + ) + assert conn.auth_class.token == "my_token" diff --git a/test/unit/test_connection_diagnostic.py b/test/unit/test_connection_diagnostic.py index ffe4015b73..99f7419cb3 100644 --- a/test/unit/test_connection_diagnostic.py +++ b/test/unit/test_connection_diagnostic.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/unit/test_construct_hostname.py b/test/unit/test_construct_hostname.py index 973ef06c6b..86239d841e 100644 --- a/test/unit/test_construct_hostname.py +++ b/test/unit/test_construct_hostname.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from snowflake.connector.util_text import construct_hostname diff --git a/test/unit/test_converter.py b/test/unit/test_converter.py index aa9243bb9c..d1b143a6cd 100644 --- a/test/unit/test_converter.py +++ b/test/unit/test_converter.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations from decimal import Decimal diff --git a/test/unit/test_cursor.py b/test/unit/test_cursor.py index 7b04c43e50..80ace1be33 100644 --- a/test/unit/test_cursor.py +++ b/test/unit/test_cursor.py @@ -1,10 +1,7 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import time +from unittest import TestCase from unittest.mock import MagicMock, patch import pytest @@ -99,3 +96,96 @@ def mock_cmd_query(*args, **kwargs): # query cancel request should be sent upon timeout assert mockCancelQuery.called + + +# The _upload/_download/_upload_stream/_download_stream are newly introduced +# and therefore should not be tested in old drivers. +@pytest.mark.skipolddriver +class TestUploadDownloadMethods(TestCase): + """Test the _upload/_download/_upload_stream/_download_stream methods.""" + + @patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent") + def test_download(self, MockFileTransferAgent): + cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks( + MockFileTransferAgent + ) + + # Call _download method + cursor._download("@st", "/tmp/test.txt", {}) + + # In the process of _download execution, we expect these methods to be called + # - parse_file_operation in connection._file_operation_parser + # - execute in SnowflakeFileTransferAgent + # And we do not expect this method to be involved + # - download_as_stream of connection._stream_downloader + fake_conn._file_operation_parser.parse_file_operation.assert_called_once() + fake_conn._stream_downloader.download_as_stream.assert_not_called() + mock_file_transfer_agent_instance.execute.assert_called_once() + + @patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent") + def test_upload(self, MockFileTransferAgent): + cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks( + MockFileTransferAgent + ) + + # Call _upload method + cursor._upload("/tmp/test.txt", 
"@st", {}) + + # In the process of _upload execution, we expect these methods to be called + # - parse_file_operation in connection._file_operation_parser + # - execute in SnowflakeFileTransferAgent + # And we do not expect this method to be involved + # - download_as_stream of connection._stream_downloader + fake_conn._file_operation_parser.parse_file_operation.assert_called_once() + fake_conn._stream_downloader.download_as_stream.assert_not_called() + mock_file_transfer_agent_instance.execute.assert_called_once() + + @patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent") + def test_download_stream(self, MockFileTransferAgent): + cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks( + MockFileTransferAgent + ) + + # Call _download_stream method + cursor._download_stream("@st/test.txt", decompress=True) + + # In the process of _download_stream execution, we expect these methods to be called + # - parse_file_operation in connection._file_operation_parser + # - download_as_stream of connection._stream_downloader + # And we do not expect this method to be involved + # - execute in SnowflakeFileTransferAgent + fake_conn._file_operation_parser.parse_file_operation.assert_called_once() + fake_conn._stream_downloader.download_as_stream.assert_called_once() + mock_file_transfer_agent_instance.execute.assert_not_called() + + @patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent") + def test_upload_stream(self, MockFileTransferAgent): + cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks( + MockFileTransferAgent + ) + + # Call _upload_stream method + fd = MagicMock() + cursor._upload_stream(fd, "@st/test.txt", {}) + + # In the process of _upload_stream execution, we expect these methods to be called + # - parse_file_operation in connection._file_operation_parser + # - execute in SnowflakeFileTransferAgent + # And we do not expect this method to be involved + # - download_as_stream of 
connection._stream_downloader + fake_conn._file_operation_parser.parse_file_operation.assert_called_once() + fake_conn._stream_downloader.download_as_stream.assert_not_called() + mock_file_transfer_agent_instance.execute.assert_called_once() + + def _setup_mocks(self, MockFileTransferAgent): + mock_file_transfer_agent_instance = MockFileTransferAgent.return_value + mock_file_transfer_agent_instance.execute.return_value = None + + fake_conn = FakeConnection() + fake_conn._file_operation_parser = MagicMock() + fake_conn._stream_downloader = MagicMock() + + cursor = SnowflakeCursor(fake_conn) + cursor.reset = MagicMock() + cursor._init_result_and_meta = MagicMock() + return cursor, fake_conn, mock_file_transfer_agent_instance diff --git a/test/unit/test_datetime.py b/test/unit/test_datetime.py index d006fc0df9..8351090076 100644 --- a/test/unit/test_datetime.py +++ b/test/unit/test_datetime.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import time diff --git a/test/unit/test_dbapi.py b/test/unit/test_dbapi.py index cf383aa908..ff2a38c1bd 100644 --- a/test/unit/test_dbapi.py +++ b/test/unit/test_dbapi.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from snowflake.connector.dbapi import Binary diff --git a/test/unit/test_dependencies.py b/test/unit/test_dependencies.py index fb0c192073..8bc0a246ec 100644 --- a/test/unit/test_dependencies.py +++ b/test/unit/test_dependencies.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - import warnings import cryptography.utils diff --git a/test/unit/test_easy_logging.py b/test/unit/test_easy_logging.py index 5eba47eaba..92f62c3a36 100644 --- a/test/unit/test_easy_logging.py +++ b/test/unit/test_easy_logging.py @@ -1,6 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# import stat import pytest diff --git a/test/unit/test_encryption_util.py b/test/unit/test_encryption_util.py index d1c08ab8c9..a35f99fd90 100644 --- a/test/unit/test_encryption_util.py +++ b/test/unit/test_encryption_util.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import codecs diff --git a/test/unit/test_error_arrow_stream.py b/test/unit/test_error_arrow_stream.py index 62f3f70470..14b8a208bb 100644 --- a/test/unit/test_error_arrow_stream.py +++ b/test/unit/test_error_arrow_stream.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import pytest from ..helpers import ( diff --git a/test/unit/test_errors.py b/test/unit/test_errors.py index 052d53debe..a09bca727b 100644 --- a/test/unit/test_errors.py +++ b/test/unit/test_errors.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import re diff --git a/test/unit/test_gcs_client.py b/test/unit/test_gcs_client.py index 963d20d579..c08b5f7c3f 100644 --- a/test/unit/test_gcs_client.py +++ b/test/unit/test_gcs_client.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import logging @@ -344,10 +340,109 @@ def test_get_file_header_none_with_presigned_url(tmp_path): ) storage_credentials = Mock() storage_credentials.creds = {} - stage_info = Mock() + stage_info: dict[str, any] = dict() connection = Mock() client = SnowflakeGCSRestClient( meta, storage_credentials, stage_info, connection, "" ) file_header = client.get_file_header(meta.name) assert file_header is None + + +@pytest.mark.parametrize( + "region,return_url,use_regional_url,endpoint,gcs_use_virtual_endpoints", + [ + ( + "US-CENTRAL1", + "https://storage.us-central1.rep.googleapis.com", + True, + None, + False, + ), + ( + "ME-CENTRAL2", + "https://storage.me-central2.rep.googleapis.com", + True, + None, + False, + ), + ("US-CENTRAL1", "https://storage.googleapis.com", False, None, False), + ("US-CENTRAL1", "https://storage.googleapis.com", False, None, False), + ("US-CENTRAL1", "https://location.storage.googleapis.com", False, None, True), + ("US-CENTRAL1", "https://location.storage.googleapis.com", True, None, True), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + False, + "https://overriddenurl.com", + False, + ), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + True, + "https://overriddenurl.com", + False, + ), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + True, + "https://overriddenurl.com", + True, + ), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + False, + "https://overriddenurl.com", + False, + ), + ( + "US-CENTRAL1", + "https://overriddenurl.com", + False, + "https://overriddenurl.com", + True, + ), + ], +) +def test_url(region, return_url, use_regional_url, endpoint, gcs_use_virtual_endpoints): + gcs_location = SnowflakeGCSRestClient.get_location( + stage_location="location", + use_regional_url=use_regional_url, + region=region, + endpoint=endpoint, + use_virtual_endpoints=gcs_use_virtual_endpoints, + ) + assert gcs_location.endpoint == return_url + + +@pytest.mark.parametrize( + 
"region,use_regional_url,return_value", + [ + ("ME-CENTRAL2", False, True), + ("ME-CENTRAL2", True, True), + ("US-CENTRAL1", False, False), + ("US-CENTRAL1", True, True), + ], +) +def test_use_regional_url(region, use_regional_url, return_value): + meta = SnowflakeFileMeta( + name="path/some_file", + src_file_name="path/some_file", + stage_location_type="GCS", + presigned_url="www.example.com", + ) + storage_credentials = Mock() + storage_credentials.creds = {} + stage_info: dict[str, any] = dict() + stage_info["region"] = region + stage_info["useRegionalUrl"] = use_regional_url + connection = Mock() + + client = SnowflakeGCSRestClient( + meta, storage_credentials, stage_info, connection, "" + ) + + assert client.use_regional_url == return_value diff --git a/test/unit/test_linux_local_file_cache.py b/test/unit/test_linux_local_file_cache.py index 9c5ac10667..51617f6094 100644 --- a/test/unit/test_linux_local_file_cache.py +++ b/test/unit/test_linux_local_file_cache.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import os diff --git a/test/unit/test_local_storage_client.py b/test/unit/test_local_storage_client.py index cbea8de7c1..49479f1ede 100644 --- a/test/unit/test_local_storage_client.py +++ b/test/unit/test_local_storage_client.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import random import string import tempfile diff --git a/test/unit/test_log_secret_detector.py b/test/unit/test_log_secret_detector.py index a6e62cb189..cbdbd91f80 100644 --- a/test/unit/test_log_secret_detector.py +++ b/test/unit/test_log_secret_detector.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import logging diff --git a/test/unit/test_mfa_no_cache.py b/test/unit/test_mfa_no_cache.py index 44e0080500..00436e60fc 100644 --- a/test/unit/test_mfa_no_cache.py +++ b/test/unit/test_mfa_no_cache.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import json diff --git a/test/unit/test_network.py b/test/unit/test_network.py index 9139a767c1..fdf493d776 100644 --- a/test/unit/test_network.py +++ b/test/unit/test_network.py @@ -1,14 +1,14 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import io +import json import unittest.mock +import uuid from test.unit.mock_utils import mock_connection import pytest +from src.snowflake.connector.network import SnowflakeRestfulJsonEncoder + try: from snowflake.connector import Error, InterfaceError from snowflake.connector.network import SnowflakeRestful @@ -67,3 +67,20 @@ def test_fetch(): # if no retry is set to False, the function raises an InterfaceError with pytest.raises(InterfaceError) as exc: assert rest.fetch(**default_parameters, no_retry=False) + + +@pytest.mark.parametrize( + "u", + [ + uuid.uuid1(), + uuid.uuid3(uuid.NAMESPACE_URL, "www.snowflake.com"), + uuid.uuid4(), + uuid.uuid5(uuid.NAMESPACE_URL, "www.snowflake.com"), + ], +) +def test_json_serialize_uuid(u): + expected = f'{{"u": "{u}", "a": 42}}' + + assert (json.dumps(u, cls=SnowflakeRestfulJsonEncoder)) == f'"{u}"' + + assert json.dumps({"u": u, "a": 42}, cls=SnowflakeRestfulJsonEncoder) == expected diff --git a/test/unit/test_ocsp.py b/test/unit/test_ocsp.py index 526a083e66..45bbfaa4f3 100644 --- a/test/unit/test_ocsp.py +++ b/test/unit/test_ocsp.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import copy @@ -261,7 +257,6 @@ def test_ocsp_bad_validity(): del environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"] -@pytest.mark.flaky(reruns=3) def test_ocsp_single_endpoint(): environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] = "True" SnowflakeOCSP.clear_cache() @@ -285,7 +280,6 @@ def test_ocsp_by_post_method(): assert ocsp.validate(url, connection), f"Failed to validate: {url}" -@pytest.mark.flaky(reruns=3) def test_ocsp_with_file_cache(tmpdir): """OCSP tests and the cache server and file.""" tmp_dir = str(tmpdir.mkdir("ocsp_response_cache")) @@ -299,7 +293,6 @@ def test_ocsp_with_file_cache(tmpdir): assert ocsp.validate(url, connection), f"Failed to validate: {url}" -@pytest.mark.flaky(reruns=3) @pytest.mark.skipolddriver def test_ocsp_with_bogus_cache_files(tmpdir, random_ocsp_response_validation_cache): with mock.patch( @@ -339,7 +332,6 @@ def test_ocsp_with_bogus_cache_files(tmpdir, random_ocsp_response_validation_cac ) -@pytest.mark.flaky(reruns=3) @pytest.mark.skipolddriver def test_ocsp_with_outdated_cache(tmpdir, random_ocsp_response_validation_cache): with mock.patch( @@ -400,7 +392,6 @@ def _store_cache_in_file(tmpdir, target_hosts=None): return filename, target_hosts -@pytest.mark.flaky(reruns=3) def test_ocsp_with_invalid_cache_file(): """OCSP tests with an invalid cache file.""" SnowflakeOCSP.clear_cache() # reset the memory cache @@ -410,7 +401,6 @@ def test_ocsp_with_invalid_cache_file(): assert ocsp.validate(url, connection), f"Failed to validate: {url}" -@pytest.mark.flaky(reruns=3) @mock.patch( "snowflake.connector.ocsp_snowflake.SnowflakeOCSP._fetch_ocsp_response", side_effect=BrokenPipeError("fake error"), @@ -433,7 +423,6 @@ def test_ocsp_cache_when_server_is_down( assert not cache_data, "no cache should present because of broken pipe" -@pytest.mark.flaky(reruns=3) def test_concurrent_ocsp_requests(tmpdir): """Run OCSP revocation checks in parallel. 
The memory and file caches are deleted randomly.""" cache_file_name = path.join(str(tmpdir), "cache_file.txt") @@ -478,7 +467,6 @@ def test_ocsp_revoked_certificate(): assert ex.value.errno == ex.value.errno == ER_OCSP_RESPONSE_CERT_STATUS_REVOKED -@pytest.mark.flaky(reruns=3) def test_ocsp_incomplete_chain(): """Tests incomplete chained certificate.""" incomplete_chain_cert = path.join( diff --git a/test/unit/test_oob_secret_detector.py b/test/unit/test_oob_secret_detector.py index 48414bf19d..3481c40788 100644 --- a/test/unit/test_oob_secret_detector.py +++ b/test/unit/test_oob_secret_detector.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import random diff --git a/test/unit/test_parse_account.py b/test/unit/test_parse_account.py index e123ec7077..c07dd46c05 100644 --- a/test/unit/test_parse_account.py +++ b/test/unit/test_parse_account.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from snowflake.connector.util_text import parse_account diff --git a/test/unit/test_programmatic_access_token.py b/test/unit/test_programmatic_access_token.py index 1113be1501..7d6ecb175e 100644 --- a/test/unit/test_programmatic_access_token.py +++ b/test/unit/test_programmatic_access_token.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - import pathlib from typing import Any, Generator, Union @@ -47,7 +43,6 @@ def test_valid_pat(wiremock_client: WiremockClient) -> None: ) cnx = snowflake.connector.connect( - user="testUser", authenticator=PROGRAMMATIC_ACCESS_TOKEN, token="some PAT", account="testAccount", @@ -74,7 +69,6 @@ def test_invalid_pat(wiremock_client: WiremockClient) -> None: with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo: snowflake.connector.connect( - user="testUser", authenticator=PROGRAMMATIC_ACCESS_TOKEN, token="some PAT", account="testAccount", @@ -84,42 +78,3 @@ def test_invalid_pat(wiremock_client: WiremockClient) -> None: ) assert str(execinfo.value).endswith("Programmatic access token is invalid.") - - -@pytest.mark.skipolddriver -def test_pat_as_password(wiremock_client: WiremockClient) -> None: - wiremock_data_dir = ( - pathlib.Path(__file__).parent.parent - / "data" - / "wiremock" - / "mappings" - / "auth" - / "pat" - ) - - wiremock_generic_data_dir = ( - pathlib.Path(__file__).parent.parent - / "data" - / "wiremock" - / "mappings" - / "generic" - ) - - wiremock_client.import_mapping(wiremock_data_dir / "successful_flow.json") - wiremock_client.add_mapping( - wiremock_generic_data_dir / "snowflake_disconnect_successful.json" - ) - - cnx = snowflake.connector.connect( - user="testUser", - authenticator=PROGRAMMATIC_ACCESS_TOKEN, - token=None, - password="some PAT", - account="testAccount", - protocol="http", - host=wiremock_client.wiremock_host, - port=wiremock_client.wiremock_http_port, - ) - - assert cnx, "invalid cnx" - cnx.close() diff --git a/test/unit/test_proxies.py b/test/unit/test_proxies.py index 55aff685ef..8835695aa2 100644 --- a/test/unit/test_proxies.py +++ b/test/unit/test_proxies.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import logging diff --git a/test/unit/test_put_get.py b/test/unit/test_put_get.py index 87d9fb46e3..a8cd43839b 100644 --- a/test/unit/test_put_get.py +++ b/test/unit/test_put_get.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from os import chmod, path diff --git a/test/unit/test_query_context_cache.py b/test/unit/test_query_context_cache.py index cd887fe749..bb4c2408e6 100644 --- a/test/unit/test_query_context_cache.py +++ b/test/unit/test_query_context_cache.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import json from random import shuffle diff --git a/test/unit/test_renew_session.py b/test/unit/test_renew_session.py index 0b2361b0a7..bfc5bf6245 100644 --- a/test/unit/test_renew_session.py +++ b/test/unit/test_renew_session.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/unit/test_result_batch.py b/test/unit/test_result_batch.py index 7206136f87..e2de635886 100644 --- a/test/unit/test_result_batch.py +++ b/test/unit/test_result_batch.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from collections import namedtuple diff --git a/test/unit/test_retry_network.py b/test/unit/test_retry_network.py index d83bc08224..3f8e2cee81 100644 --- a/test/unit/test_retry_network.py +++ b/test/unit/test_retry_network.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations import errno diff --git a/test/unit/test_s3_util.py b/test/unit/test_s3_util.py index 6bd6dda8f6..9fece987eb 100644 --- a/test/unit/test_s3_util.py +++ b/test/unit/test_s3_util.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import logging diff --git a/test/unit/test_session_manager.py b/test/unit/test_session_manager.py index 73487c5881..8ca3044b6b 100644 --- a/test/unit/test_session_manager.py +++ b/test/unit/test_session_manager.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from enum import Enum diff --git a/test/unit/test_split_statement.py b/test/unit/test_split_statement.py index 971b600524..917c8a6ace 100644 --- a/test/unit/test_split_statement.py +++ b/test/unit/test_split_statement.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations from io import StringIO diff --git a/test/unit/test_storage_client.py b/test/unit/test_storage_client.py index 9a14d186f9..6f925749ea 100644 --- a/test/unit/test_storage_client.py +++ b/test/unit/test_storage_client.py @@ -1,6 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# from os import path from unittest.mock import MagicMock diff --git a/test/unit/test_telemetry.py b/test/unit/test_telemetry.py index e5d536cee3..06646ec7b5 100644 --- a/test/unit/test_telemetry.py +++ b/test/unit/test_telemetry.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# - from __future__ import annotations from unittest.mock import Mock diff --git a/test/unit/test_telemetry_oob.py b/test/unit/test_telemetry_oob.py index a39d8b8b65..13c4524dc2 100644 --- a/test/unit/test_telemetry_oob.py +++ b/test/unit/test_telemetry_oob.py @@ -1,8 +1,4 @@ #!/usr/bin/env python -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from __future__ import annotations import time diff --git a/test/unit/test_text_util.py b/test/unit/test_text_util.py index 69895b0191..f07ea1751a 100644 --- a/test/unit/test_text_util.py +++ b/test/unit/test_text_util.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import concurrent.futures import random diff --git a/test/unit/test_url_util.py b/test/unit/test_url_util.py index b373e93de7..2c4f236631 100644 --- a/test/unit/test_url_util.py +++ b/test/unit/test_url_util.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - try: from snowflake.connector.url_util import ( extract_top_level_domain_from_hostname, diff --git a/test/unit/test_util.py b/test/unit/test_util.py index 482bd4d34b..b2862f4660 100644 --- a/test/unit/test_util.py +++ b/test/unit/test_util.py @@ -1,6 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# import pytest try: diff --git a/test/unit/test_wiremock_client.py b/test/unit/test_wiremock_client.py index 3e670227b9..df4cacd2da 100644 --- a/test/unit/test_wiremock_client.py +++ b/test/unit/test_wiremock_client.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - from typing import Any, Generator import pytest diff --git a/test/wiremock/__init__.py b/test/wiremock/__init__.py index ef416f64a0..e69de29bb2 100644 --- a/test/wiremock/__init__.py +++ b/test/wiremock/__init__.py @@ -1,3 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
-# diff --git a/test/wiremock/wiremock_utils.py b/test/wiremock/wiremock_utils.py index 6fe2f138b9..95b7374c1e 100644 --- a/test/wiremock/wiremock_utils.py +++ b/test/wiremock/wiremock_utils.py @@ -1,7 +1,3 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - import json import logging import pathlib diff --git a/tested_requirements/requirements_310.reqs b/tested_requirements/requirements_310.reqs index 9ecb96bd18..c40c82708c 100644 --- a/tested_requirements/requirements_310.reqs +++ b/tested_requirements/requirements_310.reqs @@ -1,20 +1,26 @@ -# Generated on: Python 3.10.16 +# Generated on: Python 3.10.17 asn1crypto==1.5.1 +boto3==1.37.38 +botocore==1.37.38 certifi==2025.1.31 cffi==1.17.1 charset-normalizer==3.4.1 cryptography==44.0.2 -filelock==3.17.0 +filelock==3.18.0 idna==3.10 -packaging==24.2 -platformdirs==4.3.6 +jmespath==1.0.1 +packaging==25.0 +platformdirs==4.3.7 pycparser==2.22 PyJWT==2.10.1 pyOpenSSL==25.0.0 -pytz==2025.1 +python-dateutil==2.9.0.post0 +pytz==2025.2 requests==2.32.3 +s3transfer==0.11.5 +six==1.17.0 sortedcontainers==2.4.0 tomlkit==0.13.2 -typing_extensions==4.12.2 -urllib3==2.3.0 -snowflake-connector-python==3.14.0 +typing_extensions==4.13.2 +urllib3==2.4.0 +snowflake-connector-python==3.14.1 diff --git a/tested_requirements/requirements_311.reqs b/tested_requirements/requirements_311.reqs index 7839ec674d..62f67fd30e 100644 --- a/tested_requirements/requirements_311.reqs +++ b/tested_requirements/requirements_311.reqs @@ -1,20 +1,26 @@ -# Generated on: Python 3.11.11 +# Generated on: Python 3.11.12 asn1crypto==1.5.1 +boto3==1.37.38 +botocore==1.37.38 certifi==2025.1.31 cffi==1.17.1 charset-normalizer==3.4.1 cryptography==44.0.2 -filelock==3.17.0 +filelock==3.18.0 idna==3.10 -packaging==24.2 -platformdirs==4.3.6 +jmespath==1.0.1 +packaging==25.0 +platformdirs==4.3.7 pycparser==2.22 PyJWT==2.10.1 pyOpenSSL==25.0.0 -pytz==2025.1 +python-dateutil==2.9.0.post0 +pytz==2025.2 requests==2.32.3 +s3transfer==0.11.5 
+six==1.17.0 sortedcontainers==2.4.0 tomlkit==0.13.2 -typing_extensions==4.12.2 -urllib3==2.3.0 -snowflake-connector-python==3.14.0 +typing_extensions==4.13.2 +urllib3==2.4.0 +snowflake-connector-python==3.14.1 diff --git a/tested_requirements/requirements_312.reqs b/tested_requirements/requirements_312.reqs index a9ae4f8386..232359acd6 100644 --- a/tested_requirements/requirements_312.reqs +++ b/tested_requirements/requirements_312.reqs @@ -1,22 +1,28 @@ -# Generated on: Python 3.12.9 +# Generated on: Python 3.12.10 asn1crypto==1.5.1 +boto3==1.37.38 +botocore==1.37.38 certifi==2025.1.31 cffi==1.17.1 charset-normalizer==3.4.1 cryptography==44.0.2 -filelock==3.17.0 +filelock==3.18.0 idna==3.10 -packaging==24.2 -platformdirs==4.3.6 +jmespath==1.0.1 +packaging==25.0 +platformdirs==4.3.7 pycparser==2.22 PyJWT==2.10.1 pyOpenSSL==25.0.0 -pytz==2025.1 +python-dateutil==2.9.0.post0 +pytz==2025.2 requests==2.32.3 -setuptools==75.8.2 +s3transfer==0.11.5 +setuptools==79.0.0 +six==1.17.0 sortedcontainers==2.4.0 tomlkit==0.13.2 -typing_extensions==4.12.2 -urllib3==2.3.0 +typing_extensions==4.13.2 +urllib3==2.4.0 wheel==0.45.1 -snowflake-connector-python==3.14.0 +snowflake-connector-python==3.14.1 diff --git a/tested_requirements/requirements_313.reqs b/tested_requirements/requirements_313.reqs new file mode 100644 index 0000000000..d206c77c50 --- /dev/null +++ b/tested_requirements/requirements_313.reqs @@ -0,0 +1,28 @@ +# Generated on: Python 3.13.3 +asn1crypto==1.5.1 +boto3==1.37.38 +botocore==1.37.38 +certifi==2025.1.31 +cffi==1.17.1 +charset-normalizer==3.4.1 +cryptography==44.0.2 +filelock==3.18.0 +idna==3.10 +jmespath==1.0.1 +packaging==25.0 +platformdirs==4.3.7 +pycparser==2.22 +PyJWT==2.10.1 +pyOpenSSL==25.0.0 +python-dateutil==2.9.0.post0 +pytz==2025.2 +requests==2.32.3 +s3transfer==0.11.5 +setuptools==79.0.0 +six==1.17.0 +sortedcontainers==2.4.0 +tomlkit==0.13.2 +typing_extensions==4.13.2 +urllib3==2.4.0 +wheel==0.45.1 +snowflake-connector-python==3.14.1 diff --git 
a/tested_requirements/requirements_39.reqs b/tested_requirements/requirements_39.reqs index 8d3ba20f37..25e17ca852 100644 --- a/tested_requirements/requirements_39.reqs +++ b/tested_requirements/requirements_39.reqs @@ -1,20 +1,26 @@ -# Generated on: Python 3.9.21 +# Generated on: Python 3.9.22 asn1crypto==1.5.1 +boto3==1.37.38 +botocore==1.37.38 certifi==2025.1.31 cffi==1.17.1 charset-normalizer==3.4.1 cryptography==44.0.2 -filelock==3.17.0 +filelock==3.18.0 idna==3.10 -packaging==24.2 -platformdirs==4.3.6 +jmespath==1.0.1 +packaging==25.0 +platformdirs==4.3.7 pycparser==2.22 PyJWT==2.10.1 pyOpenSSL==25.0.0 -pytz==2025.1 +python-dateutil==2.9.0.post0 +pytz==2025.2 requests==2.32.3 +s3transfer==0.11.5 +six==1.17.0 sortedcontainers==2.4.0 tomlkit==0.13.2 -typing_extensions==4.12.2 +typing_extensions==4.13.2 urllib3==1.26.20 -snowflake-connector-python==3.14.0 +snowflake-connector-python==3.14.1 diff --git a/tox.ini b/tox.ini index ba68dc88af..25bef2ffe7 100644 --- a/tox.ini +++ b/tox.ini @@ -97,6 +97,7 @@ commands = # Unit and pandas tests are already skipped for the old driver (see test/conftest.py). Avoid walking those # directories entirely to avoid loading any potentially incompatible subdirectories' own conftest.py files. {env:SNOWFLAKE_PYTEST_CMD_IGNORE_AIO} --ignore=test/unit --ignore=test/pandas -m "not skipolddriver" -vvv {posargs:} test + {env:SNOWFLAKE_PYTEST_CMD} --ignore=test/unit --ignore=test/pandas -m "not skipolddriver" -vvv {posargs:} test [testenv:noarrowextension] basepython = python3.9