Skip to content

Commit 2722b95

Browse files
sfc-gh-pmansoursfc-gh-pczajka
authored andcommitted
Add support for workload identity federation (#2203)
1 parent a4c6b9d commit 2722b95

File tree

14 files changed

+1307
-4
lines changed

14 files changed

+1307
-4
lines changed

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ python_requires = >=3.9
4444
packages = find_namespace:
4545
install_requires =
4646
asn1crypto>0.24.0,<2.0.0
47+
boto3>=1.0
48+
botocore>=1.0
4749
cffi>=1.9,<2.0.0
4850
cryptography>=3.1.0
4951
pyOpenSSL>=22.0.0,<25.0.0

src/snowflake/connector/auth/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .pat import AuthByPAT
1616
from .usrpwdmfa import AuthByUsrPwdMfa
1717
from .webbrowser import AuthByWebBrowser
18+
from .workload_identity import AuthByWorkloadIdentity
1819

1920
FIRST_PARTY_AUTHENTICATORS = frozenset(
2021
(
@@ -26,6 +27,7 @@
2627
AuthByWebBrowser,
2728
AuthByIdToken,
2829
AuthByPAT,
30+
AuthByWorkloadIdentity,
2931
AuthNoAuth,
3032
)
3133
)
@@ -39,6 +41,7 @@
3941
"AuthByOkta",
4042
"AuthByUsrPwdMfa",
4143
"AuthByWebBrowser",
44+
"AuthByWorkloadIdentity",
4245
"AuthNoAuth",
4346
"Auth",
4447
"AuthType",

src/snowflake/connector/auth/by_plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class AuthType(Enum):
5656
OKTA = "OKTA"
5757
PAT = "PROGRAMMATIC_ACCESS_TOKEN"
5858
NO_AUTH = "NO_AUTH"
59+
WORKLOAD_IDENTITY = "WORKLOAD_IDENTITY"
5960

6061

6162
class AuthByPlugin(ABC):
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
import json
8+
import typing
9+
from enum import Enum, unique
10+
11+
from ..network import WORKLOAD_IDENTITY_AUTHENTICATOR
12+
from ..wif_util import (
13+
AttestationProvider,
14+
WorkloadIdentityAttestation,
15+
create_attestation,
16+
)
17+
from .by_plugin import AuthByPlugin, AuthType
18+
19+
20+
@unique
21+
class ApiFederatedAuthenticationType(Enum):
22+
"""An API-specific enum of the WIF authentication type."""
23+
24+
AWS = "AWS"
25+
AZURE = "AZURE"
26+
GCP = "GCP"
27+
OIDC = "OIDC"
28+
29+
@staticmethod
30+
def from_attestation(
31+
attestation: WorkloadIdentityAttestation,
32+
) -> ApiFederatedAuthenticationType:
33+
"""Maps the internal / driver-specific attestation providers to API authenticator types.
34+
35+
The AttestationProvider is related to how the driver fetches the credential, while the API authenticator
36+
type is related to how the credential is verified. In most current cases these may be the same, though
37+
in the future we could have, for example, multiple AttestationProviders that all fetch an OIDC ID token.
38+
"""
39+
if attestation.provider == AttestationProvider.AWS:
40+
return ApiFederatedAuthenticationType.AWS
41+
if attestation.provider == AttestationProvider.AZURE:
42+
return ApiFederatedAuthenticationType.AZURE
43+
if attestation.provider == AttestationProvider.GCP:
44+
return ApiFederatedAuthenticationType.GCP
45+
if attestation.provider == AttestationProvider.OIDC:
46+
return ApiFederatedAuthenticationType.OIDC
47+
return ValueError(f"Unknown attestation provider '{attestation.provider}'")
48+
49+
50+
class AuthByWorkloadIdentity(AuthByPlugin):
51+
"""Plugin to authenticate via workload identity."""
52+
53+
def __init__(
54+
self,
55+
*,
56+
provider: AttestationProvider | None = None,
57+
token: str | None = None,
58+
entra_resource: str | None = None,
59+
**kwargs,
60+
) -> None:
61+
super().__init__(**kwargs)
62+
self.provider = provider
63+
self.token = token
64+
self.entra_resource = entra_resource
65+
66+
self.attestation: WorkloadIdentityAttestation | None = None
67+
68+
def type_(self) -> AuthType:
69+
return AuthType.WORKLOAD_IDENTITY
70+
71+
def reset_secrets(self) -> None:
72+
self.attestation = None
73+
74+
def update_body(self, body: dict[typing.Any, typing.Any]) -> None:
75+
body["data"]["AUTHENTICATOR"] = WORKLOAD_IDENTITY_AUTHENTICATOR
76+
body["data"]["PROVIDER"] = ApiFederatedAuthenticationType.from_attestation(
77+
self.attestation
78+
).value
79+
body["data"]["TOKEN"] = self.attestation.credential
80+
81+
def prepare(self, **kwargs: typing.Any) -> None:
82+
"""Fetch the token."""
83+
self.attestation = create_attestation(
84+
self.provider, self.entra_resource, self.token
85+
)
86+
87+
def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]:
88+
"""This is only relevant for AuthByIdToken, which uses a web-browser based flow. All other auth plugins just call authenticate() again."""
89+
return {"success": False}
90+
91+
@property
92+
def assertion_content(self) -> str:
93+
"""Returns the CSP provider name and an identifier. Used for logging purposes."""
94+
if not self.attestation:
95+
return ""
96+
properties = self.attestation.user_identifier_components
97+
properties["_provider"] = self.attestation.provider.value
98+
return json.dumps(properties, sort_keys=True, separators=(",", ":"))

