Skip to content

Cherrypicks to aio connector part13 - Oauth #2466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: cherrypicks-to-aio-connector-part12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- Dropped support for Python 3.8.
- Basic decimal floating-point type support.
- Added handling of PAT provided in `password` field.
- Added experimental support for OAuth authorization code and client credentials flows.
- Improved error message for client-side query cancellations due to timeouts.
- Added support of GCS regional endpoints.
- Added `gcs_use_virtual_endpoints` connection property that forces the usage of the virtual GCS usage. Thanks to this it should be possible to set up private DNS entry for the GCS endpoint. See more: https://cloud.google.com/storage/docs/request-endpoints#xml-api
Expand Down
57 changes: 37 additions & 20 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,46 @@ timestamps {
string(name: 'parent_job', value: env.JOB_NAME),
string(name: 'parent_build_number', value: env.BUILD_NUMBER)
]
stage('Test') {
try {
def commit_hash = "main" // default which we want to override
def bptp_tag = "bptp-stable"
def response = authenticatedGithubCall("https://api.github.com/repos/snowflakedb/snowflake/git/ref/tags/${bptp_tag}")
commit_hash = response.object.sha
// Append the bptp-stable commit sha to params
params += [string(name: 'svn_revision', value: commit_hash)]
} catch(Exception e) {
println("Exception computing commit hash from: ${response}")
parallel(
'Test': {
stage('Test') {
try {
def commit_hash = "main" // default which we want to override
def bptp_tag = "bptp-stable"
def response = authenticatedGithubCall("https://api.github.com/repos/snowflakedb/snowflake/git/ref/tags/${bptp_tag}")
commit_hash = response.object.sha
// Append the bptp-stable commit sha to params
params += [string(name: 'svn_revision', value: commit_hash)]
} catch(Exception e) {
println("Exception computing commit hash from: ${response}")
}
parallel (
'Test Python 39': { build job: 'RT-PyConnector39-PC',parameters: params},
'Test Python 310': { build job: 'RT-PyConnector310-PC',parameters: params},
'Test Python 311': { build job: 'RT-PyConnector311-PC',parameters: params},
'Test Python 312': { build job: 'RT-PyConnector312-PC',parameters: params},
'Test Python 313': { build job: 'RT-PyConnector313-PC',parameters: params},
'Test Python 39 OldDriver': { build job: 'RT-PyConnector39-OldDriver-PC',parameters: params},
'Test Python 39 FIPS': { build job: 'RT-FIPS-PyConnector39',parameters: params},
)
}
},
'Test Authentication': {
stage('Test Authentication') {
withCredentials([
string(credentialsId: 'a791118f-a1ea-46cd-b876-56da1b9bc71c', variable: 'NEXUS_PASSWORD'),
string(credentialsId: 'sfctest0-parameters-secret', variable: 'PARAMETERS_SECRET')
]) {
sh '''\
|#!/bin/bash -e
|$WORKSPACE/ci/test_authentication.sh
'''.stripMargin()
}
parallel (
'Test Python 39': { build job: 'RT-PyConnector39-PC',parameters: params},
'Test Python 310': { build job: 'RT-PyConnector310-PC',parameters: params},
'Test Python 311': { build job: 'RT-PyConnector311-PC',parameters: params},
'Test Python 312': { build job: 'RT-PyConnector312-PC',parameters: params},
'Test Python 313': { build job: 'RT-PyConnector313-PC',parameters: params},
'Test Python 39 OldDriver': { build job: 'RT-PyConnector39-OldDriver-PC',parameters: params},
'Test Python 39 FIPS': { build job: 'RT-FIPS-PyConnector39',parameters: params},
)
}
}
}
)
}
}


