Skip to content

Commit a3faefd

Browse files
SNOW-2183023: fixed user id
1 parent 5e921d8 commit a3faefd

File tree

2 files changed

+6
-82
lines changed

2 files changed

+6
-82
lines changed

src/snowflake/connector/wif_util.py

Lines changed: 1 addition & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -243,79 +243,6 @@ def get_aws_region(session_manager: SessionManager | None = None) -> str | None:
243243
return None
244244

245245

246-
def get_aws_arn(session_manager: SessionManager | None = None) -> str | None:
247-
"""Get the current AWS workload's ARN by calling GetCallerIdentity.
248-
249-
Note: This function makes a network call to AWS STS and is only used for
250-
assertion content generation (logging and backward compatibility purposes).
251-
The ARN is not required for authentication - it's just used as a unique
252-
identifier for the workload in logs and assertion content.
253-
254-
Returns the ARN of the current AWS identity, or None if it cannot be determined.
255-
"""
256-
credentials = get_aws_credentials(session_manager)
257-
if not credentials:
258-
logger.debug("No AWS credentials available for ARN lookup.")
259-
return None
260-
261-
region = get_aws_region(session_manager)
262-
if not region:
263-
logger.debug("No AWS region available for ARN lookup.")
264-
return None
265-
266-
try:
267-
# Create the GetCallerIdentity request
268-
sts_hostname = get_aws_sts_hostname(region)
269-
url = f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15"
270-
271-
base_headers = {
272-
"Content-Type": "application/x-amz-json-1.1",
273-
}
274-
275-
signed_headers = aws_signature_v4_sign(
276-
credentials=credentials,
277-
method="POST",
278-
url=url,
279-
region=region,
280-
service="sts",
281-
headers=base_headers,
282-
)
283-
284-
# Make the actual request to get caller identity
285-
response = http_request(
286-
method="POST",
287-
url=url,
288-
headers=signed_headers,
289-
timeout_sec=10,
290-
session_manager=session_manager,
291-
)
292-
293-
if response and response.ok:
294-
# Parse the XML response to extract the ARN
295-
import xml.etree.ElementTree as ET
296-
297-
# Ensure content is bytes and decode it
298-
content = response.content
299-
if isinstance(content, bytes):
300-
content_str = content.decode("utf-8")
301-
else:
302-
content_str = str(content) if content else ""
303-
304-
if content_str:
305-
root = ET.fromstring(content_str)
306-
307-
# Find the Arn element in the response
308-
for elem in root.iter():
309-
if elem.tag.endswith("Arn") and elem.text:
310-
return elem.text.strip()
311-
312-
logger.debug("Failed to get ARN from GetCallerIdentity response.")
313-
return None
314-
except Exception as e:
315-
logger.debug(f"Error getting AWS ARN: {e}")
316-
return None
317-
318-
319246
def get_aws_sts_hostname(region: str) -> str | None:
320247
"""Constructs the AWS STS hostname for a given region.
321248
@@ -481,11 +408,7 @@ def create_aws_attestation(
481408
credential = b64encode(json.dumps(attestation_request).encode("utf-8")).decode(
482409
"utf-8"
483410
)
484-
485-
# Get the ARN for user identifier components (used only for assertion content - logging and backward compatibility)
486-
# The ARN is not required for authentication, but provides a unique identifier for the workload
487-
arn = get_aws_arn(session_manager)
488-
user_identifier_components = {"arn": arn} if arn else {}
411+
user_identifier_components = {"region": region}
489412

490413
return WorkloadIdentityAttestation(
491414
AttestationProvider.AWS,

test/unit/test_auth_workload_identity.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ def test_explicit_aws_generates_unique_assertion_content(
173173
auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AWS)
174174
auth_class.prepare(conn=None)
175175

176-
assert (
177-
'{"_provider":"AWS","arn":"arn:aws:sts::123456789:assumed-role/A-Different-Role/i-34afe100cad287fab"}'
178-
== auth_class.assertion_content
179-
)
176+
expected = {
177+
"_provider": "AWS",
178+
"region": fake_aws_environment.region,
179+
}
180+
assert json.loads(auth_class.assertion_content) == expected
180181

181182

182183
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)