-
Notifications
You must be signed in to change notification settings - Fork 537
Expand file tree
/
Copy pathworkload_identity.py
More file actions
110 lines (92 loc) · 3.98 KB
/
workload_identity.py
File metadata and controls
110 lines (92 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import annotations
import json
import typing
from enum import Enum, unique
if typing.TYPE_CHECKING:
from snowflake.connector.connection import SnowflakeConnection
from ..network import WORKLOAD_IDENTITY_AUTHENTICATOR
from ..wif_util import (
AttestationProvider,
WorkloadIdentityAttestation,
create_attestation,
)
from .by_plugin import AuthByPlugin, AuthType
@unique
class ApiFederatedAuthenticationType(Enum):
    """Workload identity federation (WIF) authenticator types, as named by the API."""

    AWS = "AWS"
    AZURE = "AZURE"
    GCP = "GCP"
    OIDC = "OIDC"

    @staticmethod
    def from_attestation(
        attestation: WorkloadIdentityAttestation,
    ) -> ApiFederatedAuthenticationType:
        """Translate an attestation's driver-side provider into an API authenticator type.

        ``AttestationProvider`` describes how the driver *obtains* a credential,
        whereas this enum describes how the credential is *verified*. Today the
        two correspond one-to-one, but in the future several providers could map
        onto the same API type (for example, multiple ways of fetching an OIDC
        ID token).

        Args:
            attestation: Attestation whose ``provider`` attribute is translated.

        Returns:
            The matching ``ApiFederatedAuthenticationType`` member.

        Raises:
            ValueError: If ``attestation.provider`` matches none of the known providers.
        """
        provider = attestation.provider
        # Pairs are tested in order with ``==``, mirroring a sequential if-chain.
        known_mappings = (
            (AttestationProvider.AWS, ApiFederatedAuthenticationType.AWS),
            (AttestationProvider.AZURE, ApiFederatedAuthenticationType.AZURE),
            (AttestationProvider.GCP, ApiFederatedAuthenticationType.GCP),
            (AttestationProvider.OIDC, ApiFederatedAuthenticationType.OIDC),
        )
        for driver_provider, api_type in known_mappings:
            if provider == driver_provider:
                return api_type
        raise ValueError(f"Unknown attestation provider '{provider}'")
class AuthByWorkloadIdentity(AuthByPlugin):
    """Plugin to authenticate via workload identity.

    The attestation (credential) is fetched in ``prepare`` and then written into
    the login request body by ``update_body``.
    """

    def __init__(
        self,
        *,
        provider: AttestationProvider | None = None,
        token: str | None = None,
        entra_resource: str | None = None,
        impersonation_path: list[str] | None = None,
        **kwargs,
    ) -> None:
        """Store the settings that ``prepare`` later passes to ``create_attestation``.

        Args:
            provider: Attestation provider; forwarded as-is (including ``None``).
            token: Explicit token forwarded to ``create_attestation``.
            entra_resource: Entra resource forwarded to ``create_attestation``.
            impersonation_path: Impersonation chain forwarded to
                ``create_attestation``; its length is also reported in the
                request body's ``CLIENT_ENVIRONMENT``.
            **kwargs: Forwarded to ``AuthByPlugin.__init__``.
        """
        super().__init__(**kwargs)
        self.provider = provider
        self.token = token
        self.entra_resource = entra_resource
        self.impersonation_path = impersonation_path
        # Set by prepare(); cleared again by reset_secrets().
        self.attestation: WorkloadIdentityAttestation | None = None

    def type_(self) -> AuthType:
        """Return the auth type identifier of this plugin."""
        return AuthType.WORKLOAD_IDENTITY

    def reset_secrets(self) -> None:
        """Discard the cached attestation, which holds the credential."""
        self.attestation = None

    def update_body(self, body: dict[typing.Any, typing.Any]) -> None:
        """Write the workload identity fields into the login request body.

        Args:
            body: Request body; must already contain a ``"data"`` dict, which is
                updated in place.

        Note:
            ``prepare()`` must have been called first. Without an attestation,
            ``from_attestation(None)`` fails with ``AttributeError``.
        """
        data = body["data"]
        data["AUTHENTICATOR"] = WORKLOAD_IDENTITY_AUTHENTICATOR
        data["PROVIDER"] = ApiFederatedAuthenticationType.from_attestation(
            self.attestation
        ).value
        data["TOKEN"] = self.attestation.credential
        # Only the number of impersonation hops is reported, not the identities.
        data.setdefault("CLIENT_ENVIRONMENT", {})[
            "WORKLOAD_IDENTITY_IMPERSONATION_PATH_LENGTH"
        ] = len(self.impersonation_path or [])

    def prepare(
        self, *, conn: SnowflakeConnection | None, **kwargs: typing.Any
    ) -> None:
        """Fetch the attestation via ``create_attestation`` and cache it on ``self``.

        Args:
            conn: If given, its session manager is cloned with ``max_retries=0``
                and passed to ``create_attestation``; if ``None``, no session
                manager is passed.
            **kwargs: Unused.
        """
        # NOTE(review): retries are disabled on the cloned session, presumably so
        # an unreachable metadata endpoint fails fast — confirm intent.
        session_manager = (
            conn._session_manager.clone(max_retries=0) if conn is not None else None
        )
        self.attestation = create_attestation(
            self.provider,
            self.entra_resource,
            self.token,
            self.impersonation_path,
            session_manager=session_manager,
        )

    def reauthenticate(self, **kwargs: typing.Any) -> dict[str, bool]:
        """Report that in-place reauthentication is not supported by this plugin.

        This is only relevant for AuthByIdToken, which uses a web-browser based
        flow. All other auth plugins just call authenticate() again.
        """
        return {"success": False}

    @property
    def assertion_content(self) -> str:
        """Return the provider name and identifier fields as compact, key-sorted JSON.

        Used for logging purposes. Returns ``""`` when no attestation is cached.
        """
        if not self.attestation:
            return ""
        # Copy first: adding "_provider" must not mutate the attestation's own
        # user_identifier_components mapping as a side effect of logging.
        properties = dict(self.attestation.user_identifier_components)
        properties["_provider"] = self.attestation.provider.value
        return json.dumps(properties, sort_keys=True, separators=(",", ":"))