15
15
HTTPError ,
16
16
Timeout ,
17
17
)
18
- from snowflake .connector .wif_util import (
19
- AttestationProvider ,
20
- get_aws_partition ,
21
- get_aws_sts_hostname ,
22
- )
18
+ from snowflake .connector .wif_util import AttestationProvider , get_aws_sts_hostname
23
19
24
20
from ..csp_helpers import FakeAwsEnvironment , FakeGceMetadataService , gen_dummy_id_token
25
21
@@ -129,8 +125,19 @@ def test_explicit_aws_encodes_audience_host_signature_to_api(
129
125
verify_aws_token (data ["TOKEN" ], fake_aws_environment .region )
130
126
131
127
132
- def test_explicit_aws_uses_regional_hostname (fake_aws_environment : FakeAwsEnvironment ):
133
- fake_aws_environment .region = "antarctica-northeast-3"
128
+ @pytest .mark .parametrize (
129
+ "region,expected_hostname" ,
130
+ [
131
+ ("us-east-1" , "sts.us-east-1.amazonaws.com" ),
132
+ ("af-south-1" , "sts.af-south-1.amazonaws.com" ),
133
+ ("us-gov-west-1" , "sts.us-gov-west-1.amazonaws.com" ),
134
+ ("cn-north-1" , "sts.cn-north-1.amazonaws.com.cn" ),
135
+ ],
136
+ )
137
+ def test_explicit_aws_uses_regional_hostnames (
138
+ fake_aws_environment : FakeAwsEnvironment , region : str , expected_hostname : str
139
+ ):
140
+ fake_aws_environment .region = region
134
141
135
142
auth_class = AuthByWorkloadIdentity (provider = AttestationProvider .AWS )
136
143
auth_class .prepare ()
@@ -140,59 +147,23 @@ def test_explicit_aws_uses_regional_hostname(fake_aws_environment: FakeAwsEnviro
140
147
hostname_from_url = urlparse (decoded_token ["url" ]).hostname
141
148
hostname_from_header = decoded_token ["headers" ]["Host" ]
142
149
143
- expected_hostname = "sts.antarctica-northeast-3.amazonaws.com"
144
150
assert expected_hostname == hostname_from_url
145
151
assert expected_hostname == hostname_from_header
146
152
147
153
148
154
def test_explicit_aws_generates_unique_assertion_content (
149
155
fake_aws_environment : FakeAwsEnvironment ,
150
156
):
151
- fake_aws_environment .arn = (
152
- "arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"
153
- )
157
+ fake_aws_environment .region = "us-east-1"
154
158
auth_class = AuthByWorkloadIdentity (provider = AttestationProvider .AWS )
155
159
auth_class .prepare ()
156
160
157
161
assert (
158
- '{"_provider":"AWS","arn ":"arn: aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab "}'
162
+ '{"_provider":"AWS","partition ":"aws","region":"us-east-1 "}'
159
163
== auth_class .assertion_content
160
164
)
161
165
162
166
163
- @pytest .mark .parametrize (
164
- "arn, expected_partition" ,
165
- [
166
- ("arn:aws:iam::123456789012:role/MyTestRole" , "aws" ),
167
- (
168
- "arn:aws-cn:ec2:cn-north-1:987654321098:instance/i-1234567890abcdef0" ,
169
- "aws-cn" ,
170
- ),
171
- ("arn:aws-us-gov:s3:::my-gov-bucket" , "aws-us-gov" ),
172
- ("arn:aws:s3:::my-bucket/my/key" , "aws" ),
173
- ("arn:aws:lambda:us-east-1:123456789012:function:my-function" , "aws" ),
174
- ("arn:aws:sns:eu-west-1:111122223333:my-topic" , "aws" ),
175
- ("arn:aws:iam:" , "aws" ), # Incomplete ARN, but partition is present
176
- ],
177
- )
178
- def test_get_aws_partition_valid_arns (arn , expected_partition ):
179
- assert get_aws_partition (arn ) == expected_partition
180
-
181
-
182
- @pytest .mark .parametrize (
183
- "arn" ,
184
- [
185
- "invalid-arn" ,
186
- "arn::service:region:account:resource" , # Missing partition
187
- "" , # Empty string
188
- ],
189
- )
190
- def test_get_aws_partition_invalid_arns (arn ):
191
- with pytest .raises (ProgrammingError ) as excinfo :
192
- get_aws_partition (arn )
193
- assert "Invalid AWS ARN" in str (excinfo .value )
194
-
195
-
196
167
@pytest .mark .parametrize (
197
168
"region, partition, expected_hostname" ,
198
169
[
0 commit comments