src/snowflake/connector/connection.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
AuthByPlugin,
4545
AuthByUsrPwdMfa,
4646
AuthByWebBrowser,
47+
AuthByWorkloadIdentity,
4748
AuthNoAuth,
4849
)
4950
from .auth.idtoken import AuthByIdToken
@@ -55,6 +56,7 @@
5556
from .constants import (
5657
_CONNECTIVITY_ERR_MSG,
5758
_DOMAIN_NAME_MAP,
59+
ENV_VAR_EXPERIMENTAL_AUTHENTICATION,
5860
ENV_VAR_PARTNER,
5961
PARAMETER_AUTOCOMMIT,
6062
PARAMETER_CLIENT_PREFETCH_THREADS,
@@ -87,6 +89,7 @@
8789
ER_FAILED_TO_CONNECT_TO_DB,
8890
ER_INVALID_BACKOFF_POLICY,
8991
ER_INVALID_VALUE,
92+
ER_INVALID_WIF_SETTINGS,
9093
ER_NO_ACCOUNT_NAME,
9194
ER_NO_NUMPY,
9295
ER_NO_PASSWORD,
@@ -104,6 +107,7 @@
104107
PROGRAMMATIC_ACCESS_TOKEN,
105108
REQUEST_ID,
106109
USR_PWD_MFA_AUTHENTICATOR,
110+
WORKLOAD_IDENTITY_AUTHENTICATOR,
107111
ReauthenticationRequest,
108112
SnowflakeRestful,
109113
)
@@ -112,6 +116,7 @@
112116
from .time_util import HeartBeatTimer, get_time_millis
113117
from .url_util import extract_top_level_domain_from_hostname
114118
from .util_text import construct_hostname, parse_account, split_statements
119+
from .wif_util import AttestationProvider
115120

116121
DEFAULT_CLIENT_PREFETCH_THREADS = 4
117122
MAX_CLIENT_PREFETCH_THREADS = 10
@@ -188,12 +193,14 @@ def _get_private_bytes_from_file(
188193
"private_key": (None, (type(None), bytes, str, RSAPrivateKey)),
189194
"private_key_file": (None, (type(None), str)),
190195
"private_key_file_pwd": (None, (type(None), str, bytes)),
191-
"token": (None, (type(None), str)), # OAuth/JWT/PAT Token
196+
"token": (None, (type(None), str)), # OAuth/JWT/PAT/OIDC Token
192197
"token_file_path": (
193198
None,
194199
(type(None), str, bytes),
195-
), # OAuth/JWT/PAT Token file path
200+
), # OAuth/JWT/PAT/OIDC Token file path
196201
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
202+
"workload_identity_provider": (None, (type(None), AttestationProvider)),
203+
"workload_identity_entra_resource": (None, (type(None), str)),
197204
"mfa_callback": (None, (type(None), Callable)),
198205
"password_callback": (None, (type(None), Callable)),
199206
"auth_class": (None, (type(None), AuthByPlugin)),
@@ -1144,6 +1151,29 @@ def __open_connection(self):
11441151
if not self._token and self._password:
11451152
self._token = self._password
11461153
self.auth_class = AuthByPAT(self._token)
1154+
elif self._authenticator == WORKLOAD_IDENTITY_AUTHENTICATOR:
1155+
if ENV_VAR_EXPERIMENTAL_AUTHENTICATION not in os.environ:
1156+
Error.errorhandler_wrapper(
1157+
self,
1158+
None,
1159+
ProgrammingError,
1160+
{
1161+
"msg": f"Please set the '{ENV_VAR_EXPERIMENTAL_AUTHENTICATION}' environment variable to use the '{WORKLOAD_IDENTITY_AUTHENTICATOR}' authenticator.",
1162+
"errno": ER_INVALID_WIF_SETTINGS,
1163+
},
1164+
)
1165+
# Standardize the provider enum.
1166+
if self._workload_identity_provider and isinstance(
1167+
self._workload_identity_provider, str
1168+
):
1169+
self._workload_identity_provider = AttestationProvider.from_string(
1170+
self._workload_identity_provider
1171+
)
1172+
self.auth_class = AuthByWorkloadIdentity(
1173+
provider=self._workload_identity_provider,
1174+
token=self._token,
1175+
entra_resource=self._workload_identity_entra_resource,
1176+
)
11471177
else:
11481178
# okta URL, e.g., https://<account>.okta.com/
11491179
self.auth_class = AuthByOkta(
@@ -1267,6 +1297,7 @@ def __config(self, **kwargs):
12671297
KEY_PAIR_AUTHENTICATOR,
12681298
OAUTH_AUTHENTICATOR,
12691299
USR_PWD_MFA_AUTHENTICATOR,
1300+
WORKLOAD_IDENTITY_AUTHENTICATOR,
12701301
]:
12711302
self._authenticator = auth_tmp
12721303

