Skip to content

Commit 5e5fb5e

Browse files
remove duplication in workflow identity
1 parent 59ab6d4 commit 5e5fb5e

File tree

2 files changed

+14
-65
lines changed

2 files changed

+14
-65
lines changed

src/snowflake/connector/aio/auth/_workload_identity.py

Lines changed: 13 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,16 @@
44

55
from __future__ import annotations
66

7-
from enum import Enum, unique
87
from typing import Any
98

10-
from ...auth.by_plugin import AuthType
11-
from ...network import WORKLOAD_IDENTITY_AUTHENTICATOR
9+
from ...auth.workload_identity import (
10+
AuthByWorkloadIdentity as AuthByWorkloadIdentitySync,
11+
)
1212
from .._wif_util import AttestationProvider, create_attestation
1313
from ._by_plugin import AuthByPlugin as AuthByPluginAsync
1414

1515

16-
@unique
17-
class ApiFederatedAuthenticationType(Enum):
18-
"""An API-specific enum of the WIF authentication type."""
19-
20-
AWS = "AWS"
21-
AZURE = "AZURE"
22-
GCP = "GCP"
23-
OIDC = "OIDC"
24-
25-
@staticmethod
26-
def from_attestation(attestation) -> ApiFederatedAuthenticationType:
27-
"""Maps the internal / driver-specific attestation providers to API authenticator types."""
28-
if attestation.provider == AttestationProvider.AWS:
29-
return ApiFederatedAuthenticationType.AWS
30-
if attestation.provider == AttestationProvider.AZURE:
31-
return ApiFederatedAuthenticationType.AZURE
32-
if attestation.provider == AttestationProvider.GCP:
33-
return ApiFederatedAuthenticationType.GCP
34-
if attestation.provider == AttestationProvider.OIDC:
35-
return ApiFederatedAuthenticationType.OIDC
36-
raise ValueError(f"Unknown attestation provider '{attestation.provider}'")
37-
38-
39-
class AuthByWorkloadIdentity(AuthByPluginAsync):
16+
class AuthByWorkloadIdentity(AuthByWorkloadIdentitySync, AuthByPluginAsync):
4017
"""Plugin to authenticate via workload identity."""
4118

4219
def __init__(
@@ -48,17 +25,16 @@ def __init__(
4825
**kwargs,
4926
) -> None:
5027
"""Initializes an instance with workload identity authentication."""
51-
super().__init__(**kwargs)
52-
self.provider = provider
53-
self.token = token
54-
self.entra_resource = entra_resource
55-
self.attestation = None
56-
57-
def type_(self) -> AuthType:
58-
return AuthType.WORKLOAD_IDENTITY
28+
AuthByWorkloadIdentitySync.__init__(
29+
self,
30+
provider=provider,
31+
token=token,
32+
entra_resource=entra_resource,
33+
**kwargs,
34+
)
5935

6036
async def reset_secrets(self) -> None:
61-
self.attestation = None
37+
AuthByWorkloadIdentitySync.reset_secrets(self)
6238

6339
async def prepare(self, **kwargs: Any) -> None:
6440
"""Fetch the token using async wif_util."""
@@ -71,19 +47,4 @@ async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
7147
return {"success": False}
7248

7349
async def update_body(self, body: dict[Any, Any]) -> None:
74-
body["data"]["AUTHENTICATOR"] = WORKLOAD_IDENTITY_AUTHENTICATOR
75-
body["data"]["PROVIDER"] = ApiFederatedAuthenticationType.from_attestation(
76-
self.attestation
77-
).value
78-
body["data"]["TOKEN"] = self.attestation.credential
79-
80-
@property
81-
def assertion_content(self) -> str:
82-
"""Returns the CSP provider name and an identifier. Used for logging purposes."""
83-
if not self.attestation:
84-
return ""
85-
properties = self.attestation.user_identifier_components
86-
properties["_provider"] = self.attestation.provider.value
87-
import json
88-
89-
return json.dumps(properties, sort_keys=True, separators=(",", ":"))
50+
AuthByWorkloadIdentitySync.update_body(self, body)

test/csp_helpers.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -354,19 +354,7 @@ def __enter__(self):
354354
"snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn
355355
)
356356
)
357-
# Also patch the async versions
358-
self.patchers.append(
359-
mock.patch(
360-
"snowflake.connector.aio._wif_util.get_aws_region",
361-
side_effect=self.get_region,
362-
)
363-
)
364-
self.patchers.append(
365-
mock.patch(
366-
"snowflake.connector.aio._wif_util.get_aws_arn",
367-
side_effect=self.get_arn,
368-
)
369-
)
357+
# Note: No need to patch async versions anymore since async now imports from sync
370358
for patcher in self.patchers:
371359
patcher.__enter__()
372360
return self

0 commit comments

Comments
 (0)