Skip to content

Commit d0d9587

Browse files
Fix bug in AWS sovereign partition support (#2459)
1 parent 08dbe0e commit d0d9587

File tree

3 files changed

+22
-92
lines changed

3 files changed

+22
-92
lines changed

src/snowflake/connector/wif_util.py

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -92,41 +92,6 @@ def get_aws_region() -> str:
9292
return region
9393

9494

95-
def get_aws_arn() -> str:
96-
"""Get the current AWS workload's ARN."""
97-
caller_identity = boto3.client("sts").get_caller_identity()
98-
if not caller_identity or "Arn" not in caller_identity:
99-
raise ProgrammingError(
100-
msg="No AWS identity was found. Ensure the application is running on AWS with an IAM role attached.",
101-
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
102-
)
103-
return caller_identity["Arn"]
104-
105-
106-
def get_aws_partition(arn: str) -> str:
107-
"""Get the current AWS partition from ARN.
108-
109-
Args:
110-
arn (str): The Amazon Resource Name (ARN) string.
111-
112-
Returns:
113-
str: The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov').
114-
115-
Raises:
116-
ProgrammingError: If the ARN is invalid or does not contain a valid partition.
117-
118-
Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html.
119-
"""
120-
parts = arn.split(":")
121-
if len(parts) > 1 and parts[0] == "arn" and parts[1]:
122-
return parts[1]
123-
124-
raise ProgrammingError(
125-
msg=f"Invalid AWS ARN: '{arn}'.",
126-
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
127-
)
128-
129-
13095
def get_aws_sts_hostname(region: str, partition: str) -> str:
13196
"""Constructs the AWS STS hostname for a given region and partition.
13297
@@ -169,15 +134,15 @@ def create_aws_attestation() -> WorkloadIdentityAttestation:
169134
170135
If the application isn't running on AWS or no credentials were found, raises an error.
171136
"""
172-
aws_creds = boto3.session.Session().get_credentials()
137+
session = boto3.session.Session()
138+
aws_creds = session.get_credentials()
173139
if not aws_creds:
174140
raise ProgrammingError(
175141
msg="No AWS credentials were found. Ensure the application is running on AWS with an IAM role attached.",
176142
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
177143
)
178144
region = get_aws_region()
179-
arn = get_aws_arn()
180-
partition = get_aws_partition(arn)
145+
partition = session.get_partition_for_region(region)
181146
sts_hostname = get_aws_sts_hostname(region, partition)
182147
request = AWSRequest(
183148
method="POST",
@@ -196,8 +161,10 @@ def create_aws_attestation() -> WorkloadIdentityAttestation:
196161
"headers": dict(request.headers.items()),
197162
}
198163
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8")
164+
# Unlike other providers, for AWS, we only include general identifiers (region and partition)
165+
# rather than specific user identifiers, since we don't actually execute a GetCallerIdentity call.
199166
return WorkloadIdentityAttestation(
200-
AttestationProvider.AWS, credential, {"arn": arn}
167+
AttestationProvider.AWS, credential, {"region": region, "partition": partition}
201168
)
202169

203170

test/csp_helpers.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,6 @@ def __init__(self):
358358
def get_region(self):
359359
return self.region
360360

361-
def get_arn(self):
362-
return self.arn
363-
364361
def get_credentials(self):
365362
return self.credentials
366363

@@ -403,11 +400,6 @@ def __enter__(self):
403400
side_effect=self.get_region,
404401
)
405402
)
406-
self.patchers.append(
407-
mock.patch(
408-
"snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn
409-
)
410-
)
411403
self.patchers.append(
412404
mock.patch(
413405
"snowflake.connector.platform_detection.IMDSFetcher._get_request",

test/unit/test_auth_workload_identity.py

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
HTTPError,
1616
Timeout,
1717
)
18-
from snowflake.connector.wif_util import (
19-
AttestationProvider,
20-
get_aws_partition,
21-
get_aws_sts_hostname,
22-
)
18+
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname
2319

2420
from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token
2521

@@ -129,8 +125,19 @@ def test_explicit_aws_encodes_audience_host_signature_to_api(
129125
verify_aws_token(data["TOKEN"], fake_aws_environment.region)
130126

131127

132-
def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnvironment):
133-
fake_aws_environment.region = "antarctica-northeast-3"
128+
@pytest.mark.parametrize(
129+
"region,expected_hostname",
130+
[
131+
("us-east-1", "sts.us-east-1.amazonaws.com"),
132+
("af-south-1", "sts.af-south-1.amazonaws.com"),
133+
("us-gov-west-1", "sts.us-gov-west-1.amazonaws.com"),
134+
("cn-north-1", "sts.cn-north-1.amazonaws.com.cn"),
135+
],
136+
)
137+
def test_explicit_aws_uses_regional_hostnames(
138+
fake_aws_environment: FakeAwsEnvironment, region: str, expected_hostname: str
139+
):
140+
fake_aws_environment.region = region
134141

135142
auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
136143
auth_class.prepare()
@@ -140,59 +147,23 @@ def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnviro
140147
hostname_from_url = urlparse(decoded_token["url"]).hostname
141148
hostname_from_header = decoded_token["headers"]["Host"]
142149

143-
expected_hostname = "sts.antarctica-northeast-3.amazonaws.com"
144150
assert expected_hostname == hostname_from_url
145151
assert expected_hostname == hostname_from_header
146152

147153

148154
def test_explicit_aws_generates_unique_assertion_content(
149155
fake_aws_environment: FakeAwsEnvironment,
150156
):
151-
fake_aws_environment.arn = (
152-
"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"
153-
)
157+
fake_aws_environment.region = "us-east-1"
154158
auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
155159
auth_class.prepare()
156160

157161
assert (
158-
'{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}'
162+
'{"_provider":"AWS","partition":"aws","region":"us-east-1"}'
159163
== auth_class.assertion_content
160164
)
161165

162166

163-
@pytest.mark.parametrize(
164-
"arn, expected_partition",
165-
[
166-
("arn:aws:iam::123456789012:role/MyTestRole", "aws"),
167-
(
168-
"arn:aws-cn:ec2:cn-north-1:987654321098:instance/i-1234567890abcdef0",
169-
"aws-cn",
170-
),
171-
("arn:aws-us-gov:s3:::my-gov-bucket", "aws-us-gov"),
172-
("arn:aws:s3:::my-bucket/my/key", "aws"),
173-
("arn:aws:lambda:us-east-1:123456789012:function:my-function", "aws"),
174-
("arn:aws:sns:eu-west-1:111122223333:my-topic", "aws"),
175-
("arn:aws:iam:", "aws"), # Incomplete ARN, but partition is present
176-
],
177-
)
178-
def test_get_aws_partition_valid_arns(arn, expected_partition):
179-
assert get_aws_partition(arn) == expected_partition
180-
181-
182-
@pytest.mark.parametrize(
183-
"arn",
184-
[
185-
"invalid-arn",
186-
"arn::service:region:account:resource", # Missing partition
187-
"", # Empty string
188-
],
189-
)
190-
def test_get_aws_partition_invalid_arns(arn):
191-
with pytest.raises(ProgrammingError) as excinfo:
192-
get_aws_partition(arn)
193-
assert "Invalid AWS ARN" in str(excinfo.value)
194-
195-
196167
@pytest.mark.parametrize(
197168
"region, partition, expected_hostname",
198169
[

0 commit comments

Comments
 (0)