diff --git a/.github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg b/.github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg
new file mode 100644
index 0000000000..4cdd2a880e
Binary files /dev/null and b/.github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg differ
diff --git a/.github/workflows/parameters/private/rsa_keys/rsa_key.p8.gpg b/.github/workflows/parameters/private/rsa_keys/rsa_key.p8.gpg
new file mode 100644
index 0000000000..e90253cd3a
Binary files /dev/null and b/.github/workflows/parameters/private/rsa_keys/rsa_key.p8.gpg differ
diff --git a/.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg b/.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg
new file mode 100644
index 0000000000..3d2442a7c8
Binary files /dev/null and b/.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg differ
diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 3f8686eea4..916812e99c 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -13,6 +13,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- Dropped support for Python 3.8.
- Basic decimal floating-point type support.
- Added handling of PAT provided in `password` field.
+ - Added experimental support for OAuth authorization code and client credentials flows.
- Improved error message for client-side query cancellations due to timeouts.
- Added support of GCS regional endpoints.
- Added `gcs_use_virtual_endpoints` connection property that forces the usage of the virtual GCS usage. Thanks to this it should be possible to set up private DNS entry for the GCS endpoint. See more: https://cloud.google.com/storage/docs/request-endpoints#xml-api
diff --git a/Jenkinsfile b/Jenkinsfile
index 699a514970..00374eaf9a 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -35,29 +35,46 @@ timestamps {
string(name: 'parent_job', value: env.JOB_NAME),
string(name: 'parent_build_number', value: env.BUILD_NUMBER)
]
- stage('Test') {
- try {
- def commit_hash = "main" // default which we want to override
- def bptp_tag = "bptp-stable"
- def response = authenticatedGithubCall("https://api.github.com/repos/snowflakedb/snowflake/git/ref/tags/${bptp_tag}")
- commit_hash = response.object.sha
- // Append the bptp-stable commit sha to params
- params += [string(name: 'svn_revision', value: commit_hash)]
- } catch(Exception e) {
- println("Exception computing commit hash from: ${response}")
+ parallel(
+ 'Test': {
+ stage('Test') {
+ try {
+ def commit_hash = "main" // default which we want to override
+ def bptp_tag = "bptp-stable"
+ def response = authenticatedGithubCall("https://api.github.com/repos/snowflakedb/snowflake/git/ref/tags/${bptp_tag}")
+ commit_hash = response.object.sha
+ // Append the bptp-stable commit sha to params
+ params += [string(name: 'svn_revision', value: commit_hash)]
+ } catch(Exception e) {
+ println("Exception computing commit hash from: ${response}")
+ }
+ parallel (
+ 'Test Python 39': { build job: 'RT-PyConnector39-PC',parameters: params},
+ 'Test Python 310': { build job: 'RT-PyConnector310-PC',parameters: params},
+ 'Test Python 311': { build job: 'RT-PyConnector311-PC',parameters: params},
+ 'Test Python 312': { build job: 'RT-PyConnector312-PC',parameters: params},
+ 'Test Python 313': { build job: 'RT-PyConnector313-PC',parameters: params},
+ 'Test Python 39 OldDriver': { build job: 'RT-PyConnector39-OldDriver-PC',parameters: params},
+ 'Test Python 39 FIPS': { build job: 'RT-FIPS-PyConnector39',parameters: params},
+ )
+ }
+ },
+ 'Test Authentication': {
+ stage('Test Authentication') {
+ withCredentials([
+ string(credentialsId: 'a791118f-a1ea-46cd-b876-56da1b9bc71c', variable: 'NEXUS_PASSWORD'),
+ string(credentialsId: 'sfctest0-parameters-secret', variable: 'PARAMETERS_SECRET')
+ ]) {
+ sh '''\
+ |#!/bin/bash -e
+ |$WORKSPACE/ci/test_authentication.sh
+ '''.stripMargin()
}
- parallel (
- 'Test Python 39': { build job: 'RT-PyConnector39-PC',parameters: params},
- 'Test Python 310': { build job: 'RT-PyConnector310-PC',parameters: params},
- 'Test Python 311': { build job: 'RT-PyConnector311-PC',parameters: params},
- 'Test Python 312': { build job: 'RT-PyConnector312-PC',parameters: params},
- 'Test Python 313': { build job: 'RT-PyConnector313-PC',parameters: params},
- 'Test Python 39 OldDriver': { build job: 'RT-PyConnector39-OldDriver-PC',parameters: params},
- 'Test Python 39 FIPS': { build job: 'RT-FIPS-PyConnector39',parameters: params},
- )
}
}
- }
+ )
+ }
+}
pipeline {
diff --git a/ci/container/test_authentication.sh b/ci/container/test_authentication.sh
new file mode 100755
index 0000000000..d65c7627eb
--- /dev/null
+++ b/ci/container/test_authentication.sh
@@ -0,0 +1,24 @@
+#!/bin/bash -e
+
+set -o pipefail
+
+
+export WORKSPACE=${WORKSPACE:-/mnt/workspace}
+export SOURCE_ROOT=${SOURCE_ROOT:-/mnt/host}
+
+MVNW_EXE=$SOURCE_ROOT/mvnw
+AUTH_PARAMETER_FILE=./.github/workflows/parameters/private/parameters_aws_auth_tests.json
+eval $(jq -r '.authtestparams | to_entries | map("export \(.key)=\(.value|tostring)")|.[]' $AUTH_PARAMETER_FILE)
+
+export SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH=./.github/workflows/parameters/private/rsa_keys/rsa_key.p8
+export SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH=./.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8
+
+export SF_OCSP_TEST_MODE=true
+export SF_ENABLE_EXPERIMENTAL_AUTHENTICATION=true
+export RUN_AUTH_TESTS=true
+export AUTHENTICATION_TESTS_ENV="docker"
+export PYTHONPATH=$SOURCE_ROOT
+
+python3 -m pip install --break-system-packages -e .
+
+python3 -m pytest test/auth/*
diff --git a/ci/test_authentication.sh b/ci/test_authentication.sh
new file mode 100755
index 0000000000..dbf78c83e8
--- /dev/null
+++ b/ci/test_authentication.sh
@@ -0,0 +1,27 @@
+#!/bin/bash -e
+
+set -o pipefail
+
+
+export THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+export WORKSPACE=${WORKSPACE:-/tmp}
+
+CI_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+if [[ -n "$JENKINS_HOME" ]]; then
+ ROOT_DIR="$(cd "${CI_DIR}/.." && pwd)"
+ export WORKSPACE=${WORKSPACE:-/tmp}
+ echo "Use /sbin/ip"
+ IP_ADDR=$(/sbin/ip -4 addr show scope global dev eth0 | grep inet | awk '{print $2}' | cut -d / -f 1)
+
+fi
+
+gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/parameters_aws_auth_tests.json "$THIS_DIR/../.github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg"
+gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key.p8 "$THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key.p8.gpg"
+gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8 "$THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg"
+
+docker run \
+ -v $(cd $THIS_DIR/.. && pwd):/mnt/host \
+ -v $WORKSPACE:/mnt/workspace \
+ --rm \
+ nexus.int.snowflakecomputing.com:8086/docker/snowdrivers-test-external-browser-python:1 \
+ "/mnt/host/ci/container/test_authentication.sh"
diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py
index c7a2add13d..ce0ddd8220 100644
--- a/src/snowflake/connector/aio/_connection.py
+++ b/src/snowflake/connector/aio/_connection.py
@@ -32,6 +32,7 @@
from ..connection import _get_private_bytes_from_file
from ..constants import (
_CONNECTIVITY_ERR_MSG,
+ _OAUTH_DEFAULT_SCOPE,
ENV_VAR_EXPERIMENTAL_AUTHENTICATION,
ENV_VAR_PARTNER,
PARAMETER_AUTOCOMMIT,
@@ -51,15 +52,19 @@
from ..description import PLATFORM, PYTHON_VERSION, SNOWFLAKE_CONNECTOR_VERSION
from ..errorcode import (
ER_CONNECTION_IS_CLOSED,
+ ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED,
ER_FAILED_TO_CONNECT_TO_DB,
ER_INVALID_VALUE,
ER_INVALID_WIF_SETTINGS,
+ ER_NO_CLIENT_ID,
)
from ..network import (
DEFAULT_AUTHENTICATOR,
EXTERNAL_BROWSER_AUTHENTICATOR,
KEY_PAIR_AUTHENTICATOR,
OAUTH_AUTHENTICATOR,
+ OAUTH_AUTHORIZATION_CODE,
+ OAUTH_CLIENT_CREDENTIALS,
PROGRAMMATIC_ACCESS_TOKEN,
REQUEST_ID,
USR_PWD_MFA_AUTHENTICATOR,
@@ -84,6 +89,8 @@
AuthByIdToken,
AuthByKeyPair,
AuthByOAuth,
+ AuthByOauthCode,
+ AuthByOauthCredentials,
AuthByOkta,
AuthByPAT,
AuthByPlugin,
@@ -307,6 +314,56 @@ async def __open_connection(self):
timeout=self.login_timeout,
backoff_generator=self._backoff_generator,
)
+ elif self._authenticator == OAUTH_AUTHORIZATION_CODE:
+ self._check_experimental_authentication_flag()
+ self._check_oauth_required_parameters()
+ features = self.oauth_security_features
+ if self._role and (self._oauth_scope == ""):
+ # if role is known then let's inject it into scope
+ self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
+ self.auth_class = AuthByOauthCode(
+ application=self.application,
+ client_id=self._oauth_client_id,
+ client_secret=self._oauth_client_secret,
+ authentication_url=self._oauth_authorization_url.format(
+ host=self.host, port=self.port
+ ),
+ token_request_url=self._oauth_token_request_url.format(
+ host=self.host, port=self.port
+ ),
+ redirect_uri=self._oauth_redirect_uri,
+ scope=self._oauth_scope,
+ pkce_enabled=features.pkce_enabled,
+ token_cache=(
+ auth.get_token_cache()
+ if self._client_store_temporary_credential
+ else None
+ ),
+ refresh_token_enabled=features.refresh_token_enabled,
+ external_browser_timeout=self._external_browser_timeout,
+ )
+ elif self._authenticator == OAUTH_CLIENT_CREDENTIALS:
+ self._check_experimental_authentication_flag()
+ self._check_oauth_required_parameters()
+ features = self.oauth_security_features
+ if self._role and (self._oauth_scope == ""):
+ # if role is known then let's inject it into scope
+ self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
+ self.auth_class = AuthByOauthCredentials(
+ application=self.application,
+ client_id=self._oauth_client_id,
+ client_secret=self._oauth_client_secret,
+ token_request_url=self._oauth_token_request_url.format(
+ host=self.host, port=self.port
+ ),
+ scope=self._oauth_scope,
+ token_cache=(
+ auth.get_token_cache()
+ if self._client_store_temporary_credential
+ else None
+ ),
+ refresh_token_enabled=features.refresh_token_enabled,
+ )
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:
self.auth_class = AuthByPAT(self._token)
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
@@ -747,7 +804,11 @@ async def authenticate_with_retry(self, auth_instance) -> None:
# SSO if it has expired
await self._reauthenticate()
else:
- await self._authenticate(auth_instance)
+ # TODO pczajka: check if this is correct
+ # For OAuth and other auth types, call their reauthenticate method
+ await auth_instance.reauthenticate(conn=self)
+ # The reauthenticate method will call authenticate_with_retry internally,
+ # so we don't need to call _authenticate again here
async def autocommit(self, mode) -> None:
"""Sets autocommit mode to True, or False. Defaults to True."""
@@ -1052,3 +1113,37 @@ async def is_valid(self) -> bool:
except Exception as e:
logger.debug("session could not be validated due to exception: %s", e)
return False
+
+ def _check_experimental_authentication_flag(self) -> None:
+ if os.getenv(ENV_VAR_EXPERIMENTAL_AUTHENTICATION, "false").lower() != "true":
+ Error.errorhandler_wrapper(
+ self,
+ None,
+ ProgrammingError,
+ {
+ "msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable true to use the '{self._authenticator}' authenticator.",
+ "errno": ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED,
+ },
+ )
+
+ def _check_oauth_required_parameters(self) -> None:
+ if self._oauth_client_id is None:
+ Error.errorhandler_wrapper(
+ self,
+ None,
+ ProgrammingError,
+ {
+ "msg": "Oauth code flow requirement 'client_id' is empty",
+ "errno": ER_NO_CLIENT_ID,
+ },
+ )
+ if self._oauth_client_secret is None:
+ Error.errorhandler_wrapper(
+ self,
+ None,
+ ProgrammingError,
+ {
+ "msg": "Oauth code flow requirement 'client_secret' is empty",
+ "errno": ER_NO_CLIENT_ID,
+ },
+ )
diff --git a/src/snowflake/connector/aio/auth/__init__.py b/src/snowflake/connector/aio/auth/__init__.py
index 4091bcf06b..3caf65c6a7 100644
--- a/src/snowflake/connector/aio/auth/__init__.py
+++ b/src/snowflake/connector/aio/auth/__init__.py
@@ -8,6 +8,8 @@
from ._keypair import AuthByKeyPair
from ._no_auth import AuthNoAuth
from ._oauth import AuthByOAuth
+from ._oauth_code import AuthByOauthCode
+from ._oauth_credentials import AuthByOauthCredentials
from ._okta import AuthByOkta
from ._pat import AuthByPAT
from ._usrpwdmfa import AuthByUsrPwdMfa
@@ -19,6 +21,8 @@
AuthByDefault,
AuthByKeyPair,
AuthByOAuth,
+ AuthByOauthCode,
+ AuthByOauthCredentials,
AuthByOkta,
AuthByUsrPwdMfa,
AuthByWebBrowser,
@@ -35,6 +39,8 @@
"AuthByKeyPair",
"AuthByPAT",
"AuthByOAuth",
+ "AuthByOauthCode",
+ "AuthByOauthCredentials",
"AuthByOkta",
"AuthByUsrPwdMfa",
"AuthByWebBrowser",
diff --git a/src/snowflake/connector/aio/auth/_auth.py b/src/snowflake/connector/aio/auth/_auth.py
index 8dbb86f963..462e107ae1 100644
--- a/src/snowflake/connector/aio/auth/_auth.py
+++ b/src/snowflake/connector/aio/auth/_auth.py
@@ -30,6 +30,7 @@
ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
CONTENT_TYPE_APPLICATION_JSON,
ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE,
+ OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE,
PYTHON_CONNECTOR_USER_AGENT,
ReauthenticationRequest,
)
@@ -282,6 +283,15 @@ async def post_request_wrapper(self, url, headers, body) -> None:
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
)
)
+ elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE:
+ raise ReauthenticationRequest(
+ ProgrammingError(
+ msg=ret["message"],
+ errno=int(errno),
+ sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
+ )
+ )
+
from . import AuthByKeyPair
if isinstance(auth_instance, AuthByKeyPair):
diff --git a/src/snowflake/connector/aio/auth/_oauth_code.py b/src/snowflake/connector/aio/auth/_oauth_code.py
new file mode 100644
index 0000000000..a4b3f35ae7
--- /dev/null
+++ b/src/snowflake/connector/aio/auth/_oauth_code.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from ...auth.oauth_code import AuthByOauthCode as AuthByOauthCodeSync
+from ...token_cache import TokenCache
+from ._by_plugin import AuthByPlugin as AuthByPluginAsync
+
+if TYPE_CHECKING:
+ from .. import SnowflakeConnection
+
+logger = logging.getLogger(__name__)
+
+
+class AuthByOauthCode(AuthByPluginAsync, AuthByOauthCodeSync):
+ """Async version of OAuth authorization code authenticator."""
+
+ def __init__(
+ self,
+ application: str,
+ client_id: str,
+ client_secret: str,
+ authentication_url: str,
+ token_request_url: str,
+ redirect_uri: str,
+ scope: str,
+ pkce_enabled: bool = True,
+ token_cache: TokenCache | None = None,
+ refresh_token_enabled: bool = False,
+ external_browser_timeout: int | None = None,
+ **kwargs,
+ ) -> None:
+ """Initializes an instance with OAuth authorization code parameters."""
+ logger.debug(
+ "OAuth authentication is not supported in async version - falling back to sync implementation"
+ )
+ AuthByOauthCodeSync.__init__(
+ self,
+ application=application,
+ client_id=client_id,
+ client_secret=client_secret,
+ authentication_url=authentication_url,
+ token_request_url=token_request_url,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ pkce_enabled=pkce_enabled,
+ token_cache=token_cache,
+ refresh_token_enabled=refresh_token_enabled,
+ external_browser_timeout=external_browser_timeout,
+ **kwargs,
+ )
+
+ async def reset_secrets(self) -> None:
+ AuthByOauthCodeSync.reset_secrets(self)
+
+ async def prepare(self, **kwargs: Any) -> None:
+ AuthByOauthCodeSync.prepare(self, **kwargs)
+
+ async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
+ """Override to use async connection properly."""
+ # TODO pczajka: check if this is correct
+
+ # Call the sync reset logic but handle the connection retry ourselves
+ self._reset_access_token()
+ if self._pop_cached_refresh_token():
+ logger.debug(
+ "OAuth refresh token is available, try to use it and get a new access token"
+ )
+ self._do_refresh_token(conn=kwargs.get("conn"))
+ # Use async authenticate_with_retry
+ if "conn" in kwargs:
+ await kwargs["conn"].authenticate_with_retry(self)
+ return {"success": True}
+
+ async def update_body(self, body: dict[Any, Any]) -> None:
+ AuthByOauthCodeSync.update_body(self, body)
+
+ def _handle_failure(
+ self,
+ *,
+ conn: SnowflakeConnection,
+ ret: dict[Any, Any],
+ **kwargs: Any,
+ ) -> None:
+ """Override to ensure proper error handling in async context."""
+ # Use sync error handling directly to avoid async/sync mismatch
+ from ...errors import DatabaseError, Error
+ from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
+
+ Error.errorhandler_wrapper(
+ conn,
+ None,
+ DatabaseError,
+ {
+ "msg": "Failed to connect to DB: {host}:{port}, {message}".format(
+ host=conn._rest._host,
+ port=conn._rest._port,
+ message=ret["message"],
+ ),
+ "errno": int(ret.get("code", -1)),
+ "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
+ },
+ )
diff --git a/src/snowflake/connector/aio/auth/_oauth_credentials.py b/src/snowflake/connector/aio/auth/_oauth_credentials.py
new file mode 100644
index 0000000000..855296e372
--- /dev/null
+++ b/src/snowflake/connector/aio/auth/_oauth_credentials.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from ...auth.oauth_credentials import (
+ AuthByOauthCredentials as AuthByOauthCredentialsSync,
+)
+from ...token_cache import TokenCache
+from ._by_plugin import AuthByPlugin as AuthByPluginAsync
+
+if TYPE_CHECKING:
+ from .. import SnowflakeConnection
+
+logger = logging.getLogger(__name__)
+
+
+class AuthByOauthCredentials(AuthByPluginAsync, AuthByOauthCredentialsSync):
+ """Async version of OAuth client credentials authenticator."""
+
+ def __init__(
+ self,
+ application: str,
+ client_id: str,
+ client_secret: str,
+ token_request_url: str,
+ scope: str,
+ token_cache: TokenCache | None = None,
+ refresh_token_enabled: bool = False,
+ **kwargs,
+ ) -> None:
+ """Initializes an instance with OAuth client credentials parameters."""
+ logger.debug(
+ "OAuth authentication is not supported in async version - falling back to sync implementation"
+ )
+ AuthByOauthCredentialsSync.__init__(
+ self,
+ application=application,
+ client_id=client_id,
+ client_secret=client_secret,
+ token_request_url=token_request_url,
+ scope=scope,
+ token_cache=token_cache,
+ refresh_token_enabled=refresh_token_enabled,
+ **kwargs,
+ )
+
+ async def reset_secrets(self) -> None:
+ AuthByOauthCredentialsSync.reset_secrets(self)
+
+ async def prepare(self, **kwargs: Any) -> None:
+ AuthByOauthCredentialsSync.prepare(self, **kwargs)
+
+ async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
+ """Override to use async connection properly."""
+ # TODO pczajka: check if this is correct
+
+ # Call the sync reset logic but handle the connection retry ourselves
+ self._reset_access_token()
+ if self._pop_cached_refresh_token():
+ logger.debug(
+ "OAuth refresh token is available, try to use it and get a new access token"
+ )
+ self._do_refresh_token(conn=kwargs.get("conn"))
+ # Use async authenticate_with_retry
+ if "conn" in kwargs:
+ await kwargs["conn"].authenticate_with_retry(self)
+ return {"success": True}
+
+ async def update_body(self, body: dict[Any, Any]) -> None:
+ AuthByOauthCredentialsSync.update_body(self, body)
+
+ def _handle_failure(
+ self,
+ *,
+ conn: SnowflakeConnection,
+ ret: dict[Any, Any],
+ **kwargs: Any,
+ ) -> None:
+ """Override to ensure proper error handling in async context."""
+ # Use sync error handling directly to avoid async/sync mismatch
+ from ...errors import DatabaseError, Error
+ from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
+
+ Error.errorhandler_wrapper(
+ conn,
+ None,
+ DatabaseError,
+ {
+ "msg": "Failed to connect to DB: {host}:{port}, {message}".format(
+ host=conn._rest._host,
+ port=conn._rest._port,
+ message=ret["message"],
+ ),
+ "errno": int(ret.get("code", -1)),
+ "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
+ },
+ )
diff --git a/src/snowflake/connector/auth/__init__.py b/src/snowflake/connector/auth/__init__.py
index 0874b35ca7..cb25f7d364 100644
--- a/src/snowflake/connector/auth/__init__.py
+++ b/src/snowflake/connector/auth/__init__.py
@@ -7,6 +7,8 @@
from .keypair import AuthByKeyPair
from .no_auth import AuthNoAuth
from .oauth import AuthByOAuth
+from .oauth_code import AuthByOauthCode
+from .oauth_credentials import AuthByOauthCredentials
from .okta import AuthByOkta
from .pat import AuthByPAT
from .usrpwdmfa import AuthByUsrPwdMfa
@@ -18,6 +20,8 @@
AuthByDefault,
AuthByKeyPair,
AuthByOAuth,
+ AuthByOauthCode,
+ AuthByOauthCredentials,
AuthByOkta,
AuthByUsrPwdMfa,
AuthByWebBrowser,
@@ -34,6 +38,8 @@
"AuthByKeyPair",
"AuthByPAT",
"AuthByOAuth",
+ "AuthByOauthCode",
+ "AuthByOauthCredentials",
"AuthByOkta",
"AuthByUsrPwdMfa",
"AuthByWebBrowser",
diff --git a/src/snowflake/connector/auth/_auth.py b/src/snowflake/connector/auth/_auth.py
index cf3b6b6297..527bd5cf9b 100644
--- a/src/snowflake/connector/auth/_auth.py
+++ b/src/snowflake/connector/auth/_auth.py
@@ -47,6 +47,7 @@
ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
CONTENT_TYPE_APPLICATION_JSON,
ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE,
+ OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE,
PYTHON_CONNECTOR_USER_AGENT,
ReauthenticationRequest,
)
@@ -86,7 +87,7 @@ class Auth:
def __init__(self, rest) -> None:
self._rest = rest
- self.token_cache = TokenCache.make()
+ self._token_cache: TokenCache | None = None
@staticmethod
def base_auth_data(
@@ -350,7 +351,7 @@ def post_request_wrapper(self, url, headers, body) -> None:
# clear stored id_token if failed to connect because of id_token
# raise an exception for reauth without id_token
self._rest.id_token = None
- self.delete_temporary_credential(
+ self._delete_temporary_credential(
self._rest._host, user, TokenType.ID_TOKEN
)
raise ReauthenticationRequest(
@@ -360,6 +361,14 @@ def post_request_wrapper(self, url, headers, body) -> None:
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
)
)
+ elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE:
+ raise ReauthenticationRequest(
+ ProgrammingError(
+ msg=ret["message"],
+ errno=int(errno),
+ sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
+ )
+ )
from . import AuthByKeyPair
@@ -374,7 +383,7 @@ def post_request_wrapper(self, url, headers, body) -> None:
from . import AuthByUsrPwdMfa
if isinstance(auth_instance, AuthByUsrPwdMfa):
- self.delete_temporary_credential(
+ self._delete_temporary_credential(
self._rest._host, user, TokenType.MFA_TOKEN
)
Error.errorhandler_wrapper(
@@ -466,7 +475,7 @@ def _read_temporary_credential(
user: str,
cred_type: TokenType,
) -> str | None:
- return self.token_cache.retrieve(TokenKey(host, user, cred_type))
+ return self.get_token_cache().retrieve(TokenKey(host, user, cred_type))
def read_temporary_credentials(
self,
@@ -500,7 +509,7 @@ def _write_temporary_credential(
"no credential is given when try to store temporary credential"
)
return
- self.token_cache.store(TokenKey(host, user, cred_type), cred)
+ self.get_token_cache().store(TokenKey(host, user, cred_type), cred)
def write_temporary_credentials(
self,
@@ -524,10 +533,15 @@ def write_temporary_credentials(
host, user, TokenType.MFA_TOKEN, response["data"].get("mfaToken")
)
- def delete_temporary_credential(
+ def _delete_temporary_credential(
self, host: str, user: str, cred_type: TokenType
) -> None:
- self.token_cache.remove(TokenKey(host, user, cred_type))
+ self.get_token_cache().remove(TokenKey(host, user, cred_type))
+
+ def get_token_cache(self) -> TokenCache:
+ if self._token_cache is None:
+ self._token_cache = TokenCache.make()
+ return self._token_cache
def get_token_from_private_key(
diff --git a/src/snowflake/connector/auth/_http_server.py b/src/snowflake/connector/auth/_http_server.py
new file mode 100644
index 0000000000..a11662f25b
--- /dev/null
+++ b/src/snowflake/connector/auth/_http_server.py
@@ -0,0 +1,220 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import logging
+import os
+import select
+import socket
+import time
+import urllib.parse
+from collections.abc import Callable
+from types import TracebackType
+
+from typing_extensions import Self
+
+from ..compat import IS_WINDOWS
+
+logger = logging.getLogger(__name__)
+
+
+def _use_msg_dont_wait() -> bool:
+ if os.getenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", "false").lower() != "true":
+ return False
+ if IS_WINDOWS:
+ logger.warning(
+ "Configuration SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT is not available in Windows. Ignoring."
+ )
+ return False
+ return True
+
+
+def _wrap_socket_recv() -> Callable[[socket.socket, int], bytes]:
+ dont_wait = _use_msg_dont_wait()
+ if dont_wait:
+ # WSL containerized environment sometimes causes socket_client.recv to hang indefinetly
+ # To avoid this, passing the socket.MSG_DONTWAIT flag which raises BlockingIOError if
+ # operation would block
+ logger.debug(
+ "Will call socket.recv with MSG_DONTWAIT flag due to SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT env var"
+ )
+ socket_recv = (
+ (lambda sock, buf_size: socket.socket.recv(sock, buf_size, socket.MSG_DONTWAIT))
+ if dont_wait
+ else (lambda sock, buf_size: socket.socket.recv(sock, buf_size))
+ )
+
+ def socket_recv_checked(sock: socket.socket, buf_size: int) -> bytes:
+ raw = socket_recv(sock, buf_size)
+ # when running in a containerized environment, socket_client.recv occasionally returns an empty byte array
+ # an immediate successive call to socket_client.recv gets the actual data
+ if len(raw) == 0:
+ raw = socket_recv(sock, buf_size)
+ return raw
+
+ return socket_recv_checked
+
+
+class AuthHttpServer:
+ """Simple HTTP server to receive callbacks through for auth purposes."""
+
+ DEFAULT_MAX_ATTEMPTS = 15
+ DEFAULT_TIMEOUT = 30.0
+
+ PORT_BIND_MAX_ATTEMPTS = 10
+ PORT_BIND_TIMEOUT = 20.0
+
+ def __init__(
+ self,
+ uri: str,
+ buf_size: int = 16384,
+ ) -> None:
+ parsed_uri = urllib.parse.urlparse(uri)
+ self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.buf_size = buf_size
+ if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true":
+ if IS_WINDOWS:
+ logger.warning(
+ "Configuration SNOWFLAKE_AUTH_SOCKET_REUSE_PORT is not available in Windows. Ignoring."
+ )
+ else:
+ self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+
+ port = parsed_uri.port or 0
+ for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1):
+ try:
+ self._socket.bind(
+ (
+ parsed_uri.hostname,
+ port,
+ )
+ )
+ break
+ except socket.gaierror as ex:
+ logger.error(
+ f"Failed to bind authorization callback server to port {port}: {ex}"
+ )
+ raise
+ except OSError as ex:
+ if attempt == self.DEFAULT_MAX_ATTEMPTS:
+ logger.error(
+ f"Failed to bind authorization callback server to port {port}: {ex}"
+ )
+ raise
+ logger.warning(
+ f"Attempt {attempt}/{self.DEFAULT_MAX_ATTEMPTS}. "
+ f"Failed to bind authorization callback server to port {port}: {ex}"
+ )
+ time.sleep(self.PORT_BIND_TIMEOUT / self.PORT_BIND_MAX_ATTEMPTS)
+ try:
+ self._socket.listen(0) # no backlog
+ except Exception as ex:
+ logger.error(f"Failed to start listening for auth callback: {ex}")
+ self.close()
+ raise
+ port = self._socket.getsockname()[1]
+ self._uri = urllib.parse.ParseResult(
+ scheme=parsed_uri.scheme,
+ netloc=parsed_uri.hostname + ":" + str(port),
+ path=parsed_uri.path,
+ params=parsed_uri.params,
+ query=parsed_uri.query,
+ fragment=parsed_uri.fragment,
+ )
+
+ @property
+ def url(self) -> str:
+ return self._uri.geturl()
+
+ @property
+ def port(self) -> int:
+ return self._uri.port
+
+ @property
+ def hostname(self) -> str:
+ return self._uri.hostname
+
+ def _try_poll(
+ self, attempts: int, attempt_timeout: float | None
+ ) -> (socket.socket | None, int):
+ for attempt in range(attempts):
+ read_sockets = select.select([self._socket], [], [], attempt_timeout)[0]
+ if read_sockets and read_sockets[0] is not None:
+ return self._socket.accept()[0], attempt
+ return None, attempts
+
+ def _try_receive_block(
+ self, client_socket: socket.socket, attempts: int, attempt_timeout: float | None
+ ) -> bytes | None:
+ if attempt_timeout is not None:
+ client_socket.settimeout(attempt_timeout)
+ recv = _wrap_socket_recv()
+ for attempt in range(attempts):
+ try:
+ return recv(client_socket, self.buf_size)
+ except BlockingIOError:
+ if attempt < attempts - 1:
+ cooldown = min(attempt_timeout, 0.25) if attempt_timeout else 0.25
+ logger.debug(
+ f"BlockingIOError raised from socket.recv on {1 + attempt}/{attempts} attempt."
+ f"Waiting for {cooldown} seconds before trying again"
+ )
+ time.sleep(cooldown)
+ except socket.timeout:
+ logger.debug(
+ f"socket.recv timed out on {1 + attempt}/{attempts} attempt."
+ )
+ return None
+
+ def receive_block(
+ self,
+ max_attempts: int = None,
+ timeout: float | int | None = None,
+ ) -> (list[str] | None, socket.socket | None):
+ if max_attempts is None:
+ max_attempts = self.DEFAULT_MAX_ATTEMPTS
+ if timeout is None:
+ timeout = self.DEFAULT_TIMEOUT
+ """Receive a message with a maximum attempt count and a timeout in seconds, blocking."""
+ if not self._socket:
+ raise RuntimeError(
+ "Operation is not supported, server was already shut down."
+ )
+ attempt_timeout = timeout / max_attempts if timeout else None
+ client_socket, poll_attempts = self._try_poll(max_attempts, attempt_timeout)
+ if client_socket is None:
+ return None, None
+ raw_block = self._try_receive_block(
+ client_socket, max_attempts - poll_attempts, attempt_timeout
+ )
+ if raw_block:
+ return raw_block.decode("utf-8").split("\r\n"), client_socket
+ try:
+ client_socket.shutdown(socket.SHUT_RDWR)
+ except OSError:
+ pass
+ client_socket.close()
+ return None, None
+
+ def close(self) -> None:
+ """Closes the underlying socket.
+ After having close() being called the server object cannot be reused.
+ """
+ if self._socket:
+ self._socket.close()
+ self._socket = None
+
+ def __enter__(self) -> Self:
+ """Context manager."""
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ """Context manager with disposing underlying networking objects."""
+ self.close()
diff --git a/src/snowflake/connector/auth/_oauth_base.py b/src/snowflake/connector/auth/_oauth_base.py
new file mode 100644
index 0000000000..ec77b22735
--- /dev/null
+++ b/src/snowflake/connector/auth/_oauth_base.py
@@ -0,0 +1,367 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import base64
+import json
+import logging
+import urllib.parse
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+from urllib.error import HTTPError, URLError
+
+from ..errorcode import ER_FAILED_TO_REQUEST, ER_IDP_CONNECTION_ERROR
+from ..network import OAUTH_AUTHENTICATOR
+from ..secret_detector import SecretDetector
+from ..token_cache import TokenCache, TokenKey, TokenType
+from ..vendored import urllib3
+from .by_plugin import AuthByPlugin, AuthType
+
+if TYPE_CHECKING:
+ from .. import SnowflakeConnection
+
+logger = logging.getLogger(__name__)
+
+
+class _OAuthTokensMixin:
+ def __init__(
+ self,
+ token_cache: TokenCache | None,
+ refresh_token_enabled: bool,
+ idp_host: str,
+ ) -> None:
+ self._access_token = None
+ self._refresh_token_enabled = refresh_token_enabled
+ if self._refresh_token_enabled:
+ self._refresh_token = None
+ self._token_cache = token_cache
+ if self._token_cache:
+ logger.debug("token cache is going to be used if needed")
+ self._idp_host = idp_host
+ self._access_token_key: TokenKey | None = None
+ if self._refresh_token_enabled:
+ self._refresh_token_key: TokenKey | None = None
+
+ def _update_cache_keys(self, user: str) -> None:
+ if self._token_cache:
+ self._user = user
+
+ def _get_access_token_cache_key(self) -> TokenKey | None:
+ return (
+ TokenKey(self._user, self._idp_host, TokenType.OAUTH_ACCESS_TOKEN)
+ if self._token_cache and self._user
+ else None
+ )
+
+ def _get_refresh_token_cache_key(self) -> TokenKey | None:
+ return (
+ TokenKey(self._user, self._idp_host, TokenType.OAUTH_REFRESH_TOKEN)
+ if self._refresh_token_enabled and self._token_cache and self._user
+ else None
+ )
+
+ def _pop_cached_token(self, key: TokenKey | None) -> str | None:
+ if self._token_cache is None or key is None:
+ return None
+ return self._token_cache.retrieve(key)
+
+ def _pop_cached_access_token(self) -> bool:
+ """Retrieves OAuth access token from the token cache if enabled"""
+ self._access_token = self._pop_cached_token(self._get_access_token_cache_key())
+ return self._access_token is not None
+
+ def _pop_cached_refresh_token(self) -> bool:
+ """Retrieves OAuth refresh token from the token cache if enabled"""
+ if self._refresh_token_enabled:
+ self._refresh_token = self._pop_cached_token(
+ self._get_refresh_token_cache_key()
+ )
+ return self._refresh_token is not None
+ return False
+
+ def _reset_cached_token(self, key: TokenKey | None, token: str | None) -> None:
+ if self._token_cache is None or key is None:
+ return
+ if token:
+ self._token_cache.store(key, token)
+ else:
+ self._token_cache.remove(key)
+
+ def _reset_access_token(self, access_token: str | None = None) -> None:
+ """Updates OAuth access token both in memory and in the token cache if enabled"""
+ logger.debug(
+ "resetting access token to %s",
+ "*" * len(access_token) if access_token else None,
+ )
+ self._access_token = access_token
+ self._reset_cached_token(self._get_access_token_cache_key(), self._access_token)
+
+ def _reset_refresh_token(self, refresh_token: str | None = None) -> None:
+ """Updates OAuth refresh token both in memory and in the token cache if necessary"""
+ if self._refresh_token_enabled:
+ logger.debug(
+ "resetting refresh token to %s",
+ "*" * len(refresh_token) if refresh_token else None,
+ )
+ self._refresh_token = refresh_token
+ self._reset_cached_token(
+ self._get_refresh_token_cache_key(), self._refresh_token
+ )
+
+ def _reset_temporary_state(self) -> None:
+ self._access_token = None
+ if self._refresh_token_enabled:
+ self._refresh_token = None
+ if self._token_cache:
+ self._user = None
+
+
+class AuthByOAuthBase(AuthByPlugin, _OAuthTokensMixin, ABC):
+ """A base abstract class for OAuth authenticators"""
+
+ def __init__(
+ self,
+ client_id: str,
+ client_secret: str,
+ token_request_url: str,
+ scope: str,
+ token_cache: TokenCache | None,
+ refresh_token_enabled: bool,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ _OAuthTokensMixin.__init__(
+ self,
+ token_cache=token_cache,
+ refresh_token_enabled=refresh_token_enabled,
+ idp_host=urllib.parse.urlparse(token_request_url).hostname,
+ )
+ self._client_id = client_id
+ self._client_secret = client_secret
+ self._token_request_url = token_request_url
+ self._scope = scope
+ if refresh_token_enabled:
+ logger.debug("oauth refresh token is going to be used if needed")
+ self._scope += (" " if self._scope else "") + "offline_access"
+
+ @abstractmethod
+ def _request_tokens(
+ self,
+ *,
+ conn: SnowflakeConnection,
+ authenticator: str,
+ service_name: str | None,
+ account: str,
+ user: str,
+ password: str | None,
+ **kwargs: Any,
+ ) -> (str | None, str | None):
+ """Request new access and optionally refresh tokens from IdP.
+
+ This function should implement specific tokens querying flow.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _get_oauth_type_id(self) -> str:
+ """Get OAuth specific authenticator id to be passed to Snowflake.
+
+ This function should return a unique OAuth authenticator id.
+ """
+ raise NotImplementedError
+
+ def reset_secrets(self) -> None:
+ logger.debug("resetting secrets")
+ self._reset_temporary_state()
+
+ @property
+ def type_(self) -> AuthType:
+ return AuthType.OAUTH
+
+ @property
+ def assertion_content(self) -> str:
+ """Returns the token."""
+ return self._access_token or ""
+
+ def reauthenticate(
+ self,
+ *,
+ conn: SnowflakeConnection,
+ **kwargs: Any,
+ ) -> dict[str, bool]:
+ self._reset_access_token()
+ if self._pop_cached_refresh_token():
+ logger.debug(
+ "OAuth refresh token is available, try to use it and get a new access token"
+ )
+ self._do_refresh_token(conn=conn)
+ conn.authenticate_with_retry(self)
+ return {"success": True}
+
+ def prepare(
+ self,
+ *,
+ conn: SnowflakeConnection,
+ authenticator: str,
+ service_name: str | None,
+ account: str,
+ user: str,
+ **kwargs: Any,
+ ) -> None:
+ """Web Browser based Authentication."""
+ logger.debug("authenticating with OAuth authorization code flow")
+ self._update_cache_keys(user=user)
+ if self._pop_cached_access_token():
+ logger.info(
+ "OAuth access token is already available in cache, no need to authenticate."
+ )
+ return
+ access_token, refresh_token = self._request_tokens(
+ conn=conn,
+ authenticator=authenticator,
+ service_name=service_name,
+ account=account,
+ user=user,
+ **kwargs,
+ )
+ self._reset_access_token(access_token)
+ self._reset_refresh_token(refresh_token)
+
+ def update_body(self, body: dict[Any, Any]) -> None:
+ """Used by Auth to update the request that gets sent to /v1/login-request.
+
+ Args:
+ body: existing request dictionary
+ """
+ body["data"]["AUTHENTICATOR"] = OAUTH_AUTHENTICATOR
+ body["data"]["TOKEN"] = self._access_token
+ body["data"]["OAUTH_TYPE"] = self._get_oauth_type_id()
+
+ def _do_refresh_token(self, conn: SnowflakeConnection) -> None:
+ """If a refresh token is available exchanges it with a new access token.
+ Updates self as a side-effect. Needs at lest self._refresh_token and client_id set.
+ """
+ if not self._refresh_token_enabled:
+ logger.debug("refresh_token feature is disabled")
+ return
+
+ resp = self._get_refresh_token_response(conn)
+ if not resp:
+ logger.info(
+ "failed to exchange the refresh token on a new OAuth access token"
+ )
+ self._reset_refresh_token()
+ return
+
+ try:
+ json_resp = json.loads(resp.data.decode())
+ self._reset_access_token(json_resp["access_token"])
+ if "refresh_token" in json_resp:
+ self._reset_refresh_token(json_resp["refresh_token"])
+ except (
+ json.JSONDecodeError,
+ KeyError,
+ ):
+ logger.error(
+ "refresh token exchange response did not contain 'access_token'"
+ )
+ logger.debug(
+ "received the following response body when exchanging refresh token: %s",
+ SecretDetector.mask_secrets(str(resp.data)),
+ )
+ self._reset_refresh_token()
+
+ def _get_refresh_token_response(
+ self, conn: SnowflakeConnection
+ ) -> urllib3.BaseHTTPResponse | None:
+ fields = {
+ "grant_type": "refresh_token",
+ "refresh_token": self._refresh_token,
+ }
+ if self._scope:
+ fields["scope"] = self._scope
+ try:
+ return urllib3.PoolManager().request_encode_body(
+ # TODO: use network pool to gain use of proxy settings and so on
+ "POST",
+ self._token_request_url,
+ encode_multipart=False,
+ headers=self._create_token_request_headers(),
+ fields=fields,
+ )
+ except HTTPError as e:
+ self._handle_failure(
+ conn=conn,
+ ret={
+ "code": ER_FAILED_TO_REQUEST,
+ "message": f"Failed to request new OAuth access token with a refresh token,"
+ f" url={e.url}, code={e.code}, reason={e.reason}",
+ },
+ )
+ except URLError as e:
+ self._handle_failure(
+ conn=conn,
+ ret={
+ "code": ER_FAILED_TO_REQUEST,
+ "message": f"Failed to request new OAuth access token with a refresh token, reason: {e.reason}",
+ },
+ )
+ except Exception:
+ self._handle_failure(
+ conn=conn,
+ ret={
+ "code": ER_FAILED_TO_REQUEST,
+ "message": "Failed to request new OAuth access token with a refresh token by unknown reason",
+ },
+ )
+ return None
+
+ def _get_request_token_response(
+ self,
+ connection: SnowflakeConnection,
+ fields: dict[str, str],
+ ) -> (str | None, str | None):
+ resp = urllib3.PoolManager().request_encode_body(
+ # TODO: use network pool to gain use of proxy settings and so on
+ "POST",
+ self._token_request_url,
+ headers=self._create_token_request_headers(),
+ encode_multipart=False,
+ fields=fields,
+ )
+ try:
+ logger.debug("OAuth IdP response received, try to parse it")
+ json_resp: dict = json.loads(resp.data)
+ access_token = json_resp["access_token"]
+ refresh_token = json_resp.get("refresh_token")
+ return access_token, refresh_token
+ except (
+ json.JSONDecodeError,
+ KeyError,
+ ):
+ logger.error("oauth response invalid, does not contain 'access_token'")
+ logger.debug(
+ "received the following response body when requesting oauth token: %s",
+ SecretDetector.mask_secrets(str(resp.data)),
+ )
+ self._handle_failure(
+ conn=connection,
+ ret={
+ "code": ER_IDP_CONNECTION_ERROR,
+ "message": "Invalid HTTP request from web browser. Idp "
+ "authentication could have failed.",
+ },
+ )
+ return None, None
+
+ def _create_token_request_headers(self) -> dict[str, str]:
+ return {
+ "Authorization": "Basic "
+ + base64.b64encode(
+ f"{self._client_id}:{self._client_secret}".encode()
+ ).decode(),
+ "Accept": "application/json",
+ "Content-Type": "application/x-www-form-urlencoded; charset=UTF-8",
+ }
diff --git a/src/snowflake/connector/auth/oauth_code.py b/src/snowflake/connector/auth/oauth_code.py
new file mode 100644
index 0000000000..f93562bc3b
--- /dev/null
+++ b/src/snowflake/connector/auth/oauth_code.py
@@ -0,0 +1,383 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import base64
+import hashlib
+import json
+import logging
+import secrets
+import socket
+import time
+import urllib.parse
+import webbrowser
+from typing import TYPE_CHECKING, Any
+
+from ..compat import parse_qs, urlparse, urlsplit
+from ..constants import OAUTH_TYPE_AUTHORIZATION_CODE
+from ..errorcode import (
+ ER_OAUTH_CALLBACK_ERROR,
+ ER_OAUTH_SERVER_TIMEOUT,
+ ER_OAUTH_STATE_CHANGED,
+ ER_UNABLE_TO_OPEN_BROWSER,
+)
+from ..token_cache import TokenCache
+from ._http_server import AuthHttpServer
+from ._oauth_base import AuthByOAuthBase
+
+if TYPE_CHECKING:
+ from .. import SnowflakeConnection
+
+logger = logging.getLogger(__name__)
+
+BUF_SIZE = 16384
+
+
+def _get_query_params(
+ url: str,
+) -> dict[str, list[str]]:
+ parsed = parse_qs(urlparse(url).query)
+ return parsed
+
+
+class AuthByOauthCode(AuthByOAuthBase):
+ """Authenticates user by OAuth code flow."""
+
+ def __init__(
+ self,
+ application: str,
+ client_id: str,
+ client_secret: str,
+ authentication_url: str,
+ token_request_url: str,
+ redirect_uri: str,
+ scope: str,
+ pkce_enabled: bool = True,
+ token_cache: TokenCache | None = None,
+ refresh_token_enabled: bool = False,
+ external_browser_timeout: int | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ client_id=client_id,
+ client_secret=client_secret,
+ token_request_url=token_request_url,
+ scope=scope,
+ token_cache=token_cache,
+ refresh_token_enabled=refresh_token_enabled,
+ **kwargs,
+ )
+ self._application = application
+ self._origin: str | None = None
+ self._authentication_url = authentication_url
+ self._redirect_uri = redirect_uri
+ self._state = secrets.token_urlsafe(43)
+ logger.debug("chose oauth state: %s", "".join("*" for _ in self._state))
+ self._protocol = "http"
+ self._pkce_enabled = pkce_enabled
+ if pkce_enabled:
+ logger.debug("oauth pkce is going to be used")
+ self._verifier: str | None = None
+ self._external_browser_timeout = external_browser_timeout
+
+ def _get_oauth_type_id(self) -> str:
+ return OAUTH_TYPE_AUTHORIZATION_CODE
+
+ def _request_tokens(
+ self,
+ *,
+ conn: SnowflakeConnection,
+ authenticator: str,
+ service_name: str | None,
+ account: str,
+ user: str,
+ **kwargs: Any,
+ ) -> (str | None, str | None):
+ """Web Browser based Authentication."""
+ logger.debug("authenticating with OAuth authorization code flow")
+ with AuthHttpServer(self._redirect_uri) as callback_server:
+ code = self._do_authorization_request(callback_server, conn)
+ return self._do_token_request(code, callback_server, conn)
+
+ def _check_post_requested(
+ self, data: list[str]
+ ) -> tuple[str, str] | tuple[None, None]:
+ request_line = None
+ header_line = None
+ origin_line = None
+ for line in data:
+ if line.startswith("Access-Control-Request-Method:"):
+ request_line = line
+ elif line.startswith("Access-Control-Request-Headers:"):
+ header_line = line
+ elif line.startswith("Origin:"):
+ origin_line = line
+
+ if (
+ not request_line
+ or not header_line
+ or not origin_line
+ or request_line.split(":")[1].strip() != "POST"
+ ):
+ return (None, None)
+
+ return (
+ header_line.split(":")[1].strip(),
+ ":".join(origin_line.split(":")[1:]).strip(),
+ )
+
+ def _process_options(
+ self, data: list[str], socket_client: socket.socket, hostname: str, port: int
+ ) -> bool:
+ """Allows JS Ajax access to this endpoint."""
+ for line in data:
+ if line.startswith("OPTIONS "):
+ break
+ else:
+ return False
+ requested_headers, requested_origin = self._check_post_requested(data)
+ if requested_headers is None or requested_origin is None:
+ return False
+
+ if not self._validate_origin(requested_origin, hostname, port):
+ # validate Origin and fail if not match with the server.
+ return False
+
+ self._origin = requested_origin
+ content = [
+ "HTTP/1.1 200 OK",
+ "Date: {}".format(
+ time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())
+ ),
+ "Access-Control-Allow-Methods: POST, GET",
+ f"Access-Control-Allow-Headers: {requested_headers}",
+ "Access-Control-Max-Age: 86400",
+ f"Access-Control-Allow-Origin: {self._origin}",
+ "",
+ "",
+ ]
+ socket_client.sendall("\r\n".join(content).encode("utf-8"))
+ return True
+
+ def _validate_origin(self, requested_origin: str, hostname: str, port: int) -> bool:
+ ret = urlsplit(requested_origin)
+ netloc = ret.netloc.split(":")
+ host_got = netloc[0]
+ port_got = (
+ netloc[1] if len(netloc) > 1 else (443 if self._protocol == "https" else 80)
+ )
+
+ return (
+ ret.scheme == self._protocol and host_got == hostname and port_got == port
+ )
+
+ def _send_response(self, data: list[str], socket_client: socket.socket) -> None:
+ if not self._is_request_get(data):
+ return # error
+
+ response = [
+ "HTTP/1.1 200 OK",
+ "Content-Type: text/html",
+ ]
+ if self._origin:
+ msg = json.dumps({"consent": self.consent_cache_id_token})
+ response.append(f"Access-Control-Allow-Origin: {self._origin}")
+ response.append("Vary: Accept-Encoding, Origin")
+ else:
+ msg = f"""
+
+
+OAuth Response for Snowflake
+
+Your identity was confirmed and propagated to Snowflake {self._application}.
+You can close this window now and go back where you started from.
+"""
+ response.append(f"Content-Length: {len(msg)}")
+ response.append("")
+ response.append(msg)
+
+ socket_client.sendall("\r\n".join(response).encode("utf-8"))
+
+ @staticmethod
+ def _has_code(url: str) -> bool:
+ return "code" in parse_qs(urlparse(url).query)
+
+ @staticmethod
+ def _is_request_get(data: list[str]) -> bool:
+ """Whether an HTTP request is a GET."""
+ return any(line.startswith("GET ") for line in data)
+
+ def _construct_authorization_request(self, redirect_uri: str) -> str:
+ params = {
+ "response_type": "code",
+ "client_id": self._client_id,
+ "redirect_uri": redirect_uri,
+ "state": self._state,
+ }
+ if self._scope:
+ params["scope"] = self._scope
+ if self._pkce_enabled:
+ self._verifier = secrets.token_urlsafe(43)
+ # calculate challenge and verifier
+ challenge = (
+ base64.urlsafe_b64encode(
+ hashlib.sha256(self._verifier.encode("utf-8")).digest()
+ )
+ .decode("utf-8")
+ .rstrip("=")
+ )
+ params["code_challenge"] = challenge
+ params["code_challenge_method"] = "S256"
+ url_params = urllib.parse.urlencode(params)
+ url = f"{self._authentication_url}?{url_params}"
+ return url
+
+ def _do_authorization_request(
+ self,
+ callback_server: AuthHttpServer,
+ connection: SnowflakeConnection,
+ ) -> str | None:
+ authorization_request = self._construct_authorization_request(
+ callback_server.url
+ )
+ logger.debug("step 1: going to open authorization URL")
+ print(
+ "Initiating login request with your identity provider. A "
+ "browser window should have opened for you to complete the "
+ "login. If you can't see it, check existing browser windows, "
+ "or your OS settings. Press CTRL+C to abort and try again..."
+ )
+ code, state = (
+ self._receive_authorization_callback(callback_server, connection)
+ if webbrowser.open(authorization_request)
+ else self._ask_authorization_callback_from_user(
+ authorization_request, connection
+ )
+ )
+ if not code:
+ self._handle_failure(
+ conn=connection,
+ ret={
+ "code": ER_UNABLE_TO_OPEN_BROWSER,
+ "message": (
+ "Unable to open a browser in this environment and "
+ "OAuth URL contained no authorization code."
+ ),
+ },
+ )
+ return None
+ if state != self._state:
+ self._handle_failure(
+ conn=connection,
+ ret={
+ "code": ER_OAUTH_STATE_CHANGED,
+ "message": "State changed during OAuth process.",
+ },
+ )
+ logger.debug(
+ "received oauth code: %s and state: %s",
+ "*" * len(code),
+ "*" * len(state),
+ )
+ return None
+ return code
+
+ def _do_token_request(
+ self,
+ code: str,
+ callback_server: AuthHttpServer,
+ connection: SnowflakeConnection,
+ ) -> (str | None, str | None):
+ logger.debug("step 2: received OAuth callback, requesting token")
+ fields = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": callback_server.url,
+ }
+ if self._pkce_enabled:
+ assert self._verifier is not None
+ fields["code_verifier"] = self._verifier
+ return self._get_request_token_response(connection, fields)
+
+ def _receive_authorization_callback(
+ self,
+ http_server: AuthHttpServer,
+ connection: SnowflakeConnection,
+ ) -> (str | None, str | None):
+ logger.debug("trying to receive authorization redirected uri")
+ data, socket_connection = http_server.receive_block(
+ timeout=self._external_browser_timeout
+ )
+ if socket_connection is None:
+ self._handle_failure(
+ conn=connection,
+ ret={
+ "code": ER_OAUTH_SERVER_TIMEOUT,
+ "message": "Unable to receive the OAuth message within a given timeout. Please check the redirect URI and try again.",
+ },
+ )
+ return None, None
+ try:
+ if not self._process_options(
+ data, socket_connection, http_server.hostname, http_server.port
+ ):
+ self._send_response(data, socket_connection)
+ socket_connection.shutdown(socket.SHUT_RDWR)
+ except OSError:
+ pass
+ finally:
+ socket_connection.close()
+ return self._parse_authorization_redirected_request(
+ data[0].split(maxsplit=2)[1],
+ connection,
+ )
+
+ def _ask_authorization_callback_from_user(
+ self,
+ authorization_request: str,
+ connection: SnowflakeConnection,
+ ) -> (str | None, str | None):
+ logger.debug("requesting authorization redirected url from user")
+ print(
+ "We were unable to open a browser window for you, "
+ "please open the URL manually then paste the "
+ "URL you are redirected to into the terminal:\n"
+ f"{authorization_request}"
+ )
+ received_redirected_request = input(
+ "Enter the URL the OAuth flow redirected you to: "
+ )
+ code, state = self._parse_authorization_redirected_request(
+ received_redirected_request,
+ connection,
+ )
+ if not code:
+ self._handle_failure(
+ conn=connection,
+ ret={
+ "code": ER_UNABLE_TO_OPEN_BROWSER,
+ "message": (
+ "Unable to open a browser in this environment and "
+ "OAuth URL contained no code"
+ ),
+ },
+ )
+ return code, state
+
+ def _parse_authorization_redirected_request(
+ self,
+ url: str,
+ conn: SnowflakeConnection,
+ ) -> (str | None, str | None):
+ parsed = parse_qs(urlparse(url).query)
+ if "error" in parsed:
+ self._handle_failure(
+ conn=conn,
+ ret={
+ "code": ER_OAUTH_CALLBACK_ERROR,
+ "message": f"Oauth callback returned an {parsed['error'][0]} error{': ' + parsed['error_description'][0] if 'error_description' in parsed else '.'}",
+ },
+ )
+ return parsed.get("code", [None])[0], parsed.get("state", [None])[0]
diff --git a/src/snowflake/connector/auth/oauth_credentials.py b/src/snowflake/connector/auth/oauth_credentials.py
new file mode 100644
index 0000000000..6061ead023
--- /dev/null
+++ b/src/snowflake/connector/auth/oauth_credentials.py
@@ -0,0 +1,64 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from ..constants import OAUTH_TYPE_CLIENT_CREDENTIALS
+from ..token_cache import TokenCache
+from ._oauth_base import AuthByOAuthBase
+
+if TYPE_CHECKING:
+ from .. import SnowflakeConnection
+
+logger = logging.getLogger(__name__)
+
+
+class AuthByOauthCredentials(AuthByOAuthBase):
+ """Authenticates user by OAuth credentials - a client_id/client_secret pair."""
+
+ def __init__(
+ self,
+ application: str,
+ client_id: str,
+ client_secret: str,
+ token_request_url: str,
+ scope: str,
+ token_cache: TokenCache | None = None,
+ refresh_token_enabled: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ client_id=client_id,
+ client_secret=client_secret,
+ token_request_url=token_request_url,
+ scope=scope,
+ token_cache=token_cache,
+ refresh_token_enabled=refresh_token_enabled,
+ **kwargs,
+ )
+ self._application = application
+ self._origin: str | None = None
+
+ def _get_oauth_type_id(self) -> str:
+ return OAUTH_TYPE_CLIENT_CREDENTIALS
+
+ def _request_tokens(
+ self,
+ *,
+ conn: SnowflakeConnection,
+ authenticator: str,
+ service_name: str | None,
+ account: str,
+ user: str,
+ **kwargs: Any,
+ ) -> (str | None, str | None):
+ logger.debug("authenticating with OAuth client credentials flow")
+ fields = {
+ "grant_type": "client_credentials",
+ "scope": self._scope,
+ }
+ return self._get_request_token_response(conn, fields)
diff --git a/src/snowflake/connector/auth/webbrowser.py b/src/snowflake/connector/auth/webbrowser.py
index 20b92efb52..f5bddd4fcc 100644
--- a/src/snowflake/connector/auth/webbrowser.py
+++ b/src/snowflake/connector/auth/webbrowser.py
@@ -112,6 +112,7 @@ def prepare(
"""Web Browser based Authentication."""
logger.debug("authenticating by Web Browser")
+ # TODO: switch to the new AuthHttpServer class instead of doing this manually
socket_connection = self._socket(socket.AF_INET, socket.SOCK_STREAM)
if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true":
diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py
index 84e0052a62..9103710f7a 100644
--- a/src/snowflake/connector/connection.py
+++ b/src/snowflake/connector/connection.py
@@ -2,6 +2,7 @@
from __future__ import annotations
import atexit
+import collections.abc
import logging
import os
import pathlib
@@ -35,6 +36,8 @@
AuthByDefault,
AuthByKeyPair,
AuthByOAuth,
+ AuthByOauthCode,
+ AuthByOauthCredentials,
AuthByOkta,
AuthByPAT,
AuthByPlugin,
@@ -52,6 +55,7 @@
from .constants import (
_CONNECTIVITY_ERR_MSG,
_DOMAIN_NAME_MAP,
+ _OAUTH_DEFAULT_SCOPE,
ENV_VAR_EXPERIMENTAL_AUTHENTICATION,
ENV_VAR_PARTNER,
PARAMETER_AUTOCOMMIT,
@@ -81,6 +85,7 @@
from .direct_file_operation_utils import FileOperationParser, StreamDownloader
from .errorcode import (
ER_CONNECTION_IS_CLOSED,
+ ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED,
ER_FAILED_PROCESSING_PYFORMAT,
ER_FAILED_PROCESSING_QMARK,
ER_FAILED_TO_CONNECT_TO_DB,
@@ -88,6 +93,7 @@
ER_INVALID_VALUE,
ER_INVALID_WIF_SETTINGS,
ER_NO_ACCOUNT_NAME,
+ ER_NO_CLIENT_ID,
ER_NO_NUMPY,
ER_NO_PASSWORD,
ER_NO_USER,
@@ -101,6 +107,8 @@
KEY_PAIR_AUTHENTICATOR,
NO_AUTH_AUTHENTICATOR,
OAUTH_AUTHENTICATOR,
+ OAUTH_AUTHORIZATION_CODE,
+ OAUTH_CLIENT_CREDENTIALS,
PROGRAMMATIC_ACCESS_TOKEN,
REQUEST_ID,
USR_PWD_MFA_AUTHENTICATOR,
@@ -166,13 +174,13 @@ def _get_private_bytes_from_file(
"user": ("", str), # standard
"password": ("", str), # standard
"host": ("127.0.0.1", str), # standard
- "port": (8080, (int, str)), # standard
+ "port": (443, (int, str)), # standard
"database": (None, (type(None), str)), # standard
"proxy_host": (None, (type(None), str)), # snowflake
"proxy_port": (None, (type(None), str)), # snowflake
"proxy_user": (None, (type(None), str)), # snowflake
"proxy_password": (None, (type(None), str)), # snowflake
- "protocol": ("http", str), # snowflake
+ "protocol": ("https", str), # snowflake
"warehouse": (None, (type(None), str)), # snowflake
"region": (None, (type(None), str)), # snowflake
"account": (None, (type(None), str)), # snowflake
@@ -185,6 +193,7 @@ def _get_private_bytes_from_file(
(type(None), int),
), # network timeout (infinite by default)
"socket_timeout": (None, (type(None), int)),
+ "external_browser_timeout": (120, int),
"backoff_policy": (DEFAULT_BACKOFF_POLICY, Callable),
"passcode_in_password": (False, bool), # Snowflake MFA
"passcode": (None, (type(None), str)), # Snowflake MFA
@@ -315,6 +324,37 @@ def _get_private_bytes_from_file(
False,
bool,
), # use https://{bucket}.storage.googleapis.com instead of https://storage.googleapis.com/{bucket}
+ "oauth_client_id": (
+ None,
+ (type(None), str),
+ # SNOW-1825621: OAUTH implementation
+ ),
+ "oauth_client_secret": (
+ None,
+ (type(None), str),
+ # SNOW-1825621: OAUTH implementation
+ ),
+ "oauth_authorization_url": (
+ "https://{host}:{port}/oauth/authorize",
+ str,
+ # SNOW-1825621: OAUTH implementation
+ ),
+ "oauth_token_request_url": (
+ "https://{host}:{port}/oauth/token-request",
+ str,
+ # SNOW-1825621: OAUTH implementation
+ ),
+ "oauth_redirect_uri": ("http://127.0.0.1/", str),
+ "oauth_scope": (
+ "",
+ str,
+ # SNOW-1825621: OAUTH implementation
+ ),
+ "oauth_security_features": (
+ ("pkce",),
+ collections.abc.Iterable, # of strings
+ # SNOW-1825621: OAUTH PKCE
+ ),
"check_arrow_conversion_error_on_every_column": (
True,
bool,
@@ -552,8 +592,8 @@ def host(self) -> str:
return self._host
@property
- def port(self) -> int | str: # TODO: shouldn't be a string
- return self._port
+ def port(self) -> int:
+ return int(self._port)
@property
def region(self) -> str | None:
@@ -806,6 +846,21 @@ def unsafe_file_write(self) -> bool:
def unsafe_file_write(self, value: bool) -> None:
self._unsafe_file_write = value
+ class _OAuthSecurityFeatures(NamedTuple):
+ pkce_enabled: bool
+ refresh_token_enabled: bool
+
+ @property
+ def oauth_security_features(self) -> _OAuthSecurityFeatures:
+ features = self._oauth_security_features
+ if isinstance(features, str):
+ features = features.split(" ")
+ features = [feat.lower() for feat in features]
+ return self._OAuthSecurityFeatures(
+ pkce_enabled="pkce" in features,
+ refresh_token_enabled="refresh_token" in features,
+ )
+
@property
def gcs_use_virtual_endpoints(self) -> bool:
return self._gcs_use_virtual_endpoints
@@ -1134,7 +1189,7 @@ def __open_connection(self):
self.auth_class = AuthByWebBrowser(
application=self.application,
protocol=self._protocol,
- host=self.host,
+ host=self.host, # TODO: delete this?
port=self.port,
timeout=self.login_timeout,
backoff_generator=self._backoff_generator,
@@ -1170,6 +1225,56 @@ def __open_connection(self):
timeout=self.login_timeout,
backoff_generator=self._backoff_generator,
)
+ elif self._authenticator == OAUTH_AUTHORIZATION_CODE:
+ self._check_experimental_authentication_flag()
+ self._check_oauth_required_parameters()
+ features = self.oauth_security_features
+ if self._role and (self._oauth_scope == ""):
+ # if role is known then let's inject it into scope
+ self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
+ self.auth_class = AuthByOauthCode(
+ application=self.application,
+ client_id=self._oauth_client_id,
+ client_secret=self._oauth_client_secret,
+ authentication_url=self._oauth_authorization_url.format(
+ host=self.host, port=self.port
+ ),
+ token_request_url=self._oauth_token_request_url.format(
+ host=self.host, port=self.port
+ ),
+ redirect_uri=self._oauth_redirect_uri,
+ scope=self._oauth_scope,
+ pkce_enabled=features.pkce_enabled,
+ token_cache=(
+ auth.get_token_cache()
+ if self._client_store_temporary_credential
+ else None
+ ),
+ refresh_token_enabled=features.refresh_token_enabled,
+ external_browser_timeout=self._external_browser_timeout,
+ )
+ elif self._authenticator == OAUTH_CLIENT_CREDENTIALS:
+ self._check_experimental_authentication_flag()
+ self._check_oauth_required_parameters()
+ features = self.oauth_security_features
+ if self._role and (self._oauth_scope == ""):
+ # if role is known then let's inject it into scope
+ self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
+ self.auth_class = AuthByOauthCredentials(
+ application=self.application,
+ client_id=self._oauth_client_id,
+ client_secret=self._oauth_client_secret,
+ token_request_url=self._oauth_token_request_url.format(
+ host=self.host, port=self.port
+ ),
+ scope=self._oauth_scope,
+ token_cache=(
+ auth.get_token_cache()
+ if self._client_store_temporary_credential
+ else None
+ ),
+ refresh_token_enabled=features.refresh_token_enabled,
+ )
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (
self._client_request_mfa_token if IS_LINUX else True
@@ -1189,16 +1294,7 @@ def __open_connection(self):
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:
self.auth_class = AuthByPAT(self._token)
elif self._authenticator == WORKLOAD_IDENTITY_AUTHENTICATOR:
- if ENV_VAR_EXPERIMENTAL_AUTHENTICATION not in os.environ:
- Error.errorhandler_wrapper(
- self,
- None,
- ProgrammingError,
- {
- "msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable to use the '{WORKLOAD_IDENTITY_AUTHENTICATOR}' authenticator.",
- "errno": ER_INVALID_WIF_SETTINGS,
- },
- )
+ self._check_experimental_authentication_flag()
# Standardize the provider enum.
if self._workload_identity_provider and isinstance(
self._workload_identity_provider, str
@@ -1311,10 +1407,6 @@ def __config(self, **kwargs):
if "account" in kwargs:
if "host" not in kwargs:
self._host = construct_hostname(kwargs.get("region"), self._account)
- if "port" not in kwargs:
- self._port = "443"
- if "protocol" not in kwargs:
- self._protocol = "https"
logger.info(
f"Connecting to {_DOMAIN_NAME_MAP.get(extract_top_level_domain_from_hostname(self._host), 'GLOBAL')} Snowflake domain"
@@ -1393,6 +1485,8 @@ def __config(self, **kwargs):
not in (
EXTERNAL_BROWSER_AUTHENTICATOR,
OAUTH_AUTHENTICATOR,
+ OAUTH_AUTHORIZATION_CODE,
+ OAUTH_CLIENT_CREDENTIALS,
KEY_PAIR_AUTHENTICATOR,
PROGRAMMATIC_ACCESS_TOKEN,
WORKLOAD_IDENTITY_AUTHENTICATOR,
@@ -1542,9 +1636,13 @@ def authenticate_with_retry(self, auth_instance) -> None:
except ReauthenticationRequest as ex:
# cached id_token expiration error, we have cleaned id_token and try to authenticate again
logger.debug("ID token expired. Reauthenticating...: %s", ex)
- if isinstance(auth_instance, AuthByIdToken):
- # Note: SNOW-733835 IDToken auth needs to authenticate through
- # SSO if it has expired
+ if type(auth_instance) in (
+ AuthByIdToken,
+ AuthByOauthCode,
+ AuthByOauthCredentials,
+ ):
+ # IDToken and OAuth auth need to authenticate through
+ # SSO if its credential has expired
self._reauthenticate()
else:
self._authenticate(auth_instance)
@@ -2146,6 +2244,40 @@ def is_valid(self) -> bool:
logger.debug("session could not be validated due to exception: %s", e)
return False
+ def _check_experimental_authentication_flag(self) -> None:
+ if os.getenv(ENV_VAR_EXPERIMENTAL_AUTHENTICATION, "false").lower() != "true":
+ Error.errorhandler_wrapper(
+ self,
+ None,
+ ProgrammingError,
+ {
+ "msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable true to use the '{self._authenticator}' authenticator.",
+ "errno": ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED,
+ },
+ )
+
+ def _check_oauth_required_parameters(self) -> None:
+ if self._oauth_client_id is None:
+ Error.errorhandler_wrapper(
+ self,
+ None,
+ ProgrammingError,
+ {
+ "msg": "Oauth code flow requirement 'client_id' is empty",
+ "errno": ER_NO_CLIENT_ID,
+ },
+ )
+ if self._oauth_client_secret is None:
+ Error.errorhandler_wrapper(
+ self,
+ None,
+ ProgrammingError,
+ {
+ "msg": "Oauth code flow requirement 'client_secret' is empty",
+ "errno": ER_NO_CLIENT_ID,
+ },
+ )
+
@staticmethod
def _detect_application() -> None | str:
if ENV_VAR_PARTNER in os.environ.keys():
diff --git a/src/snowflake/connector/constants.py b/src/snowflake/connector/constants.py
index 085ec7a2b3..739fcd3fcc 100644
--- a/src/snowflake/connector/constants.py
+++ b/src/snowflake/connector/constants.py
@@ -321,7 +321,7 @@ class FileHeader(NamedTuple):
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL = "CLIENT_STORE_TEMPORARY_CREDENTIAL"
PARAMETER_CLIENT_REQUEST_MFA_TOKEN = "CLIENT_REQUEST_MFA_TOKEN"
PARAMETER_CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTIAL = (
- "CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTAIL"
+ "CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTIAL"
)
PARAMETER_QUERY_CONTEXT_CACHE_SIZE = "QUERY_CONTEXT_CACHE_SIZE"
PARAMETER_TIMEZONE = "TIMEZONE"
@@ -436,3 +436,7 @@ class IterUnit(Enum):
"\nTo further troubleshoot your connection you may reference the following article: "
"https://docs.snowflake.com/en/user-guide/client-connectivity-troubleshooting/overview."
)
+
+_OAUTH_DEFAULT_SCOPE = "session:role:{role}"
+OAUTH_TYPE_AUTHORIZATION_CODE = "authorization_code"
+OAUTH_TYPE_CLIENT_CREDENTIALS = "client_credentials"
diff --git a/src/snowflake/connector/errorcode.py b/src/snowflake/connector/errorcode.py
index 1bc9138df2..0a0dbe0a45 100644
--- a/src/snowflake/connector/errorcode.py
+++ b/src/snowflake/connector/errorcode.py
@@ -27,8 +27,13 @@
ER_JWT_RETRY_EXPIRED = 251010
ER_CONNECTION_TIMEOUT = 251011
ER_RETRYABLE_CODE = 251012
-ER_INVALID_WIF_SETTINGS = 251013
-ER_WIF_CREDENTIALS_NOT_FOUND = 251014
+ER_NO_CLIENT_ID = 251013
+ER_OAUTH_STATE_CHANGED = 251014
+ER_OAUTH_CALLBACK_ERROR = 251015
+ER_OAUTH_SERVER_TIMEOUT = 251016
+ER_INVALID_WIF_SETTINGS = 251017
+ER_WIF_CREDENTIALS_NOT_FOUND = 251018
+ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED = 251019
# cursor
ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT = 252001
diff --git a/src/snowflake/connector/file_lock.py b/src/snowflake/connector/file_lock.py
new file mode 100644
index 0000000000..dd3bc85ab9
--- /dev/null
+++ b/src/snowflake/connector/file_lock.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+import logging
+import time
+from os import stat_result
+from pathlib import Path
+from time import sleep
+
+MAX_RETRIES = 5
+INITIAL_BACKOFF_SECONDS = 0.025
+STALE_LOCK_AGE_SECONDS = 1
+
+
+class FileLockError(Exception):
+ pass
+
+
+class FileLock:
+ def __init__(self, path: Path) -> None:
+ self.path: Path = path
+ self.locked = False
+ self.logger = logging.getLogger(__name__)
+
+ def __enter__(self):
+ statinfo: stat_result | None = None
+ try:
+ statinfo = self.path.stat()
+ except FileNotFoundError:
+ pass
+ except OSError as e:
+ raise FileLockError(f"Failed to stat lock file {self.path} due to {e=}")
+
+ if statinfo and statinfo.st_ctime < time.time() - STALE_LOCK_AGE_SECONDS:
+ self.logger.debug("Removing stale file lock")
+ try:
+ self.path.rmdir()
+ except FileNotFoundError:
+ pass
+ except OSError as e:
+ raise FileLockError(
+ f"Failed to remove stale lock file {self.path} due to {e=}"
+ )
+
+ backoff_seconds = INITIAL_BACKOFF_SECONDS
+ for attempt in range(MAX_RETRIES):
+ self.logger.debug(
+ f"Trying to acquire file lock after {backoff_seconds} seconds in attempt number {attempt}.",
+ )
+ backoff_seconds = backoff_seconds * 2
+ try:
+ self.path.mkdir(mode=0o700)
+ self.locked = True
+ break
+ except FileExistsError:
+ sleep(backoff_seconds)
+ continue
+ except OSError as e:
+ raise FileLockError(
+ f"Failed to acquire lock file {self.path} due to {e=}"
+ )
+
+ if not self.locked:
+ raise FileLockError(
+ f"Failed to acquire file lock, after {MAX_RETRIES} attempts."
+ )
+
+ def __exit__(self, exc_type, exc_val, exc_tbc):
+ try:
+ self.path.rmdir()
+ except FileNotFoundError:
+ pass
+ self.locked = False
diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py
index adffc4b6b9..acfe14c589 100644
--- a/src/snowflake/connector/network.py
+++ b/src/snowflake/connector/network.py
@@ -138,6 +138,7 @@
MASTER_TOKEN_INVALD_GS_CODE = "390115"
ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE = "390195"
BAD_REQUEST_GS_CODE = "390400"
+OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE = "390318"
# other constants
CONTENT_TYPE_APPLICATION_JSON = "application/json"
@@ -181,6 +182,8 @@
EXTERNAL_BROWSER_AUTHENTICATOR = "EXTERNALBROWSER"
KEY_PAIR_AUTHENTICATOR = "SNOWFLAKE_JWT"
OAUTH_AUTHENTICATOR = "OAUTH"
+OAUTH_AUTHORIZATION_CODE = "OAUTH_AUTHORIZATION_CODE"
+OAUTH_CLIENT_CREDENTIALS = "OAUTH_CLIENT_CREDENTIALS"
ID_TOKEN_AUTHENTICATOR = "ID_TOKEN"
USR_PWD_MFA_AUTHENTICATOR = "USERNAME_PASSWORD_MFA"
PROGRAMMATIC_ACCESS_TOKEN = "PROGRAMMATIC_ACCESS_TOKEN"
diff --git a/src/snowflake/connector/token_cache.py b/src/snowflake/connector/token_cache.py
index 40a55f9e8b..a5ace1f6a8 100644
--- a/src/snowflake/connector/token_cache.py
+++ b/src/snowflake/connector/token_cache.py
@@ -1,22 +1,24 @@
from __future__ import annotations
import codecs
+import hashlib
import json
import logging
-import tempfile
-import time
+import os
+import stat
+import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
-from os import getenv, makedirs, mkdir, path, remove, removedirs, rmdir
-from os.path import expanduser
-from threading import Lock
+from pathlib import Path
+from typing import Any, TypeVar
from .compat import IS_LINUX, IS_MACOS, IS_WINDOWS
-from .file_util import owner_rw_opener
+from .file_lock import FileLock, FileLockError
from .options import installed_keyring, keyring
-KEYRING_DRIVER_NAME = "SNOWFLAKE-PYTHON-DRIVER"
+logger = logging.getLogger(__name__)
+T = TypeVar("T")
class TokenType(Enum):
@@ -26,46 +28,65 @@ class TokenType(Enum):
OAUTH_REFRESH_TOKEN = "OAUTH_REFRESH_TOKEN"
+class _InvalidTokenKeyError(Exception):
+ pass
+
+
@dataclass
class TokenKey:
user: str
host: str
tokenType: TokenType
+ def string_key(self) -> str:
+ if len(self.host) == 0:
+ raise _InvalidTokenKeyError("Invalid key, host is empty")
+ if len(self.user) == 0:
+ raise _InvalidTokenKeyError("Invalid key, user is empty")
+ return f"{self.host.upper()}:{self.user.upper()}:{self.tokenType.value}"
-class TokenCache(ABC):
- def build_temporary_credential_name(
- self, host: str, user: str, cred_type: TokenType
- ) -> str:
- return "{host}:{user}:{driver}:{cred}".format(
- host=host.upper(),
- user=user.upper(),
- driver=KEYRING_DRIVER_NAME,
- cred=cred_type.value,
- )
+ def hash_key(self) -> str:
+ m = hashlib.sha256()
+ m.update(self.string_key().encode(encoding="utf-8"))
+ return m.hexdigest()
+
+
+def _warn(warning: str) -> None:
+ logger.warning(warning)
+ print("Warning: " + warning, file=sys.stderr)
+
+class TokenCache(ABC):
@staticmethod
def make() -> TokenCache:
if IS_MACOS or IS_WINDOWS:
if not installed_keyring:
- logging.getLogger(__name__).debug(
+ _warn(
"Dependency 'keyring' is not installed, cannot cache id token. You might experience "
- "multiple authentication pop ups while using ExternalBrowser Authenticator. To avoid "
- "this please install keyring module using the following command : pip install "
- "snowflake-connector-python[secure-local-storage]"
+ "multiple authentication pop ups while using ExternalBrowser/OAuth/MFA Authenticator. To avoid "
+ "this please install keyring module using the following command:\n"
+ " pip install snowflake-connector-python[secure-local-storage]"
)
return NoopTokenCache()
return KeyringTokenCache()
if IS_LINUX:
- return FileTokenCache()
+ cache = FileTokenCache.make()
+ if cache:
+ return cache
+ else:
+ _warn(
+ "Failed to initialize file based token cache. You might experience "
+ "multiple authentication pop ups while using ExternalBrowser/OAuth/MFA Authenticator."
+ )
+ return NoopTokenCache()
@abstractmethod
def store(self, key: TokenKey, token: str) -> None:
pass
@abstractmethod
- def retrieve(self, key: TokenKey) -> str:
+ def retrieve(self, key: TokenKey) -> str | None:
pass
@abstractmethod
@@ -73,196 +94,255 @@ def remove(self, key: TokenKey) -> None:
pass
+class _FileTokenCacheError(Exception):
+ pass
+
+
+class _OwnershipError(_FileTokenCacheError):
+ pass
+
+
+class _PermissionsTooWideError(_FileTokenCacheError):
+ pass
+
+
+class _CacheDirNotFoundError(_FileTokenCacheError):
+ pass
+
+
+class _InvalidCacheDirError(_FileTokenCacheError):
+ pass
+
+
+class _MalformedCacheFileError(_FileTokenCacheError):
+ pass
+
+
+class _CacheFileReadError(_FileTokenCacheError):
+ pass
+
+
+class _CacheFileWriteError(_FileTokenCacheError):
+ pass
+
+
class FileTokenCache(TokenCache):
+ @staticmethod
+ def make() -> FileTokenCache | None:
+ cache_dir = FileTokenCache.find_cache_dir()
+ if cache_dir is None:
+ logging.getLogger(__name__).debug(
+ "Failed to find suitable cache directory for token cache. File based token cache initialization failed."
+ )
+ return None
+ else:
+ return FileTokenCache(cache_dir)
- def __init__(self):
+ def __init__(self, cache_dir: Path) -> None:
self.logger = logging.getLogger(__name__)
- self.CACHE_ROOT_DIR = (
- getenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR")
- or expanduser("~")
- or tempfile.gettempdir()
- )
- self.CACHE_DIR = path.join(self.CACHE_ROOT_DIR, ".cache", "snowflake")
-
- if not path.exists(self.CACHE_DIR):
- try:
- makedirs(self.CACHE_DIR, mode=0o700)
- except Exception as ex:
- self.logger.debug(
- "cannot create a cache directory: [%s], err=[%s]",
- self.CACHE_DIR,
- ex,
- )
- self.CACHE_DIR = None
- self.logger.debug("cache directory: %s", self.CACHE_DIR)
-
- # temporary credential cache
- self.TEMPORARY_CREDENTIAL: dict[str, dict[str, str | None]] = {}
-
- self.TEMPORARY_CREDENTIAL_LOCK = Lock()
-
- # temporary credential cache file name
- self.TEMPORARY_CREDENTIAL_FILE = "temporary_credential.json"
- self.TEMPORARY_CREDENTIAL_FILE = (
- path.join(self.CACHE_DIR, self.TEMPORARY_CREDENTIAL_FILE)
- if self.CACHE_DIR
- else ""
- )
-
- # temporary credential cache lock directory name
- self.TEMPORARY_CREDENTIAL_FILE_LOCK = self.TEMPORARY_CREDENTIAL_FILE + ".lck"
-
- def flush_temporary_credentials(self) -> None:
- """Flush temporary credentials in memory into disk. Need to hold TEMPORARY_CREDENTIAL_LOCK."""
- for _ in range(10):
- if self.lock_temporary_credential_file():
- break
- time.sleep(1)
- else:
- self.logger.debug(
- "The lock file still persists after the maximum wait time."
- "Will ignore it and write temporary credential file: %s",
- self.TEMPORARY_CREDENTIAL_FILE,
- )
+ self.cache_dir: Path = cache_dir
+
+ def store(self, key: TokenKey, token: str) -> None:
try:
- with open(
- self.TEMPORARY_CREDENTIAL_FILE,
- "w",
- encoding="utf-8",
- errors="ignore",
- opener=owner_rw_opener,
- ) as f:
- json.dump(self.TEMPORARY_CREDENTIAL, f)
- except Exception as ex:
- self.logger.debug(
- "Failed to write a credential file: " "file=[%s], err=[%s]",
- self.TEMPORARY_CREDENTIAL_FILE,
- ex,
+ FileTokenCache.validate_cache_dir(self.cache_dir)
+ with FileLock(self.lock_file()):
+ cache = self._read_cache_file()
+ cache["tokens"][key.hash_key()] = token
+ self._write_cache_file(cache)
+ except _FileTokenCacheError as e:
+ self.logger.error(f"Failed to store token: {e=}")
+ except FileLockError as e:
+ self.logger.error(f"Unable to lock file lock: {e=}")
+ except _InvalidTokenKeyError as e:
+ self.logger.error(f"Failed to produce token key {e=}")
+
+ def retrieve(self, key: TokenKey) -> str | None:
+ try:
+ FileTokenCache.validate_cache_dir(self.cache_dir)
+ with FileLock(self.lock_file()):
+ cache = self._read_cache_file()
+ token = cache["tokens"].get(key.hash_key(), None)
+ if isinstance(token, str):
+ return token
+ else:
+ return None
+ except _FileTokenCacheError as e:
+ self.logger.error(f"Failed to retrieve token: {e=}")
+ return None
+ except FileLockError as e:
+ self.logger.error(f"Unable to lock file lock: {e=}")
+ return None
+ except _InvalidTokenKeyError as e:
+ self.logger.error(f"Failed to produce token key {e=}")
+ return None
+
+ def remove(self, key: TokenKey) -> None:
+ try:
+ FileTokenCache.validate_cache_dir(self.cache_dir)
+ with FileLock(self.lock_file()):
+ cache = self._read_cache_file()
+ cache["tokens"].pop(key.hash_key(), None)
+ self._write_cache_file(cache)
+ except _FileTokenCacheError as e:
+ self.logger.error(f"Failed to remove token: {e=}")
+ except FileLockError as e:
+ self.logger.error(f"Unable to lock file lock: {e=}")
+ except _InvalidTokenKeyError as e:
+ self.logger.error(f"Failed to produce token key {e=}")
+
+ def cache_file(self) -> Path:
+ return self.cache_dir / "credential_cache_v1.json"
+
+ def lock_file(self) -> Path:
+ return self.cache_dir / "credential_cache_v1.json.lck"
+
+ def _read_cache_file(self) -> dict[str, dict[str, Any]]:
+ fd = -1
+ json_data = {"tokens": {}}
+ try:
+ fd = os.open(self.cache_file(), os.O_RDONLY)
+ self._ensure_permissions(fd, 0o600)
+ size = os.lseek(fd, 0, os.SEEK_END)
+ os.lseek(fd, 0, os.SEEK_SET)
+ data = os.read(fd, size)
+ json_data = json.loads(codecs.decode(data, "utf-8"))
+ except FileNotFoundError:
+ self.logger.debug(f"{self.cache_file()} not found")
+ except json.decoder.JSONDecodeError as e:
+ self.logger.warning(
+ f"Failed to decode json read from cache file {self.cache_file()}: {e.__class__.__name__}"
+ )
+ except UnicodeError as e:
+ self.logger.warning(
+ f"Failed to decode utf-8 read from cache file {self.cache_file()}: {e.__class__.__name__}"
)
+ except OSError as e:
+ self.logger.warning(f"Failed to read cache file {self.cache_file()}: {e}")
finally:
- self.unlock_temporary_credential_file()
+ if fd > 0:
+ os.close(fd)
- def lock_temporary_credential_file(self) -> bool:
+ if "tokens" not in json_data or not isinstance(json_data["tokens"], dict):
+ json_data["tokens"] = {}
+
+ return json_data
+
+ def _write_cache_file(self, json_data: dict):
+ fd = -1
+ self.logger.debug(f"Writing cache file {self.cache_file()}")
try:
- mkdir(self.TEMPORARY_CREDENTIAL_FILE_LOCK)
- return True
- except OSError:
- self.logger.debug(
- "Temporary cache file lock already exists. Other "
- "process may be updating the temporary "
+ fd = os.open(
+ self.cache_file(), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600
)
- return False
+ self._ensure_permissions(fd, 0o600)
+ os.write(fd, codecs.encode(json.dumps(json_data), "utf-8"))
+ return json_data
+ except OSError as e:
+ raise _CacheFileWriteError("Failed to write cache file", e)
+ finally:
+ if fd > 0:
+ os.close(fd)
- def unlock_temporary_credential_file(self) -> bool:
- try:
- rmdir(self.TEMPORARY_CREDENTIAL_FILE_LOCK)
- return True
- except OSError:
- self.logger.debug("Temporary cache file lock no longer exists.")
- return False
-
- def write_temporary_credential_file(
- self, host: str, cred_name: str, cred: str
- ) -> None:
- """Writes temporary credential file when OS is Linux."""
- if not self.CACHE_DIR:
- # no cache is enabled
- return
- with self.TEMPORARY_CREDENTIAL_LOCK:
- # update the cache
- host_data = self.TEMPORARY_CREDENTIAL.get(host.upper(), {})
- host_data[cred_name.upper()] = cred
- self.TEMPORARY_CREDENTIAL[host.upper()] = host_data
- self.flush_temporary_credentials()
-
- def read_temporary_credential_file(self):
- """Reads temporary credential file when OS is Linux."""
- if not self.CACHE_DIR:
- # no cache is enabled
- return
-
- with self.TEMPORARY_CREDENTIAL_LOCK:
- for _ in range(10):
- if self.lock_temporary_credential_file():
- break
- time.sleep(1)
- else:
- self.logger.debug(
- "The lock file still persists. Will ignore and "
- "write the temporary credential file: %s",
- self.TEMPORARY_CREDENTIAL_FILE,
+ @staticmethod
+ def find_cache_dir() -> Path | None:
+ def lookup_env_dir(env_var: str, subpath_segments: list[str]) -> Path | None:
+ env_val = os.getenv(env_var)
+ if env_val is None:
+ logger.debug(
+ f"Environment variable {env_var} not set. Skipping it in cache directory lookup."
)
+ return None
+
+ directory = Path(env_val)
+
+ if len(subpath_segments) > 0:
+ if not directory.exists():
+ logger.debug(
+ f"Path {str(directory)} does not exist. Skipping it in cache directory lookup."
+ )
+ return None
+
+ if not directory.is_dir():
+ logger.debug(
+ f"Path {str(directory)} is not a directory. Skipping it in cache directory lookup."
+ )
+ return None
+
+ for subpath in subpath_segments[:-1]:
+ directory = directory / subpath
+ directory.mkdir(exist_ok=True, mode=0o755)
+
+ directory = directory / subpath_segments[-1]
+ directory.mkdir(exist_ok=True, mode=0o700)
+
try:
- with codecs.open(
- self.TEMPORARY_CREDENTIAL_FILE,
- "r",
- encoding="utf-8",
- errors="ignore",
- ) as f:
- self.TEMPORARY_CREDENTIAL = json.load(f)
- return self.TEMPORARY_CREDENTIAL
- except Exception as ex:
- self.logger.debug(
- "Failed to read a credential file. The file may not"
- "exists: file=[%s], err=[%s]",
- self.TEMPORARY_CREDENTIAL_FILE,
- ex,
+ FileTokenCache.validate_cache_dir(directory)
+ return directory
+ except _FileTokenCacheError as e:
+ logger.debug(
+ f"Cache directory validation failed for {str(directory)} due to error '{e}'. Skipping it in cache directory lookup."
)
- finally:
- self.unlock_temporary_credential_file()
-
- def temporary_credential_file_delete_password(
- self, host: str, user: str, cred_type: TokenType
- ) -> None:
- """Remove credential from temporary credential file when OS is Linux."""
- if not self.CACHE_DIR:
- # no cache is enabled
- return
- with self.TEMPORARY_CREDENTIAL_LOCK:
- # update the cache
- host_data = self.TEMPORARY_CREDENTIAL.get(host.upper(), {})
- host_data.pop(
- self.build_temporary_credential_name(host, user, cred_type), None
- )
- if not host_data:
- self.TEMPORARY_CREDENTIAL.pop(host.upper(), None)
- else:
- self.TEMPORARY_CREDENTIAL[host.upper()] = host_data
- self.flush_temporary_credentials()
+ return None
+
+ lookup_functions = [
+ lambda: lookup_env_dir("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", []),
+ lambda: lookup_env_dir("XDG_CACHE_HOME", ["snowflake"]),
+ lambda: lookup_env_dir("HOME", [".cache", "snowflake"]),
+ ]
- def delete_temporary_credential_file(self) -> None:
- """Deletes temporary credential file and its lock file."""
+ for lf in lookup_functions:
+ cache_dir = lf()
+ if cache_dir:
+ return cache_dir
+
+ return None
+
+ @staticmethod
+ def validate_cache_dir(cache_dir: Path | None) -> None:
try:
- remove(self.TEMPORARY_CREDENTIAL_FILE)
- except Exception as ex:
- self.logger.debug(
- "Failed to delete a credential file: " "file=[%s], err=[%s]",
- self.TEMPORARY_CREDENTIAL_FILE,
- ex,
+ statinfo = cache_dir.stat()
+
+ if cache_dir is None:
+ raise _CacheDirNotFoundError("Cache dir was not found")
+
+ if not stat.S_ISDIR(statinfo.st_mode):
+ raise _InvalidCacheDirError(f"Cache dir {cache_dir} is not a directory")
+
+ permissions = stat.S_IMODE(statinfo.st_mode)
+ if permissions != 0o700:
+ raise _PermissionsTooWideError(
+ f"Cache dir {cache_dir} has incorrect permissions. {permissions:o} != 0700"
+ )
+
+ euid = os.geteuid()
+ if statinfo.st_uid != euid:
+ raise _OwnershipError(
+ f"Cache dir {cache_dir} has incorrect owner. {euid} != {statinfo.st_uid}"
+ )
+
+ except FileNotFoundError:
+ raise _CacheDirNotFoundError(
+ f"Cache dir {cache_dir} was not found. Failed to stat."
)
+
+ def _ensure_permissions(self, fd: int, permissions: int) -> None:
try:
- removedirs(self.TEMPORARY_CREDENTIAL_FILE_LOCK)
- except Exception as ex:
- self.logger.debug("Failed to delete credential lock file: err=[%s]", ex)
+ statinfo = os.fstat(fd)
+ actual_permissions = stat.S_IMODE(statinfo.st_mode)
- def store(self, key: TokenKey, token: str) -> None:
- return self.write_temporary_credential_file(
- key.host,
- self.build_temporary_credential_name(key.host, key.user, key.tokenType),
- token,
- )
-
- def retrieve(self, key: TokenKey) -> str:
- self.read_temporary_credential_file()
- token = self.TEMPORARY_CREDENTIAL.get(key.host.upper(), {}).get(
- self.build_temporary_credential_name(key.host, key.user, key.tokenType)
- )
- return token
+ if actual_permissions != permissions:
+ raise _PermissionsTooWideError(
+ f"Cache file {self.cache_file()} has incorrect permissions. {permissions:o} != {actual_permissions:o}"
+ )
- def remove(self, key: TokenKey) -> None:
- return self.temporary_credential_file_delete_password(
- key.host, key.user, key.tokenType
- )
+ euid = os.geteuid()
+ if statinfo.st_uid != euid:
+ raise _OwnershipError(
+ f"Cache file {self.cache_file()} has incorrect owner. {euid} != {statinfo.st_uid}"
+ )
+
+ except FileNotFoundError:
+ pass
class KeyringTokenCache(TokenCache):
@@ -272,17 +352,19 @@ def __init__(self) -> None:
def store(self, key: TokenKey, token: str) -> None:
try:
keyring.set_password(
- self.build_temporary_credential_name(key.host, key.user, key.tokenType),
+ key.string_key(),
key.user.upper(),
token,
)
+ except _InvalidTokenKeyError as e:
+ self.logger.error(f"Could not retrieve {key.tokenType} from keyring, {e=}")
except keyring.errors.KeyringError as ke:
self.logger.error("Could not store id_token to keyring, %s", str(ke))
- def retrieve(self, key: TokenKey) -> str:
+ def retrieve(self, key: TokenKey) -> str | None:
try:
return keyring.get_password(
- self.build_temporary_credential_name(key.host, key.user, key.tokenType),
+ key.string_key(),
key.user.upper(),
)
except keyring.errors.KeyringError as ke:
@@ -291,13 +373,17 @@ def retrieve(self, key: TokenKey) -> str:
key.tokenType.value, str(ke)
)
)
+ except _InvalidTokenKeyError as e:
+ self.logger.error(f"Could not retrieve {key.tokenType} from keyring, {e=}")
def remove(self, key: TokenKey) -> None:
try:
keyring.delete_password(
- self.build_temporary_credential_name(key.host, key.user, key.tokenType),
+ key.string_key(),
key.user.upper(),
)
+ except _InvalidTokenKeyError as e:
+ self.logger.error(f"Could not retrieve {key.tokenType} from keyring, {e=}")
except Exception as ex:
self.logger.error(
"Failed to delete credential in the keyring: err=[%s]", ex
diff --git a/src/snowflake/connector/vendored/requests/__init__.py b/src/snowflake/connector/vendored/requests/__init__.py
index 03c3f69d31..f3d57da6de 100644
--- a/src/snowflake/connector/vendored/requests/__init__.py
+++ b/src/snowflake/connector/vendored/requests/__init__.py
@@ -41,7 +41,6 @@
import warnings
from .. import urllib3
-
from .exceptions import RequestsDependencyWarning
try:
diff --git a/src/snowflake/connector/vendored/requests/adapters.py b/src/snowflake/connector/vendored/requests/adapters.py
index ab92194fb5..0c14ac32fd 100644
--- a/src/snowflake/connector/vendored/requests/adapters.py
+++ b/src/snowflake/connector/vendored/requests/adapters.py
@@ -25,7 +25,6 @@
from ..urllib3.util import Timeout as TimeoutSauce
from ..urllib3.util import parse_url
from ..urllib3.util.retry import Retry
-
from .auth import _basic_auth_str
from .compat import basestring, urlparse
from .cookies import extract_cookies_to_jar
diff --git a/src/snowflake/connector/vendored/requests/exceptions.py b/src/snowflake/connector/vendored/requests/exceptions.py
index 5efb9c99e1..2ee5d1cfcd 100644
--- a/src/snowflake/connector/vendored/requests/exceptions.py
+++ b/src/snowflake/connector/vendored/requests/exceptions.py
@@ -5,7 +5,6 @@
This module contains the set of Requests' exceptions.
"""
from ..urllib3.exceptions import HTTPError as BaseHTTPError
-
from .compat import JSONDecodeError as CompatJSONDecodeError
diff --git a/src/snowflake/connector/vendored/requests/help.py b/src/snowflake/connector/vendored/requests/help.py
index fc3d1daef5..85f091e3b0 100644
--- a/src/snowflake/connector/vendored/requests/help.py
+++ b/src/snowflake/connector/vendored/requests/help.py
@@ -6,8 +6,8 @@
import sys
import idna
-from .. import urllib3
+from .. import urllib3
from . import __version__ as requests_version
try:
diff --git a/src/snowflake/connector/vendored/requests/models.py b/src/snowflake/connector/vendored/requests/models.py
index bc73aabc52..e88d2a1904 100644
--- a/src/snowflake/connector/vendored/requests/models.py
+++ b/src/snowflake/connector/vendored/requests/models.py
@@ -23,7 +23,6 @@
from ..urllib3.fields import RequestField
from ..urllib3.filepost import encode_multipart_formdata
from ..urllib3.util import parse_url
-
from ._internal_utils import to_native_string, unicode_is_ascii
from .auth import HTTPBasicAuth
from .compat import (
diff --git a/src/snowflake/connector/vendored/requests/utils.py b/src/snowflake/connector/vendored/requests/utils.py
index 1da5e1c34a..e90f96cc81 100644
--- a/src/snowflake/connector/vendored/requests/utils.py
+++ b/src/snowflake/connector/vendored/requests/utils.py
@@ -20,7 +20,6 @@
from collections import OrderedDict
from ..urllib3.util import make_headers, parse_url
-
from . import certs
from .__version__ import __version__
diff --git a/test/auth/__init__.py b/test/auth/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/test/auth/authorization_parameters.py b/test/auth/authorization_parameters.py
new file mode 100644
index 0000000000..fe33ee8ea5
--- /dev/null
+++ b/test/auth/authorization_parameters.py
@@ -0,0 +1,218 @@
+import os
+import sys
+from typing import Union
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+
+sys.path.append(os.path.abspath(os.path.dirname(__file__)))
+
+
+def get_oauth_token_parameters() -> dict[str, str]:
+ return {
+ "auth_url": _get_env_variable("SNOWFLAKE_AUTH_TEST_OAUTH_URL"),
+ "oauth_client_id": _get_env_variable("SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_ID"),
+ "oauth_client_secret": _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_OAUTH_CLIENT_SECRET"
+ ),
+ "okta_user": _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_USER"),
+ "okta_pass": _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_PASS"),
+ "role": (_get_env_variable("SNOWFLAKE_AUTH_TEST_ROLE")).lower(),
+ }
+
+
+def _get_env_variable(name: str, required: bool = True) -> str:
+ value = os.getenv(name)
+ if required and value is None:
+ raise OSError(f"Environment variable {name} is not set")
+ return value
+
+
+def get_okta_login_credentials() -> dict[str, str]:
+ return {
+ "login": _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_USER"),
+ "password": _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_PASS"),
+ }
+
+
+def get_soteria_okta_login_credentials() -> dict[str, str]:
+ return {
+ "login": _get_env_variable("SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID"),
+ "password": _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_USER_PASSWORD"
+ ),
+ }
+
+
+def get_rsa_private_key_for_key_pair(
+ key_path: str,
+) -> serialization.load_pem_private_key:
+ with open(_get_env_variable(key_path), "rb") as key_file:
+ private_key = serialization.load_pem_private_key(
+ key_file.read(), password=None, backend=default_backend()
+ )
+ return private_key
+
+
+def get_pat_setup_command_variables() -> dict[str, Union[str, bool, int]]:
+ return {
+ "snowflake_user": _get_env_variable("SNOWFLAKE_AUTH_TEST_SNOWFLAKE_USER"),
+ "role": _get_env_variable("SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE"),
+ }
+
+
+class AuthConnectionParameters:
+ def __init__(self):
+ self.basic_config = {
+ "host": _get_env_variable("SNOWFLAKE_AUTH_TEST_HOST"),
+ "port": _get_env_variable("SNOWFLAKE_AUTH_TEST_PORT"),
+ "role": _get_env_variable("SNOWFLAKE_AUTH_TEST_ROLE"),
+ "account": _get_env_variable("SNOWFLAKE_AUTH_TEST_ACCOUNT"),
+ "db": _get_env_variable("SNOWFLAKE_AUTH_TEST_DATABASE"),
+ "schema": _get_env_variable("SNOWFLAKE_AUTH_TEST_SCHEMA"),
+ "warehouse": _get_env_variable("SNOWFLAKE_AUTH_TEST_WAREHOUSE"),
+ "CLIENT_STORE_TEMPORARY_CREDENTIAL": False,
+ }
+
+ def get_base_connection_parameters(self) -> dict[str, Union[str, bool, int]]:
+ return self.basic_config
+
+ def get_key_pair_connection_parameters(self):
+ config = self.basic_config.copy()
+ config["authenticator"] = "KEY_PAIR_AUTHENTICATOR"
+ config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER")
+
+ return config
+
+ def get_external_browser_connection_parameters(self) -> dict[str, str]:
+ config = self.basic_config.copy()
+
+ config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER")
+ config["authenticator"] = "externalbrowser"
+
+ return config
+
+ def get_store_id_token_connection_parameters(self) -> dict[str, str]:
+ config = self.get_external_browser_connection_parameters()
+
+ config["CLIENT_STORE_TEMPORARY_CREDENTIAL"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_STORE_ID_TOKEN_USER"
+ )
+
+ return config
+
+ def get_okta_connection_parameters(self) -> dict[str, str]:
+ config = self.basic_config.copy()
+
+ config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER")
+ config["password"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_OKTA_PASS")
+ config["authenticator"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_OAUTH_URL")
+
+ return config
+
+ def get_oauth_connection_parameters(self, token: str) -> dict[str, str]:
+ config = self.basic_config.copy()
+
+ config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER")
+ config["authenticator"] = "OAUTH"
+ config["token"] = token
+ return config
+
+ def get_oauth_external_authorization_code_connection_parameters(
+ self,
+ ) -> dict[str, Union[str, bool, int]]:
+ config = self.basic_config.copy()
+
+ config["authenticator"] = "OAUTH_AUTHORIZATION_CODE"
+ config["oauth_client_id"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID"
+ )
+ config["oauth_client_secret"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_SECRET"
+ )
+ config["oauth_redirect_uri"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_REDIRECT_URI"
+ )
+ config["oauth_authorization_url"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_AUTH_URL"
+ )
+ config["oauth_token_request_url"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_TOKEN"
+ )
+ config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER")
+
+ return config
+
+ def get_snowflake_authorization_code_connection_parameters(
+ self,
+ ) -> dict[str, Union[str, bool, int]]:
+ config = self.basic_config.copy()
+
+ config["authenticator"] = "OAUTH_AUTHORIZATION_CODE"
+ config["oauth_client_id"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_CLIENT_ID"
+ )
+ config["oauth_client_secret"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_CLIENT_SECRET"
+ )
+ config["oauth_redirect_uri"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_REDIRECT_URI"
+ )
+ config["role"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE"
+ )
+ config["user"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID"
+ )
+
+ return config
+
+ def get_snowflake_wildcard_external_authorization_code_connection_parameters(
+ self,
+ ) -> dict[str, Union[str, bool, int]]:
+ config = self.basic_config.copy()
+
+ config["authenticator"] = "OAUTH_AUTHORIZATION_CODE"
+ config["oauth_client_id"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_WILDCARDS_CLIENT_ID"
+ )
+ config["oauth_client_secret"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_WILDCARDS_CLIENT_SECRET"
+ )
+ config["role"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_INTERNAL_OAUTH_SNOWFLAKE_ROLE"
+ )
+ config["user"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID"
+ )
+
+ return config
+
+ def get_oauth_external_client_credential_connection_parameters(
+ self,
+ ) -> dict[str, str]:
+ config = self.basic_config.copy()
+
+ config["authenticator"] = "OAUTH_CLIENT_CREDENTIALS"
+ config["oauth_client_id"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID"
+ )
+ config["oauth_client_secret"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_SECRET"
+ )
+ config["oauth_token_request_url"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_TOKEN"
+ )
+ config["user"] = _get_env_variable(
+ "SNOWFLAKE_AUTH_TEST_EXTERNAL_OAUTH_OKTA_CLIENT_ID"
+ )
+
+ return config
+
+ def get_pat_connection_parameters(self) -> dict[str, str]:
+ config = self.basic_config.copy()
+
+ config["authenticator"] = "PROGRAMMATIC_ACCESS_TOKEN"
+ config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER")
+
+ return config
diff --git a/test/auth/authorization_test_helper.py b/test/auth/authorization_test_helper.py
new file mode 100644
index 0000000000..0d3148be0d
--- /dev/null
+++ b/test/auth/authorization_test_helper.py
@@ -0,0 +1,144 @@
+import logging.config
+import os
+import subprocess
+import threading
+import webbrowser
+from enum import Enum
+from typing import Union
+
+import requests
+
+import snowflake.connector
+
+try:
+ from src.snowflake.connector.vendored.requests.auth import HTTPBasicAuth
+except ImportError:
+ pass
+
+logger = logging.getLogger(__name__)
+
+logger.setLevel(logging.INFO)
+
+
+class Scenario(Enum):
+ SUCCESS = "success"
+ FAIL = "fail"
+ TIMEOUT = "timeout"
+ EXTERNAL_OAUTH_OKTA_SUCCESS = "externalOauthOktaSuccess"
+ INTERNAL_OAUTH_SNOWFLAKE_SUCCESS = "internalOauthSnowflakeSuccess"
+
+
+def get_access_token_oauth(cfg):
+ auth_url = cfg["auth_url"]
+
+ data = {
+ "username": cfg["okta_user"],
+ "password": cfg["okta_pass"],
+ "grant_type": "password",
+ "scope": f"session:role:{cfg['role']}",
+ }
+
+ headers = {"Content-Type": "application/x-www-form-urlencoded;charset=UTF-8"}
+
+ auth_credentials = HTTPBasicAuth(cfg["oauth_client_id"], cfg["oauth_client_secret"])
+ try:
+ response = requests.post(
+ url=auth_url, data=data, headers=headers, auth=auth_credentials
+ )
+ response.raise_for_status()
+ return response.json()["access_token"]
+
+ except requests.exceptions.HTTPError as http_err:
+ logger.error(f"HTTP error occurred: {http_err}")
+ raise
+
+
+def clean_browser_processes():
+ if os.getenv("AUTHENTICATION_TESTS_ENV") == "docker":
+ try:
+ clean_browser_processes_path = "/externalbrowser/cleanBrowserProcesses.js"
+ process = subprocess.run(["node", clean_browser_processes_path], timeout=15)
+ logger.debug(f"OUTPUT: {process.stdout}, ERRORS: {process.stderr}")
+ except Exception as e:
+ raise RuntimeError(e)
+
+
+class AuthorizationTestHelper:
+ def __init__(self, configuration: dict):
+ self.auth_test_env = os.getenv("AUTHENTICATION_TESTS_ENV")
+ self.configuration = configuration
+ self.error_msg = ""
+
+ def update_config(self, configuration):
+ self.configuration = configuration
+
+ def connect_and_provide_credentials(
+ self, scenario: Scenario, login: str, password: str
+ ):
+ try:
+ connect = threading.Thread(target=self.connect_and_execute_simple_query)
+ connect.start()
+ if self.auth_test_env == "docker":
+ browser = threading.Thread(
+ target=self._provide_credentials, args=(scenario, login, password)
+ )
+ browser.start()
+ browser.join()
+ connect.join()
+
+ except Exception as e:
+ self.error_msg = e
+ logger.error(e)
+
+ def get_error_msg(self) -> str:
+ return str(self.error_msg)
+
+ def connect_and_execute_simple_query(self):
+ try:
+ logger.info("Trying to connect to Snowflake")
+ with snowflake.connector.connect(**self.configuration) as con:
+ result = con.cursor().execute("select 1;")
+ logger.debug(result.fetchall())
+ logger.info("Successfully connected to Snowflake")
+ return True
+ except Exception as e:
+ self.error_msg = e
+ logger.error(e)
+ return False
+
+ def _provide_credentials(self, scenario: Scenario, login: str, password: str):
+ try:
+ webbrowser.register("xdg-open", None, webbrowser.GenericBrowser("xdg-open"))
+ provide_browser_credentials_path = (
+ "/externalbrowser/provideBrowserCredentials.js"
+ )
+ process = subprocess.run(
+ [
+ "node",
+ provide_browser_credentials_path,
+ scenario.value,
+ login,
+ password,
+ ],
+ timeout=15,
+ )
+ logger.debug(f"OUTPUT: {process.stdout}, ERRORS: {process.stderr}")
+ except Exception as e:
+ self.error_msg = e
+ raise RuntimeError(e)
+
+ def connect_using_okta_connection_and_execute_custom_command(
+ self, command: str, return_token: bool = False
+ ) -> Union[bool, str]:
+ try:
+ logger.info("Setup PAT")
+ with snowflake.connector.connect(**self.configuration) as con:
+ result = con.cursor().execute(command)
+ token = result.fetchall()[0][1]
+ except Exception as e:
+ self.error_msg = e
+ logger.error(e)
+ return False
+ if return_token:
+ return token
+ return False
diff --git a/test/auth/test_external_browser.py b/test/auth/test_external_browser.py
new file mode 100644
index 0000000000..0658bb2c7c
--- /dev/null
+++ b/test/auth/test_external_browser.py
@@ -0,0 +1,90 @@
+import logging
+from test.auth.authorization_parameters import (
+ AuthConnectionParameters,
+ get_okta_login_credentials,
+)
+
+import pytest
+from authorization_test_helper import (
+ AuthorizationTestHelper,
+ Scenario,
+ clean_browser_processes,
+)
+
+
+@pytest.fixture(autouse=True)
+def setup_and_teardown():
+ logging.info("Cleanup before test")
+ clean_browser_processes()
+
+ yield
+
+ logging.info("Teardown: Performing specific actions after the test")
+ clean_browser_processes()
+
+
+@pytest.mark.auth
+def test_external_browser_successful():
+ connection_parameters = (
+ AuthConnectionParameters().get_external_browser_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_okta_login_credentials().values()
+ test_helper.connect_and_provide_credentials(
+ Scenario.SUCCESS, browser_login, browser_password
+ )
+ assert test_helper.error_msg == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_external_browser_mismatched_user():
+ connection_parameters = (
+ AuthConnectionParameters().get_external_browser_connection_parameters()
+ )
+ connection_parameters["user"] = "differentUsername"
+ browser_login, browser_password = get_okta_login_credentials().values()
+
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ test_helper.connect_and_provide_credentials(
+ Scenario.SUCCESS, browser_login, browser_password
+ )
+ assert (
+ "The user you were trying to authenticate as differs from the user"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+@pytest.mark.skip(reason="SNOW-2007651 Adding custom browser timeout")
+def test_external_browser_wrong_credentials():
+ connection_parameters = (
+ AuthConnectionParameters().get_external_browser_connection_parameters()
+ )
+ browser_login, browser_password = "invalidUser", "invalidPassword"
+ connection_parameters["external_browser_timeout"] = 10
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ test_helper.connect_and_provide_credentials(
+ Scenario.FAIL, browser_login, browser_password
+ )
+
+ assert (
+ "Unable to receive the OAuth message within a given timeout"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+@pytest.mark.skip(reason="SNOW-2007651 Adding custom browser timeout")
+def test_external_browser_timeout():
+ connection_parameters = (
+ AuthConnectionParameters().get_external_browser_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ connection_parameters["external_browser_timeout"] = 1
+ assert (
+ not test_helper.connect_and_execute_simple_query()
+ ), "Connection should not be established"
+ assert (
+ "Unable to receive the OAuth message within a given timeout"
+ in test_helper.get_error_msg()
+ )
diff --git a/test/auth/test_key_pair.py b/test/auth/test_key_pair.py
new file mode 100644
index 0000000000..21b46c5738
--- /dev/null
+++ b/test/auth/test_key_pair.py
@@ -0,0 +1,39 @@
+from test.auth.authorization_parameters import (
+ AuthConnectionParameters,
+ get_rsa_private_key_for_key_pair,
+)
+from test.auth.authorization_test_helper import AuthorizationTestHelper
+
+import pytest
+
+
+@pytest.mark.auth
+def test_key_pair_successful():
+ connection_parameters = (
+ AuthConnectionParameters().get_key_pair_connection_parameters()
+ )
+ connection_parameters["private_key"] = get_rsa_private_key_for_key_pair(
+ "SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH"
+ )
+
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ assert (
+ test_helper.connect_and_execute_simple_query()
+ ), "Failed to connect with Snowflake"
+ assert test_helper.error_msg == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_key_pair_invalid_key():
+ connection_parameters = (
+ AuthConnectionParameters().get_key_pair_connection_parameters()
+ )
+ connection_parameters["private_key"] = get_rsa_private_key_for_key_pair(
+ "SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH"
+ )
+
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ assert (
+ not test_helper.connect_and_execute_simple_query()
+ ), "Connection to Snowflake should not be established"
+ assert "JWT token is invalid" in test_helper.get_error_msg()
diff --git a/test/auth/test_oauth.py b/test/auth/test_oauth.py
new file mode 100644
index 0000000000..de977fc92d
--- /dev/null
+++ b/test/auth/test_oauth.py
@@ -0,0 +1,59 @@
+from test.auth.authorization_parameters import (
+ AuthConnectionParameters,
+ get_oauth_token_parameters,
+)
+from test.auth.authorization_test_helper import (
+ AuthorizationTestHelper,
+ get_access_token_oauth,
+)
+
+import pytest
+
+
+@pytest.mark.auth
+def test_oauth_successful():
+ token = get_oauth_token()
+ connection_parameters = AuthConnectionParameters().get_oauth_connection_parameters(
+ token
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ assert (
+ test_helper.connect_and_execute_simple_query()
+ ), "Failed to connect with OAuth token"
+ assert test_helper.error_msg == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_oauth_mismatched_user():
+ token = get_oauth_token()
+ connection_parameters = AuthConnectionParameters().get_oauth_connection_parameters(
+ token
+ )
+ connection_parameters["user"] = "differentUsername"
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ assert (
+ test_helper.connect_and_execute_simple_query() is False
+ ), "Connection should not be established"
+ assert (
+ "The user you were trying to authenticate as differs from the user"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+def test_oauth_invalid_token():
+ token = "invalidToken"
+ connection_parameters = AuthConnectionParameters().get_oauth_connection_parameters(
+ token
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ assert (
+ test_helper.connect_and_execute_simple_query() is False
+ ), "Connection should not be established"
+ assert "Invalid OAuth access token" in test_helper.get_error_msg()
+
+
+def get_oauth_token():
+ oauth_config = get_oauth_token_parameters()
+ token = get_access_token_oauth(oauth_config)
+ return token
diff --git a/test/auth/test_okta.py b/test/auth/test_okta.py
new file mode 100644
index 0000000000..adfffd31df
--- /dev/null
+++ b/test/auth/test_okta.py
@@ -0,0 +1,58 @@
+from test.auth.authorization_parameters import AuthConnectionParameters
+from test.auth.authorization_test_helper import AuthorizationTestHelper
+
+import pytest
+
+
+@pytest.mark.auth
+def test_okta_successful():
+ connection_parameters = AuthConnectionParameters().get_okta_connection_parameters()
+ test_helper = AuthorizationTestHelper(connection_parameters)
+
+ assert (
+ test_helper.connect_and_execute_simple_query()
+ ), "Failed to connect with Snowflake"
+ assert test_helper.error_msg == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_okta_with_wrong_okta_username():
+ connection_parameters = AuthConnectionParameters().get_okta_connection_parameters()
+ connection_parameters["user"] = "differentUsername"
+
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ assert (
+ not test_helper.connect_and_execute_simple_query()
+ ), "Connection to Snowflake should not be established"
+ assert "Failed to get authentication by OKTA" in test_helper.get_error_msg()
+
+
+@pytest.mark.auth
+def test_okta_wrong_url():
+ connection_parameters = AuthConnectionParameters().get_okta_connection_parameters()
+
+ connection_parameters["authenticator"] = "https://invalid.okta.com/"
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ assert (
+ not test_helper.connect_and_execute_simple_query()
+ ), "Connection to Snowflake should not be established"
+ assert (
+ "The specified authenticator is not accepted by your Snowflake account configuration"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+@pytest.mark.skip(reason="SNOW-1852279 implement error handling for invalid URL")
+def test_okta_wrong_url_2():
+ connection_parameters = AuthConnectionParameters().get_okta_connection_parameters()
+
+ connection_parameters["authenticator"] = "https://invalid.abc.com/"
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ assert (
+ not test_helper.connect_and_execute_simple_query()
+ ), "Connection to Snowflake should not be established"
+ assert (
+ "The specified authenticator is not accepted by your Snowflake account configuration"
+ in test_helper.get_error_msg()
+ )
diff --git a/test/auth/test_okta_authorization_code.py b/test/auth/test_okta_authorization_code.py
new file mode 100644
index 0000000000..db4f16dd34
--- /dev/null
+++ b/test/auth/test_okta_authorization_code.py
@@ -0,0 +1,96 @@
+import logging
+from test.auth.authorization_parameters import (
+ AuthConnectionParameters,
+ get_okta_login_credentials,
+)
+
+import pytest
+from authorization_test_helper import (
+ AuthorizationTestHelper,
+ Scenario,
+ clean_browser_processes,
+)
+
+
+@pytest.fixture(autouse=True)
+def setup_and_teardown():
+ logging.info("Cleanup before test")
+ clean_browser_processes()
+
+ yield
+
+ logging.info("Teardown: Performing specific actions after the test")
+ clean_browser_processes()
+
+
+@pytest.mark.auth
+def test_okta_authorization_code_successful():
+ connection_parameters = (
+ AuthConnectionParameters().get_oauth_external_authorization_code_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_okta_login_credentials().values()
+ test_helper.connect_and_provide_credentials(
+ Scenario.SUCCESS, browser_login, browser_password
+ )
+
+ assert test_helper.error_msg == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_okta_authorization_code_mismatched_user():
+ connection_parameters = (
+ AuthConnectionParameters().get_oauth_external_authorization_code_connection_parameters()
+ )
+ connection_parameters["user"] = "differentUsername"
+ browser_login, browser_password = get_okta_login_credentials().values()
+
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ test_helper.connect_and_provide_credentials(
+ Scenario.SUCCESS, browser_login, browser_password
+ )
+
+ assert (
+ "The user you were trying to authenticate as differs from the user"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+def test_okta_authorization_code_timeout():
+ connection_parameters = (
+ AuthConnectionParameters().get_oauth_external_authorization_code_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ connection_parameters["external_browser_timeout"] = 1
+
+ assert (
+ test_helper.connect_and_execute_simple_query() is False
+ ), "Connection should not be established"
+ assert (
+ "Unable to receive the OAuth message within a given timeout"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+def test_okta_authorization_code_with_token_cache():
+ connection_parameters = (
+ AuthConnectionParameters().get_oauth_external_authorization_code_connection_parameters()
+ )
+ connection_parameters["client_store_temporary_credential"] = True
+ connection_parameters["external_browser_timeout"] = 10
+
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_okta_login_credentials().values()
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.SUCCESS, browser_login, browser_password
+ )
+
+ clean_browser_processes()
+
+ assert (
+ test_helper.connect_and_execute_simple_query() is True
+ ), "Connection should be established"
+ assert test_helper.error_msg == "", "Error message should be empty"
diff --git a/test/auth/test_okta_client_credentials.py b/test/auth/test_okta_client_credentials.py
new file mode 100644
index 0000000000..063e22d786
--- /dev/null
+++ b/test/auth/test_okta_client_credentials.py
@@ -0,0 +1,57 @@
+import logging
+from test.auth.authorization_parameters import AuthConnectionParameters
+
+import pytest
+from authorization_test_helper import AuthorizationTestHelper, clean_browser_processes
+
+
+@pytest.fixture(autouse=True)
+def setup_and_teardown():
+ logging.info("Cleanup before test")
+ clean_browser_processes()
+
+ yield
+
+ logging.info("Teardown: Performing specific actions after the test")
+ clean_browser_processes()
+
+
+@pytest.mark.auth
+def test_okta_client_credentials_successful():
+ connection_parameters = (
+ AuthConnectionParameters().get_oauth_external_client_credential_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+
+ test_helper.connect_and_execute_simple_query()
+
+ assert test_helper.error_msg == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_okta_client_credentials_mismatched_user():
+ connection_parameters = (
+ AuthConnectionParameters().get_oauth_external_client_credential_connection_parameters()
+ )
+ connection_parameters["user"] = "differentUsername"
+ test_helper = AuthorizationTestHelper(connection_parameters)
+
+ test_helper.connect_and_execute_simple_query()
+
+ assert (
+ "The user you were trying to authenticate as differs from the user"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+def test_okta_client_credentials_unauthorized():
+ connection_parameters = (
+ AuthConnectionParameters().get_oauth_external_client_credential_connection_parameters()
+ )
+ connection_parameters["oauth_client_id"] = "invalidClientID"
+ test_helper = AuthorizationTestHelper(connection_parameters)
+
+ test_helper.connect_and_execute_simple_query()
+
+ assert "Invalid HTTP request from web browser" in test_helper.get_error_msg()
diff --git a/test/auth/test_pat.py b/test/auth/test_pat.py
new file mode 100644
index 0000000000..5db79967f2
--- /dev/null
+++ b/test/auth/test_pat.py
@@ -0,0 +1,82 @@
+from datetime import datetime
+from test.auth.authorization_parameters import (
+ AuthConnectionParameters,
+ get_pat_setup_command_variables,
+)
+from typing import Union
+
+import pytest
+from authorization_test_helper import AuthorizationTestHelper
+
+
+@pytest.mark.auth
+def test_authenticate_with_pat_successful() -> None:
+ pat_command_variables = get_pat_setup_command_variables()
+ connection_parameters = AuthConnectionParameters().get_pat_connection_parameters()
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ try:
+ pat_command_variables = get_pat_token(pat_command_variables)
+ connection_parameters["token"] = pat_command_variables["token"]
+ test_helper.connect_and_execute_simple_query()
+ finally:
+ remove_pat_token(pat_command_variables)
+ assert test_helper.get_error_msg() == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_authenticate_with_pat_mismatched_user() -> None:
+ pat_command_variables = get_pat_setup_command_variables()
+ connection_parameters = AuthConnectionParameters().get_pat_connection_parameters()
+ connection_parameters["user"] = "differentUsername"
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ try:
+ pat_command_variables = get_pat_token(pat_command_variables)
+ connection_parameters["token"] = pat_command_variables["token"]
+ test_helper.connect_and_execute_simple_query()
+ finally:
+ remove_pat_token(pat_command_variables)
+
+ assert "Programmatic access token is invalid" in test_helper.get_error_msg()
+
+
+@pytest.mark.auth
+def test_authenticate_with_pat_invalid_token() -> None:
+ connection_parameters = AuthConnectionParameters().get_pat_connection_parameters()
+ connection_parameters["token"] = "invalidToken"
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ test_helper.connect_and_execute_simple_query()
+ assert "Programmatic access token is invalid" in test_helper.get_error_msg()
+
+
+def get_pat_token(pat_command_variables) -> dict[str, Union[str, bool]]:
+ okta_connection_parameters = (
+ AuthConnectionParameters().get_okta_connection_parameters()
+ )
+
+ pat_name = "PAT_PYTHON_" + generate_random_suffix()
+ pat_command_variables["pat_name"] = pat_name
+ command = (
+ f"alter user {pat_command_variables['snowflake_user']} add programmatic access token {pat_name} "
+ f"ROLE_RESTRICTION = '{pat_command_variables['role']}' DAYS_TO_EXPIRY=1;"
+ )
+ test_helper = AuthorizationTestHelper(okta_connection_parameters)
+ pat_command_variables["token"] = (
+ test_helper.connect_using_okta_connection_and_execute_custom_command(
+ command, True
+ )
+ )
+ return pat_command_variables
+
+
+def remove_pat_token(pat_command_variables: dict[str, Union[str, bool]]) -> None:
+ okta_connection_parameters = (
+ AuthConnectionParameters().get_okta_connection_parameters()
+ )
+
+ command = f"alter user {pat_command_variables['snowflake_user']} remove programmatic access token {pat_command_variables['pat_name']};"
+ test_helper = AuthorizationTestHelper(okta_connection_parameters)
+ test_helper.connect_using_okta_connection_and_execute_custom_command(command)
+
+
+def generate_random_suffix() -> str:
+ return datetime.now().strftime("%Y%m%d%H%M%S%f")
diff --git a/test/auth/test_snowflake_authorization_code.py b/test/auth/test_snowflake_authorization_code.py
new file mode 100644
index 0000000000..9116c9008e
--- /dev/null
+++ b/test/auth/test_snowflake_authorization_code.py
@@ -0,0 +1,122 @@
+import logging
+from test.auth.authorization_parameters import (
+ AuthConnectionParameters,
+ get_soteria_okta_login_credentials,
+)
+
+import pytest
+from authorization_test_helper import (
+ AuthorizationTestHelper,
+ Scenario,
+ clean_browser_processes,
+)
+
+
+@pytest.fixture(autouse=True)
+def setup_and_teardown():
+ logging.info("Cleanup before test")
+ clean_browser_processes()
+
+ yield
+
+ logging.info("Teardown: Performing specific actions after the test")
+ clean_browser_processes()
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_successful():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_soteria_okta_login_credentials().values()
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password
+ )
+
+ assert test_helper.error_msg == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_mismatched_user():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters()
+ )
+ connection_parameters["user"] = "differentUsername"
+ browser_login, browser_password = get_soteria_okta_login_credentials().values()
+ test_helper = AuthorizationTestHelper(connection_parameters)
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password
+ )
+
+ assert (
+ "The user you were trying to authenticate as differs from the user"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_timeout():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ connection_parameters["external_browser_timeout"] = 1
+
+ assert (
+ test_helper.connect_and_execute_simple_query() is False
+ ), "Connection should not be established"
+ assert (
+ "Unable to receive the OAuth message within a given timeout"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_with_token_cache():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters()
+ )
+ connection_parameters["external_browser_timeout"] = 15
+ connection_parameters["client_store_temporary_credential"] = True
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_soteria_okta_login_credentials().values()
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password
+ )
+
+ clean_browser_processes()
+
+ assert (
+ test_helper.connect_and_execute_simple_query() is True
+ ), "Connection should be established"
+ assert test_helper.get_error_msg() == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_without_token_cache():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_authorization_code_connection_parameters()
+ )
+ connection_parameters["client_store_temporary_credential"] = False
+ connection_parameters["external_browser_timeout"] = 15
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_soteria_okta_login_credentials().values()
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password
+ )
+
+ clean_browser_processes()
+
+ assert (
+ test_helper.connect_and_execute_simple_query() is False
+ ), "Connection should be established"
+
+ assert (
+ "Unable to receive the OAuth message within a given timeout"
+ in test_helper.get_error_msg()
+ ), "Error message should contain timeout"
diff --git a/test/auth/test_snowflake_authorization_code_wildcards.py b/test/auth/test_snowflake_authorization_code_wildcards.py
new file mode 100644
index 0000000000..f38db07bdf
--- /dev/null
+++ b/test/auth/test_snowflake_authorization_code_wildcards.py
@@ -0,0 +1,121 @@
+import logging
+from test.auth.authorization_parameters import (
+ AuthConnectionParameters,
+ get_soteria_okta_login_credentials,
+)
+
+import pytest
+from authorization_test_helper import (
+ AuthorizationTestHelper,
+ Scenario,
+ clean_browser_processes,
+)
+
+
+@pytest.fixture(autouse=True)
+def setup_and_teardown():
+ logging.info("Cleanup before test")
+ clean_browser_processes()
+
+ yield
+
+ logging.info("Teardown: Performing specific actions after the test")
+ clean_browser_processes()
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_wildcards_successful():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_soteria_okta_login_credentials().values()
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password
+ )
+
+ assert test_helper.error_msg == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_wildcards_mismatched_user():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters()
+ )
+ connection_parameters["user"] = "differentUsername"
+ browser_login, browser_password = get_soteria_okta_login_credentials().values()
+ test_helper = AuthorizationTestHelper(connection_parameters)
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password
+ )
+
+ assert (
+ "The user you were trying to authenticate as differs from the user"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_wildcards_timeout():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters()
+ )
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ connection_parameters["external_browser_timeout"] = 1
+
+ assert (
+ test_helper.connect_and_execute_simple_query() is False
+ ), "Connection should not be established"
+ assert (
+ "Unable to receive the OAuth message within a given timeout"
+ in test_helper.get_error_msg()
+ )
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_wildcards_with_token_cache():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters()
+ )
+ connection_parameters["external_browser_timeout"] = 15
+ connection_parameters["client_store_temporary_credential"] = True
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_soteria_okta_login_credentials().values()
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password
+ )
+
+ clean_browser_processes()
+
+ assert (
+ test_helper.connect_and_execute_simple_query() is True
+ ), "Connection should be established"
+ assert test_helper.get_error_msg() == "", "Error message should be empty"
+
+
+@pytest.mark.auth
+def test_snowflake_authorization_code_wildcards_without_token_cache():
+ connection_parameters = (
+ AuthConnectionParameters().get_snowflake_wildcard_external_authorization_code_connection_parameters()
+ )
+ connection_parameters["client_store_temporary_credential"] = False
+ connection_parameters["external_browser_timeout"] = 15
+ test_helper = AuthorizationTestHelper(connection_parameters)
+ browser_login, browser_password = get_soteria_okta_login_credentials().values()
+
+ test_helper.connect_and_provide_credentials(
+ Scenario.INTERNAL_OAUTH_SNOWFLAKE_SUCCESS, browser_login, browser_password
+ )
+
+ clean_browser_processes()
+
+ assert (
+ test_helper.connect_and_execute_simple_query() is False
+ ), "Connection should be established"
+ assert (
+ "Unable to receive the OAuth message within a given timeout"
+ in test_helper.get_error_msg()
+ ), "Error message should contain timeout"
diff --git a/test/conftest.py b/test/conftest.py
index 9f0fcbc7c8..5cdc714216 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -144,3 +144,7 @@ def pytest_runtest_setup(item) -> None:
pytest.skip("cannot run this test on public Snowflake deployment")
elif INTERNAL_SKIP_TAGS.intersection(test_tags) and not running_on_public_ci():
pytest.skip("cannot run this test on private Snowflake deployment")
+
+ if "auth" in test_tags:
+ if os.getenv("RUN_AUTH_TESTS") != "true":
+ pytest.skip("Skipping auth test in current environment")
diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/browser_timeout_authorization_error.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/browser_timeout_authorization_error.json
new file mode 100644
index 0000000000..b14718c2ba
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/browser_timeout_authorization_error.json
@@ -0,0 +1,15 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "Browser Authorization timeout",
+ "request": {
+ "urlPathPattern": "/oauth/authorize.*",
+ "method": "GET"
+ },
+ "response": {
+ "status": 200,
+ "fixedDelayMilliseconds": 5000
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json
new file mode 100644
index 0000000000..327c779c70
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/external_idp_custom_urls.json
@@ -0,0 +1,80 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "Custom urls OAuth authorization code flow",
+ "requiredScenarioState": "Started",
+ "newScenarioState": "Authorized",
+ "request": {
+ "urlPathPattern": "/authorization",
+ "method": "GET",
+ "queryParameters": {
+ "response_type": {
+ "equalTo": "code"
+ },
+ "scope": {
+ "equalTo": "session:role:ANALYST"
+ },
+ "code_challenge_method": {
+ "equalTo": "S256"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:8009/snowflake/oauth-redirect"
+ },
+ "code_challenge": {
+ "matches": ".*"
+ },
+ "state": {
+ "matches": ".*"
+ },
+ "client_id": {
+ "equalTo": "123"
+ }
+ }
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123"
+ }
+ }
+ },
+ {
+ "scenarioName": "Custom urls OAuth authorization code flow",
+ "requiredScenarioState": "Authorized",
+ "newScenarioState": "Acquired access token",
+ "request": {
+ "urlPathPattern": "/tokenrequest.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "contains": "grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A8009%2Fsnowflake%2Foauth-redirect&code_verifier="
+ }
+ ]
+ },
+ "response": {
+ "status": 200,
+ "headers": {
+ "Content-Type": "application/json"
+ },
+ "jsonBody": {
+ "access_token": "access-token-123",
+ "refresh_token": "123",
+ "token_type": "Bearer",
+ "username": "user",
+ "scope": "refresh_token session:role:ANALYST",
+ "expires_in": 600,
+ "refresh_token_expires_in": 86399,
+ "idpInitiated": false
+ }
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_scope_error.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_scope_error.json
new file mode 100644
index 0000000000..fc495213e1
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_scope_error.json
@@ -0,0 +1,17 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "Invalid scope authorization error",
+ "request": {
+ "urlPathPattern": "/oauth/authorize.*",
+ "method": "GET"
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:8009/snowflake/oauth-redirect?error=invalid_scope&error_description=One+or+more+scopes+are+not+configured+for+the+authorization+server+resource."
+ }
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_state_error.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_state_error.json
new file mode 100644
index 0000000000..23799a655c
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/invalid_state_error.json
@@ -0,0 +1,17 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "Invalid scope authorization error",
+ "request": {
+ "urlPathPattern": "/oauth/authorize.*",
+ "method": "GET"
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=invalidstate"
+ }
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json
new file mode 100644
index 0000000000..55d60fe066
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/new_tokens_after_failed_refresh.json
@@ -0,0 +1,37 @@
+{
+ "requiredScenarioState": "Authorized",
+ "newScenarioState": "Acquired access token",
+ "request": {
+ "urlPathPattern": "/oauth/token-request.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "matches": "^grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A([0-9]+)%2Fsnowflake%2Foauth-redirect&code_verifier=abc123$"
+ }
+ ]
+ },
+ "response": {
+ "status": 200,
+ "headers": {
+ "Content-Type": "application/json"
+ },
+ "jsonBody": {
+ "access_token": "access-token-123",
+ "refresh_token": "refresh-token-123",
+ "token_type": "Bearer",
+ "username": "user",
+ "scope": "refresh_token session:role:ANALYST",
+ "expires_in": 600,
+ "refresh_token_expires_in": 86399,
+ "idpInitiated": false
+ }
+ }
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_auth_after_failed_refresh.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_auth_after_failed_refresh.json
new file mode 100644
index 0000000000..f61d618011
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_auth_after_failed_refresh.json
@@ -0,0 +1,37 @@
+{
+ "requiredScenarioState": "Failed refresh token attempt",
+ "newScenarioState": "Authorized",
+ "request": {
+ "urlPathPattern": "/oauth/authorize",
+ "queryParameters": {
+ "response_type": {
+ "equalTo": "code"
+ },
+ "scope": {
+ "equalTo": "session:role:ANALYST offline_access"
+ },
+ "code_challenge_method": {
+ "equalTo": "S256"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:8009/snowflake/oauth-redirect"
+ },
+ "code_challenge": {
+ "matches": ".*"
+ },
+ "state": {
+ "matches": ".*"
+ },
+ "client_id": {
+ "equalTo": "123"
+ }
+ },
+ "method": "GET"
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123"
+ }
+ }
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json
new file mode 100644
index 0000000000..5ca87b98c8
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/successful_flow.json
@@ -0,0 +1,80 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "Successful OAuth authorization code flow",
+ "requiredScenarioState": "Started",
+ "newScenarioState": "Authorized",
+ "request": {
+ "urlPathPattern": "/oauth/authorize",
+ "queryParameters": {
+ "response_type": {
+ "equalTo": "code"
+ },
+ "scope": {
+ "equalTo": "session:role:ANALYST"
+ },
+ "code_challenge_method": {
+ "equalTo": "S256"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:8009/snowflake/oauth-redirect"
+ },
+ "code_challenge": {
+ "matches": ".*"
+ },
+ "state": {
+ "matches": ".*"
+ },
+ "client_id": {
+ "equalTo": "123"
+ }
+ },
+ "method": "GET"
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123"
+ }
+ }
+ },
+ {
+ "scenarioName": "Successful OAuth authorization code flow",
+ "requiredScenarioState": "Authorized",
+ "newScenarioState": "Acquired access token",
+ "request": {
+ "urlPathPattern": "/oauth/token-request.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "matches": "^grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A([0-9]+)%2Fsnowflake%2Foauth-redirect&code_verifier=abc123$"
+ }
+ ]
+ },
+ "response": {
+ "status": 200,
+ "headers": {
+ "Content-Type": "application/json"
+ },
+ "jsonBody": {
+ "access_token": "access-token-123",
+ "refresh_token": "123",
+ "token_type": "Bearer",
+ "username": "user",
+ "scope": "refresh_token session:role:ANALYST",
+ "expires_in": 600,
+ "refresh_token_expires_in": 86399,
+ "idpInitiated": false
+ }
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/authorization_code/token_request_error.json b/test/data/wiremock/mappings/auth/oauth/authorization_code/token_request_error.json
new file mode 100644
index 0000000000..ca925266be
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/authorization_code/token_request_error.json
@@ -0,0 +1,67 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "OAuth token request error",
+ "requiredScenarioState": "Started",
+ "newScenarioState": "Authorized",
+ "request": {
+ "urlPathPattern": "/oauth/authorize",
+ "queryParameters": {
+ "response_type": {
+ "equalTo": "code"
+ },
+ "scope": {
+ "equalTo": "session:role:ANALYST"
+ },
+ "code_challenge_method": {
+ "equalTo": "S256"
+ },
+ "redirect_uri": {
+ "equalTo": "http://localhost:8009/snowflake/oauth-redirect"
+ },
+ "code_challenge": {
+ "matches": ".*"
+ },
+ "state": {
+ "matches": ".*"
+ },
+ "client_id": {
+ "equalTo": "123"
+ }
+ },
+ "method": "GET"
+ },
+ "response": {
+ "status": 302,
+ "headers": {
+ "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123"
+ }
+ }
+ },
+ {
+ "scenarioName": "OAuth token request error",
+ "requiredScenarioState": "Authorized",
+ "newScenarioState": "Token request error",
+ "request": {
+ "urlPathPattern": "/oauth/token-request.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "contains": "grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A8009%2Fsnowflake%2Foauth-redirect&code_verifier="
+ }
+ ]
+ },
+ "response": {
+ "status": 400
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json
new file mode 100644
index 0000000000..6b8e9699f5
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_auth_after_failed_refresh.json
@@ -0,0 +1,38 @@
+{
+ "scenarioName": "Successful OAuth client credentials flow",
+ "requiredScenarioState": "Started",
+ "newScenarioState": "Acquired access token",
+ "request": {
+ "urlPathPattern": "/oauth/token-request.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "contains": "grant_type=client_credentials&scope=session%3Arole%3AANALYST"
+ }
+ ]
+ },
+ "response": {
+ "status": 200,
+ "headers": {
+ "Content-Type": "application/json"
+ },
+ "jsonBody": {
+ "access_token": "access-token-123",
+ "refresh_token": "refresh-token-123",
+ "token_type": "Bearer",
+ "username": "user",
+ "scope": "refresh_token session:role:ANALYST",
+ "expires_in": 600,
+ "refresh_token_expires_in": 86399,
+ "idpInitiated": false
+ }
+ }
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json
new file mode 100644
index 0000000000..5e6137bd0e
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/client_credentials/successful_flow.json
@@ -0,0 +1,42 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "Successful OAuth client credentials flow",
+ "requiredScenarioState": "Started",
+ "newScenarioState": "Acquired access token",
+ "request": {
+ "urlPathPattern": "/oauth/token-request.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "contains": "grant_type=client_credentials&scope=session%3Arole%3AANALYST"
+ }
+ ]
+ },
+ "response": {
+ "status": 200,
+ "headers": {
+ "Content-Type": "application/json"
+ },
+ "jsonBody": {
+ "access_token": "access-token-123",
+ "refresh_token": "123",
+ "token_type": "Bearer",
+ "username": "user",
+ "scope": "refresh_token session:role:ANALYST",
+ "expires_in": 600,
+ "refresh_token_expires_in": 86399,
+ "idpInitiated": false
+ }
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/client_credentials/token_request_error.json b/test/data/wiremock/mappings/auth/oauth/client_credentials/token_request_error.json
new file mode 100644
index 0000000000..b30b6056bf
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/client_credentials/token_request_error.json
@@ -0,0 +1,29 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "OAuth client credentials flow with token request error",
+ "requiredScenarioState": "Started",
+ "newScenarioState": "Acquired access token",
+ "request": {
+ "urlPathPattern": "/oauth/token-request.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "contains": "grant_type=client_credentials&scope=session%3Arole%3AANALYST"
+ }
+ ]
+ },
+ "response": {
+ "status": 400
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_failed.json b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_failed.json
new file mode 100644
index 0000000000..5529590b4b
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_failed.json
@@ -0,0 +1,28 @@
+{
+ "requiredScenarioState": "Expired access token",
+ "newScenarioState": "Failed refresh token attempt",
+ "request": {
+ "urlPathPattern": "/oauth/token-request.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "contains": "grant_type=refresh_token&refresh_token=expired-refresh-token-123&scope=session%3Arole%3AANALYST+offline_access"
+ }
+ ]
+ },
+ "response": {
+ "status": 400,
+ "jsonBody": {
+ "error": "invalid_grant",
+ "error_description": "Unknown or invalid refresh token."
+ }
+ }
+}
diff --git a/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json
new file mode 100644
index 0000000000..6a1ec8cf56
--- /dev/null
+++ b/test/data/wiremock/mappings/auth/oauth/refresh_token/refresh_successful.json
@@ -0,0 +1,33 @@
+{
+ "requiredScenarioState": "Expired access token",
+ "newScenarioState": "Acquired access token",
+ "request": {
+ "urlPathPattern": "/oauth/token-request.*",
+ "method": "POST",
+ "headers": {
+ "Authorization": {
+ "contains": "Basic"
+ },
+ "Content-Type": {
+ "contains": "application/x-www-form-urlencoded; charset=UTF-8"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "contains": "grant_type=refresh_token&refresh_token=refresh-token-123&scope=session%3Arole%3AANALYST+offline_access"
+ }
+ ]
+ },
+ "response": {
+ "status": 200,
+ "headers": {
+ "Content-Type": "application/json"
+ },
+ "jsonBody": {
+ "access_token": "access-token-123",
+ "token_type": "Bearer",
+ "expires_in": 599,
+ "idpInitiated": false
+ }
+ }
+}
diff --git a/test/data/wiremock/mappings/generic/snowflake_login_failed.json b/test/data/wiremock/mappings/generic/snowflake_login_failed.json
new file mode 100644
index 0000000000..bf848d16b3
--- /dev/null
+++ b/test/data/wiremock/mappings/generic/snowflake_login_failed.json
@@ -0,0 +1,51 @@
+{
+ "mappings": [
+ {
+ "scenarioName": "Refresh expired access token",
+ "requiredScenarioState": "Started",
+ "newScenarioState": "Expired access token",
+ "request": {
+ "urlPathPattern": "/session/v1/login-request",
+ "method": "POST",
+ "queryParameters": {
+ "request_id": {
+ "matches": ".*"
+ },
+ "roleName": {
+ "equalTo": "ANALYST"
+ }
+ },
+ "headers": {
+ "Content-Type": {
+ "contains": "application/json"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "matchesJsonPath": "$.data"
+ },
+ {
+ "matchesJsonPath": "$[?(@.data.TOKEN==\"expired-access-token-123\")]"
+ }
+ ]
+ },
+ "response": {
+ "status": 200,
+ "headers": {
+ "Content-Type": "application/json"
+ },
+ "jsonBody": {
+ "data": {
+ "nextAction": "RETRY_LOGIN",
+ "authnMethod": "OAUTH",
+ "signInOptions": {}
+ },
+ "code": "390318",
+ "message": "OAuth access token expired. [1172527951366]",
+ "success": false,
+ "headers": null
+ }
+ }
+ }
+ ]
+}
diff --git a/test/data/wiremock/mappings/generic/snowflake_login_successful.json b/test/data/wiremock/mappings/generic/snowflake_login_successful.json
new file mode 100644
index 0000000000..940ffad2e6
--- /dev/null
+++ b/test/data/wiremock/mappings/generic/snowflake_login_successful.json
@@ -0,0 +1,67 @@
+{
+ "requiredScenarioState": "Acquired access token",
+ "newScenarioState": "Connected",
+ "request": {
+ "urlPathPattern": "/session/v1/login-request",
+ "method": "POST",
+ "queryParameters": {
+ "request_id": {
+ "matches": ".*"
+ },
+ "roleName": {
+ "equalTo": "ANALYST"
+ }
+ },
+ "headers": {
+ "Content-Type": {
+ "contains": "application/json"
+ }
+ },
+ "bodyPatterns": [
+ {
+ "matchesJsonPath": "$.data"
+ },
+ {
+ "matchesJsonPath": "$[?(@.data.TOKEN==\"access-token-123\")]"
+ }
+ ]
+ },
+ "response": {
+ "status": 200,
+ "fixedDelayMilliseconds": "1000",
+ "headers": {
+ "Content-Type": "application/json"
+ },
+ "jsonBody": {
+ "data": {
+ "masterToken": "token-m1",
+ "token": "token-t1",
+ "validityInSeconds": 3599,
+ "masterValidityInSeconds": 14400,
+ "displayUserName": "***",
+ "serverVersion": "***",
+ "firstLogin": false,
+ "remMeToken": null,
+ "remMeValidityInSeconds": 0,
+ "healthCheckInterval": 45,
+ "newClientForUpgrade": null,
+ "sessionId": 1313,
+ "parameters": [],
+ "sessionInfo": {
+ "databaseName": null,
+ "schemaName": null,
+ "warehouseName": "TEST",
+ "roleName": "ACCOUNTADMIN"
+ },
+ "idToken": null,
+ "idTokenValidityInSeconds": 0,
+ "responseData": null,
+ "mfaToken": null,
+ "mfaTokenValidityInSeconds": 0
+ },
+ "code": null,
+ "message": null,
+ "success": true
+ }
+ }
+}
diff --git a/test/integ/aio/__init__.py b/test/integ/aio_it/__init__.py
similarity index 100%
rename from test/integ/aio/__init__.py
rename to test/integ/aio_it/__init__.py
diff --git a/test/integ/aio/conftest.py b/test/integ/aio_it/conftest.py
similarity index 100%
rename from test/integ/aio/conftest.py
rename to test/integ/aio_it/conftest.py
diff --git a/test/integ/aio/lambda/__init__.py b/test/integ/aio_it/lambda/__init__.py
similarity index 100%
rename from test/integ/aio/lambda/__init__.py
rename to test/integ/aio_it/lambda/__init__.py
diff --git a/test/integ/aio/lambda/test_basic_query_async.py b/test/integ/aio_it/lambda/test_basic_query_async.py
similarity index 100%
rename from test/integ/aio/lambda/test_basic_query_async.py
rename to test/integ/aio_it/lambda/test_basic_query_async.py
diff --git a/test/integ/aio/pandas/__init__.py b/test/integ/aio_it/pandas/__init__.py
similarity index 100%
rename from test/integ/aio/pandas/__init__.py
rename to test/integ/aio_it/pandas/__init__.py
diff --git a/test/integ/aio/pandas/test_arrow_chunk_iterator_async.py b/test/integ/aio_it/pandas/test_arrow_chunk_iterator_async.py
similarity index 100%
rename from test/integ/aio/pandas/test_arrow_chunk_iterator_async.py
rename to test/integ/aio_it/pandas/test_arrow_chunk_iterator_async.py
diff --git a/test/integ/aio/pandas/test_arrow_pandas_async.py b/test/integ/aio_it/pandas/test_arrow_pandas_async.py
similarity index 100%
rename from test/integ/aio/pandas/test_arrow_pandas_async.py
rename to test/integ/aio_it/pandas/test_arrow_pandas_async.py
diff --git a/test/integ/aio/pandas/test_logging_async.py b/test/integ/aio_it/pandas/test_logging_async.py
similarity index 100%
rename from test/integ/aio/pandas/test_logging_async.py
rename to test/integ/aio_it/pandas/test_logging_async.py
diff --git a/test/integ/aio/sso/__init__.py b/test/integ/aio_it/sso/__init__.py
similarity index 100%
rename from test/integ/aio/sso/__init__.py
rename to test/integ/aio_it/sso/__init__.py
diff --git a/test/integ/aio/sso/test_connection_manual_async.py b/test/integ/aio_it/sso/test_connection_manual_async.py
similarity index 100%
rename from test/integ/aio/sso/test_connection_manual_async.py
rename to test/integ/aio_it/sso/test_connection_manual_async.py
diff --git a/test/integ/aio/sso/test_unit_mfa_cache_async.py b/test/integ/aio_it/sso/test_unit_mfa_cache_async.py
similarity index 100%
rename from test/integ/aio/sso/test_unit_mfa_cache_async.py
rename to test/integ/aio_it/sso/test_unit_mfa_cache_async.py
diff --git a/test/integ/aio/test_arrow_result_async.py b/test/integ/aio_it/test_arrow_result_async.py
similarity index 100%
rename from test/integ/aio/test_arrow_result_async.py
rename to test/integ/aio_it/test_arrow_result_async.py
diff --git a/test/integ/aio/test_async_async.py b/test/integ/aio_it/test_async_async.py
similarity index 100%
rename from test/integ/aio/test_async_async.py
rename to test/integ/aio_it/test_async_async.py
diff --git a/test/integ/aio/test_autocommit_async.py b/test/integ/aio_it/test_autocommit_async.py
similarity index 100%
rename from test/integ/aio/test_autocommit_async.py
rename to test/integ/aio_it/test_autocommit_async.py
diff --git a/test/integ/aio/test_bindings_async.py b/test/integ/aio_it/test_bindings_async.py
similarity index 100%
rename from test/integ/aio/test_bindings_async.py
rename to test/integ/aio_it/test_bindings_async.py
diff --git a/test/integ/aio/test_boolean_async.py b/test/integ/aio_it/test_boolean_async.py
similarity index 100%
rename from test/integ/aio/test_boolean_async.py
rename to test/integ/aio_it/test_boolean_async.py
diff --git a/test/integ/aio/test_client_session_keep_alive_async.py b/test/integ/aio_it/test_client_session_keep_alive_async.py
similarity index 100%
rename from test/integ/aio/test_client_session_keep_alive_async.py
rename to test/integ/aio_it/test_client_session_keep_alive_async.py
diff --git a/test/integ/aio/test_concurrent_create_objects_async.py b/test/integ/aio_it/test_concurrent_create_objects_async.py
similarity index 100%
rename from test/integ/aio/test_concurrent_create_objects_async.py
rename to test/integ/aio_it/test_concurrent_create_objects_async.py
diff --git a/test/integ/aio/test_concurrent_insert_async.py b/test/integ/aio_it/test_concurrent_insert_async.py
similarity index 100%
rename from test/integ/aio/test_concurrent_insert_async.py
rename to test/integ/aio_it/test_concurrent_insert_async.py
diff --git a/test/integ/aio/test_connection_async.py b/test/integ/aio_it/test_connection_async.py
similarity index 99%
rename from test/integ/aio/test_connection_async.py
rename to test/integ/aio_it/test_connection_async.py
index e0c771664a..df76fa1df4 100644
--- a/test/integ/aio/test_connection_async.py
+++ b/test/integ/aio_it/test_connection_async.py
@@ -1674,7 +1674,7 @@ async def test_is_valid(conn_cnx):
async def test_no_auth_connection_negative_case():
# AuthNoAuth does not exist in old drivers, so we import at test level to
# skip importing it for old driver tests.
- from test.integ.aio.conftest import create_connection
+ from test.integ.aio_it.conftest import create_connection
from snowflake.connector.aio.auth._no_auth import AuthNoAuth
diff --git a/test/integ/aio/test_converter_async.py b/test/integ/aio_it/test_converter_async.py
similarity index 100%
rename from test/integ/aio/test_converter_async.py
rename to test/integ/aio_it/test_converter_async.py
diff --git a/test/integ/aio/test_converter_more_timestamp_async.py b/test/integ/aio_it/test_converter_more_timestamp_async.py
similarity index 100%
rename from test/integ/aio/test_converter_more_timestamp_async.py
rename to test/integ/aio_it/test_converter_more_timestamp_async.py
diff --git a/test/integ/aio/test_converter_null_async.py b/test/integ/aio_it/test_converter_null_async.py
similarity index 100%
rename from test/integ/aio/test_converter_null_async.py
rename to test/integ/aio_it/test_converter_null_async.py
diff --git a/test/integ/aio/test_cursor_async.py b/test/integ/aio_it/test_cursor_async.py
similarity index 100%
rename from test/integ/aio/test_cursor_async.py
rename to test/integ/aio_it/test_cursor_async.py
diff --git a/test/integ/aio/test_cursor_binding_async.py b/test/integ/aio_it/test_cursor_binding_async.py
similarity index 100%
rename from test/integ/aio/test_cursor_binding_async.py
rename to test/integ/aio_it/test_cursor_binding_async.py
diff --git a/test/integ/aio/test_cursor_context_manager_async.py b/test/integ/aio_it/test_cursor_context_manager_async.py
similarity index 100%
rename from test/integ/aio/test_cursor_context_manager_async.py
rename to test/integ/aio_it/test_cursor_context_manager_async.py
diff --git a/test/integ/aio/test_dataintegrity_async.py b/test/integ/aio_it/test_dataintegrity_async.py
similarity index 100%
rename from test/integ/aio/test_dataintegrity_async.py
rename to test/integ/aio_it/test_dataintegrity_async.py
diff --git a/test/integ/aio/test_daylight_savings_async.py b/test/integ/aio_it/test_daylight_savings_async.py
similarity index 100%
rename from test/integ/aio/test_daylight_savings_async.py
rename to test/integ/aio_it/test_daylight_savings_async.py
diff --git a/test/integ/aio/test_dbapi_async.py b/test/integ/aio_it/test_dbapi_async.py
similarity index 100%
rename from test/integ/aio/test_dbapi_async.py
rename to test/integ/aio_it/test_dbapi_async.py
diff --git a/test/integ/aio/test_decfloat_async.py b/test/integ/aio_it/test_decfloat_async.py
similarity index 100%
rename from test/integ/aio/test_decfloat_async.py
rename to test/integ/aio_it/test_decfloat_async.py
diff --git a/test/integ/aio/test_direct_file_operation_utils_async.py b/test/integ/aio_it/test_direct_file_operation_utils_async.py
similarity index 100%
rename from test/integ/aio/test_direct_file_operation_utils_async.py
rename to test/integ/aio_it/test_direct_file_operation_utils_async.py
diff --git a/test/integ/aio/test_errors_async.py b/test/integ/aio_it/test_errors_async.py
similarity index 100%
rename from test/integ/aio/test_errors_async.py
rename to test/integ/aio_it/test_errors_async.py
diff --git a/test/integ/aio/test_execute_multi_statements_async.py b/test/integ/aio_it/test_execute_multi_statements_async.py
similarity index 100%
rename from test/integ/aio/test_execute_multi_statements_async.py
rename to test/integ/aio_it/test_execute_multi_statements_async.py
diff --git a/test/integ/aio/test_key_pair_authentication_async.py b/test/integ/aio_it/test_key_pair_authentication_async.py
similarity index 100%
rename from test/integ/aio/test_key_pair_authentication_async.py
rename to test/integ/aio_it/test_key_pair_authentication_async.py
diff --git a/test/integ/aio/test_large_put_async.py b/test/integ/aio_it/test_large_put_async.py
similarity index 100%
rename from test/integ/aio/test_large_put_async.py
rename to test/integ/aio_it/test_large_put_async.py
diff --git a/test/integ/aio/test_large_result_set_async.py b/test/integ/aio_it/test_large_result_set_async.py
similarity index 100%
rename from test/integ/aio/test_large_result_set_async.py
rename to test/integ/aio_it/test_large_result_set_async.py
diff --git a/test/integ/aio/test_load_unload_async.py b/test/integ/aio_it/test_load_unload_async.py
similarity index 100%
rename from test/integ/aio/test_load_unload_async.py
rename to test/integ/aio_it/test_load_unload_async.py
diff --git a/test/integ/aio/test_multi_statement_async.py b/test/integ/aio_it/test_multi_statement_async.py
similarity index 100%
rename from test/integ/aio/test_multi_statement_async.py
rename to test/integ/aio_it/test_multi_statement_async.py
diff --git a/test/integ/aio/test_network_async.py b/test/integ/aio_it/test_network_async.py
similarity index 100%
rename from test/integ/aio/test_network_async.py
rename to test/integ/aio_it/test_network_async.py
diff --git a/test/integ/aio/test_numpy_binding_async.py b/test/integ/aio_it/test_numpy_binding_async.py
similarity index 100%
rename from test/integ/aio/test_numpy_binding_async.py
rename to test/integ/aio_it/test_numpy_binding_async.py
diff --git a/test/integ/aio/test_pickle_timestamp_tz_async.py b/test/integ/aio_it/test_pickle_timestamp_tz_async.py
similarity index 100%
rename from test/integ/aio/test_pickle_timestamp_tz_async.py
rename to test/integ/aio_it/test_pickle_timestamp_tz_async.py
diff --git a/test/integ/aio/test_put_get_async.py b/test/integ/aio_it/test_put_get_async.py
similarity index 100%
rename from test/integ/aio/test_put_get_async.py
rename to test/integ/aio_it/test_put_get_async.py
diff --git a/test/integ/aio/test_put_get_compress_enc_async.py b/test/integ/aio_it/test_put_get_compress_enc_async.py
similarity index 100%
rename from test/integ/aio/test_put_get_compress_enc_async.py
rename to test/integ/aio_it/test_put_get_compress_enc_async.py
diff --git a/test/integ/aio/test_put_get_medium_async.py b/test/integ/aio_it/test_put_get_medium_async.py
similarity index 100%
rename from test/integ/aio/test_put_get_medium_async.py
rename to test/integ/aio_it/test_put_get_medium_async.py
diff --git a/test/integ/aio/test_put_get_snow_4525_async.py b/test/integ/aio_it/test_put_get_snow_4525_async.py
similarity index 100%
rename from test/integ/aio/test_put_get_snow_4525_async.py
rename to test/integ/aio_it/test_put_get_snow_4525_async.py
diff --git a/test/integ/aio/test_put_get_user_stage_async.py b/test/integ/aio_it/test_put_get_user_stage_async.py
similarity index 100%
rename from test/integ/aio/test_put_get_user_stage_async.py
rename to test/integ/aio_it/test_put_get_user_stage_async.py
diff --git a/test/integ/aio/test_put_get_with_aws_token_async.py b/test/integ/aio_it/test_put_get_with_aws_token_async.py
similarity index 100%
rename from test/integ/aio/test_put_get_with_aws_token_async.py
rename to test/integ/aio_it/test_put_get_with_aws_token_async.py
diff --git a/test/integ/aio/test_put_get_with_azure_token_async.py b/test/integ/aio_it/test_put_get_with_azure_token_async.py
similarity index 100%
rename from test/integ/aio/test_put_get_with_azure_token_async.py
rename to test/integ/aio_it/test_put_get_with_azure_token_async.py
diff --git a/test/integ/aio/test_put_get_with_gcp_account_async.py b/test/integ/aio_it/test_put_get_with_gcp_account_async.py
similarity index 100%
rename from test/integ/aio/test_put_get_with_gcp_account_async.py
rename to test/integ/aio_it/test_put_get_with_gcp_account_async.py
diff --git a/test/integ/aio/test_put_windows_path_async.py b/test/integ/aio_it/test_put_windows_path_async.py
similarity index 100%
rename from test/integ/aio/test_put_windows_path_async.py
rename to test/integ/aio_it/test_put_windows_path_async.py
diff --git a/test/integ/aio/test_qmark_async.py b/test/integ/aio_it/test_qmark_async.py
similarity index 100%
rename from test/integ/aio/test_qmark_async.py
rename to test/integ/aio_it/test_qmark_async.py
diff --git a/test/integ/aio/test_query_cancelling_async.py b/test/integ/aio_it/test_query_cancelling_async.py
similarity index 100%
rename from test/integ/aio/test_query_cancelling_async.py
rename to test/integ/aio_it/test_query_cancelling_async.py
diff --git a/test/integ/aio/test_results_async.py b/test/integ/aio_it/test_results_async.py
similarity index 100%
rename from test/integ/aio/test_results_async.py
rename to test/integ/aio_it/test_results_async.py
diff --git a/test/integ/aio/test_reuse_cursor_async.py b/test/integ/aio_it/test_reuse_cursor_async.py
similarity index 100%
rename from test/integ/aio/test_reuse_cursor_async.py
rename to test/integ/aio_it/test_reuse_cursor_async.py
diff --git a/test/integ/aio/test_session_parameters_async.py b/test/integ/aio_it/test_session_parameters_async.py
similarity index 100%
rename from test/integ/aio/test_session_parameters_async.py
rename to test/integ/aio_it/test_session_parameters_async.py
diff --git a/test/integ/aio/test_statement_parameter_binding_async.py b/test/integ/aio_it/test_statement_parameter_binding_async.py
similarity index 100%
rename from test/integ/aio/test_statement_parameter_binding_async.py
rename to test/integ/aio_it/test_statement_parameter_binding_async.py
diff --git a/test/integ/aio/test_structured_types_async.py b/test/integ/aio_it/test_structured_types_async.py
similarity index 100%
rename from test/integ/aio/test_structured_types_async.py
rename to test/integ/aio_it/test_structured_types_async.py
diff --git a/test/integ/aio/test_transaction_async.py b/test/integ/aio_it/test_transaction_async.py
similarity index 100%
rename from test/integ/aio/test_transaction_async.py
rename to test/integ/aio_it/test_transaction_async.py
diff --git a/test/unit/aio/test_auth_oauth_code_async.py b/test/unit/aio/test_auth_oauth_code_async.py
new file mode 100644
index 0000000000..646c2df7d3
--- /dev/null
+++ b/test/unit/aio/test_auth_oauth_code_async.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import os
+
+from snowflake.connector.aio.auth import AuthByOauthCode
+
+
+async def test_auth_oauth_code():
+ """Simple OAuth Code test."""
+ # Set experimental auth flag for the test
+ os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] = "true"
+
+ auth = AuthByOauthCode(
+ application="test_app",
+ client_id="test_client_id",
+ client_secret="test_client_secret",
+ authentication_url="https://example.com/auth",
+ token_request_url="https://example.com/token",
+ redirect_uri="http://localhost:8080/callback",
+ scope="session:role:test_role",
+ pkce_enabled=True,
+ refresh_token_enabled=False,
+ )
+
+ body = {"data": {}}
+ await auth.update_body(body)
+
+ # Check that OAuth authenticator is set
+ assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
+ # OAuth type should be set to authorization_code
+ assert body["data"]["OAUTH_TYPE"] == "authorization_code", body
+
+ # Clean up environment variable
+ del os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"]
+
+
+def test_mro():
+ """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
+ from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync
+ from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync
+
+ assert AuthByOauthCode.mro().index(AuthByPluginAsync) < AuthByOauthCode.mro().index(
+ AuthByPluginSync
+ )
diff --git a/test/unit/aio/test_auth_oauth_credentials_async.py b/test/unit/aio/test_auth_oauth_credentials_async.py
new file mode 100644
index 0000000000..297614bd48
--- /dev/null
+++ b/test/unit/aio/test_auth_oauth_credentials_async.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import os
+
+from snowflake.connector.aio.auth import AuthByOauthCredentials
+
+
+async def test_auth_oauth_credentials():
+ """Simple OAuth Credentials test."""
+ # Set experimental auth flag for the test
+ os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"] = "true"
+
+ auth = AuthByOauthCredentials(
+ application="test_app",
+ client_id="test_client_id",
+ client_secret="test_client_secret",
+ token_request_url="https://example.com/token",
+ scope="session:role:test_role",
+ refresh_token_enabled=False,
+ )
+
+ body = {"data": {}}
+ await auth.update_body(body)
+
+ # Check that OAuth authenticator is set
+ assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
+ # OAuth type should be set to client_credentials
+ assert body["data"]["OAUTH_TYPE"] == "client_credentials", body
+
+ # Clean up environment variable
+ del os.environ["SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"]
+
+
+def test_mro():
+ """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
+ from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync
+ from snowflake.connector.auth import AuthByPlugin as AuthByPluginSync
+
+ assert AuthByOauthCredentials.mro().index(
+ AuthByPluginAsync
+ ) < AuthByOauthCredentials.mro().index(AuthByPluginSync)
diff --git a/test/unit/aio/test_oauth_token_async.py b/test/unit/aio/test_oauth_token_async.py
new file mode 100644
index 0000000000..3d89af5186
--- /dev/null
+++ b/test/unit/aio/test_oauth_token_async.py
@@ -0,0 +1,760 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+import logging
+import pathlib
+from typing import Any, Generator, Union
+from unittest import mock
+from unittest.mock import Mock, patch
+
+import pytest
+
+try:
+ from snowflake.connector.aio import SnowflakeConnection
+ from snowflake.connector.aio.auth import AuthByOauthCredentials
+except ImportError:
+ pass
+
+import snowflake.connector.errors
+from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType
+
+from ...wiremock.wiremock_utils import WiremockClient
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="session")
+def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]:
+ with WiremockClient() as client:
+ yield client
+
+
+@pytest.fixture(scope="session")
+def wiremock_oauth_authorization_code_dir() -> pathlib.Path:
+ return (
+ pathlib.Path(__file__).parent.parent.parent
+ / "data"
+ / "wiremock"
+ / "mappings"
+ / "auth"
+ / "oauth"
+ / "authorization_code"
+ )
+
+
+@pytest.fixture(scope="session")
+def wiremock_oauth_client_creds_dir() -> pathlib.Path:
+ return (
+ pathlib.Path(__file__).parent.parent.parent
+ / "data"
+ / "wiremock"
+ / "mappings"
+ / "auth"
+ / "oauth"
+ / "client_credentials"
+ )
+
+
+@pytest.fixture(scope="session")
+def wiremock_generic_mappings_dir() -> pathlib.Path:
+ return (
+ pathlib.Path(__file__).parent.parent.parent
+ / "data"
+ / "wiremock"
+ / "mappings"
+ / "generic"
+ )
+
+
+@pytest.fixture(scope="session")
+def wiremock_oauth_refresh_token_dir() -> pathlib.Path:
+ return (
+ pathlib.Path(__file__).parent.parent.parent
+ / "data"
+ / "wiremock"
+ / "mappings"
+ / "auth"
+ / "oauth"
+ / "refresh_token"
+ )
+
+
+def _call_auth_server_sync(url: str):
+ """Sync version of auth server call for OAuth redirect simulation.
+
+ Since async classes call sync methods, we need to use sync requests.
+ """
+ import requests
+
+ # Use sync requests since the OAuth implementation uses sync urllib3
+ requests.get(url, allow_redirects=True, timeout=6)
+
+
+def _webbrowser_redirect_sync(*args):
+ """Sync version of webbrowser redirect simulation.
+
+ Since async OAuth classes use sync webbrowser.open(), we need sync simulation.
+ """
+ assert len(args) == 1, "Invalid number of arguments passed to webbrowser open"
+
+ from threading import Thread
+
+ # Use threading to avoid blocking since sync OAuth expects this pattern
+ thread = Thread(target=_call_auth_server_sync, args=(args[0],))
+ thread.start()
+
+ return thread.is_alive()
+
+
+@pytest.fixture(scope="session")
+def webbrowser_mock_sync() -> Mock:
+ """Mock for sync webbrowser since async OAuth classes use sync webbrowser.open()."""
+ webbrowser_mock = Mock()
+ webbrowser_mock.open = _webbrowser_redirect_sync
+ return webbrowser_mock
+
+
+@pytest.fixture()
+def temp_cache_async():
+ """Async-compatible temporary cache."""
+
+ class TemporaryCache(TokenCache):
+ def __init__(self):
+ self._cache = {}
+
+ def store(self, key: TokenKey, token: str) -> None:
+ self._cache[(key.user, key.host, key.tokenType)] = token
+
+ def retrieve(self, key: TokenKey) -> str:
+ return self._cache.get((key.user, key.host, key.tokenType))
+
+ def remove(self, key: TokenKey) -> None:
+ self._cache.pop((key.user, key.host, key.tokenType))
+
+ tmp_cache = TemporaryCache()
+ # Patch both sync and async versions to be safe since async Auth inherits from sync Auth
+ # but the actual Auth instance used is async
+ with mock.patch(
+ "snowflake.connector.aio.auth._auth.Auth.get_token_cache",
+ return_value=tmp_cache,
+ ), mock.patch(
+ "snowflake.connector.auth._auth.Auth.get_token_cache",
+ return_value=tmp_cache,
+ ):
+ yield tmp_cache
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+async def test_oauth_code_successful_flow_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ wiremock_generic_mappings_dir,
+ webbrowser_mock_sync,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "successful_flow.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ await cnx.connect()
+ await cnx.close()
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+async def test_oauth_code_invalid_state_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ webbrowser_mock_sync,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "invalid_state_error.json"
+ )
+
+ with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo:
+ with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ oauth_client_secret="testClientSecret",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+ await cnx.connect()
+
+ assert str(execinfo.value).endswith("State changed during OAuth process.")
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+async def test_oauth_code_scope_error_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ webbrowser_mock_sync,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "invalid_scope_error.json"
+ )
+
+ with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo:
+ with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+ await cnx.connect()
+
+ assert str(execinfo.value).endswith(
+ "Oauth callback returned an invalid_scope error: One or more scopes are not configured for the authorization server resource."
+ )
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+async def test_oauth_code_token_request_error_async(
+ wiremock_oauth_authorization_code_dir,
+ webbrowser_mock_sync,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ with WiremockClient() as wiremock_client:
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "token_request_error.json"
+ )
+
+ with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo:
+ with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ oauth_client_secret="testClientSecret",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+ await cnx.connect()
+
+ assert str(execinfo.value).endswith(
+ "Invalid HTTP request from web browser. Idp authentication could have failed."
+ )
+
+
+@pytest.mark.skipolddriver
+async def test_oauth_code_browser_timeout_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ webbrowser_mock_sync,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir
+ / "browser_timeout_authorization_error.json"
+ )
+
+ with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo:
+ with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ oauth_client_secret="testClientSecret",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ external_browser_timeout=2,
+ )
+ await cnx.connect()
+
+ assert str(execinfo.value).endswith(
+ "Unable to receive the OAuth message within a given timeout. Please check the redirect URI and try again."
+ )
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+async def test_oauth_code_custom_urls_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ wiremock_generic_mappings_dir,
+ webbrowser_mock_sync,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "external_idp_custom_urls.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ oauth_client_secret="testClientSecret",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/tokenrequest",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/authorization",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ await cnx.connect()
+ await cnx.close()
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+async def test_oauth_code_successful_refresh_token_flow_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_refresh_token_dir,
+ wiremock_generic_mappings_dir,
+ monkeypatch,
+ temp_cache_async,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_refresh_token_dir / "refresh_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+ user = "testUser"
+ access_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
+ )
+ refresh_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
+ )
+ temp_cache_async.store(access_token_key, "expired-access-token-123")
+ temp_cache_async.store(refresh_token_key, "refresh-token-123")
+ cnx = SnowflakeConnection(
+ user=user,
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ oauth_security_features=("pkce", "refresh_token"),
+ client_store_temporary_credential=True,
+ )
+ await cnx.connect()
+ await cnx.close()
+ new_access_token = temp_cache_async.retrieve(access_token_key)
+ new_refresh_token = temp_cache_async.retrieve(refresh_token_key)
+
+ assert new_access_token == "access-token-123"
+ assert new_refresh_token == "refresh-token-123"
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+async def test_oauth_code_expired_refresh_token_flow_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_refresh_token_dir,
+ wiremock_oauth_authorization_code_dir,
+ wiremock_generic_mappings_dir,
+ webbrowser_mock_sync,
+ monkeypatch,
+ temp_cache_async,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_refresh_token_dir / "refresh_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_authorization_code_dir
+ / "successful_auth_after_failed_refresh.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_authorization_code_dir / "new_tokens_after_failed_refresh.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ user = "testUser"
+ access_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
+ )
+ refresh_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
+ )
+ temp_cache_async.store(access_token_key, "expired-access-token-123")
+ temp_cache_async.store(refresh_token_key, "expired-refresh-token-123")
+ with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user=user,
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ oauth_security_features=("pkce", "refresh_token"),
+ client_store_temporary_credential=True,
+ )
+ await cnx.connect()
+ await cnx.close()
+
+ new_access_token = temp_cache_async.retrieve(access_token_key)
+ new_refresh_token = temp_cache_async.retrieve(refresh_token_key)
+ assert new_access_token == "access-token-123"
+ assert new_refresh_token == "refresh-token-123"
+
+
+@pytest.mark.skipolddriver
+async def test_client_creds_oauth_type_async():
+ """Simple OAuth Client credentials type test for async."""
+ auth = AuthByOauthCredentials(
+ "app",
+ "clientId",
+ "clientSecret",
+ "tokenRequestUrl",
+ "scope",
+ )
+ body = {"data": {}}
+ await auth.update_body(body)
+ assert body["data"]["OAUTH_TYPE"] == "client_credentials"
+
+
+@pytest.mark.skipolddriver
+async def test_client_creds_successful_flow_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_client_creds_dir,
+ wiremock_generic_mappings_dir,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ wiremock_client.import_mapping(
+ wiremock_oauth_client_creds_dir / "successful_flow.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ oauth_client_id="123",
+ oauth_client_secret="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ await cnx.connect()
+ await cnx.close()
+
+
+@pytest.mark.skipolddriver
+async def test_client_creds_token_request_error_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_client_creds_dir,
+ wiremock_generic_mappings_dir,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ wiremock_client.import_mapping(
+ wiremock_oauth_client_creds_dir / "token_request_error.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ with pytest.raises(snowflake.connector.errors.DatabaseError) as execinfo:
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ oauth_client_id="123",
+ oauth_client_secret="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+ await cnx.connect()
+
+ assert str(execinfo.value).endswith(
+ "Invalid HTTP request from web browser. Idp authentication could have failed."
+ )
+
+
+@pytest.mark.skipolddriver
+async def test_client_creds_successful_refresh_token_flow_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_refresh_token_dir,
+ wiremock_generic_mappings_dir,
+ monkeypatch,
+ temp_cache_async,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_refresh_token_dir / "refresh_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+ user = "testUser"
+ access_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
+ )
+ refresh_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
+ )
+ temp_cache_async.store(access_token_key, "expired-access-token-123")
+ temp_cache_async.store(refresh_token_key, "refresh-token-123")
+ cnx = SnowflakeConnection(
+ user=user,
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ oauth_security_features=("refresh_token",),
+ client_store_temporary_credential=True,
+ )
+ await cnx.connect()
+ await cnx.close()
+
+ new_access_token = temp_cache_async.retrieve(access_token_key)
+ new_refresh_token = temp_cache_async.retrieve(refresh_token_key)
+ assert new_access_token == "access-token-123"
+ assert new_refresh_token == "refresh-token-123"
+
+
+@pytest.mark.skipolddriver
+async def test_client_creds_expired_refresh_token_flow_async(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_refresh_token_dir,
+ wiremock_oauth_client_creds_dir,
+ wiremock_generic_mappings_dir,
+ webbrowser_mock_sync,
+ monkeypatch,
+ temp_cache_async,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_refresh_token_dir / "refresh_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_client_creds_dir / "successful_auth_after_failed_refresh.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ user = "testUser"
+ access_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
+ )
+ refresh_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
+ )
+ temp_cache_async.store(access_token_key, "expired-access-token-123")
+ temp_cache_async.store(refresh_token_key, "expired-refresh-token-123")
+ cnx = SnowflakeConnection(
+ user=user,
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ oauth_security_features=("refresh_token",),
+ client_store_temporary_credential=True,
+ )
+ await cnx.connect()
+ await cnx.close()
+
+ new_access_token = temp_cache_async.retrieve(access_token_key)
+ new_refresh_token = temp_cache_async.retrieve(refresh_token_key)
+ assert new_access_token == "access-token-123"
+ assert new_refresh_token == "refresh-token-123"
+
+
+@pytest.mark.skipolddriver
+@pytest.mark.parametrize(
+ "authenticator", ["OAUTH_AUTHORIZATION_CODE", "OAUTH_CLIENT_CREDENTIALS"]
+)
+async def test_auth_is_experimental_async(
+ authenticator,
+ monkeypatch,
+) -> None:
+ monkeypatch.delenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", False)
+ with pytest.raises(
+ snowflake.connector.errors.ProgrammingError,
+ match=r"SF_ENABLE_EXPERIMENTAL_AUTHENTICATION",
+ ):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ account="testAccount",
+ authenticator=authenticator,
+ )
+ await cnx.connect()
+
+
+@pytest.mark.skipolddriver
+@pytest.mark.parametrize(
+ "authenticator", ["OAUTH_AUTHORIZATION_CODE", "OAUTH_CLIENT_CREDENTIALS"]
+)
+async def test_auth_experimental_when_variable_set_to_false_async(
+ authenticator,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "false")
+ with pytest.raises(
+ snowflake.connector.errors.ProgrammingError,
+ match=r"SF_ENABLE_EXPERIMENTAL_AUTHENTICATION",
+ ):
+ cnx = SnowflakeConnection(
+ user="testUser",
+ account="testAccount",
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ )
+ await cnx.connect()
diff --git a/test/unit/test_auth_callback_server.py b/test/unit/test_auth_callback_server.py
new file mode 100644
index 0000000000..bf03a8d5f6
--- /dev/null
+++ b/test/unit/test_auth_callback_server.py
@@ -0,0 +1,63 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+import socket
+import time
+from threading import Thread
+
+import pytest
+
+from snowflake.connector.auth._http_server import AuthHttpServer
+from snowflake.connector.vendored import requests
+
+
+@pytest.mark.parametrize(
+ "dontwait",
+ ["false", "true"],
+)
+@pytest.mark.parametrize("timeout", [None, 0.05])
+@pytest.mark.parametrize("reuse_port", ["true"])
+def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> None:
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
+ test_response: requests.Response | None = None
+ with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
+
+ def request_callback():
+ nonlocal test_response
+ if timeout:
+ time.sleep(timeout / 5)
+ test_response = requests.get(
+ f"http://{callback_server.hostname}:{callback_server.port}/test_request"
+ )
+
+ request_callback_thread = Thread(target=request_callback)
+ request_callback_thread.start()
+ block, client_socket = callback_server.receive_block(timeout=timeout)
+ test_callback_request = block[0]
+ response = ["HTTP/1.1 200 OK", "Content-Type: text/html", "", "test_response"]
+ client_socket.sendall("\r\n".join(response).encode("utf-8"))
+ client_socket.shutdown(socket.SHUT_RDWR)
+ client_socket.close()
+ request_callback_thread.join()
+ assert test_response.ok
+ assert test_response.text == "test_response"
+ assert test_callback_request == "GET /test_request HTTP/1.1"
+
+
+@pytest.mark.parametrize(
+ "dontwait",
+ ["false", "true"],
+)
+@pytest.mark.parametrize("timeout", [0.05])
+@pytest.mark.parametrize("reuse_port", ["true"])
+def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None:
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
+ with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
+ block, client_socket = callback_server.receive_block(timeout=timeout)
+ assert block is None
+ assert client_socket is None
diff --git a/test/unit/test_auth_oauth_auth_code.py b/test/unit/test_auth_oauth_auth_code.py
new file mode 100644
index 0000000000..6a01bb014f
--- /dev/null
+++ b/test/unit/test_auth_oauth_auth_code.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from snowflake.connector.auth import AuthByOauthCode
+
+
+def test_auth_oauth_auth_code_oauth_type():
+ """Simple OAuth Auth Code oauth type test."""
+ auth = AuthByOauthCode(
+ "app",
+ "clientId",
+ "clientSecret",
+ "auth_url",
+ "tokenRequestUrl",
+ "redirectUri:{port}",
+ "scope",
+ )
+ body = {"data": {}}
+ auth.update_body(body)
+ assert body["data"]["OAUTH_TYPE"] == "authorization_code"
diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py
index 8e229b751f..a29babc2c4 100644
--- a/test/unit/test_connection.py
+++ b/test/unit/test_connection.py
@@ -637,7 +637,7 @@ def test_cannot_set_wlid_authenticator_without_env_variable(mock_post_requests):
account="account", authenticator="WORKLOAD_IDENTITY"
)
assert (
- "Please set the 'SF_ENABLE_EXPERIMENTAL_AUTHENTICATION' environment variable to use the 'WORKLOAD_IDENTITY' authenticator"
+ "Please set the 'SF_ENABLE_EXPERIMENTAL_AUTHENTICATION' environment variable true to use the 'WORKLOAD_IDENTITY' authenticator"
in str(excinfo.value)
)
@@ -647,7 +647,7 @@ def test_connection_params_are_plumbed_into_authbyworkloadidentity(monkeypatch):
m.setattr(
"snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
)
- m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "") # Can be set to anything.
+ m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
conn = snowflake.connector.connect(
account="my_account_1",
@@ -689,7 +689,7 @@ def test_toml_connection_params_are_plumbed_into_authbyworkloadidentity(
m.setattr(
"snowflake.connector.SnowflakeConnection._authenticate", lambda *_: None
)
- m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "")
+ m.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
conn = snowflake.connector.connect(connections_file_path=connections_file)
assert conn.auth_class.provider == AttestationProvider.OIDC
diff --git a/test/unit/test_linux_local_file_cache.py b/test/unit/test_linux_local_file_cache.py
index 51617f6094..2cf7c6348f 100644
--- a/test/unit/test_linux_local_file_cache.py
+++ b/test/unit/test_linux_local_file_cache.py
@@ -1,12 +1,15 @@
#!/usr/bin/env python
from __future__ import annotations
-import os
+import time
import pytest
+from _pytest import pathlib
from snowflake.connector.compat import IS_LINUX
+pytestmark = pytest.mark.skipif(not IS_LINUX, reason="Testing on linux only")
+
try:
from snowflake.connector.token_cache import FileTokenCache, TokenKey, TokenType
@@ -23,13 +26,13 @@
CRED_1 = "cred_1"
-@pytest.mark.skipif(not IS_LINUX, reason="The test is only for Linux platform")
@pytest.mark.skipolddriver
-def test_basic_store(tmpdir):
- os.environ["SF_TEMPORARY_CREDENTIAL_CACHE_DIR"] = str(tmpdir)
-
- cache = FileTokenCache()
- cache.delete_temporary_credential_file()
+def test_basic_store(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ assert cache.cache_dir == pathlib.Path(tmpdir)
+ cache.cache_file().unlink(missing_ok=True)
cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
cache.store(TokenKey(HOST_1, USER_1, CRED_TYPE_1), CRED_1)
@@ -39,13 +42,15 @@ def test_basic_store(tmpdir):
assert cache.retrieve(TokenKey(HOST_1, USER_1, CRED_TYPE_1)) == CRED_1
assert cache.retrieve(TokenKey(HOST_0, USER_1, CRED_TYPE_1)) == CRED_1
- cache.delete_temporary_credential_file()
+ cache.cache_file().unlink(missing_ok=True)
-def test_delete_specific_item():
- """The old behavior of delete cache is deleting the whole cache file. Now we change it to partially deletion."""
- cache = FileTokenCache()
- cache.delete_temporary_credential_file()
+@pytest.mark.skipif(not IS_LINUX, reason="The test is only for Linux platform")
+def test_delete_specific_item(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.cache_file().unlink(missing_ok=True)
cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_1), CRED_1)
@@ -55,4 +60,170 @@ def test_delete_specific_item():
cache.remove(TokenKey(HOST_0, USER_0, CRED_TYPE_0))
assert not cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0))
assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_1)) == CRED_1
- cache.delete_temporary_credential_file()
+ cache.cache_file().unlink(missing_ok=True)
+
+
+def test_malformed_json_cache(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.cache_file().unlink(missing_ok=True)
+ cache.cache_file().touch(0o600)
+ invalid_json = "[}"
+ cache.cache_file().write_text(invalid_json)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+
+
+def test_malformed_utf_cache(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.cache_file().unlink(missing_ok=True)
+ cache.cache_file().touch(0o600)
+ invalid_utf_sequence = bytes.fromhex("c0af")
+ cache.cache_file().write_bytes(invalid_utf_sequence)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+
+
+def test_cache_dir_is_not_a_directory(tmpdir, monkeypatch):
+ file = pathlib.Path(str(tmpdir)) / "file"
+ file.touch()
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(file))
+ monkeypatch.delenv("XDG_CACHE_HOME", raising=False)
+ monkeypatch.delenv("HOME", raising=False)
+ cache_dir = FileTokenCache.find_cache_dir()
+ assert cache_dir is None
+ file.unlink()
+
+
+def test_cache_dir_does_not_exist(tmpdir, monkeypatch):
+ directory = pathlib.Path(str(tmpdir)) / "dir"
+ directory.unlink(missing_ok=True)
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(directory))
+ monkeypatch.delenv("XDG_CACHE_HOME", raising=False)
+ monkeypatch.delenv("HOME", raising=False)
+ cache_dir = FileTokenCache.find_cache_dir()
+ assert cache_dir is None
+
+
+def test_cache_dir_incorrect_permissions(tmpdir, monkeypatch):
+ directory = pathlib.Path(str(tmpdir)) / "dir"
+ directory.unlink(missing_ok=True)
+ directory.touch(0o777)
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(directory))
+ monkeypatch.delenv("XDG_CACHE_HOME", raising=False)
+ monkeypatch.delenv("HOME", raising=False)
+ cache_dir = FileTokenCache.find_cache_dir()
+ assert cache_dir is None
+ directory.unlink()
+
+
+def test_cache_file_incorrect_permissions(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.cache_file().unlink(missing_ok=True)
+ cache.cache_file().touch(0o777)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None
+ assert len(cache.cache_file().read_text("utf-8")) == 0
+ cache.cache_file().unlink()
+
+
+def test_cache_dir_xdg_cache_home(tmpdir, monkeypatch):
+ monkeypatch.delenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", raising=False)
+ monkeypatch.setenv("XDG_CACHE_HOME", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.cache_file().unlink(missing_ok=True)
+ assert cache.cache_dir == pathlib.Path(str(tmpdir)) / "snowflake"
+ assert (
+ cache.cache_file()
+ == pathlib.Path(str(tmpdir)) / "snowflake" / "credential_cache_v1.json"
+ )
+ assert (
+ cache.lock_file()
+ == pathlib.Path(str(tmpdir)) / "snowflake" / "credential_cache_v1.json.lck"
+ )
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+ cache.cache_file().unlink()
+
+
+def test_cache_dir_home(tmpdir, monkeypatch):
+ monkeypatch.delenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", raising=False)
+ monkeypatch.delenv("XDG_CACHE_HOME", raising=False)
+ monkeypatch.setenv("HOME", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.cache_file().unlink(missing_ok=True)
+ assert cache.cache_dir == pathlib.Path(str(tmpdir)) / ".cache" / "snowflake"
+ assert (
+ cache.cache_file()
+ == pathlib.Path(str(tmpdir))
+ / ".cache"
+ / "snowflake"
+ / "credential_cache_v1.json"
+ )
+ assert (
+ cache.lock_file()
+ == pathlib.Path(str(tmpdir))
+ / ".cache"
+ / "snowflake"
+ / "credential_cache_v1.json.lck"
+ )
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+
+
+def test_file_lock(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+ cache.lock_file().mkdir(0o700)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None
+ assert cache.lock_file().exists()
+ cache.lock_file().rmdir()
+
+
+def test_file_lock_stale(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+ cache.lock_file().mkdir(0o700)
+ time.sleep(1)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+ assert not cache.lock_file().exists()
+
+
+def test_file_missing_tokens_field(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.cache_file().touch(0o600)
+ cache.cache_file().write_text("{}")
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+ cache.cache_file().unlink()
+
+
+def test_file_tokens_is_not_dict(tmpdir, monkeypatch):
+ monkeypatch.setenv("SF_TEMPORARY_CREDENTIAL_CACHE_DIR", str(tmpdir))
+ cache = FileTokenCache.make()
+ assert cache
+ cache.cache_file().touch(0o600)
+ cache.cache_file().write_text('{ "tokens": [] }')
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) is None
+ cache.store(TokenKey(HOST_0, USER_0, CRED_TYPE_0), CRED_0)
+ assert cache.retrieve(TokenKey(HOST_0, USER_0, CRED_TYPE_0)) == CRED_0
+ cache.cache_file().unlink()
diff --git a/test/unit/test_oauth_token.py b/test/unit/test_oauth_token.py
new file mode 100644
index 0000000000..9152f39c8c
--- /dev/null
+++ b/test/unit/test_oauth_token.py
@@ -0,0 +1,729 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+import logging
+import pathlib
+from threading import Thread
+from typing import Any, Generator, Union
+from unittest import mock
+from unittest.mock import Mock, patch
+
+import pytest
+import requests
+
+import snowflake.connector
+from snowflake.connector.auth import AuthByOauthCredentials
+from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType
+
+from ..wiremock.wiremock_utils import WiremockClient
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="session")
+def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]:
+ with WiremockClient() as client:
+ yield client
+
+
+@pytest.fixture(scope="session")
+def wiremock_oauth_authorization_code_dir() -> pathlib.Path:
+ return (
+ pathlib.Path(__file__).parent.parent
+ / "data"
+ / "wiremock"
+ / "mappings"
+ / "auth"
+ / "oauth"
+ / "authorization_code"
+ )
+
+
+@pytest.fixture(scope="session")
+def wiremock_oauth_client_creds_dir() -> pathlib.Path:
+ return (
+ pathlib.Path(__file__).parent.parent
+ / "data"
+ / "wiremock"
+ / "mappings"
+ / "auth"
+ / "oauth"
+ / "client_credentials"
+ )
+
+
+@pytest.fixture(scope="session")
+def wiremock_generic_mappings_dir() -> pathlib.Path:
+ return (
+ pathlib.Path(__file__).parent.parent
+ / "data"
+ / "wiremock"
+ / "mappings"
+ / "generic"
+ )
+
+
+@pytest.fixture(scope="session")
+def wiremock_oauth_refresh_token_dir() -> pathlib.Path:
+ return (
+ pathlib.Path(__file__).parent.parent
+ / "data"
+ / "wiremock"
+ / "mappings"
+ / "auth"
+ / "oauth"
+ / "refresh_token"
+ )
+
+
+def _call_auth_server(url: str):
+ requests.get(url, allow_redirects=True, timeout=6)
+
+
+def _webbrowser_redirect(*args):
+ assert len(args) == 1, "Invalid number of arguments passed to webbrowser open"
+
+ thread = Thread(target=_call_auth_server, args=(args[0],))
+ thread.start()
+
+ return thread.is_alive()
+
+
+@pytest.fixture(scope="session")
+def webbrowser_mock() -> Mock:
+ webbrowser_mock = Mock()
+ webbrowser_mock.open = _webbrowser_redirect
+ return webbrowser_mock
+
+
+@pytest.fixture()
+def temp_cache():
+ class TemporaryCache(TokenCache):
+ def __init__(self):
+ self._cache = {}
+
+ def store(self, key: TokenKey, token: str) -> None:
+ self._cache[(key.user, key.host, key.tokenType)] = token
+
+ def retrieve(self, key: TokenKey) -> str:
+ return self._cache.get((key.user, key.host, key.tokenType))
+
+ def remove(self, key: TokenKey) -> None:
+ self._cache.pop((key.user, key.host, key.tokenType))
+
+ tmp_cache = TemporaryCache()
+ with mock.patch(
+ "snowflake.connector.auth._auth.Auth.get_token_cache", return_value=tmp_cache
+ ):
+ yield tmp_cache
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+def test_oauth_code_successful_flow(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ wiremock_generic_mappings_dir,
+ webbrowser_mock,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "successful_flow.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ with mock.patch("webbrowser.open", new=webbrowser_mock.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = snowflake.connector.connect(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ assert cnx, "invalid cnx"
+ cnx.close()
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+def test_oauth_code_invalid_state(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ webbrowser_mock,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "invalid_state_error.json"
+ )
+
+ with pytest.raises(snowflake.connector.DatabaseError) as execinfo:
+ with mock.patch("webbrowser.open", new=webbrowser_mock.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ snowflake.connector.connect(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ oauth_client_secret="testClientSecret",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ assert str(execinfo.value).endswith("State changed during OAuth process.")
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+def test_oauth_code_scope_error(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ webbrowser_mock,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "invalid_scope_error.json"
+ )
+
+ with pytest.raises(snowflake.connector.DatabaseError) as execinfo:
+ with mock.patch("webbrowser.open", new=webbrowser_mock.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ snowflake.connector.connect(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ assert str(execinfo.value).endswith(
+ "Oauth callback returned an invalid_scope error: One or more scopes are not configured for the authorization server resource."
+ )
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+def test_oauth_code_token_request_error(
+ wiremock_oauth_authorization_code_dir,
+ webbrowser_mock,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ with WiremockClient() as wiremock_client:
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "token_request_error.json"
+ )
+
+ with pytest.raises(snowflake.connector.DatabaseError) as execinfo:
+ with mock.patch("webbrowser.open", new=webbrowser_mock.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ snowflake.connector.connect(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ oauth_client_secret="testClientSecret",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ assert str(execinfo.value).endswith(
+ "Invalid HTTP request from web browser. Idp authentication could have failed."
+ )
+
+
+@pytest.mark.skipolddriver
+def test_oauth_code_browser_timeout(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ webbrowser_mock,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir
+ / "browser_timeout_authorization_error.json"
+ )
+
+ with pytest.raises(snowflake.connector.DatabaseError) as execinfo:
+ with mock.patch("webbrowser.open", new=webbrowser_mock.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ snowflake.connector.connect(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ oauth_client_secret="testClientSecret",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ external_browser_timeout=2,
+ )
+
+ assert str(execinfo.value).endswith(
+ "Unable to receive the OAuth message within a given timeout. Please check the redirect URI and try again."
+ )
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+def test_oauth_code_custom_urls(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_authorization_code_dir,
+ wiremock_generic_mappings_dir,
+ webbrowser_mock,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_oauth_authorization_code_dir / "external_idp_custom_urls.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ with mock.patch("webbrowser.open", new=webbrowser_mock.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = snowflake.connector.connect(
+ user="testUser",
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ oauth_client_secret="testClientSecret",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/tokenrequest",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/authorization",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ assert cnx, "invalid cnx"
+ cnx.close()
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+def test_oauth_code_successful_refresh_token_flow(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_refresh_token_dir,
+ wiremock_generic_mappings_dir,
+ monkeypatch,
+ temp_cache,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_refresh_token_dir / "refresh_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+ user = "testUser"
+ access_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
+ )
+ refresh_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
+ )
+ temp_cache.store(access_token_key, "expired-access-token-123")
+ temp_cache.store(refresh_token_key, "refresh-token-123")
+ cnx = snowflake.connector.connect(
+ user=user,
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ oauth_security_features=("pkce", "refresh_token"),
+ client_store_temporary_credential=True,
+ )
+ assert cnx, "invalid cnx"
+ cnx.close()
+ new_access_token = temp_cache.retrieve(access_token_key)
+ new_refresh_token = temp_cache.retrieve(refresh_token_key)
+
+ assert new_access_token == "access-token-123"
+ assert new_refresh_token == "refresh-token-123"
+
+
+@pytest.mark.skipolddriver
+@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
+def test_oauth_code_expired_refresh_token_flow(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_refresh_token_dir,
+ wiremock_oauth_authorization_code_dir,
+ wiremock_generic_mappings_dir,
+ webbrowser_mock,
+ monkeypatch,
+ temp_cache,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_refresh_token_dir / "refresh_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_authorization_code_dir
+ / "successful_auth_after_failed_refresh.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_authorization_code_dir / "new_tokens_after_failed_refresh.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ user = "testUser"
+ access_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
+ )
+ refresh_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
+ )
+ temp_cache.store(access_token_key, "expired-access-token-123")
+ temp_cache.store(refresh_token_key, "expired-refresh-token-123")
+ with mock.patch("webbrowser.open", new=webbrowser_mock.open):
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = snowflake.connector.connect(
+ user=user,
+ authenticator="OAUTH_AUTHORIZATION_CODE",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ oauth_security_features=("pkce", "refresh_token"),
+ client_store_temporary_credential=True,
+ )
+ assert cnx, "invalid cnx"
+ cnx.close()
+
+ new_access_token = temp_cache.retrieve(access_token_key)
+ new_refresh_token = temp_cache.retrieve(refresh_token_key)
+ assert new_access_token == "access-token-123"
+ assert new_refresh_token == "refresh-token-123"
+
+
+@pytest.mark.skipolddriver
+def test_client_creds_oauth_type():
+ """Simple OAuth Client credentials type test."""
+ auth = AuthByOauthCredentials(
+ "app",
+ "clientId",
+ "clientSecret",
+ "auth_url",
+ "tokenRequestUrl",
+ "scope",
+ )
+ body = {"data": {}}
+ auth.update_body(body)
+ assert body["data"]["OAUTH_TYPE"] == "client_credentials"
+
+
+@pytest.mark.skipolddriver
+def test_client_creds_successful_flow(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_client_creds_dir,
+ wiremock_generic_mappings_dir,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ wiremock_client.import_mapping(
+ wiremock_oauth_client_creds_dir / "successful_flow.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ cnx = snowflake.connector.connect(
+ user="testUser",
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ oauth_client_id="123",
+ oauth_client_secret="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ assert cnx, "invalid cnx"
+ cnx.close()
+
+
+@pytest.mark.skipolddriver
+def test_client_creds_token_request_error(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_client_creds_dir,
+ wiremock_generic_mappings_dir,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+ wiremock_client.import_mapping(
+ wiremock_oauth_client_creds_dir / "token_request_error.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ with pytest.raises(snowflake.connector.DatabaseError) as execinfo:
+ with mock.patch("secrets.token_urlsafe", return_value="abc123"):
+ snowflake.connector.connect(
+ user="testUser",
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ oauth_client_id="123",
+ oauth_client_secret="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ oauth_authorization_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/authorize",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ )
+
+ assert str(execinfo.value).endswith(
+ "Invalid HTTP request from web browser. Idp authentication could have failed."
+ )
+
+
+@pytest.mark.skipolddriver
+def test_client_creds_successful_refresh_token_flow(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_refresh_token_dir,
+ wiremock_generic_mappings_dir,
+ monkeypatch,
+ temp_cache,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_refresh_token_dir / "refresh_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+ user = "testUser"
+ access_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
+ )
+ refresh_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
+ )
+ temp_cache.store(access_token_key, "expired-access-token-123")
+ temp_cache.store(refresh_token_key, "refresh-token-123")
+ cnx = snowflake.connector.connect(
+ user=user,
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ oauth_security_features=("refresh_token",),
+ client_store_temporary_credential=True,
+ )
+ assert cnx, "invalid cnx"
+ cnx.close()
+
+ new_access_token = temp_cache.retrieve(access_token_key)
+ new_refresh_token = temp_cache.retrieve(refresh_token_key)
+ assert new_access_token == "access-token-123"
+ assert new_refresh_token == "refresh-token-123"
+
+
+@pytest.mark.skipolddriver
+def test_client_creds_expired_refresh_token_flow(
+ wiremock_client: WiremockClient,
+ wiremock_oauth_refresh_token_dir,
+ wiremock_oauth_client_creds_dir,
+ wiremock_generic_mappings_dir,
+ webbrowser_mock,
+ monkeypatch,
+ temp_cache,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "true")
+
+ wiremock_client.import_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_refresh_token_dir / "refresh_failed.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_oauth_client_creds_dir / "successful_auth_after_failed_refresh.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_login_successful.json"
+ )
+ wiremock_client.add_mapping(
+ wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
+ )
+
+ user = "testUser"
+ access_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_ACCESS_TOKEN
+ )
+ refresh_token_key = TokenKey(
+ user, wiremock_client.wiremock_host, TokenType.OAUTH_REFRESH_TOKEN
+ )
+ temp_cache.store(access_token_key, "expired-access-token-123")
+ temp_cache.store(refresh_token_key, "expired-refresh-token-123")
+ cnx = snowflake.connector.connect(
+ user=user,
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ oauth_client_id="123",
+ account="testAccount",
+ protocol="http",
+ role="ANALYST",
+ oauth_client_secret="testClientSecret",
+ oauth_token_request_url=f"http://{wiremock_client.wiremock_host}:{wiremock_client.wiremock_http_port}/oauth/token-request",
+ host=wiremock_client.wiremock_host,
+ port=wiremock_client.wiremock_http_port,
+ oauth_security_features=("refresh_token",),
+ client_store_temporary_credential=True,
+ )
+ assert cnx, "invalid cnx"
+ cnx.close()
+
+ new_access_token = temp_cache.retrieve(access_token_key)
+ new_refresh_token = temp_cache.retrieve(refresh_token_key)
+ assert new_access_token == "access-token-123"
+ assert new_refresh_token == "refresh-token-123"
+
+
+@pytest.mark.skipolddriver
+@pytest.mark.parametrize(
+ "authenticator", ["OAUTH_AUTHORIZATION_CODE", "OAUTH_CLIENT_CREDENTIALS"]
+)
+def test_auth_is_experimental(
+ authenticator,
+ monkeypatch,
+) -> None:
+ monkeypatch.delenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", False)
+ with pytest.raises(
+ snowflake.connector.ProgrammingError,
+ match=r"SF_ENABLE_EXPERIMENTAL_AUTHENTICATION",
+ ):
+ snowflake.connector.connect(
+ user="testUser",
+ account="testAccount",
+ authenticator=authenticator,
+ )
+
+
+@pytest.mark.skipolddriver
+@pytest.mark.skipolddriver
+@pytest.mark.parametrize(
+ "authenticator", ["OAUTH_AUTHORIZATION_CODE", "OAUTH_CLIENT_CREDENTIALS"]
+)
+def test_auth_experimental_when_variable_set_to_false(
+ authenticator,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION", "false")
+ with pytest.raises(
+ snowflake.connector.ProgrammingError,
+ match=r"SF_ENABLE_EXPERIMENTAL_AUTHENTICATION",
+ ):
+ snowflake.connector.connect(
+ user="testUser",
+ account="testAccount",
+ authenticator="OAUTH_CLIENT_CREDENTIALS",
+ )
diff --git a/test/unit/test_wiremock_client.py b/test/unit/test_wiremock_client.py
index df4cacd2da..b471f39df7 100644
--- a/test/unit/test_wiremock_client.py
+++ b/test/unit/test_wiremock_client.py
@@ -12,6 +12,7 @@
from ..wiremock.wiremock_utils import WiremockClient
+@pytest.mark.skipolddriver
@pytest.fixture(scope="session")
def wiremock_client() -> Generator[WiremockClient, Any, None]:
with WiremockClient() as client:
diff --git a/test/wiremock/wiremock_utils.py b/test/wiremock/wiremock_utils.py
index 95b7374c1e..1d036a8023 100644
--- a/test/wiremock/wiremock_utils.py
+++ b/test/wiremock/wiremock_utils.py
@@ -31,11 +31,12 @@ def _get_mapping_str(mapping: Union[str, dict, pathlib.Path]) -> str:
class WiremockClient:
- def __init__(self):
+ def __init__(self, forbidden_ports: Optional[List[int]] = None) -> None:
self.wiremock_filename = "wiremock-standalone.jar"
self.wiremock_host = "localhost"
self.wiremock_http_port = None
self.wiremock_https_port = None
+ self.forbidden_ports = forbidden_ports if forbidden_ports is not None else []
self.wiremock_dir = pathlib.Path(__file__).parent.parent.parent / ".wiremock"
assert self.wiremock_dir.exists(), f"{self.wiremock_dir} does not exist"
@@ -46,9 +47,11 @@ def __init__(self):
), f"{self.wiremock_jar_path} does not exist"
def _start_wiremock(self):
- self.wiremock_http_port = self._find_free_port()
+ self.wiremock_http_port = self._find_free_port(
+ forbidden_ports=self.forbidden_ports,
+ )
self.wiremock_https_port = self._find_free_port(
- forbidden_ports=[self.wiremock_http_port]
+ forbidden_ports=self.forbidden_ports + [self.wiremock_http_port]
)
self.wiremock_process = subprocess.Popen(
[
@@ -119,6 +122,10 @@ def _health_check(self):
return True
def _reset_wiremock(self):
+ clean_journal_endpoint = (
+ f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/requests"
+ )
+ requests.delete(clean_journal_endpoint)
reset_endpoint = (
f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/reset"
)