Skip to content

Commit ca903c1

Browse files
[Async] Apply #2203 to async code
1 parent b8f333c commit ca903c1

File tree

5 files changed

+621
-0
lines changed

5 files changed

+621
-0
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from ..connection import _get_private_bytes_from_file
3636
from ..constants import (
3737
_CONNECTIVITY_ERR_MSG,
38+
ENV_VAR_EXPERIMENTAL_AUTHENTICATION,
3839
ENV_VAR_PARTNER,
3940
PARAMETER_AUTOCOMMIT,
4041
PARAMETER_CLIENT_PREFETCH_THREADS,
@@ -55,6 +56,7 @@
5556
ER_CONNECTION_IS_CLOSED,
5657
ER_FAILED_TO_CONNECT_TO_DB,
5758
ER_INVALID_VALUE,
59+
ER_INVALID_WIF_SETTINGS,
5860
)
5961
from ..network import (
6062
DEFAULT_AUTHENTICATOR,
@@ -64,12 +66,14 @@
6466
PROGRAMMATIC_ACCESS_TOKEN,
6567
REQUEST_ID,
6668
USR_PWD_MFA_AUTHENTICATOR,
69+
WORKLOAD_IDENTITY_AUTHENTICATOR,
6770
ReauthenticationRequest,
6871
)
6972
from ..sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_FEATURE_NOT_SUPPORTED
7073
from ..telemetry import TelemetryData, TelemetryField
7174
from ..time_util import get_time_millis
7275
from ..util_text import split_statements
76+
from ..wif_util import AttestationProvider
7377
from ._cursor import SnowflakeCursor
7478
from ._description import CLIENT_NAME
7579
from ._network import SnowflakeRestful
@@ -87,6 +91,7 @@
8791
AuthByPlugin,
8892
AuthByUsrPwdMfa,
8993
AuthByWebBrowser,
94+
AuthByWorkloadIdentity,
9095
)
9196

9297
logger = getLogger(__name__)
@@ -320,6 +325,29 @@ async def __open_connection(self):
320325
timeout=self.login_timeout,
321326
backoff_generator=self._backoff_generator,
322327
)
328+
elif self._authenticator == WORKLOAD_IDENTITY_AUTHENTICATOR:
329+
if ENV_VAR_EXPERIMENTAL_AUTHENTICATION not in os.environ:
330+
Error.errorhandler_wrapper(
331+
self,
332+
None,
333+
ProgrammingError,
334+
{
335+
"msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable to use the '{WORKLOAD_IDENTITY_AUTHENTICATOR}' authenticator.",
336+
"errno": ER_INVALID_WIF_SETTINGS,
337+
},
338+
)
339+
# Standardize the provider enum.
340+
if self._workload_identity_provider and isinstance(
341+
self._workload_identity_provider, str
342+
):
343+
self._workload_identity_provider = AttestationProvider.from_string(
344+
self._workload_identity_provider
345+
)
346+
self.auth_class = AuthByWorkloadIdentity(
347+
provider=self._workload_identity_provider,
348+
token=self._token,
349+
entra_resource=self._workload_identity_entra_resource,
350+
)
323351
else:
324352
# okta URL, e.g., https://<account>.okta.com/
325353
self.auth_class = AuthByOkta(

src/snowflake/connector/aio/auth/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ._pat import AuthByPAT
1717
from ._usrpwdmfa import AuthByUsrPwdMfa
1818
from ._webbrowser import AuthByWebBrowser
19+
from ._workload_identity import AuthByWorkloadIdentity
1920

2021
FIRST_PARTY_AUTHENTICATORS = frozenset(
2122
(
@@ -27,6 +28,7 @@
2728
AuthByWebBrowser,
2829
AuthByIdToken,
2930
AuthByPAT,
31+
AuthByWorkloadIdentity,
3032
AuthNoAuth,
3133
)
3234
)
@@ -40,6 +42,7 @@
4042
"AuthByOkta",
4143
"AuthByUsrPwdMfa",
4244
"AuthByWebBrowser",
45+
"AuthByWorkloadIdentity",
4346
"AuthNoAuth",
4447
"Auth",
4548
"AuthType",
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
from typing import Any
8+
9+
from ...auth.workload_identity import (
10+
AuthByWorkloadIdentity as AuthByWorkloadIdentitySync,
11+
)
12+
from ._by_plugin import AuthByPlugin as AuthByPluginAsync
13+
14+
15+
class AuthByWorkloadIdentity(AuthByPluginAsync, AuthByWorkloadIdentitySync):
16+
def __init__(
17+
self,
18+
*,
19+
provider=None,
20+
token: str | None = None,
21+
entra_resource: str | None = None,
22+
**kwargs,
23+
) -> None:
24+
"""Initializes an instance with workload identity authentication."""
25+
AuthByWorkloadIdentitySync.__init__(
26+
self,
27+
provider=provider,
28+
token=token,
29+
entra_resource=entra_resource,
30+
**kwargs,
31+
)
32+
33+
async def reset_secrets(self) -> None:
34+
AuthByWorkloadIdentitySync.reset_secrets(self)
35+
36+
async def prepare(self, **kwargs: Any) -> None:
37+
AuthByWorkloadIdentitySync.prepare(self, **kwargs)
38+
39+
async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
40+
return AuthByWorkloadIdentitySync.reauthenticate(self, **kwargs)
41+
42+
async def update_body(self, body: dict[Any, Any]) -> None:
43+
AuthByWorkloadIdentitySync.update_body(self, body)

0 commit comments

Comments
 (0)