@@ -1277,14 +1308,18 @@ def __config(self, **kwargs):
12771308
self._token = f.read()
12781309

12791310
# Set of authenticators allowing empty user.
1280-
empty_user_allowed_authenticators = {OAUTH_AUTHENTICATOR, NO_AUTH_AUTHENTICATOR}
1311+
empty_user_allowed_authenticators = {
1312+
OAUTH_AUTHENTICATOR,
1313+
NO_AUTH_AUTHENTICATOR,
1314+
WORKLOAD_IDENTITY_AUTHENTICATOR,
1315+
}
12811316

12821317
if not (self._master_token and self._session_token):
12831318
if (
12841319
not self.user
12851320
and self._authenticator not in empty_user_allowed_authenticators
12861321
):
1287-
# OAuth and NoAuth Authentications does not require a username
1322+
# Some authenticators do not require a username
12881323
Error.errorhandler_wrapper(
12891324
self,
12901325
None,
@@ -1295,6 +1330,25 @@ def __config(self, **kwargs):
12951330
if self._private_key or self._private_key_file:
12961331
self._authenticator = KEY_PAIR_AUTHENTICATOR
12971332

1333+
workload_identity_dependent_options = [
1334+
"workload_identity_provider",
1335+
"workload_identity_entra_resource",
1336+
]
1337+
for dependent_option in workload_identity_dependent_options:
1338+
if (
1339+
self.__getattribute__(f"_{dependent_option}") is not None
1340+
and self._authenticator != WORKLOAD_IDENTITY_AUTHENTICATOR
1341+
):
1342+
Error.errorhandler_wrapper(
1343+
self,
1344+
None,
1345+
ProgrammingError,
1346+
{
1347+
"msg": f"{dependent_option} was set but authenticator was not set to {WORKLOAD_IDENTITY_AUTHENTICATOR}",
1348+
"errno": ER_INVALID_WIF_SETTINGS,
1349+
},
1350+
)
1351+
12981352
if (
12991353
self.auth_class is None
13001354
and self._authenticator
@@ -1303,6 +1357,7 @@ def __config(self, **kwargs):
13031357
OAUTH_AUTHENTICATOR,
13041358
KEY_PAIR_AUTHENTICATOR,
13051359
PROGRAMMATIC_ACCESS_TOKEN,
1360+
WORKLOAD_IDENTITY_AUTHENTICATOR,
13061361
)
13071362
and not self._password
13081363
):

src/snowflake/connector/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ class IterUnit(Enum):
430430
# TODO: all env variables definitions should be here
431431
ENV_VAR_PARTNER = "SF_PARTNER"
432432
ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE"
433+
ENV_VAR_EXPERIMENTAL_AUTHENTICATION = "SF_ENABLE_EXPERIMENTAL_AUTHENTICATION" # Needed to enable new strong auth features during the private preview.
433434

434435

435436
_DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"}

src/snowflake/connector/errorcode.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
ER_JWT_RETRY_EXPIRED = 251010
3232
ER_CONNECTION_TIMEOUT = 251011
3333
ER_RETRYABLE_CODE = 251012
34+
ER_INVALID_WIF_SETTINGS = 251013
35+
ER_WIF_CREDENTIALS_NOT_FOUND = 251014
3436

3537
# cursor
3638
ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT = 252001

src/snowflake/connector/network.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@
189189
USR_PWD_MFA_AUTHENTICATOR = "USERNAME_PASSWORD_MFA"
190190
PROGRAMMATIC_ACCESS_TOKEN = "PROGRAMMATIC_ACCESS_TOKEN"
191191
NO_AUTH_AUTHENTICATOR = "NO_AUTH"
192+
WORKLOAD_IDENTITY_AUTHENTICATOR = "WORKLOAD_IDENTITY"
192193

193194

194195
def is_retryable_http_code(code: int) -> bool:

0 commit comments

Comments
 (0)