Skip to content

Commit c6be86c

Browse files
Link sync implementation of Oauth to async code
1 parent 607db2e commit c6be86c

File tree

4 files changed

+217
-0
lines changed

4 files changed

+217
-0
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..connection import _get_private_bytes_from_file
3333
from ..constants import (
3434
_CONNECTIVITY_ERR_MSG,
35+
_OAUTH_DEFAULT_SCOPE,
3536
ENV_VAR_EXPERIMENTAL_AUTHENTICATION,
3637
ENV_VAR_PARTNER,
3738
PARAMETER_AUTOCOMMIT,
@@ -51,15 +52,19 @@
5152
from ..description import PLATFORM, PYTHON_VERSION, SNOWFLAKE_CONNECTOR_VERSION
5253
from ..errorcode import (
5354
ER_CONNECTION_IS_CLOSED,
55+
ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED,
5456
ER_FAILED_TO_CONNECT_TO_DB,
5557
ER_INVALID_VALUE,
5658
ER_INVALID_WIF_SETTINGS,
59+
ER_NO_CLIENT_ID,
5760
)
5861
from ..network import (
5962
DEFAULT_AUTHENTICATOR,
6063
EXTERNAL_BROWSER_AUTHENTICATOR,
6164
KEY_PAIR_AUTHENTICATOR,
6265
OAUTH_AUTHENTICATOR,
66+
OAUTH_AUTHORIZATION_CODE,
67+
OAUTH_CLIENT_CREDENTIALS,
6368
PROGRAMMATIC_ACCESS_TOKEN,
6469
REQUEST_ID,
6570
USR_PWD_MFA_AUTHENTICATOR,
@@ -84,6 +89,8 @@
8489
AuthByIdToken,
8590
AuthByKeyPair,
8691
AuthByOAuth,
92+
AuthByOauthCode,
93+
AuthByOauthCredentials,
8794
AuthByOkta,
8895
AuthByPAT,
8996
AuthByPlugin,
@@ -307,6 +314,56 @@ async def __open_connection(self):
307314
timeout=self.login_timeout,
308315
backoff_generator=self._backoff_generator,
309316
)
317+
elif self._authenticator == OAUTH_AUTHORIZATION_CODE:
318+
self._check_experimental_authentication_flag()
319+
self._check_oauth_required_parameters()
320+
features = self.oauth_security_features
321+
if self._role and (self._oauth_scope == ""):
322+
# if role is known then let's inject it into scope
323+
self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
324+
self.auth_class = AuthByOauthCode(
325+
application=self.application,
326+
client_id=self._oauth_client_id,
327+
client_secret=self._oauth_client_secret,
328+
authentication_url=self._oauth_authorization_url.format(
329+
host=self.host, port=self.port
330+
),
331+
token_request_url=self._oauth_token_request_url.format(
332+
host=self.host, port=self.port
333+
),
334+
redirect_uri=self._oauth_redirect_uri,
335+
scope=self._oauth_scope,
336+
pkce_enabled=features.pkce_enabled,
337+
token_cache=(
338+
auth.get_token_cache()
339+
if self._client_store_temporary_credential
340+
else None
341+
),
342+
refresh_token_enabled=features.refresh_token_enabled,
343+
external_browser_timeout=self._external_browser_timeout,
344+
)
345+
elif self._authenticator == OAUTH_CLIENT_CREDENTIALS:
346+
self._check_experimental_authentication_flag()
347+
self._check_oauth_required_parameters()
348+
features = self.oauth_security_features
349+
if self._role and (self._oauth_scope == ""):
350+
# if role is known then let's inject it into scope
351+
self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
352+
self.auth_class = AuthByOauthCredentials(
353+
application=self.application,
354+
client_id=self._oauth_client_id,
355+
client_secret=self._oauth_client_secret,
356+
token_request_url=self._oauth_token_request_url.format(
357+
host=self.host, port=self.port
358+
),
359+
scope=self._oauth_scope,
360+
token_cache=(
361+
auth.get_token_cache()
362+
if self._client_store_temporary_credential
363+
else None
364+
),
365+
refresh_token_enabled=features.refresh_token_enabled,
366+
)
310367
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:
311368
self.auth_class = AuthByPAT(self._token)
312369
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
@@ -1052,3 +1109,37 @@ async def is_valid(self) -> bool:
10521109
except Exception as e:
10531110
logger.debug("session could not be validated due to exception: %s", e)
10541111
return False
1112+
1113+
def _check_experimental_authentication_flag(self) -> None:
1114+
if os.getenv(ENV_VAR_EXPERIMENTAL_AUTHENTICATION, "false").lower() != "true":
1115+
Error.errorhandler_wrapper(
1116+
self,
1117+
None,
1118+
ProgrammingError,
1119+
{
1120+
"msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable true to use the '{self._authenticator}' authenticator.",
1121+
"errno": ER_EXPERIMENTAL_AUTHENTICATION_NOT_SUPPORTED,
1122+
},
1123+
)
1124+
1125+
def _check_oauth_required_parameters(self) -> None:
1126+
if self._oauth_client_id is None:
1127+
Error.errorhandler_wrapper(
1128+
self,
1129+
None,
1130+
ProgrammingError,
1131+
{
1132+
"msg": "Oauth code flow requirement 'client_id' is empty",
1133+
"errno": ER_NO_CLIENT_ID,
1134+
},
1135+
)
1136+
if self._oauth_client_secret is None:
1137+
Error.errorhandler_wrapper(
1138+
self,
1139+
None,
1140+
ProgrammingError,
1141+
{
1142+
"msg": "Oauth code flow requirement 'client_secret' is empty",
1143+
"errno": ER_NO_CLIENT_ID,
1144+
},
1145+
)

src/snowflake/connector/aio/auth/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from ._keypair import AuthByKeyPair
99
from ._no_auth import AuthNoAuth
1010
from ._oauth import AuthByOAuth
11+
from ._oauth_code import AuthByOauthCode
12+
from ._oauth_credentials import AuthByOauthCredentials
1113
from ._okta import AuthByOkta
1214
from ._pat import AuthByPAT
1315
from ._usrpwdmfa import AuthByUsrPwdMfa
@@ -19,6 +21,8 @@
1921
AuthByDefault,
2022
AuthByKeyPair,
2123
AuthByOAuth,
24+
AuthByOauthCode,
25+
AuthByOauthCredentials,
2226
AuthByOkta,
2327
AuthByUsrPwdMfa,
2428
AuthByWebBrowser,
@@ -35,6 +39,8 @@
3539
"AuthByKeyPair",
3640
"AuthByPAT",
3741
"AuthByOAuth",
42+
"AuthByOauthCode",
43+
"AuthByOauthCredentials",
3844
"AuthByOkta",
3945
"AuthByUsrPwdMfa",
4046
"AuthByWebBrowser",
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/usr/bin/env python
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import Any
7+
8+
from ...auth.oauth_code import AuthByOauthCode as AuthByOauthCodeSync
9+
from ...token_cache import TokenCache
10+
from ._by_plugin import AuthByPlugin as AuthByPluginAsync
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class AuthByOauthCode(AuthByPluginAsync, AuthByOauthCodeSync):
16+
"""Async version of OAuth authorization code authenticator."""
17+
18+
def __init__(
19+
self,
20+
application: str,
21+
client_id: str,
22+
client_secret: str,
23+
authentication_url: str,
24+
token_request_url: str,
25+
redirect_uri: str,
26+
scope: str,
27+
pkce_enabled: bool = True,
28+
token_cache: TokenCache | None = None,
29+
refresh_token_enabled: bool = False,
30+
external_browser_timeout: int | None = None,
31+
**kwargs,
32+
) -> None:
33+
"""Initializes an instance with OAuth authorization code parameters."""
34+
logger.debug(
35+
"OAuth authentication is not supported in async version - falling back to sync implementation"
36+
)
37+
AuthByOauthCodeSync.__init__(
38+
self,
39+
application=application,
40+
client_id=client_id,
41+
client_secret=client_secret,
42+
authentication_url=authentication_url,
43+
token_request_url=token_request_url,
44+
redirect_uri=redirect_uri,
45+
scope=scope,
46+
pkce_enabled=pkce_enabled,
47+
token_cache=token_cache,
48+
refresh_token_enabled=refresh_token_enabled,
49+
external_browser_timeout=external_browser_timeout,
50+
**kwargs,
51+
)
52+
53+
async def reset_secrets(self) -> None:
54+
AuthByOauthCodeSync.reset_secrets(self)
55+
56+
async def prepare(self, **kwargs: Any) -> None:
57+
AuthByOauthCodeSync.prepare(self, **kwargs)
58+
59+
async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
60+
return AuthByOauthCodeSync.reauthenticate(self, **kwargs)
61+
62+
async def update_body(self, body: dict[Any, Any]) -> None:
63+
AuthByOauthCodeSync.update_body(self, body)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/usr/bin/env python
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import Any
7+
8+
from ...auth.oauth_credentials import (
9+
AuthByOauthCredentials as AuthByOauthCredentialsSync,
10+
)
11+
from ...token_cache import TokenCache
12+
from ._by_plugin import AuthByPlugin as AuthByPluginAsync
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class AuthByOauthCredentials(AuthByPluginAsync, AuthByOauthCredentialsSync):
18+
"""Async version of OAuth client credentials authenticator."""
19+
20+
def __init__(
21+
self,
22+
application: str,
23+
client_id: str,
24+
client_secret: str,
25+
token_request_url: str,
26+
scope: str,
27+
token_cache: TokenCache | None = None,
28+
refresh_token_enabled: bool = False,
29+
**kwargs,
30+
) -> None:
31+
"""Initializes an instance with OAuth client credentials parameters."""
32+
logger.debug(
33+
"OAuth authentication is not supported in async version - falling back to sync implementation"
34+
)
35+
AuthByOauthCredentialsSync.__init__(
36+
self,
37+
application=application,
38+
client_id=client_id,
39+
client_secret=client_secret,
40+
token_request_url=token_request_url,
41+
scope=scope,
42+
token_cache=token_cache,
43+
refresh_token_enabled=refresh_token_enabled,
44+
**kwargs,
45+
)
46+
47+
async def reset_secrets(self) -> None:
48+
AuthByOauthCredentialsSync.reset_secrets(self)
49+
50+
async def prepare(self, **kwargs: Any) -> None:
51+
AuthByOauthCredentialsSync.prepare(self, **kwargs)
52+
53+
async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
54+
return AuthByOauthCredentialsSync.reauthenticate(self, **kwargs)
55+
56+
async def update_body(self, body: dict[Any, Any]) -> None:
57+
AuthByOauthCredentialsSync.update_body(self, body)

0 commit comments

Comments
 (0)