Skip to content

Commit 3956ff1

Browse files
[async] Applied #2451 to async code - part 3 - test passing
1 parent c0d3327 commit 3956ff1

File tree

6 files changed

+254
-20
lines changed

6 files changed

+254
-20
lines changed

src/snowflake/connector/aio/_result_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
raise_failed_request_error,
1414
raise_okta_unauthorized_error,
1515
)
16+
from snowflake.connector.aio._session_manager import SessionManagerFactory
1617
from snowflake.connector.aio._time_util import TimerContextManager
1718
from snowflake.connector.arrow_context import ArrowConverterContext
1819
from snowflake.connector.backoff_policies import exponential_backoff
@@ -33,7 +34,6 @@
3334
from snowflake.connector.result_batch import ResultBatch as ResultBatchSync
3435
from snowflake.connector.result_batch import _create_nanoarrow_iterator
3536
from snowflake.connector.secret_detector import SecretDetector
36-
from src.snowflake.connector.aio._session_manager import SessionManagerFactory
3737

3838
if TYPE_CHECKING:
3939
from pandas import DataFrame

src/snowflake/connector/aio/_session_manager.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def clone(
441441
if connector_factory is not None:
442442
overrides["connector_factory"] = connector_factory
443443

444-
return SessionManager.from_config(self._cfg, **overrides)
444+
return self.from_config(self._cfg, **overrides)
445445

446446

447447
async def request(
@@ -480,7 +480,7 @@ async def request(
480480
url: StrOrURL,
481481
**kwargs: Unpack[_RequestOptions],
482482
):
483-
# Apply proxy per request and inject Host header when proxying
483+
# Inject Host header when proxying
484484
try:
485485
# respect caller-provided proxy and proxy_headers if any
486486
provided_proxy = kwargs.get("proxy") or self._default_proxy
@@ -515,21 +515,6 @@ def make_session(self) -> aiohttp.ClientSession:
515515
proxy=self.proxy_url,
516516
)
517517

518-
def clone(
519-
self,
520-
*,
521-
use_pooling: bool | None = None,
522-
connector_factory: ConnectorFactory | None = None,
523-
) -> SessionManager:
524-
"""Return a new async SessionManager sharing this instance's config."""
525-
overrides: dict[str, Any] = {}
526-
if use_pooling is not None:
527-
overrides["use_pooling"] = use_pooling
528-
if connector_factory is not None:
529-
overrides["connector_factory"] = connector_factory
530-
531-
return SessionManager.from_config(self._cfg, **overrides)
532-
533518

534519
class SessionManagerFactory:
535520
@staticmethod

test/data/wiremock/mappings/queries/select_large_request_successful.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
},
1212
"response": {
1313
"status": 200,
14+
"headers": { "Content-Type": "application/json" },
1415
"jsonBody": {
1516
"data": {
1617
"parameters": [

test/unit/aio/test_connection_async_unit.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from cryptography.hazmat.primitives.asymmetric import rsa
2828

2929
import snowflake.connector.aio
30+
from snowflake.connector.aio import connect as async_connect
3031
from snowflake.connector.aio._network import SnowflakeRestful
3132
from snowflake.connector.aio.auth import (
3233
AuthByDefault,
@@ -773,3 +774,93 @@ async def test_invalid_authenticator():
773774
)
774775
await conn.connect()
775776
assert "Unknown authenticator: INVALID" in str(excinfo.value)
777+
778+
779+
@pytest.mark.skipolddriver
780+
@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"])
781+
async def test_large_query_through_proxy_async(
782+
wiremock_generic_mappings_dir,
783+
wiremock_target_proxy_pair,
784+
wiremock_mapping_dir,
785+
proxy_env_vars,
786+
proxy_method,
787+
):
788+
target_wm, proxy_wm = wiremock_target_proxy_pair
789+
790+
password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json"
791+
multi_chunk_request_mapping = (
792+
wiremock_mapping_dir / "queries/select_large_request_successful.json"
793+
)
794+
disconnect_mapping = (
795+
wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json"
796+
)
797+
telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json"
798+
chunk_1_mapping = wiremock_mapping_dir / "queries/chunk_1.json"
799+
chunk_2_mapping = wiremock_mapping_dir / "queries/chunk_2.json"
800+
801+
expected_headers = {"Via": {"contains": "wiremock"}}
802+
803+
target_wm.import_mapping(password_mapping, expected_headers=expected_headers)
804+
target_wm.add_mapping_with_default_placeholders(
805+
multi_chunk_request_mapping, expected_headers
806+
)
807+
target_wm.add_mapping(disconnect_mapping, expected_headers=expected_headers)
808+
target_wm.add_mapping(telemetry_mapping, expected_headers=expected_headers)
809+
target_wm.add_mapping_with_default_placeholders(chunk_1_mapping, expected_headers)
810+
target_wm.add_mapping_with_default_placeholders(chunk_2_mapping, expected_headers)
811+
812+
set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars
813+
connect_kwargs = {
814+
"user": "testUser",
815+
"password": "testPassword",
816+
"account": "testAccount",
817+
"host": target_wm.wiremock_host,
818+
"port": target_wm.wiremock_http_port,
819+
"protocol": "http",
820+
"warehouse": "TEST_WH",
821+
}
822+
823+
if proxy_method == "explicit_args":
824+
connect_kwargs.update(
825+
{
826+
"proxy_host": proxy_wm.wiremock_host,
827+
"proxy_port": str(proxy_wm.wiremock_http_port),
828+
"proxy_user": "proxyUser",
829+
"proxy_password": "proxyPass",
830+
}
831+
)
832+
clear_proxy_env_vars()
833+
else:
834+
proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}"
835+
set_proxy_env_vars(proxy_url)
836+
837+
row_count = 50_000
838+
conn = await async_connect(**connect_kwargs)
839+
try:
840+
cur = conn.cursor()
841+
await cur.execute(
842+
f"select seq4() as n from table(generator(rowcount => {row_count}));"
843+
)
844+
assert len(cur._result_set.batches) > 1
845+
_ = [r async for r in cur]
846+
finally:
847+
await conn.close()
848+
849+
async with aiohttp.ClientSession() as session:
850+
async with session.get(
851+
f"{proxy_wm.http_host_with_port}/__admin/requests"
852+
) as resp:
853+
proxy_reqs = await resp.json()
854+
assert any(
855+
"/queries/v1/query-request" in r["request"]["url"]
856+
for r in proxy_reqs["requests"]
857+
)
858+
859+
async with session.get(
860+
f"{target_wm.http_host_with_port}/__admin/requests"
861+
) as resp:
862+
target_reqs = await resp.json()
863+
assert any(
864+
"/queries/v1/query-request" in r["request"]["url"]
865+
for r in target_reqs["requests"]
866+
)

test/unit/aio/test_oauth_token_async.py

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from unittest import mock
88
from unittest.mock import Mock, patch
99

10+
import aiohttp
1011
import pytest
1112

1213
try:
@@ -18,7 +19,7 @@
1819
import snowflake.connector.errors
1920
from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType
2021

21-
from ...wiremock.wiremock_utils import WiremockClient
22+
from ...test_utils.wiremock.wiremock_utils import WiremockClient
2223
from ..test_oauth_token import omit_oauth_urls_check # noqa: F401
2324

2425
logger = logging.getLogger(__name__)
@@ -699,3 +700,159 @@ async def test_client_creds_expired_refresh_token_flow_async(
699700
cached_refresh_token = temp_cache_async.retrieve(refresh_token_key)
700701
assert cached_access_token == "expired-access-token-123"
701702
assert cached_refresh_token == "expired-refresh-token-123"
703+
704+
705+
@pytest.mark.skipolddriver
706+
@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"])
707+
async def test_client_credentials_flow_through_proxy_async(
708+
wiremock_oauth_client_creds_dir,
709+
wiremock_generic_mappings_dir,
710+
wiremock_target_proxy_pair,
711+
temp_cache_async,
712+
proxy_env_vars,
713+
proxy_method,
714+
):
715+
"""Run OAuth Client-Credentials flow and ensure it goes through proxy (async)."""
716+
from snowflake.connector.aio import SnowflakeConnection
717+
718+
target_wm, proxy_wm = wiremock_target_proxy_pair
719+
720+
expected_headers = {"Via": {"contains": "wiremock"}}
721+
722+
target_wm.import_mapping_with_default_placeholders(
723+
wiremock_oauth_client_creds_dir / "successful_flow.json", expected_headers
724+
)
725+
target_wm.add_mapping_with_default_placeholders(
726+
wiremock_generic_mappings_dir / "snowflake_login_successful.json",
727+
expected_headers,
728+
)
729+
target_wm.add_mapping(
730+
wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json",
731+
expected_headers=expected_headers,
732+
)
733+
734+
token_request_url = f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/token-request"
735+
736+
set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars
737+
connect_kwargs = {
738+
"user": "testUser",
739+
"authenticator": "OAUTH_CLIENT_CREDENTIALS",
740+
"oauth_client_id": "cid",
741+
"oauth_client_secret": "secret",
742+
"account": "testAccount",
743+
"protocol": "http",
744+
"role": "ANALYST",
745+
"oauth_token_request_url": token_request_url,
746+
"host": target_wm.wiremock_host,
747+
"port": target_wm.wiremock_http_port,
748+
"oauth_enable_refresh_tokens": True,
749+
"client_store_temporary_credential": True,
750+
"token_cache": temp_cache_async,
751+
}
752+
753+
if proxy_method == "explicit_args":
754+
connect_kwargs.update(
755+
{
756+
"proxy_host": proxy_wm.wiremock_host,
757+
"proxy_port": str(proxy_wm.wiremock_http_port),
758+
"proxy_user": "proxyUser",
759+
"proxy_password": "proxyPass",
760+
}
761+
)
762+
clear_proxy_env_vars()
763+
else:
764+
proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}"
765+
set_proxy_env_vars(proxy_url)
766+
767+
with mock.patch("secrets.token_urlsafe", return_value="abc123"):
768+
cnx = SnowflakeConnection(**connect_kwargs)
769+
await cnx.connect()
770+
await cnx.close()
771+
772+
async with aiohttp.ClientSession() as session:
773+
async with session.get(
774+
f"{proxy_wm.http_host_with_port}/__admin/requests"
775+
) as resp:
776+
proxy_requests = await resp.json()
777+
assert any(
778+
req["request"]["url"].endswith("/oauth/token-request")
779+
for req in proxy_requests["requests"]
780+
)
781+
782+
async with session.get(
783+
f"{target_wm.http_host_with_port}/__admin/requests"
784+
) as resp:
785+
target_requests = await resp.json()
786+
assert any(
787+
req["request"]["url"].endswith("/oauth/token-request")
788+
for req in target_requests["requests"]
789+
)
790+
791+
792+
@pytest.mark.skipolddriver
793+
@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30)
794+
async def test_oauth_code_successful_flow_through_proxy_async(
795+
wiremock_oauth_authorization_code_dir,
796+
wiremock_generic_mappings_dir,
797+
wiremock_target_proxy_pair,
798+
webbrowser_mock_sync,
799+
monkeypatch,
800+
omit_oauth_urls_check, # noqa: F811
801+
) -> None:
802+
from snowflake.connector.aio import SnowflakeConnection
803+
804+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
805+
target_wm, proxy_wm = wiremock_target_proxy_pair
806+
807+
target_wm.import_mapping_with_default_placeholders(
808+
wiremock_oauth_authorization_code_dir / "successful_flow.json",
809+
)
810+
target_wm.add_mapping_with_default_placeholders(
811+
wiremock_generic_mappings_dir / "snowflake_login_successful.json",
812+
)
813+
target_wm.add_mapping(
814+
wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json",
815+
)
816+
817+
with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open):
818+
with mock.patch("secrets.token_urlsafe", return_value="abc123"):
819+
cnx = SnowflakeConnection(
820+
user="testUser",
821+
authenticator="OAUTH_AUTHORIZATION_CODE",
822+
oauth_client_id="123",
823+
account="testAccount",
824+
protocol="http",
825+
role="ANALYST",
826+
proxy_host=proxy_wm.wiremock_host,
827+
proxy_port=str(proxy_wm.wiremock_http_port),
828+
proxy_user="proxyUser",
829+
proxy_password="proxyPass",
830+
oauth_client_secret="testClientSecret",
831+
oauth_token_request_url=f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/token-request",
832+
oauth_authorization_url=f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/authorize",
833+
oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect",
834+
host=target_wm.wiremock_host,
835+
port=target_wm.wiremock_http_port,
836+
)
837+
838+
await cnx.connect()
839+
await cnx.close()
840+
841+
async with aiohttp.ClientSession() as session:
842+
async with session.get(
843+
f"http://{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}/__admin/requests"
844+
) as resp:
845+
proxy_requests = await resp.json()
846+
assert any(
847+
req["request"]["url"].endswith("/oauth/token-request")
848+
for req in proxy_requests["requests"]
849+
), "Proxy did not record token-request"
850+
851+
async with session.get(
852+
f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/__admin/requests"
853+
) as resp:
854+
target_requests = await resp.json()
855+
assert any(
856+
req["request"]["url"].endswith("/oauth/token-request")
857+
for req in target_requests["requests"]
858+
), "Target did not receive token-request forwarded by proxy"

test/unit/aio/test_programmatic_access_token_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import snowflake.connector.errors
1818

19-
from ...wiremock.wiremock_utils import WiremockClient
19+
from ...test_utils.wiremock import WiremockClient
2020

2121

2222
@pytest.mark.skipolddriver

0 commit comments

Comments
 (0)