Skip to content

Commit 08dbaa8

Browse files
SNOW-2111644 Support sovereign clouds for WIF (#2367)
1 parent cef7b51 commit 08dbaa8

File tree

2 files changed

+173
-5
lines changed

2 files changed

+173
-5
lines changed

src/snowflake/connector/wif_util.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@
2222
SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"
2323
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE = "api://fd3f753b-eed3-462c-b6a7-a4b5bb650aad"
2424

25+
"""
26+
References:
27+
- https://learn.microsoft.com/en-us/entra/identity-platform/authentication-national-cloud#microsoft-entra-authentication-endpoints
28+
- https://learn.microsoft.com/en-us/answers/questions/1190472/what-are-the-token-issuers-for-the-sovereign-cloud
29+
"""
30+
AZURE_ISSUER_PREFIXES = [
31+
"https://sts.windows.net/", # Public and USGov (v1 issuer)
32+
"https://sts.chinacloudapi.cn/", # Mooncake (v1 issuer)
33+
"https://login.microsoftonline.com/", # Public (v2 issuer)
34+
"https://login.microsoftonline.us/", # USGov (v2 issuer)
35+
"https://login.partner.microsoftonline.cn/", # Mooncake (v2 issuer)
36+
]
37+
2538

2639
@unique
2740
class AttestationProvider(Enum):
@@ -108,6 +121,70 @@ def get_aws_arn() -> str | None:
108121
return caller_identity["Arn"]
109122

110123

124+
def get_aws_partition(arn: str) -> str | None:
125+
"""Get the current AWS partition from ARN, if any.
126+
127+
Args:
128+
arn (str): The Amazon Resource Name (ARN) string.
129+
130+
Returns:
131+
str | None: The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov')
132+
if found, otherwise None.
133+
134+
Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html.
135+
"""
136+
if not arn or not isinstance(arn, str):
137+
return None
138+
parts = arn.split(":")
139+
if len(parts) > 1 and parts[0] == "arn" and parts[1]:
140+
return parts[1]
141+
logger.warning("Invalid AWS ARN: %s", arn)
142+
return None
143+
144+
145+
def get_aws_sts_hostname(region: str, partition: str) -> str | None:
146+
"""Constructs the AWS STS hostname for a given region and partition.
147+
148+
Args:
149+
region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1').
150+
partition (str): The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov').
151+
152+
Returns:
153+
str | None: The AWS STS hostname (e.g., 'sts.us-east-1.amazonaws.com')
154+
if a valid hostname can be constructed, otherwise None.
155+
156+
References:
157+
- https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
158+
- https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_region-endpoints.html
159+
- https://docs.aws.amazon.com/general/latest/gr/sts.html
160+
"""
161+
if (
162+
not region
163+
or not partition
164+
or not isinstance(region, str)
165+
or not isinstance(partition, str)
166+
):
167+
return None
168+
169+
if partition == "aws":
170+
# For the 'aws' partition, STS endpoints are generally regional
171+
# except for the global endpoint (sts.amazonaws.com) which is
172+
# generally resolved to us-east-1 under the hood by the SDKs
173+
# when a region is not explicitly specified.
174+
# However, for explicit regional endpoints, the format is sts.<region>.amazonaws.com
175+
return f"sts.{region}.amazonaws.com"
176+
elif partition == "aws-cn":
177+
# China regions have a different domain suffix
178+
return f"sts.{region}.amazonaws.com.cn"
179+
elif partition == "aws-us-gov":
180+
return (
181+
f"sts.{region}.amazonaws.com" # GovCloud uses .com, but dedicated regions
182+
)
183+
else:
184+
logger.warning("Invalid AWS partition: %s", partition)
185+
return None
186+
187+
111188
def create_aws_attestation() -> WorkloadIdentityAttestation | None:
112189
"""Tries to create a workload identity attestation for AWS.
113190
@@ -125,8 +202,12 @@ def create_aws_attestation() -> WorkloadIdentityAttestation | None:
125202
if not arn:
126203
logger.debug("No AWS caller identity was found.")
127204
return None
205+
partition = get_aws_partition(arn)
206+
if not partition:
207+
logger.debug("No AWS partition was found.")
208+
return None
128209

129-
sts_hostname = f"sts.{region}.amazonaws.com"
210+
sts_hostname = get_aws_sts_hostname(region, partition)
130211
request = AWSRequest(
131212
method="POST",
132213
url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15",
@@ -234,9 +315,8 @@ def create_azure_attestation(
234315
issuer, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
235316
if not issuer or not subject:
236317
return None
237-
if not (
238-
issuer.startswith("https://sts.windows.net/")
239-
or issuer.startswith("https://login.microsoftonline.com/")
318+
if not any(
319+
issuer.startswith(issuer_prefix) for issuer_prefix in AZURE_ISSUER_PREFIXES
240320
):
241321
# This might happen if we're running on a different platform that responds to the same metadata request signature as Azure.
242322
logger.debug("Unexpected Azure token issuer '%s'", issuer)

test/unit/test_auth_workload_identity.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
HTTPError,
1515
Timeout,
1616
)
17-
from snowflake.connector.wif_util import AttestationProvider
17+
from snowflake.connector.wif_util import (
18+
AZURE_ISSUER_PREFIXES,
19+
AttestationProvider,
20+
get_aws_partition,
21+
get_aws_sts_hostname,
22+
)
1823

1924
from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token
2025

@@ -154,6 +159,73 @@ def test_explicit_aws_generates_unique_assertion_content(
154159
)
155160

156161

162+
@pytest.mark.parametrize(
163+
"arn, expected_partition",
164+
[
165+
("arn:aws:iam::123456789012:role/MyTestRole", "aws"),
166+
(
167+
"arn:aws-cn:ec2:cn-north-1:987654321098:instance/i-1234567890abcdef0",
168+
"aws-cn",
169+
),
170+
("arn:aws-us-gov:s3:::my-gov-bucket", "aws-us-gov"),
171+
("arn:aws:s3:::my-bucket/my/key", "aws"),
172+
("arn:aws:lambda:us-east-1:123456789012:function:my-function", "aws"),
173+
("arn:aws:sns:eu-west-1:111122223333:my-topic", "aws"),
174+
# Edge cases / Invalid inputs
175+
("invalid-arn", None),
176+
("arn::service:region:account:resource", None), # Missing partition
177+
("arn:aws:iam:", "aws"), # Incomplete ARN, but partition is present
178+
("", None), # Empty string
179+
(None, None), # None input
180+
(123, None), # Non-string input
181+
],
182+
)
183+
def test_get_aws_partition_valid_and_invalid_arns(arn, expected_partition):
184+
assert get_aws_partition(arn) == expected_partition
185+
186+
187+
@pytest.mark.parametrize(
188+
"region, partition, expected_hostname",
189+
[
190+
# AWS partition
191+
("us-east-1", "aws", "sts.us-east-1.amazonaws.com"),
192+
("eu-west-2", "aws", "sts.eu-west-2.amazonaws.com"),
193+
("ap-southeast-1", "aws", "sts.ap-southeast-1.amazonaws.com"),
194+
(
195+
"us-east-1",
196+
"aws",
197+
"sts.us-east-1.amazonaws.com",
198+
), # Redundant but good for coverage
199+
# AWS China partition
200+
("cn-north-1", "aws-cn", "sts.cn-north-1.amazonaws.com.cn"),
201+
("cn-northwest-1", "aws-cn", "sts.cn-northwest-1.amazonaws.com.cn"),
202+
("", "aws-cn", None), # No global endpoint for 'aws-cn' without region
203+
# AWS GovCloud partition
204+
("us-gov-west-1", "aws-us-gov", "sts.us-gov-west-1.amazonaws.com"),
205+
("us-gov-east-1", "aws-us-gov", "sts.us-gov-east-1.amazonaws.com"),
206+
("", "aws-us-gov", None), # No global endpoint for 'aws-us-gov' without region
207+
# Invalid/Edge cases
208+
("us-east-1", "unknown-partition", None), # Unknown partition
209+
("some-region", "invalid-partition", None), # Invalid partition
210+
(None, "aws", None), # None region
211+
("us-east-1", None, None), # None partition
212+
(123, "aws", None), # Non-string region
213+
("us-east-1", 456, None), # Non-string partition
214+
("", "", None), # Empty region and partition
215+
("us-east-1", "", None), # Empty partition
216+
(
217+
"invalid-region",
218+
"aws",
219+
"sts.invalid-region.amazonaws.com",
220+
), # Valid format, invalid region name
221+
],
222+
)
223+
def test_get_aws_sts_hostname_valid_and_invalid_inputs(
224+
region, partition, expected_hostname
225+
):
226+
assert get_aws_sts_hostname(region, partition) == expected_hostname
227+
228+
157229
# -- GCP Tests --
158230

159231

@@ -312,6 +384,22 @@ def test_explicit_azure_uses_explicit_entra_resource(fake_azure_metadata_service
312384
assert parsed["aud"] == "api://non-standard"
313385

314386

387+
@pytest.mark.parametrize(
388+
"issuer",
389+
[
390+
"https://sts.windows.net/067802cd-8f92-4c7c-bceb-ea8f15d31cc5",
391+
"https://sts.chinacloudapi.cn/067802cd-8f92-4c7c-bceb-ea8f15d31cc5",
392+
"https://login.microsoftonline.com/067802cd-8f92-4c7c-bceb-ea8f15d31cc5/v2.0",
393+
"https://login.microsoftonline.us/067802cd-8f92-4c7c-bceb-ea8f15d31cc5/v2.0",
394+
"https://login.partner.microsoftonline.cn/067802cd-8f92-4c7c-bceb-ea8f15d31cc5/v2.0",
395+
],
396+
)
397+
def test_azure_issuer_prefixes(issuer):
398+
assert any(
399+
issuer.startswith(issuer_prefix) for issuer_prefix in AZURE_ISSUER_PREFIXES
400+
)
401+
402+
315403
# -- Auto-detect Tests --
316404

317405

0 commit comments

Comments
 (0)