Skip to content

Commit 7604784

Browse files
Merge branch 'main' into SNOW-2176203-partial-chains
2 parents d2a4d29 + 29b5b7e commit 7604784

File tree

10 files changed

+114
-31
lines changed

10 files changed

+114
-31
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
99
# Release Notes
1010
- v3.18.0(TBD)
1111
- Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP workloads only
12+
- Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once
1213
- Added support for intermediate cetificates as roots when they are stored in the trust store
1314

1415
- v3.17.3(September 02,2025)
53 Bytes
Binary file not shown.

src/snowflake/connector/connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,14 +1358,18 @@ def __open_connection(self):
13581358
)
13591359
if (
13601360
self._workload_identity_impersonation_path
1361-
and self._workload_identity_provider != AttestationProvider.GCP
1361+
and self._workload_identity_provider
1362+
not in (
1363+
AttestationProvider.GCP,
1364+
AttestationProvider.AWS,
1365+
)
13621366
):
13631367
Error.errorhandler_wrapper(
13641368
self,
13651369
None,
13661370
ProgrammingError,
13671371
{
1368-
"msg": "workload_identity_impersonation_path is currently only supported for GCP.",
1372+
"msg": "workload_identity_impersonation_path is currently only supported for GCP and AWS.",
13691373
"errno": ER_INVALID_WIF_SETTINGS,
13701374
},
13711375
)

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,8 +1741,7 @@ def wait_until_ready() -> None:
17411741
self.connection.get_query_status_throw_if_error(
17421742
sfqid
17431743
) # Trigger an exception if query failed
1744-
klass = self.__class__
1745-
self._inner_cursor = klass(self.connection)
1744+
self._inner_cursor = SnowflakeCursor(self.connection)
17461745
self._sfqid = sfqid
17471746
self._prefetch_hook = wait_until_ready
17481747

src/snowflake/connector/wif_util.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,37 @@ def get_aws_sts_hostname(region: str, partition: str) -> str:
145145
)
146146

147147

148+
def get_aws_session(impersonation_path: list[str] | None = None):
149+
"""Creates a boto3 session with the appropriate credentials.
150+
151+
If impersonation_path is provided, this uses the role at the end of the path. Otherwise, this uses the role attached to the current workload.
152+
"""
153+
session = boto3.session.Session()
154+
155+
impersonation_path = impersonation_path or []
156+
for arn in impersonation_path:
157+
response = session.client("sts").assume_role(
158+
RoleArn=arn, RoleSessionName="identity-federation-session"
159+
)
160+
creds = response["Credentials"]
161+
session = boto3.session.Session(
162+
aws_access_key_id=creds["AccessKeyId"],
163+
aws_secret_access_key=creds["SecretAccessKey"],
164+
aws_session_token=creds["SessionToken"],
165+
)
166+
return session
167+
168+
148169
def create_aws_attestation(
149-
session_manager: SessionManager | None = None,
170+
impersonation_path: list[str] | None = None,
150171
) -> WorkloadIdentityAttestation:
151172
"""Tries to create a workload identity attestation for AWS.
152173
153174
If the application isn't running on AWS or no credentials were found, raises an error.
154175
"""
155176
# TODO: SNOW-2223669 Investigate if our adapters - containing settings of http traffic - should be passed here as boto urllib3session. Those requests go to local servers, so they do not need Proxy setup or Headers customization in theory. But we may want to have all the traffic going through one class (e.g. Adapter or mixin).
156-
session = boto3.session.Session()
177+
session = get_aws_session(impersonation_path)
178+
157179
aws_creds = session.get_credentials()
158180
if not aws_creds:
159181
raise ProgrammingError(
@@ -387,7 +409,7 @@ def create_attestation(
387409
)
388410

389411
if provider == AttestationProvider.AWS:
390-
return create_aws_attestation(session_manager)
412+
return create_aws_attestation(impersonation_path)
391413
elif provider == AttestationProvider.AZURE:
392414
return create_azure_attestation(entra_resource, session_manager)
393415
elif provider == AttestationProvider.GCP:

test/csp_helpers.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ def gen_dummy_id_token(
4040
)
4141

4242

43-
def gen_dummy_access_token(sub="test-subject") -> str:
43+
def gen_dummy_access_token(sub="test-subject", key="secret") -> str:
4444
"""Generates a dummy access token using the given subject."""
45-
key = "secret"
4645
logger.debug(f"Generating dummy access token for subject {sub}")
4746
return (sub + key).encode("utf-8").hex()
4847

@@ -368,6 +367,11 @@ class FakeAwsEnvironment:
368367
def __init__(self):
369368
# Defaults used for generating a token. Can be overriden in individual tests.
370369
self.arn = "arn:aws:sts::123456789:assumed-role/My-Role/i-34afe100cad287fab"
370+
# Path of roles that can be assumed. Empty if no impersonation is allowed.
371+
# Can be overriden in individual tests.
372+
self.assumption_path = []
373+
self.assume_role_call_count = 0
374+
371375
self.caller_identity = {"Arn": self.arn}
372376
self.region = "us-east-1"
373377
self.credentials = Credentials(access_key="ak", secret_key="sk")
@@ -376,6 +380,25 @@ def __init__(self):
376380
)
377381
self.metadata_token = "test-token"
378382