pipeline {
Expand Down
24 changes: 24 additions & 0 deletions ci/container/test_authentication.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash -e

set -o pipefail


export WORKSPACE=${WORKSPACE:-/mnt/workspace}
export SOURCE_ROOT=${SOURCE_ROOT:-/mnt/host}

MVNW_EXE=$SOURCE_ROOT/mvnw
AUTH_PARAMETER_FILE=./.github/workflows/parameters/private/parameters_aws_auth_tests.json
eval $(jq -r '.authtestparams | to_entries | map("export \(.key)=\(.value|tostring)")|.[]' $AUTH_PARAMETER_FILE)

export SNOWFLAKE_AUTH_TEST_PRIVATE_KEY_PATH=./.github/workflows/parameters/private/rsa_keys/rsa_key.p8
export SNOWFLAKE_AUTH_TEST_INVALID_PRIVATE_KEY_PATH=./.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8

export SF_OCSP_TEST_MODE=true
export SF_ENABLE_EXPERIMENTAL_AUTHENTICATION=true
export RUN_AUTH_TESTS=true
export AUTHENTICATION_TESTS_ENV="docker"
export PYTHONPATH=$SOURCE_ROOT

python3 -m pip install --break-system-packages -e .

python3 -m pytest test/auth/*
27 changes: 27 additions & 0 deletions ci/test_authentication.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash -e

set -o pipefail


export THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export WORKSPACE=${WORKSPACE:-/tmp}

CI_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
if [[ -n "$JENKINS_HOME" ]]; then
ROOT_DIR="$(cd "${CI_DIR}/.." && pwd)"
export WORKSPACE=${WORKSPACE:-/tmp}
echo "Use /sbin/ip"
IP_ADDR=$(/sbin/ip -4 addr show scope global dev eth0 | grep inet | awk '{print $2}' | cut -d / -f 1)

fi

gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/parameters_aws_auth_tests.json "$THIS_DIR/../.github/workflows/parameters/private/parameters_aws_auth_tests.json.gpg"
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key.p8 "$THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key.p8.gpg"
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" --output $THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8 "$THIS_DIR/../.github/workflows/parameters/private/rsa_keys/rsa_key_invalid.p8.gpg"

docker run \
-v $(cd $THIS_DIR/.. && pwd):/mnt/host \
-v $WORKSPACE:/mnt/workspace \
--rm \
nexus.int.snowflakecomputing.com:8086/docker/snowdrivers-test-external-browser-python:1 \
"/mnt/host/ci/container/test_authentication.sh"
97 changes: 96 additions & 1 deletion src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..connection import _get_private_bytes_from_file
from ..constants import (
_CONNECTIVITY_ERR_MSG,
_OAUTH_DEFAULT_SCOPE,
ENV_VAR_EXPERIMENTAL_AUTHENTICATION,
ENV_VAR_PARTNER,
PARAMETER_AUTOCOMMIT,
Expand All @@ -51,15 +52,19 @@
from ..description import PLATFORM, PYTHON_VERSION, SNOWFLAKE_CONNECTOR_VERSION
from ..errorcode import (
ER_CONNECTION_IS_CLOSED,
ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED,
ER_FAILED_TO_CONNECT_TO_DB,
ER_INVALID_VALUE,
ER_INVALID_WIF_SETTINGS,
ER_NO_CLIENT_ID,
)
from ..network import (
DEFAULT_AUTHENTICATOR,
EXTERNAL_BROWSER_AUTHENTICATOR,
KEY_PAIR_AUTHENTICATOR,
OAUTH_AUTHENTICATOR,
OAUTH_AUTHORIZATION_CODE,
OAUTH_CLIENT_CREDENTIALS,
PROGRAMMATIC_ACCESS_TOKEN,
REQUEST_ID,
USR_PWD_MFA_AUTHENTICATOR,
Expand All @@ -84,6 +89,8 @@
AuthByIdToken,
AuthByKeyPair,
AuthByOAuth,
AuthByOauthCode,
AuthByOauthCredentials,
AuthByOkta,
AuthByPAT,
AuthByPlugin,
Expand Down Expand Up @@ -307,6 +314,56 @@ async def __open_connection(self):
timeout=self.login_timeout,
backoff_generator=self._backoff_generator,
)
elif self._authenticator == OAUTH_AUTHORIZATION_CODE:
self._check_experimental_authentication_flag()
self._check_oauth_required_parameters()
features = self.oauth_security_features
if self._role and (self._oauth_scope == ""):
# if role is known then let's inject it into scope
self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
self.auth_class = AuthByOauthCode(
application=self.application,
client_id=self._oauth_client_id,
client_secret=self._oauth_client_secret,
authentication_url=self._oauth_authorization_url.format(
host=self.host, port=self.port
),
token_request_url=self._oauth_token_request_url.format(
host=self.host, port=self.port
),
redirect_uri=self._oauth_redirect_uri,
scope=self._oauth_scope,
pkce_enabled=features.pkce_enabled,
token_cache=(
auth.get_token_cache()
if self._client_store_temporary_credential
else None
),
refresh_token_enabled=features.refresh_token_enabled,
external_browser_timeout=self._external_browser_timeout,
)
elif self._authenticator == OAUTH_CLIENT_CREDENTIALS:
self._check_experimental_authentication_flag()
self._check_oauth_required_parameters()
features = self.oauth_security_features
if self._role and (self._oauth_scope == ""):
# if role is known then let's inject it into scope
self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
self.auth_class = AuthByOauthCredentials(
application=self.application,
client_id=self._oauth_client_id,
client_secret=self._oauth_client_secret,
token_request_url=self._oauth_token_request_url.format(
host=self.host, port=self.port
),
scope=self._oauth_scope,
token_cache=(
auth.get_token_cache()
if self._client_store_temporary_credential
else None
),
refresh_token_enabled=features.refresh_token_enabled,
)
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:
self.auth_class = AuthByPAT(self._token)
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
Expand Down Expand Up @@ -747,7 +804,11 @@ async def authenticate_with_retry(self, auth_instance) -> None:
# SSO if it has expired
await self._reauthenticate()
else:
await self._authenticate(auth_instance)
# TODO pczajka: check if this is correct
# For OAuth and other auth types, call their reauthenticate method
await auth_instance.reauthenticate(conn=self)
# The reauthenticate method will call authenticate_with_retry internally,
# so we don't need to call _authenticate again here

async def autocommit(self, mode) -> None:
"""Sets autocommit mode to True, or False. Defaults to True."""
Expand Down Expand Up @@ -1052,3 +1113,37 @@ async def is_valid(self) -> bool:
except Exception as e:
logger.debug("session could not be validated due to exception: %s", e)
return False

def _check_experimental_authentication_flag(self) -> None:
if os.getenv(ENV_VAR_EXPERIMENTAL_AUTHENTICATION, "false").lower() != "true":
Error.errorhandler_wrapper(
self,
None,
ProgrammingError,
{
"msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable true to use the '{self._authenticator}' authenticator.",
"errno": ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED,
},
)

def _check_oauth_required_parameters(self) -> None:
if self._oauth_client_id is None:
Error.errorhandler_wrapper(
self,
None,
ProgrammingError,
{
"msg": "Oauth code flow requirement 'client_id' is empty",
"errno": ER_NO_CLIENT_ID,
},
)
if self._oauth_client_secret is None:
Error.errorhandler_wrapper(
self,
None,
ProgrammingError,
{
"msg": "Oauth code flow requirement 'client_secret' is empty",
"errno": ER_NO_CLIENT_ID,
},
)
6 changes: 6 additions & 0 deletions src/snowflake/connector/aio/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from ._keypair import AuthByKeyPair
from ._no_auth import AuthNoAuth
from ._oauth import AuthByOAuth
from ._oauth_code import AuthByOauthCode
from ._oauth_credentials import AuthByOauthCredentials
from ._okta import AuthByOkta
from ._pat import AuthByPAT
from ._usrpwdmfa import AuthByUsrPwdMfa
Expand All @@ -19,6 +21,8 @@
AuthByDefault,
AuthByKeyPair,
AuthByOAuth,
AuthByOauthCode,
AuthByOauthCredentials,
AuthByOkta,
AuthByUsrPwdMfa,
AuthByWebBrowser,
Expand All @@ -35,6 +39,8 @@
"AuthByKeyPair",
"AuthByPAT",
"AuthByOAuth",
"AuthByOauthCode",
"AuthByOauthCredentials",
"AuthByOkta",
"AuthByUsrPwdMfa",
"AuthByWebBrowser",
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/connector/aio/auth/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
CONTENT_TYPE_APPLICATION_JSON,
ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE,
OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE,
PYTHON_CONNECTOR_USER_AGENT,
ReauthenticationRequest,
)
Expand Down Expand Up @@ -282,6 +283,15 @@ async def post_request_wrapper(self, url, headers, body) -> None:
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
)
)
elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE:
raise ReauthenticationRequest(
ProgrammingError(
msg=ret["message"],
errno=int(errno),
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
)
)

from . import AuthByKeyPair

if isinstance(auth_instance, AuthByKeyPair):
Expand Down
Loading
Loading