Skip to content

Commit 116604c

Browse files
use boto method to get partition
1 parent 0e03580 commit 116604c

File tree

3 files changed

+22
-74
lines changed

3 files changed

+22
-74
lines changed

src/snowflake/connector/wif_util.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from base64 import b64encode
77
from dataclasses import dataclass
88
from enum import Enum, unique
9-
from hashlib import sha256
109

1110
import boto3
1211
import jwt
@@ -93,23 +92,6 @@ def get_aws_region() -> str:
9392
return region
9493

9594

96-
def get_aws_partition(region: str) -> str:
97-
"""Get the current AWS partition from region.
98-
99-
Args:
100-
region (str): The AWS region (e.g., 'us-east-1', 'cn-north-1').
101-
102-
Returns:
103-
str: The AWS partition (e.g., 'aws', 'aws-cn', 'aws-us-gov').
104-
"""
105-
if region.startswith("cn-"):
106-
return "aws-cn"
107-
elif region.startswith("us-gov-"):
108-
return "aws-us-gov"
109-
else:
110-
return "aws"
111-
112-
11395
def get_aws_sts_hostname(region: str, partition: str) -> str:
11496
"""Constructs the AWS STS hostname for a given region and partition.
11597
@@ -152,14 +134,15 @@ def create_aws_attestation() -> WorkloadIdentityAttestation:
152134
153135
If the application isn't running on AWS or no credentials were found, raises an error.
154136
"""
155-
aws_creds = boto3.session.Session().get_credentials()
137+
session = boto3.session.Session()
138+
aws_creds = session.get_credentials()
156139
if not aws_creds:
157140
raise ProgrammingError(
158141
msg="No AWS credentials were found. Ensure the application is running on AWS with an IAM role attached.",
159142
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
160143
)
161144
region = get_aws_region()
162-
partition = get_aws_partition(region)
145+
partition = session.get_partition_for_region(region)
163146
sts_hostname = get_aws_sts_hostname(region, partition)
164147
request = AWSRequest(
165148
method="POST",
@@ -178,8 +161,10 @@ def create_aws_attestation() -> WorkloadIdentityAttestation:
178161
"headers": dict(request.headers.items()),
179162
}
180163
credential = b64encode(json.dumps(assertion_dict).encode("utf-8")).decode("utf-8")
164+
# Unlike other providers, for AWS, we only include general identifiers (region and partition)
165+
# rather than specific user identifiers, since we don't actually execute a GetCallerIdentity call.
181166
return WorkloadIdentityAttestation(
182-
AttestationProvider.AWS, credential, {"hashed_attestation": sha256(credential.encode("utf-8")).hexdigest()}
167+
AttestationProvider.AWS, credential, {"region": region, "partition": partition}
183168
)
184169

185170

test/csp_helpers.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,6 @@ def __init__(self):
358358
def get_region(self):
359359
return self.region
360360

361-
def get_arn(self):
362-
return self.arn
363-
364361
def get_credentials(self):
365362
return self.credentials
366363

@@ -403,11 +400,6 @@ def __enter__(self):
403400
side_effect=self.get_region,
404401
)
405402
)
406-
self.patchers.append(
407-
mock.patch(
408-
"snowflake.connector.wif_util.get_aws_arn", side_effect=self.get_arn
409-
)
410-
)
411403
self.patchers.append(
412404
mock.patch(
413405
"snowflake.connector.platform_detection.IMDSFetcher._get_request",

test/unit/test_auth_workload_identity.py

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@
1515
HTTPError,
1616
Timeout,
1717
)
18-
from snowflake.connector.wif_util import (
19-
AttestationProvider,
20-
get_aws_partition,
21-
get_aws_sts_hostname,
22-
)
18+
from snowflake.connector.wif_util import AttestationProvider, get_aws_sts_hostname
2319

2420
from ..csp_helpers import FakeAwsEnvironment, FakeGceMetadataService, gen_dummy_id_token
2521

@@ -129,8 +125,19 @@ def test_explicit_aws_encodes_audience_host_signature_to_api(
129125
verify_aws_token(data["TOKEN"], fake_aws_environment.region)
130126

131127

132-
def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnvironment):
133-
fake_aws_environment.region = "antarctica-northeast-3"
128+
@pytest.mark.parametrize(
129+
"region,expected_hostname",
130+
[
131+
("us-east-1", "sts.us-east-1.amazonaws.com"),
132+
("af-south-1", "sts.af-south-1.amazonaws.com"),
133+
("us-gov-west-1", "sts.us-gov-west-1.amazonaws.com"),
134+
("cn-north-1", "sts.cn-north-1.amazonaws.com.cn"),
135+
],
136+
)
137+
def test_explicit_aws_uses_regional_hostnames(
138+
fake_aws_environment: FakeAwsEnvironment, region: str, expected_hostname: str
139+
):
140+
fake_aws_environment.region = region
134141

135142
auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
136143
auth_class.prepare()
@@ -140,59 +147,23 @@ def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnviro
140147
hostname_from_url = urlparse(decoded_token["url"]).hostname
141148
hostname_from_header = decoded_token["headers"]["Host"]
142149

143-
expected_hostname = "sts.antarctica-northeast-3.amazonaws.com"
144150
assert expected_hostname == hostname_from_url
145151
assert expected_hostname == hostname_from_header
146152

147153

148154
def test_explicit_aws_generates_unique_assertion_content(
149155
fake_aws_environment: FakeAwsEnvironment,
150156
):
151-
fake_aws_environment.arn = (
152-
"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"
153-
)
157+
fake_aws_environment.region = "us-east-1"
154158
auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
155159
auth_class.prepare()
156160

157161
assert (
158-
'{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}'
162+
'{"_provider":"AWS","partition":"aws","region":"us-east-1"}'
159163
== auth_class.assertion_content
160164
)
161165

162166

163-
@pytest.mark.parametrize(
164-
"arn, expected_partition",
165-
[
166-
("arn:aws:iam::123456789012:role/MyTestRole", "aws"),
167-
(
168-
"arn:aws-cn:ec2:cn-north-1:987654321098:instance/i-1234567890abcdef0",
169-
"aws-cn",
170-
),
171-
("arn:aws-us-gov:s3:::my-gov-bucket", "aws-us-gov"),
172-
("arn:aws:s3:::my-bucket/my/key", "aws"),
173-
("arn:aws:lambda:us-east-1:123456789012:function:my-function", "aws"),
174-
("arn:aws:sns:eu-west-1:111122223333:my-topic", "aws"),
175-
("arn:aws:iam:", "aws"), # Incomplete ARN, but partition is present
176-
],
177-
)
178-
def test_get_aws_partition_valid_arns(arn, expected_partition):
179-
assert get_aws_partition(arn) == expected_partition
180-
181-
182-
@pytest.mark.parametrize(
183-
"arn",
184-
[
185-
"invalid-arn",
186-
"arn::service:region:account:resource", # Missing partition
187-
"", # Empty string
188-
],
189-
)
190-
def test_get_aws_partition_invalid_arns(arn):
191-
with pytest.raises(ProgrammingError) as excinfo:
192-
get_aws_partition(arn)
193-
assert "Invalid AWS ARN" in str(excinfo.value)
194-
195-
196167
@pytest.mark.parametrize(
197168
"region, partition, expected_hostname",
198169
[

0 commit comments

Comments
 (0)