383+
def assume_role(self, **kwargs):
384+
if (
385+
self.assumption_path
386+
and kwargs["RoleArn"] == self.assumption_path[self.assume_role_call_count]
387+
):
388+
arn = self.assumption_path[self.assume_role_call_count]
389+
self.assume_role_call_count += 1
390+
return {
391+
"Credentials": {
392+
"AccessKeyId": "access_key",
393+
"SecretAccessKey": "secret_key",
394+
"SessionToken": "session_token",
395+
"Expiration": int(time()) + 60 * 60,
396+
},
397+
"AssumedRoleUser": {"AssumedRoleId": hash(arn), "Arn": arn},
398+
"ResponseMetadata": {},
399+
}
400+
return {}
401+
379402
def get_region(self):
380403
return self.region
381404

@@ -399,6 +422,7 @@ def fetcher_fetch_metadata_token(self):
399422
def boto3_client(self, *args, **kwargs):
400423
mock_client = mock.Mock()
401424
mock_client.get_caller_identity.return_value = self.caller_identity
425+
mock_client.assume_role = self.assume_role
402426
return mock_client
403427

404428
def __enter__(self):
@@ -439,6 +463,9 @@ def __enter__(self):
439463
side_effect=self.boto3_client,
440464
)
441465
)
466+
self.patchers.append(
467+
mock.patch("boto3.session.Session.client", side_effect=self.boto3_client)
468+
)
442469
for patcher in self.patchers:
443470
patcher.__enter__()
444471
return self

test/integ/test_async.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from snowflake.connector import DatabaseError, ProgrammingError
10+
from snowflake.connector.cursor import DictCursor, SnowflakeCursor
1011

1112
# Mark all tests in this file to time out after 2 minutes to prevent hanging forever
1213
pytestmark = [pytest.mark.timeout(120), pytest.mark.skipolddriver]
@@ -17,14 +18,15 @@
1718
QueryStatus = None
1819

1920

20-
def test_simple_async(conn_cnx):
21+
@pytest.mark.parametrize("cursor_class", [SnowflakeCursor, DictCursor])
22+
def test_simple_async(conn_cnx, cursor_class):
2123
"""Simple test to that shows the most simple usage of fire and forget.
2224
2325
This test also makes sure that wait_until_ready function's sleeping is tested and
2426
that some fields are copied over correctly from the original query.
2527
"""
2628
with conn_cnx() as con:
27-
with con.cursor() as cur:
29+
with con.cursor(cursor_class) as cur:
2830
cur.execute_async("select count(*) from table(generator(timeLimit => 5))")
2931
cur.get_results_from_sfqid(cur.sfqid)
3032
assert len(cur.fetchall()) == 1

test/integ/test_multi_statement.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import snowflake.connector.cursor
1616
from snowflake.connector import ProgrammingError, errors
17+
from snowflake.connector.cursor import DictCursor, SnowflakeCursor
1718

1819
try: # pragma: no cover
1920
from snowflake.connector.constants import (
@@ -153,10 +154,11 @@ def test_binding_multi(conn_cnx, style: str, skip_to_last_set: bool):
153154
)
154155

155156

156-
def test_async_exec_multi(conn_cnx, skip_to_last_set: bool):
157+
@pytest.mark.parametrize("cursor_class", [SnowflakeCursor, DictCursor])
158+
def test_async_exec_multi(conn_cnx, cursor_class, skip_to_last_set: bool):
157159
"""Tests whether async execution query works within a multi-statement"""
158160
with conn_cnx() as con:
159-
with con.cursor() as cur:
161+
with con.cursor(cursor_class) as cur:
160162
cur.execute_async(
161163
"select 1; select 2; select count(*) from table(generator(timeLimit => 1)); select 'b';",
162164
num_statements=4,
@@ -165,14 +167,29 @@ def test_async_exec_multi(conn_cnx, skip_to_last_set: bool):
165167
assert con.is_still_running(con.get_query_status(q_id))
166168
_wait_while_query_running(con, q_id, sleep_time=1)
167169
with conn_cnx() as con:
168-
with con.cursor() as cur:
170+
with con.cursor(cursor_class) as cur:
169171
_wait_until_query_success(con, q_id, num_checks=3, sleep_per_check=1)
170172
assert con.get_query_status_throw_if_error(q_id) == QueryStatus.SUCCESS
171173

174+
if cursor_class == SnowflakeCursor:
175+
expected = [
176+
[(1,)],
177+
[(2,)],
178+
lambda x: len(x) == 1 and len(x[0]) == 1 and x[0][0] > 0,
179+
[("b",)],
180+
]
181+
elif cursor_class == DictCursor:
182+
expected = [
183+
[{"1": 1}],
184+
[{"2": 2}],
185+
lambda x: len(x) == 1 and len(x[0]) == 1 and x[0]["COUNT(*)"] > 0,
186+
[{"'B'": "b"}],
187+
]
188+
172189
cur.get_results_from_sfqid(q_id)
173190
_check_multi_statement_results(
174191
cur,
175-
checks=[[(1,)], [(2,)], lambda x: x > [(0,)], [("b",)]],
192+
checks=expected,
176193
skip_to_last_set=skip_to_last_set,
177194
)
178195

test/unit/test_auth_workload_identity.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,22 @@ def test_get_aws_sts_hostname_invalid_inputs(region, partition):
274274
assert "Invalid AWS partition" in str(excinfo.value)
275275

276276

277+
def test_aws_impersonation_calls_correct_apis_for_each_role_in_impersonation_path(
278+
fake_aws_environment: FakeAwsEnvironment,
279+
):
280+
impersonation_path = [
281+
"arn:aws:iam::123456789:role/role2",
282+
"arn:aws:iam::123456789:role/role3",
283+
]
284+
fake_aws_environment.assumption_path = impersonation_path
285+
auth_class = AuthByWorkloadIdentity(
286+
provider=AttestationProvider.AWS, impersonation_path=impersonation_path
287+
)
288+
auth_class.prepare(conn=None)
289+
290+
assert fake_aws_environment.assume_role_call_count == 2
291+
292+
277293
# -- GCP Tests --
278294

279295

test/unit/test_connection.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -682,16 +682,14 @@ def test_workload_identity_provider_is_required_for_wif_authenticator(
682682
"provider_param",
683683
[
684684
# Strongly-typed values.
685-
AttestationProvider.AWS,
686685
AttestationProvider.AZURE,
687686
AttestationProvider.OIDC,
688687
# String values.
689-
"AWS",
690688
"AZURE",
691689
"OIDC",
692690
],
693691
)
694-
def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
692+
def test_workload_identity_impersonation_path_errors_for_unsupported_providers(
695693
monkeypatch, provider_param
696694
):
697695
with monkeypatch.context() as m:
@@ -709,20 +707,22 @@ def test_workload_identity_impersonation_path_unsupported_for_non_gcp_providers(
709707
],
710708
)
711709
assert (
712-
"workload_identity_impersonation_path is currently only supported for GCP."
710+
"workload_identity_impersonation_path is currently only supported for GCP and AWS."
713711
in str(excinfo.value)
714712
)
715713

716714

717715
@pytest.mark.parametrize(
718-
"provider_param",
716+
"provider_param,impersonation_path",
719717
[
720-
AttestationProvider.GCP,
721-
"GCP",
718+
(AttestationProvider.GCP, ["[email protected]"]),
719+
(AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]),
720+
("GCP", ["[email protected]"]),
721+
("AWS", ["arn:aws:iam::1234567890:role/role2"]),
722722
],
723723
)
724-
def test_workload_identity_impersonation_path_supported_for_gcp_provider(
725-
monkeypatch, provider_param
724+
def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider(
725+
monkeypatch, provider_param, impersonation_path
726726
):
727727
with monkeypatch.context() as m:
728728
m.setattr(
@@ -733,14 +733,9 @@ def test_workload_identity_impersonation_path_supported_for_gcp_provider(
733733
account="account",
734734
authenticator="WORKLOAD_IDENTITY",
735735
workload_identity_provider=provider_param,
736-
workload_identity_impersonation_path=[
737-
738-
],
736+
workload_identity_impersonation_path=impersonation_path,
739737
)
740-
assert conn.auth_class.provider == AttestationProvider.GCP
741-
assert conn.auth_class.impersonation_path == [
742-
743-
]
738+
assert conn.auth_class.impersonation_path == impersonation_path
744739

745740

746741
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)