diff --git a/src/snowflake/connector/aio/__init__.py b/src/snowflake/connector/aio/__init__.py index 0b0410ebaa..d2142bab7b 100644 --- a/src/snowflake/connector/aio/__init__.py +++ b/src/snowflake/connector/aio/__init__.py @@ -1,5 +1,8 @@ from __future__ import annotations +from functools import wraps +from typing import Any, Coroutine, Generator + from ._connection import SnowflakeConnection from ._cursor import DictCursor, SnowflakeCursor @@ -10,7 +13,90 @@ ] -async def connect(**kwargs) -> SnowflakeConnection: - conn = SnowflakeConnection(**kwargs) - await conn.connect() - return conn +class _AsyncConnectContextManager: + """Hybrid wrapper that enables both awaiting and async context manager usage. + + Allows both patterns: + - conn = await connect(...) + - async with connect(...) as conn: + + Implements the full coroutine protocol for maximum compatibility. + """ + + __slots__ = ("_coro", "_conn") + + def __init__(self, coro: Coroutine[Any, Any, SnowflakeConnection]) -> None: + self._coro = coro + self._conn: SnowflakeConnection | None = None + + def send(self, arg: Any) -> Any: + """Send a value into the wrapped coroutine.""" + return self._coro.send(arg) + + def throw(self, *args: Any, **kwargs: Any) -> Any: + """Throw an exception into the wrapped coroutine.""" + return self._coro.throw(*args, **kwargs) + + def close(self) -> None: + """Close the wrapped coroutine.""" + return self._coro.close() + + def __await__(self) -> Generator[Any, None, SnowflakeConnection]: + """Enable await connect(...)""" + return self._coro.__await__() + + def __iter__(self) -> Generator[Any, None, SnowflakeConnection]: + """Make the wrapper iterable like a coroutine.""" + return self.__await__() + + async def __aenter__(self) -> SnowflakeConnection: + """Enable async with connect(...) as conn:""" + self._conn = await self._coro + # Connection is already connected by the coroutine + self._conn._prepare_aenter() + return self._conn + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + """Exit async context manager.""" + if self._conn is not None: + return await self._conn.__aexit__(exc_type, exc, tb) + else: + return None + + +class _AsyncConnectWrapper: + """Preserves SnowflakeConnection.__init__ metadata for async connect function. + + This wrapper enables introspection tools and IDEs to see the same signature + as the synchronous snowflake.connector.connect function. + """ + + def __init__(self) -> None: + self.__wrapped__ = SnowflakeConnection.__init__ + self.__name__ = "connect" + self.__doc__ = SnowflakeConnection.__init__.__doc__ + self.__module__ = __name__ + self.__qualname__ = "connect" + self.__annotations__ = getattr( + SnowflakeConnection.__init__, "__annotations__", {} + ) + + @wraps(SnowflakeConnection.__init__) + def __call__(self, **kwargs: Any) -> _AsyncConnectContextManager: + """Create and connect to a Snowflake connection asynchronously. + + Returns an awaitable that can also be used as an async context manager. + Supports both patterns: + - conn = await connect(...) + - async with connect(...) 
as conn: + """ + + async def _connect_coro() -> SnowflakeConnection: + conn = SnowflakeConnection(**kwargs) + await conn.connect() + return conn + + return _AsyncConnectContextManager(_connect_coro()) + + +connect = _AsyncConnectWrapper() diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index 479af373ad..6f2ad680d1 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -21,7 +21,6 @@ Error, OperationalError, ProgrammingError, - proxy, ) from .._query_context_cache import QueryContextCache @@ -80,6 +79,7 @@ from ._session_manager import ( AioHttpConfig, SessionManager, + SessionManagerFactory, SnowflakeSSLConnectorFactory, ) from ._telemetry import TelemetryClient @@ -128,6 +128,7 @@ def __init__( if "platform_detection_timeout_seconds" not in kwargs: self._platform_detection_timeout_seconds = 0.0 + # TODO: why do we have this here if it is never changed? self._connected = False self.expired = False # check SNOW-1218851 for long term improvement plan to refactor ocsp code @@ -165,8 +166,20 @@ def __exit__(self, exc_type, exc_val, exc_tb): "'SnowflakeConnection' object does not support the context manager protocol" ) + def _prepare_aenter(self) -> None: + """ + Any connection changes required before entering the connection context must be made here; the same API is exposed through snowflake.connector.aio.connect(), whose __aenter__ also calls this method. + """ + pass + async def __aenter__(self) -> SnowflakeConnection: - """Context manager.""" + """ + Context manager. + + Any connection changes required before entering the connection context must be made in _prepare_aenter() only. + The same API is exposed through snowflake.connector.aio.connect(), whose __aenter__ also calls that method, so no logic may run here that does not also run there; we cannot simply call conn.__aenter__() from there because the connection is already connected at that point.
+ """ + self._prepare_aenter() await self.connect() return self @@ -191,10 +204,6 @@ async def __open_connection(self): use_numpy=self._numpy, support_negative_year=self._support_negative_year ) - proxy.set_proxies( - self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password - ) - self._rest = SnowflakeRestful( host=self.host, port=self.port, @@ -1014,13 +1023,17 @@ async def connect(self, **kwargs) -> None: else: self.__config(**self._conn_parameters) - self._http_config = AioHttpConfig( + self._http_config: AioHttpConfig = AioHttpConfig( connector_factory=SnowflakeSSLConnectorFactory(), use_pooling=not self.disable_request_pooling, + proxy_host=self.proxy_host, + proxy_port=self.proxy_port, + proxy_user=self.proxy_user, + proxy_password=self.proxy_password, snowflake_ocsp_mode=self._ocsp_mode(), trust_env=True, # Required for proxy support via environment variables ) - self._session_manager = SessionManager(self._http_config) + self._session_manager = SessionManagerFactory.get_manager(self._http_config) if self.enable_connection_diag: raise NotImplementedError( diff --git a/src/snowflake/connector/aio/_network.py b/src/snowflake/connector/aio/_network.py index 95ba4e97a2..34730ba601 100644 --- a/src/snowflake/connector/aio/_network.py +++ b/src/snowflake/connector/aio/_network.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator import OpenSSL.SSL -from urllib3.util.url import parse_url from ..compat import FORBIDDEN, OK, UNAUTHORIZED, urlencode, urlparse, urlsplit from ..constants import ( @@ -79,7 +78,11 @@ ) from ..time_util import TimeoutBackoffCtx from ._description import CLIENT_NAME -from ._session_manager import SessionManager, SnowflakeSSLConnectorFactory +from ._session_manager import ( + SessionManager, + SessionManagerFactory, + SnowflakeSSLConnectorFactory, +) if TYPE_CHECKING: from snowflake.connector.aio import SnowflakeConnection @@ -145,15 +148,12 @@ def __init__( session_manager = ( connection._session_manager if (connection and connection._session_manager) - else SessionManager(connector_factory=SnowflakeSSLConnectorFactory()) + else SessionManagerFactory.get_manager( + connector_factory=SnowflakeSSLConnectorFactory() + ) ) self._session_manager = session_manager - if self._connection and self._connection.proxy_host: - self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname} - else: - self._get_proxy_headers = lambda _: None - async def close(self) -> None: if hasattr(self, "_token"): del self._token @@ -737,7 +737,6 @@ async def _request_exec( headers=headers, data=input_data, timeout=aiohttp.ClientTimeout(socket_timeout), - proxy_headers=self._get_proxy_headers(full_url), ) try: if raw_ret.status == OK: diff --git a/src/snowflake/connector/aio/_result_batch.py b/src/snowflake/connector/aio/_result_batch.py index 86c6d3d316..b04f5c49f0 100644 --- a/src/snowflake/connector/aio/_result_batch.py +++ b/src/snowflake/connector/aio/_result_batch.py @@ -13,7 +13,7 @@ raise_failed_request_error, raise_okta_unauthorized_error, ) -from snowflake.connector.aio._session_manager import SessionManager +from snowflake.connector.aio._session_manager import SessionManagerFactory from snowflake.connector.aio._time_util import TimerContextManager from snowflake.connector.arrow_context import ArrowConverterContext from snowflake.connector.backoff_policies import exponential_backoff @@ -261,7 +261,9 @@ async def download_chunk(http_session): logger.debug( f"downloading result batch id: {self.id} with new session through local session 
manager" ) - local_session_manager = SessionManager(use_pooling=False) + local_session_manager = SessionManagerFactory.get_manager( + use_pooling=False + ) async with local_session_manager.use_session() as session: response, content, encoding = await download_chunk(session) diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index dcf95c1be9..aba3e0b840 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -4,8 +4,10 @@ from typing import TYPE_CHECKING from aiohttp import ClientRequest, ClientTimeout +from aiohttp.client import _RequestOptions from aiohttp.client_proto import ResponseHandler from aiohttp.connector import Connection +from aiohttp.typedefs import StrOrURL from .. import OperationalError from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED @@ -14,6 +16,8 @@ if TYPE_CHECKING: from aiohttp.tracing import Trace + from typing import Unpack + from aiohttp.client import _RequestContextManager import abc import collections @@ -44,10 +48,10 @@ def __init__( ): self._snowflake_ocsp_mode = snowflake_ocsp_mode if session_manager is None: - logger.debug( - "SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance so please verify why it isn't true in the current context" + logger.warning( + "SessionManager instance was not passed to SSLConnector - OCSP will use default settings which may be distinct from the customer's specific one. Code should always pass such instance - verify why it isn't true in the current context" ) - session_manager = SessionManager() + session_manager = SessionManagerFactory.get_manager() self._session_manager = session_manager if self._snowflake_ocsp_mode == OCSPMode.FAIL_OPEN and sys.version_info < ( 3, @@ -345,13 +349,27 @@ def __init__( lambda: SessionPool(self) ) + @classmethod + def from_config(cls, cfg: AioHttpConfig, **overrides: Any) -> SessionManager: + """Build a new manager from *cfg*, optionally overriding fields. 
+ + Example:: + + no_pool_cfg = conn._http_config.copy_with(use_pooling=False) + manager = SessionManager.from_config(no_pool_cfg) + """ + + if overrides: + cfg = cfg.copy_with(**overrides) + return cls(config=cfg) + @property def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]: return self._cfg.connector_factory @connector_factory.setter def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None: - self._cfg = self._cfg.copy_with(connector_factory=value) + self._cfg: AioHttpConfig = self._cfg.copy_with(connector_factory=value) def make_session(self) -> aiohttp.ClientSession: """Create a new aiohttp.ClientSession with configured connector.""" @@ -359,10 +377,10 @@ def make_session(self) -> aiohttp.ClientSession: session_manager=self.clone(), snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode, ) - return aiohttp.ClientSession( connector=connector, trust_env=self._cfg.trust_env, + proxy=self.proxy_url, ) @contextlib.asynccontextmanager @@ -425,7 +443,7 @@ def clone( if connector_factory is not None: overrides["connector_factory"] = connector_factory - return SessionManager.from_config(self._cfg, **overrides) + return self.from_config(self._cfg, **overrides) async def request( @@ -454,3 +472,82 @@ async def request( use_pooling=use_pooling, **kwargs, ) + + +class ProxySessionManager(SessionManager): + class SessionWithProxy(aiohttp.ClientSession): + if sys.version_info >= (3, 11) and TYPE_CHECKING: + + def request( + self, + method: str, + url: StrOrURL, + **kwargs: Unpack[_RequestOptions], + ) -> _RequestContextManager: ... + + else: + + def request( + self, method: str, url: StrOrURL, **kwargs: Any + ) -> _RequestContextManager: + """Perform HTTP request.""" + # Inject Host header when proxying + try: + # respect caller-provided proxy and proxy_headers if any + provided_proxy = kwargs.get("proxy") or self._default_proxy + provided_proxy_headers = kwargs.get("proxy_headers") + if provided_proxy is not None: + authority = urlparse(str(url)).netloc + if provided_proxy_headers is None: + kwargs["proxy_headers"] = {"Host": authority} + elif "Host" not in provided_proxy_headers: + provided_proxy_headers["Host"] = authority + else: + logger.debug( + "Host header was already set - not overriding with netloc at the ClientSession.request method level." + ) + except Exception: + logger.warning( + "Failed to compute proxy settings for %s", + urlparse(url).hostname, + exc_info=True, + ) + return super().request(method, url, **kwargs) + + def make_session(self) -> aiohttp.ClientSession: + connector = self._cfg.get_connector( + session_manager=self.clone(), + snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode, + ) + # Construct session with base proxy set, request() may override per-URL when bypassing + return self.SessionWithProxy( + connector=connector, + trust_env=self._cfg.trust_env, + proxy=self.proxy_url, + ) + + +class SessionManagerFactory: + @staticmethod + def get_manager( + config: AioHttpConfig | None = None, **http_config_kwargs + ) -> SessionManager: + """Return a proxy-aware or plain async SessionManager based on config. + + If any explicit proxy parameters are provided (in config or kwargs), + return ProxySessionManager; otherwise return the base SessionManager. 
+ """ + + def _has_proxy_params(cfg: AioHttpConfig | None, kwargs: dict) -> bool: + cfg_keys = ( + "proxy_host", + "proxy_port", + ) + in_cfg = any(getattr(cfg, k, None) for k in cfg_keys) if cfg else False + in_kwargs = "proxy" in kwargs + return in_cfg or in_kwargs + + if _has_proxy_params(config, http_config_kwargs): + return ProxySessionManager(config, **http_config_kwargs) + else: + return SessionManager(config, **http_config_kwargs) diff --git a/src/snowflake/connector/aio/_storage_client.py b/src/snowflake/connector/aio/_storage_client.py index 01a3d59135..94e5bc92ed 100644 --- a/src/snowflake/connector/aio/_storage_client.py +++ b/src/snowflake/connector/aio/_storage_client.py @@ -15,7 +15,7 @@ from ..encryption_util import SnowflakeEncryptionUtil from ..errors import RequestExceedMaxRetryError from ..storage_client import SnowflakeStorageClient as SnowflakeStorageClientSync -from ._session_manager import SessionManager +from ._session_manager import SessionManagerFactory if TYPE_CHECKING: # pragma: no cover from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential @@ -205,7 +205,9 @@ async def _send_request_with_retry( # SessionManager on the fly, if code ends up here, since we probably do not care about losing # proxy or HTTP setup. logger.debug("storage client request with new session") - session_manager = SessionManager(use_pooling=False) + session_manager = SessionManagerFactory.get_manager( + use_pooling=False + ) response = await session_manager.request(verb, url, **rest_kwargs) if await self._has_expired_presigned_url(response): diff --git a/src/snowflake/connector/aio/_wif_util.py b/src/snowflake/connector/aio/_wif_util.py index 553e8e6309..1f2a62ff5c 100644 --- a/src/snowflake/connector/aio/_wif_util.py +++ b/src/snowflake/connector/aio/_wif_util.py @@ -21,7 +21,7 @@ extract_iss_and_sub_without_signature_verification, get_aws_sts_hostname, ) -from ._session_manager import SessionManager +from ._session_manager import SessionManager, SessionManagerFactory logger = logging.getLogger(__name__) @@ -187,7 +187,9 @@ async def create_attestation( """ entra_resource = entra_resource or DEFAULT_ENTRA_SNOWFLAKE_RESOURCE session_manager = ( - session_manager.clone() if session_manager else SessionManager(use_pooling=True) + session_manager.clone() + if session_manager + else SessionManagerFactory.get_manager(use_pooling=True) ) if provider == AttestationProvider.AWS: diff --git a/src/snowflake/connector/auth/_oauth_base.py b/src/snowflake/connector/auth/_oauth_base.py index 85deaf7f13..2ff1241638 100644 --- a/src/snowflake/connector/auth/_oauth_base.py +++ b/src/snowflake/connector/auth/_oauth_base.py @@ -20,9 +20,12 @@ ) from ..errors import Error, ProgrammingError from ..network import OAUTH_AUTHENTICATOR +from ..proxy import get_proxy_url from ..secret_detector import SecretDetector from ..token_cache import TokenCache, TokenKey, TokenType from ..vendored import urllib3 +from ..vendored.requests.utils import get_environ_proxies, select_proxy +from ..vendored.urllib3.poolmanager import ProxyManager from .by_plugin import AuthByPlugin, AuthType if TYPE_CHECKING: @@ -319,7 +322,13 @@ def _get_refresh_token_response( fields["scope"] = self._scope try: # TODO(SNOW-2229411) Session manager should be used here. It may require additional security validation (since we would transition from PoolManager to requests.Session) and some parameters would be passed implicitly. OAuth token exchange must NOT reuse pooled HTTP sessions. 
We should create a fresh SessionManager with use_pooling=False for each call. - return urllib3.PoolManager().request_encode_body( + proxy_url = self._resolve_proxy_url(conn, self._token_request_url) + http_client = ( + ProxyManager(proxy_url=proxy_url) + if proxy_url + else urllib3.PoolManager() + ) + return http_client.request_encode_body( "POST", self._token_request_url, encode_multipart=False, @@ -359,7 +368,11 @@ def _get_request_token_response( fields: dict[str, str], ) -> (str | None, str | None): # TODO(SNOW-2229411) Session manager should be used here. It may require additional security validation (since we would transition from PoolManager to requests.Session) and some parameters would be passed implicitly. Token request must bypass HTTP connection pools. - resp = urllib3.PoolManager().request_encode_body( + proxy_url = self._resolve_proxy_url(connection, self._token_request_url) + http_client = ( + ProxyManager(proxy_url=proxy_url) if proxy_url else urllib3.PoolManager() + ) + resp = http_client.request_encode_body( "POST", self._token_request_url, headers=self._create_token_request_headers(), @@ -400,3 +413,25 @@ def _create_token_request_headers(self) -> dict[str, str]: "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded; charset=UTF-8", } + + @staticmethod + def _resolve_proxy_url( + connection: SnowflakeConnection, request_url: str + ) -> str | None: + # TODO(SNOW-2229411) Session manager should be used instead. It may require additional security validation. + """Resolve proxy URL from explicit config first, then environment variables.""" + # First try explicit proxy configuration from connection parameters + proxy_url = get_proxy_url( + connection.proxy_host, + connection.proxy_port, + connection.proxy_user, + connection.proxy_password, + ) + + if proxy_url: + return proxy_url + + # Fall back to environment variables (HTTP_PROXY, HTTPS_PROXY) + # Use proper proxy selection that considers the URL scheme + proxies = get_environ_proxies(request_url) + return select_proxy(request_url, proxies) diff --git a/src/snowflake/connector/auth/oauth_code.py b/src/snowflake/connector/auth/oauth_code.py index 1c0c41eb6d..a5aaf31fb9 100644 --- a/src/snowflake/connector/auth/oauth_code.py +++ b/src/snowflake/connector/auth/oauth_code.py @@ -269,6 +269,7 @@ def _do_authorization_request( "login. If you can't see it, check existing browser windows, " "or your OS settings. Press CTRL+C to abort and try again..." ) + # TODO(SNOW-2229411) Investigate if Session manager / Http Config should be used here. code, state = ( self._receive_authorization_callback(callback_server, connection) if webbrowser.open(authorization_request) diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 000907d5c4..38f4e5301d 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -27,7 +27,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from . import errors, proxy +from . 
import errors from ._query_context_cache import QueryContextCache from ._utils import ( _DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER, @@ -924,6 +924,10 @@ def connect(self, **kwargs) -> None: self._http_config = HttpConfig( adapter_factory=ProxySupportAdapterFactory(), use_pooling=(not self.disable_request_pooling), + proxy_host=self.proxy_host, + proxy_port=self.proxy_port, + proxy_user=self.proxy_user, + proxy_password=self.proxy_password, ) self._session_manager = SessionManager(self._http_config) @@ -1125,10 +1129,6 @@ def __open_connection(self): use_numpy=self._numpy, support_negative_year=self._support_negative_year ) - proxy.set_proxies( - self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password - ) - self._rest = SnowflakeRestful( host=self.host, port=self.port, diff --git a/src/snowflake/connector/proxy.py b/src/snowflake/connector/proxy.py index 6b54e29ee5..996fd563ba 100644 --- a/src/snowflake/connector/proxy.py +++ b/src/snowflake/connector/proxy.py @@ -1,43 +1,28 @@ #!/usr/bin/env python from __future__ import annotations -import os - -def set_proxies( +def get_proxy_url( proxy_host: str | None, proxy_port: str | None, proxy_user: str | None = None, proxy_password: str | None = None, -) -> dict[str, str] | None: - """Sets proxy dict for requests.""" - PREFIX_HTTP = "http://" - PREFIX_HTTPS = "https://" - proxies = None +) -> str | None: + http_prefix = "http://" + https_prefix = "https://" + if proxy_host and proxy_port: - if proxy_host.startswith(PREFIX_HTTP): - proxy_host = proxy_host[len(PREFIX_HTTP) :] - elif proxy_host.startswith(PREFIX_HTTPS): - proxy_host = proxy_host[len(PREFIX_HTTPS) :] - if proxy_user or proxy_password: - proxy_auth = "{proxy_user}:{proxy_password}@".format( - proxy_user=proxy_user if proxy_user is not None else "", - proxy_password=proxy_password if proxy_password is not None else "", - ) + if proxy_host.startswith(http_prefix): + host = proxy_host[len(http_prefix) :] + elif proxy_host.startswith(https_prefix): + host = proxy_host[len(https_prefix) :] else: - proxy_auth = "" - proxies = { - "http": "http://{proxy_auth}{proxy_host}:{proxy_port}".format( - proxy_host=proxy_host, - proxy_port=str(proxy_port), - proxy_auth=proxy_auth, - ), - "https": "http://{proxy_auth}{proxy_host}:{proxy_port}".format( - proxy_host=proxy_host, - proxy_port=str(proxy_port), - proxy_auth=proxy_auth, - ), - } - os.environ["HTTP_PROXY"] = proxies["http"] - os.environ["HTTPS_PROXY"] = proxies["https"] - return proxies + host = proxy_host + auth = ( + f"{proxy_user or ''}:{proxy_password or ''}@" + if proxy_user or proxy_password + else "" + ) + return f"{http_prefix}{auth}{host}:{proxy_port}" + + return None diff --git a/src/snowflake/connector/result_batch.py b/src/snowflake/connector/result_batch.py index 742cbbaf13..8225997011 100644 --- a/src/snowflake/connector/result_batch.py +++ b/src/snowflake/connector/result_batch.py @@ -26,7 +26,7 @@ from .options import installed_pandas from .options import pyarrow as pa from .secret_detector import SecretDetector -from .session_manager import SessionManager +from .session_manager import HttpConfig, SessionManager from .time_util import TimerContextManager logger = getLogger(__name__) @@ -261,6 +261,8 @@ def __init__( [s._to_result_metadata_v1() for s in schema] if schema is not None else None ) self._use_dict_result = use_dict_result + # Passed to contain the configured Http behavior in case the connection is no longer active for the download + # Can be overridden with setters if needed. 
self._session_manager = session_manager self._metrics: dict[str, int] = {} self._data: str | list[tuple[Any, ...]] | None = None @@ -300,6 +302,25 @@ def uncompressed_size(self) -> int | None: def column_names(self) -> list[str]: return [col.name for col in self._schema] + @property + def session_manager(self) -> SessionManager | None: + return self._session_manager + + @session_manager.setter + def session_manager(self, session_manager: SessionManager | None) -> None: + self._session_manager = session_manager + + @property + def http_config(self): + return self._session_manager.config + + @http_config.setter + def http_config(self, config: HttpConfig) -> None: + if self._session_manager: + self._session_manager.config = config + else: + self._session_manager = SessionManager(config=config) + def __iter__( self, ) -> Iterator[dict | Exception] | Iterator[tuple | Exception]: diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py index 918a4b429d..43eeb87ee4 100644 --- a/src/snowflake/connector/session_manager.py +++ b/src/snowflake/connector/session_manager.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generator, Generic, Mapping, TypeVar from .compat import urlparse +from .proxy import get_proxy_url from .vendored import requests from .vendored.requests import Response, Session from .vendored.requests.adapters import BaseAdapter, HTTPAdapter @@ -79,8 +80,15 @@ def get_connection( proxy_manager = self.proxy_manager_for(proxy) if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname + # Add Host to proxy header SNOW-232777 and SNOW-694457 + + # RFC 7230 / 5.4 – a proxy’s Host header must repeat the request authority + # verbatim: host[:port] with IPv6 still in [brackets]. We take that + # straight from urlparse(url).netloc, which preserves port and brackets (and case-sensitive hostname). + # Note: netloc also keeps user-info (user:pass@host) if present in URL. The driver never sends + # URLs with embedded credentials, so we leave them unhandled — for full support + # we’d need to manually concatenate hostname with optional port and IPv6 brackets. + proxy_manager.proxy_headers["Host"] = parsed_url.netloc else: logger.debug( f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" @@ -112,6 +120,10 @@ class BaseHttpConfig: use_pooling: bool = True max_retries: int | None = REQUESTS_RETRY + proxy_host: str | None = None + proxy_port: str | None = None + proxy_user: str | None = None + proxy_password: str | None = None def copy_with(self, **overrides: Any) -> BaseHttpConfig: """Return a new config with overrides applied.""" @@ -325,13 +337,13 @@ class SessionManager(_RequestVerbsUsingSessionMixin): **Two Operating Modes**: - use_pooling=False: One-shot sessions (create, use, close) - suitable for infrequent requests - use_pooling=True: Per-hostname session pools - reuses TCP connections, avoiding handshake - and SSL/TLS negotiation overhead for repeated requests to the same host + and SSL/TLS negotiation overhead for repeated requests to the same host.
**Key Benefits**: - Centralized HTTP configuration management and easy propagation across the codebase - Consistent proxy setup (SNOW-694457) and headers customization (SNOW-2043816) - HTTPAdapter customization for connection-level request manipulation - - Performance optimization through connection reuse for high-traffic scenarios + - Performance optimization through connection reuse for high-traffic scenarios. **Usage**: Create the base session manager, then use clone() for derived managers to ensure proper config propagation. Pre-commit checks enforce usage to prevent code drift back to @@ -347,7 +359,6 @@ def __init__(self, config: HttpConfig | None = None, **http_config_kwargs) -> No logger.debug("Creating a config for the SessionManager") config = HttpConfig(**http_config_kwargs) self._cfg: HttpConfig = config - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( lambda: SessionPool(self) ) @@ -370,6 +381,19 @@ def from_config(cls, cfg: HttpConfig, **overrides: Any) -> SessionManager: def config(self) -> HttpConfig: return self._cfg + @config.setter + def config(self, cfg: HttpConfig) -> None: + self._cfg = cfg + + @property + def proxy_url(self) -> str: + return get_proxy_url( + self._cfg.proxy_host, + self._cfg.proxy_port, + self._cfg.proxy_user, + self._cfg.proxy_password, + ) + @property def use_pooling(self) -> bool: return self._cfg.use_pooling @@ -427,6 +451,7 @@ def _mount_adapters(self, session: requests.Session) -> None: def make_session(self) -> Session: session = requests.Session() self._mount_adapters(session) + session.proxies = {"http": self.proxy_url, "https": self.proxy_url} return session @contextlib.contextmanager diff --git a/test/conftest.py b/test/conftest.py index e8a8081b20..50b7f287c3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,6 +5,8 @@ from contextlib import contextmanager from logging import getLogger from pathlib import Path +from test.test_utils.cross_module_fixtures.http_fixtures import * # NOQA +from test.test_utils.cross_module_fixtures.wiremock_fixtures import * # NOQA from typing import Generator import pytest diff --git a/test/data/wiremock/mappings/auth/password/successful_flow.json b/test/data/wiremock/mappings/auth/password/successful_flow.json new file mode 100644 index 0000000000..9f2db70eec --- /dev/null +++ b/test/data/wiremock/mappings/auth/password/successful_flow.json @@ -0,0 +1,61 @@ +{ + "mappings": [ + { + "request": { + "urlPathPattern": "/session/v1/login-request.*", + "method": "POST", + "bodyPatterns": [ + { + "equalToJson" : { + "data": { + "LOGIN_NAME": "testUser", + "PASSWORD": "testPassword" + } + }, + "ignoreExtraElements" : true + } + ] + }, + "response": { + "status": 200, + "headers": { "Content-Type": "application/json" }, + "jsonBody": { + "data": { + "masterToken": "master token", + "token": "session token", + "validityInSeconds": 3600, + "masterValidityInSeconds": 14400, + "displayUserName": "TEST_USER", + "serverVersion": "8.48.0 b2024121104444034239f05", + "firstLogin": false, + "remMeToken": null, + "remMeValidityInSeconds": 0, + "healthCheckInterval": 45, + "newClientForUpgrade": "3.12.3", + "sessionId": 1172562260498, + "parameters": [ + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + } + ], + "sessionInfo": { + "databaseName": "TEST_DB", + "schemaName": "TEST_GO", + "warehouseName": "TEST_XSMALL", + "roleName": "ANALYST" + }, + "idToken": null, + "idTokenValidityInSeconds": 0, + "responseData": null, + "mfaToken": null, + "mfaTokenValidityInSeconds": 0 + }, + "code": 
null, + "message": null, + "success": true + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/generic/proxy_forward_all.json b/test/data/wiremock/mappings/generic/proxy_forward_all.json new file mode 100644 index 0000000000..62ba091bf2 --- /dev/null +++ b/test/data/wiremock/mappings/generic/proxy_forward_all.json @@ -0,0 +1,12 @@ +{ + "request": { + "urlPattern": "/.*", + "method": "ANY" + }, + "response": { + "proxyBaseUrl": "{{TARGET_HTTP_HOST_WITH_PORT}}", + "additionalProxyRequestHeaders": { + "Via": "1.1 wiremock-proxy" + } + } +} diff --git a/test/data/wiremock/mappings/generic/telemetry.json b/test/data/wiremock/mappings/generic/telemetry.json new file mode 100644 index 0000000000..9b734a0cf2 --- /dev/null +++ b/test/data/wiremock/mappings/generic/telemetry.json @@ -0,0 +1,18 @@ +{ + "scenarioName": "Successful telemetry flow", + "request": { + "urlPathPattern": "/telemetry/send", + "method": "POST" + }, + "response": { + "status": 200, + "jsonBody": { + "data": { + "code": null, + "data": "Log Received", + "message": null, + "success": true + } + } + } +} diff --git a/test/data/wiremock/mappings/queries/chunk_1.json b/test/data/wiremock/mappings/queries/chunk_1.json new file mode 100644 index 0000000000..246874d3c4 --- /dev/null +++ b/test/data/wiremock/mappings/queries/chunk_1.json @@ -0,0 +1,14 @@ +{ + "request": { + "method": "GET", + "url": "/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_1?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip" + }, + "response": { + "status": 200, + "headers": { + "Content-Encoding": "gzip", + "x-amz-server-side-encryption-customer-algorithm": "AES256" + }, + "base64Body": "H4sIAM7YUGgC/4s21FHAiWIB81FB/x4AAAA=" + } +} diff --git a/test/data/wiremock/mappings/queries/chunk_2.json b/test/data/wiremock/mappings/queries/chunk_2.json new file mode 100644 index 0000000000..60f2756d0e --- /dev/null +++ b/test/data/wiremock/mappings/queries/chunk_2.json @@ -0,0 +1,14 @@ +{ + "request": { + "method": "GET", + "url": "/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_2?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip" + }, + "response": { + "status": 200, + "headers": { + "Content-Encoding": "gzip", + "x-amz-server-side-encryption-customer-algorithm": "AES256" + }, + "base64Body": "H4sIAM7YUGgC/4s21FHAiWIB81FB/x4AAAA=" + } +} diff --git a/test/data/wiremock/mappings/queries/select_1_successful.json b/test/data/wiremock/mappings/queries/select_1_successful.json new file mode 100644 index 0000000000..d0d880903d --- /dev/null +++ b/test/data/wiremock/mappings/queries/select_1_successful.json @@ -0,0 +1,200 @@ +{ + "scenarioName": "Successful SELECT 1 flow", + "request": { + "urlPathPattern": "/queries/v1/query-request.*", + "method": "POST", + "headers": { + "Authorization": { + "equalTo": "Snowflake Token=\"session token\"" + } + } + }, + "response": { + "status": 200, + "headers": { "Content-Type": "application/json" }, + "jsonBody": { + "data": { + "parameters": [ + { + "name": "TIMESTAMP_OUTPUT_FORMAT", + "value": "YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM" + }, + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + }, + { + "name": "TIME_OUTPUT_FORMAT", + "value": "HH24:MI:SS" + }, + { + "name": "CLIENT_RESULT_CHUNK_SIZE", + "value": 16 + }, + { + "name": "TIMESTAMP_TZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE", + "value": false + }, + { + "name": 
"QUERY_CONTEXT_CACHE_SIZE", + "value": 5 + }, + { + "name": "CLIENT_METADATA_USE_SESSION_DATABASE", + "value": false + }, + { + "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", + "value": false + }, + { + "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", + "value": true + }, + { + "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", + "value": "YYYY-MM-DD HH24:MI:SS.FF3" + }, + { + "name": "CLIENT_RESULT_PREFETCH_THREADS", + "value": 1 + }, + { + "name": "CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX", + "value": false + }, + { + "name": "CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ", + "value": true + }, + { + "name": "CLIENT_MEMORY_LIMIT", + "value": 1536 + }, + { + "name": "CLIENT_TIMESTAMP_TYPE_MAPPING", + "value": "TIMESTAMP_LTZ" + }, + { + "name": "TIMEZONE", + "value": "America/Los_Angeles" + }, + { + "name": "SERVICE_NAME", + "value": "" + }, + { + "name": "CLIENT_RESULT_PREFETCH_SLOTS", + "value": 2 + }, + { + "name": "CLIENT_TELEMETRY_ENABLED", + "value": true + }, + { + "name": "CLIENT_DISABLE_INCIDENTS", + "value": true + }, + { + "name": "CLIENT_USE_V1_QUERY_API", + "value": true + }, + { + "name": "CLIENT_RESULT_COLUMN_CASE_INSENSITIVE", + "value": false + }, + { + "name": "CSV_TIMESTAMP_FORMAT", + "value": "" + }, + { + "name": "BINARY_OUTPUT_FORMAT", + "value": "HEX" + }, + { + "name": "CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS", + "value": false + }, + { + "name": "CLIENT_TELEMETRY_SESSIONLESS_ENABLED", + "value": true + }, + { + "name": "DATE_OUTPUT_FORMAT", + "value": "YYYY-MM-DD" + }, + { + "name": "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD", + "value": 65280 + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", + "value": 3600 + }, + { + "name": "CLIENT_SESSION_CLONE", + "value": false + }, + { + "name": "AUTOCOMMIT", + "value": true + }, + { + "name": "TIMESTAMP_LTZ_OUTPUT_FORMAT", + "value": "" + } + ], + "rowtype": [ + { + "name": "1", + "database": "", + "schema": "", + "table": "", + "nullable": false, + "length": null, + "type": "fixed", + "scale": 0, + "precision": 1, + "byteLength": null, + "collation": null + } + ], + "rowset": [ + [ + "1" + ] + ], + "total": 1, + "returned": 1, + "queryId": "01ba13b4-0104-e9fd-0000-0111029ca00e", + "databaseProvider": null, + "finalDatabaseName": null, + "finalSchemaName": null, + "finalWarehouseName": "TEST_XSMALL", + "numberOfBinds": 0, + "arrayBindSupported": false, + "statementTypeId": 4096, + "version": 1, + "sendResultTime": 1738317395581, + "queryResultFormat": "json", + "queryContext": { + "entries": [ + { + "id": 0, + "timestamp": 1738317395574564, + "priority": 0, + "context": "CPbPTg==" + } + ] + } + }, + "code": null, + "message": null, + "success": true + } + } +} diff --git a/test/data/wiremock/mappings/queries/select_large_request_successful.json b/test/data/wiremock/mappings/queries/select_large_request_successful.json new file mode 100644 index 0000000000..7199e2d279 --- /dev/null +++ b/test/data/wiremock/mappings/queries/select_large_request_successful.json @@ -0,0 +1,414 @@ +{ + "scenarioName": "Successful SELECT 1 flow", + "request": { + "urlPathPattern": "/queries/v1/query-request.*", + "method": "POST", + "headers": { + "Authorization": { + "equalTo": "Snowflake Token=\"session token\"" + } + } + }, + "response": { + "status": 200, + "headers": { "Content-Type": "application/json" }, + "jsonBody": { + "data": { + "parameters": [ + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + }, + { + "name": "TIMESTAMP_OUTPUT_FORMAT", + "value": "DY, DD MON YYYY HH24:MI:SS TZHTZM" + }, + { + "name": 
"PYTHON_SNOWPARK_CLIENT_MIN_VERSION_FOR_AST", + "value": "1.29.0" + }, + { + "name": "TIME_OUTPUT_FORMAT", + "value": "HH24:MI:SS" + }, + { + "name": "CLIENT_RESULT_CHUNK_SIZE", + "value": 160 + }, + { + "name": "TIMESTAMP_TZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE", + "value": false + }, + { + "name": "PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_VERSION", + "value": "1.31.1" + }, + { + "name": "CLIENT_METADATA_USE_SESSION_DATABASE", + "value": false + }, + { + "name": "QUERY_CONTEXT_CACHE_SIZE", + "value": 5 + }, + { + "name": "PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_VERSION", + "value": "" + }, + { + "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", + "value": false + }, + { + "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_RESULT_PREFETCH_THREADS", + "value": 1 + }, + { + "name": "CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX", + "value": false + }, + { + "name": "CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ", + "value": true + }, + { + "name": "CLIENT_MEMORY_LIMIT", + "value": 1536 + }, + { + "name": "CLIENT_TIMESTAMP_TYPE_MAPPING", + "value": "TIMESTAMP_LTZ" + }, + { + "name": "TIMEZONE", + "value": "UTC" + }, + { + "name": "PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER", + "value": true + }, + { + "name": "SNOWPARK_REQUEST_TIMEOUT_IN_SECONDS", + "value": 86400 + }, + { + "name": "PYTHON_SNOWPARK_USE_AST", + "value": false + }, + { + "name": "SERVICE_NAME", + "value": "" + }, + { + "name": "PYTHON_CONNECTOR_USE_NANOARROW", + "value": true + }, + { + "name": "CLIENT_RESULT_PREFETCH_SLOTS", + "value": 2 + }, + { + "name": "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND", + "value": 10000000 + }, + { + "name": "PYTHON_SNOWPARK_GENERATE_MULTILINE_QUERIES", + "value": true + }, + { + "name": "CLIENT_DISABLE_INCIDENTS", + "value": true + }, + { + "name": "CSV_TIMESTAMP_FORMAT", + "value": "" + }, + { + "name": "BINARY_OUTPUT_FORMAT", + "value": "HEX" + }, + { + "name": "CLIENT_TELEMETRY_SESSIONLESS_ENABLED", + "value": true + }, + { + "name": "DATE_OUTPUT_FORMAT", + "value": "YYYY-MM-DD" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", + "value": 3600 + }, + { + "name": "PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED", + "value": false + }, + { + "name": "AUTOCOMMIT", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED", + "value": false + }, + { + "name": "CLIENT_SESSION_CLONE", + "value": false + }, + { + "name": "TIMESTAMP_LTZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", + "value": false + }, + { + "name": "PYTHON_SNOWPARK_DATAFRAME_JOIN_ALIAS_FIX_VERSION", + "value": "" + }, + { + "name": "PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION", + "value": "1.28.0" + }, + { + "name": "PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED_VERSION", + "value": "" + }, + { + "name": "CLIENT_TELEMETRY_ENABLED", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED", + "value": false + }, + { + "name": "CLIENT_USE_V1_QUERY_API", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", + "value": true + }, + { + "name": "CLIENT_RESULT_COLUMN_CASE_INSENSITIVE", + "value": false + }, + { + "name": "CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS", + "value": false + }, + { + "name": "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD", + "value": 65280 + }, + { + "name": "PYTHON_SNOWPARK_COMPILATION_STAGE_ENABLED", + "value": true + }, + { + "name": 
"PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND", + "value": 12000000 + }, + { + "name": "PYTHON_SNOWPARK_CLIENT_AST_MODE", + "value": 0 + } + ], + "rowtype": [ + { + "name": "C0", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C1", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C2", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C3", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C4", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C5", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C6", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C7", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C8", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C9", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + } + ], + + "rowset": [ + [ + "1" + ] + ], + "qrmk": "+ZSmIj7I0L0BnU3zdVnSaHH5MW6cwY0GmLtz/Un5zSM=", + "chunkHeaders": { + "x-amz-server-side-encryption-customer-key": 
"+ZSmIj7I0L0BnU3zdVnSaHH5MW6cwY0GmLtz/Un5zSM=", + "x-amz-server-side-encryption-customer-key-md5": "ByrEgrMhjgAEMRr1QA/nGg==" + }, + "chunks": [ + { + "url": "{{WIREMOCK_HTTP_HOST_WITH_PORT}}/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_1?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip", + "rowCount": 4096, + "uncompressedSize": 331328, + "compressedSize": 326422 + }, + { + "url": "{{WIREMOCK_HTTP_HOST_WITH_PORT}}/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_2?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip", + "rowCount": 4096, + "uncompressedSize": 331328, + "compressedSize": 326176 + } + ], + "total": 50000, + "returned": 50000, + "queryId": "01bd137c-0100-0001-0000-0000001005b1", + "databaseProvider": null, + "finalDatabaseName": "TESTDB", + "finalSchemaName": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "finalWarehouseName": "REGRESS", + "finalRoleName": "ACCOUNTADMIN", + "numberOfBinds": 0, + "arrayBindSupported": false, + "statementTypeId": 4096, + "version": 1, + "sendResultTime": 1750110502822, + "queryResultFormat": "json", + "queryContext": { + "entries": [ + { + "id": 0, + "timestamp": 1748552075465658, + "priority": 0, + "context": "CAQ=" + } + ] + } + }, + "code": null, + "message": null, + "success": true + } + } +} diff --git a/test/integ/aio_it/conftest.py b/test/integ/aio_it/conftest.py index c3949c2424..dba36fba0c 100644 --- a/test/integ/aio_it/conftest.py +++ b/test/integ/aio_it/conftest.py @@ -9,11 +9,12 @@ get_db_parameters, is_public_testaccount, ) -from typing import AsyncContextManager, AsyncGenerator, Callable +from typing import Any, AsyncContextManager, AsyncGenerator, Callable import pytest from snowflake.connector.aio import SnowflakeConnection +from snowflake.connector.aio import connect as async_connect from snowflake.connector.aio._telemetry import TelemetryClient from snowflake.connector.connection import DefaultConverterClass from snowflake.connector.telemetry import TelemetryData @@ -70,13 +71,7 @@ def capture_sf_telemetry_async() -> TelemetryCaptureFixtureAsync: return TelemetryCaptureFixtureAsync() -async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection: - """Creates a connection using the parameters defined in parameters.py. - - You can select from the different connections by supplying the appropiate - connection_name parameter and then anything else supplied will overwrite the values - from parameters.py. - """ +def fill_conn_kwargs_for_tests(connection_name: str, **kwargs) -> dict[str, Any]: ret = get_db_parameters(connection_name) ret.update(kwargs) @@ -95,9 +90,21 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti ret.pop("private_key", None) ret.pop("private_key_file", None) + return ret + + +async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection: + """Creates a connection using the parameters defined in parameters.py. + + You can select from the different connections by supplying the appropiate + connection_name parameter and then anything else supplied will overwrite the values + from parameters.py. 
+ """ + ret = fill_conn_kwargs_for_tests(connection_name, **kwargs) connection = SnowflakeConnection(**ret) + conn = await async_connect(**ret) await connection.connect() - return connection + return conn @asynccontextmanager diff --git a/test/integ/aio_it/test_connection_async.py b/test/integ/aio_it/test_connection_async.py index 315f5291f0..3ca3a1370b 100644 --- a/test/integ/aio_it/test_connection_async.py +++ b/test/integ/aio_it/test_connection_async.py @@ -46,7 +46,7 @@ CONNECTION_PARAMETERS_ADMIN = {} from snowflake.connector.aio.auth import AuthByOkta, AuthByPlugin -from .conftest import create_connection +from .conftest import create_connection, fill_conn_kwargs_for_tests try: from snowflake.connector.errorcode import ER_FAILED_PROCESSING_QMARK @@ -415,6 +415,8 @@ async def test_invalid_account_timeout(conn_cnx): @pytest.mark.timeout(15) async def test_invalid_proxy(conn_cnx): + http_proxy = os.environ.get("HTTP_PROXY") + https_proxy = os.environ.get("HTTPS_PROXY") with pytest.raises(OperationalError): async with conn_cnx( protocol="http", @@ -424,9 +426,41 @@ async def test_invalid_proxy(conn_cnx): proxy_port="3333", ): pass - # NOTE environment variable is set if the proxy parameter is specified. - del os.environ["HTTP_PROXY"] - del os.environ["HTTPS_PROXY"] + # NOTE environment variable is set ONLY FOR THE OLD DRIVER if the proxy parameter is specified. + # So this deletion is needed for old driver tests only. + if http_proxy is not None: + os.environ["HTTP_PROXY"] = http_proxy + else: + try: + del os.environ["HTTP_PROXY"] + except KeyError: + pass + if https_proxy is not None: + os.environ["HTTPS_PROXY"] = https_proxy + else: + try: + del os.environ["HTTPS_PROXY"] + except KeyError: + pass + + +@pytest.mark.skipolddriver +@pytest.mark.timeout(15) +async def test_invalid_proxy_not_impacting_env_vars(conn_cnx): + http_proxy = os.environ.get("HTTP_PROXY") + https_proxy = os.environ.get("HTTPS_PROXY") + with pytest.raises(OperationalError): + async with conn_cnx( + protocol="http", + account="testaccount", + login_timeout=5, + proxy_host="localhost", + proxy_port="3333", + ): + pass + # Proxy environment variables should not change + assert os.environ.get("HTTP_PROXY") == http_proxy + assert os.environ.get("HTTPS_PROXY") == https_proxy @pytest.mark.timeout(15) @@ -1437,6 +1471,64 @@ async def test_platform_detection_timeout(conn_cnx): assert cnx.platform_detection_timeout_seconds == 2.5 +@pytest.mark.skipolddriver +async def test_conn_cnx_basic(conn_cnx): + """Tests platform detection timeout. + + Creates a connection with platform_detection_timeout parameter. 
+ """ + async with conn_cnx() as conn: + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + +@pytest.mark.skipolddriver +async def test_conn_assigned_method(conn_cnx): + conn = await snowflake.connector.aio.connect( + **fill_conn_kwargs_for_tests("default") + ) + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + +@pytest.mark.skipolddriver +async def test_conn_assigned_class(conn_cnx): + conn = snowflake.connector.aio.SnowflakeConnection( + **fill_conn_kwargs_for_tests("default") + ) + await conn.connect() + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + +@pytest.mark.skipolddriver +async def test_conn_with_method(conn_cnx): + async with snowflake.connector.aio.connect( + **fill_conn_kwargs_for_tests("default") + ) as conn: + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + +@pytest.mark.skipolddriver +async def test_conn_with_class(conn_cnx): + async with snowflake.connector.aio.SnowflakeConnection( + **fill_conn_kwargs_for_tests("default") + ) as conn: + async with conn.cursor() as cur: + result = await (await cur.execute("select 1")).fetchall() + assert len(result) == 1 + assert result[0][0] == 1 + + @pytest.mark.skipolddriver async def test_platform_detection_zero_timeout(conn_cnx): with ( diff --git a/test/integ/test_connection.py b/test/integ/test_connection.py index ee625a2dcd..c2dd3a3470 100644 --- a/test/integ/test_connection.py +++ b/test/integ/test_connection.py @@ -419,6 +419,8 @@ def test_invalid_account_timeout(conn_cnx): @pytest.mark.timeout(15) def test_invalid_proxy(conn_cnx): + http_proxy = os.environ.get("HTTP_PROXY") + https_proxy = os.environ.get("HTTPS_PROXY") with pytest.raises(OperationalError): with conn_cnx( protocol="http", @@ -428,9 +430,41 @@ def test_invalid_proxy(conn_cnx): proxy_port="3333", ): pass - # NOTE environment variable is set if the proxy parameter is specified. - del os.environ["HTTP_PROXY"] - del os.environ["HTTPS_PROXY"] + # NOTE environment variable is set ONLY FOR THE OLD DRIVER if the proxy parameter is specified. + # So this deletion is needed for old driver tests only. 
+ if http_proxy is not None: + os.environ["HTTP_PROXY"] = http_proxy + else: + try: + del os.environ["HTTP_PROXY"] + except KeyError: + pass + if https_proxy is not None: + os.environ["HTTPS_PROXY"] = https_proxy + else: + try: + del os.environ["HTTPS_PROXY"] + except KeyError: + pass + + +@pytest.mark.skipolddriver +@pytest.mark.timeout(15) +def test_invalid_proxy_not_impacting_env_vars(conn_cnx): + http_proxy = os.environ.get("HTTP_PROXY") + https_proxy = os.environ.get("HTTPS_PROXY") + with pytest.raises(OperationalError): + with conn_cnx( + protocol="http", + account="testaccount", + login_timeout=5, + proxy_host="localhost", + proxy_port="3333", + ): + pass + # Proxy environment variables should not change + assert os.environ.get("HTTP_PROXY") == http_proxy + assert os.environ.get("HTTPS_PROXY") == https_proxy @pytest.mark.timeout(15) diff --git a/test/wiremock/__init__.py b/test/test_utils/__init__.py similarity index 100% rename from test/wiremock/__init__.py rename to test/test_utils/__init__.py diff --git a/test/test_utils/cross_module_fixtures/__init__.py b/test/test_utils/cross_module_fixtures/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test_utils/cross_module_fixtures/http_fixtures.py b/test/test_utils/cross_module_fixtures/http_fixtures.py new file mode 100644 index 0000000000..a34d349be9 --- /dev/null +++ b/test/test_utils/cross_module_fixtures/http_fixtures.py @@ -0,0 +1,36 @@ +import os + +import pytest + + +@pytest.fixture +def proxy_env_vars(): + """Manages HTTP_PROXY and HTTPS_PROXY environment variables for testing.""" + original_http_proxy = os.environ.get("HTTP_PROXY") + original_https_proxy = os.environ.get("HTTPS_PROXY") + + def set_proxy_env_vars(proxy_url: str): + """Set both HTTP_PROXY and HTTPS_PROXY to the given URL.""" + os.environ["HTTP_PROXY"] = proxy_url + os.environ["HTTPS_PROXY"] = proxy_url + + def clear_proxy_env_vars(): + """Clear proxy environment variables.""" + if "HTTP_PROXY" in os.environ: + del os.environ["HTTP_PROXY"] + if "HTTPS_PROXY" in os.environ: + del os.environ["HTTPS_PROXY"] + + # Yield the helper functions + yield set_proxy_env_vars, clear_proxy_env_vars + + # Cleanup: restore original values + if original_http_proxy is not None: + os.environ["HTTP_PROXY"] = original_http_proxy + elif "HTTP_PROXY" in os.environ: + del os.environ["HTTP_PROXY"] + + if original_https_proxy is not None: + os.environ["HTTPS_PROXY"] = original_https_proxy + elif "HTTPS_PROXY" in os.environ: + del os.environ["HTTPS_PROXY"] diff --git a/test/test_utils/cross_module_fixtures/wiremock_fixtures.py b/test/test_utils/cross_module_fixtures/wiremock_fixtures.py new file mode 100644 index 0000000000..ddf7c22d12 --- /dev/null +++ b/test/test_utils/cross_module_fixtures/wiremock_fixtures.py @@ -0,0 +1,83 @@ +import pathlib +import uuid +from contextlib import contextmanager +from functools import partial +from typing import Any, Callable, ContextManager, Generator, Union + +import pytest + +import snowflake.connector + +from ..wiremock.wiremock_utils import WiremockClient, get_clients_for_proxy_and_target + + +@pytest.fixture(scope="session") +def wiremock_mapping_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent.parent / "data" / "wiremock" / "mappings" + ) + + +@pytest.fixture(scope="session") +def wiremock_generic_mappings_dir(wiremock_mapping_dir) -> pathlib.Path: + return wiremock_mapping_dir / "generic" + + +@pytest.fixture(scope="session") +def wiremock_client() -> Generator[Union[WiremockClient, Any], 
Any, None]:
+    with WiremockClient() as client:
+        yield client
+
+
+@pytest.fixture
+def default_db_wiremock_parameters(wiremock_client: WiremockClient) -> dict[str, Any]:
+    db_params = {
+        "account": "testAccount",
+        "user": "testUser",
+        "password": "testPassword",
+        "host": wiremock_client.wiremock_host,
+        "port": wiremock_client.wiremock_http_port,
+        "protocol": "http",
+        "name": "python_tests_" + str(uuid.uuid4()).replace("-", "_"),
+    }
+    return db_params
+
+
+@contextmanager
+def db_wiremock(
+    default_db_wiremock_parameters: dict[str, Any],
+    **kwargs,
+) -> Generator[snowflake.connector.SnowflakeConnection, None, None]:
+    ret = default_db_wiremock_parameters
+    ret.update(kwargs)
+    cnx = snowflake.connector.connect(**ret)
+    try:
+        yield cnx
+    finally:
+        cnx.close()
+
+
+@pytest.fixture
+def conn_cnx_wiremock(
+    default_db_wiremock_parameters,
+) -> Callable[..., ContextManager[snowflake.connector.SnowflakeConnection]]:
+    return partial(
+        db_wiremock, default_db_wiremock_parameters=default_db_wiremock_parameters
+    )
+
+
+@pytest.fixture
+def wiremock_target_proxy_pair(wiremock_generic_mappings_dir):
+    """Starts a *target* Wiremock and a *proxy* Wiremock pre-configured to forward to it.
+
+    The fixture yields a tuple ``(target_wm, proxy_wm)`` of ``WiremockClient``
+    instances. It is a thin wrapper around
+    ``test.test_utils.wiremock.wiremock_utils.get_clients_for_proxy_and_target``.
+    """
+    wiremock_proxy_mapping_path = (
+        wiremock_generic_mappings_dir / "proxy_forward_all.json"
+    )
+    with get_clients_for_proxy_and_target(
+        proxy_mapping_template=wiremock_proxy_mapping_path
+    ) as pair:
+        yield pair
diff --git a/test/test_utils/wiremock/__init__.py b/test/test_utils/wiremock/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/test/test_utils/wiremock/wiremock_utils.py b/test/test_utils/wiremock/wiremock_utils.py
new file mode 100644
index 0000000000..7b7d15da54
--- /dev/null
+++ b/test/test_utils/wiremock/wiremock_utils.py
@@ -0,0 +1,347 @@
+import json
+import logging
+import pathlib
+import socket
+import subprocess
+from contextlib import contextmanager
+from time import sleep
+from typing import Iterable, List, Optional, Union
+
+try:
+    from snowflake.connector.vendored import requests
+except ImportError:
+    import requests
+
+WIREMOCK_START_MAX_RETRY_COUNT = 12
+logger = logging.getLogger(__name__)
+
+
+def _get_mapping_str(mapping: Union[str, dict, pathlib.Path]) -> str:
+    if isinstance(mapping, str):
+        return mapping
+    if isinstance(mapping, dict):
+        return json.dumps(mapping)
+    if isinstance(mapping, pathlib.Path):
+        if mapping.is_file():
+            with open(mapping) as f:
+                return f.read()
+        else:
+            raise RuntimeError(f"File with mapping: {mapping} does not exist")
+
+    raise RuntimeError(f"Mapping {mapping} is of an invalid type")
+
+
+class WiremockClient:
+    HTTP_HOST_PLACEHOLDER: str = "{{WIREMOCK_HTTP_HOST_WITH_PORT}}"
+
+    def __init__(
+        self,
+        forbidden_ports: Optional[List[int]] = None,
+        additional_wiremock_process_args: Optional[Iterable[str]] = None,
+    ) -> None:
+        self.wiremock_filename = "wiremock-standalone.jar"
+        self.wiremock_host = "localhost"
+        self.wiremock_http_port = None
+        self.wiremock_https_port = None
+        self.forbidden_ports = forbidden_ports if forbidden_ports is not None else []
+
+        self.wiremock_dir = (
+            pathlib.Path(__file__).parent.parent.parent.parent / ".wiremock"
+        )
+        assert self.wiremock_dir.exists(), f"{self.wiremock_dir} does not exist"
+
+        self.wiremock_jar_path = self.wiremock_dir / self.wiremock_filename
+        assert (
self.wiremock_jar_path.exists() + ), f"{self.wiremock_jar_path} does not exist" + self._additional_wiremock_process_args = ( + additional_wiremock_process_args or list() + ) + + @property + def http_host_with_port(self) -> str: + return f"http://{self.wiremock_host}:{self.wiremock_http_port}" + + def get_http_placeholders(self) -> dict[str, str]: + """Placeholder that substitutes the target Wiremock's host:port in JSON.""" + return {self.HTTP_HOST_PLACEHOLDER: self.http_host_with_port} + + def add_expected_headers_to_mapping( + self, + mapping_str: str, + expected_headers: dict, + ) -> str: + """Add expected headers to all request matchers in mapping string.""" + mapping_dict = json.loads(mapping_str) + + def add_headers_to_request(request_dict: dict) -> None: + if "headers" not in request_dict: + request_dict["headers"] = {} + request_dict["headers"].update(expected_headers) + + if "request" in mapping_dict: + add_headers_to_request(mapping_dict["request"]) + + if "mappings" in mapping_dict: + for single_mapping in mapping_dict["mappings"]: + if "request" in single_mapping: + add_headers_to_request(single_mapping["request"]) + + return json.dumps(mapping_dict) + + def get_default_placeholders(self) -> dict[str, str]: + return self.get_http_placeholders() + + def _start_wiremock(self): + self.wiremock_http_port = self._find_free_port( + forbidden_ports=self.forbidden_ports, + ) + self.wiremock_https_port = self._find_free_port( + forbidden_ports=self.forbidden_ports + [self.wiremock_http_port] + ) + self.wiremock_process = subprocess.Popen( + [ + "java", + "-jar", + self.wiremock_jar_path, + "--root-dir", + self.wiremock_dir, + "--enable-browser-proxying", # work as forward proxy + "--proxy-pass-through", + "false", # pass through only matched requests + "--port", + str(self.wiremock_http_port), + "--https-port", + str(self.wiremock_https_port), + "--https-keystore", + self.wiremock_dir / "ca-cert.jks", + "--ca-keystore", + self.wiremock_dir / "ca-cert.jks", + ] + + self._additional_wiremock_process_args + ) + self._wait_for_wiremock() + + def _stop_wiremock(self): + if self.wiremock_process.poll() is not None: + logger.warning("Wiremock process already exited, skipping shutdown") + return + + try: + response = self._wiremock_post( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/shutdown" + ) + if response.status_code != 200: + logger.info("Wiremock shutdown failed, the process will be killed") + self.wiremock_process.kill() + else: + logger.debug("Wiremock shutdown gracefully") + except requests.exceptions.RequestException as e: + logger.warning(f"Shutdown request failed: {e}. 
Killing process directly.") + self.wiremock_process.kill() + + def _wait_for_wiremock(self): + retry_count = 0 + while retry_count < WIREMOCK_START_MAX_RETRY_COUNT: + if self._health_check(): + return + retry_count += 1 + sleep(1) + + raise TimeoutError( + f"WiremockClient did not respond within {WIREMOCK_START_MAX_RETRY_COUNT} seconds" + ) + + def _health_check(self): + mappings_endpoint = ( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/health" + ) + try: + response = requests.get(mappings_endpoint) + except requests.exceptions.RequestException as e: + logger.warning(f"Wiremock healthcheck failed with exception: {e}") + return False + + if ( + response.status_code == requests.codes.ok + and response.json()["status"] != "healthy" + ): + logger.warning(f"Wiremock healthcheck failed with response: {response}") + return False + elif response.status_code != requests.codes.ok: + logger.warning( + f"Wiremock healthcheck failed with status code: {response.status_code}" + ) + return False + + return True + + def _reset_wiremock(self): + clean_journal_endpoint = ( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/requests" + ) + requests.delete(clean_journal_endpoint) + reset_endpoint = ( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/reset" + ) + response = self._wiremock_post(reset_endpoint) + if response.status_code != requests.codes.ok: + raise RuntimeError("Failed to reset WiremockClient") + + def _wiremock_post( + self, endpoint: str, body: Optional[str] = None + ) -> requests.Response: + headers = {"Accept": "application/json", "Content-Type": "application/json"} + return requests.post(endpoint, data=body, headers=headers) + + def _replace_placeholders_in_mapping( + self, mapping_str: str, placeholders: Optional[dict[str, object]] + ) -> str: + if placeholders: + for key, value in placeholders.items(): + mapping_str = mapping_str.replace(str(key), str(value)) + return mapping_str + + def import_mapping( + self, + mapping: Union[str, dict, pathlib.Path], + placeholders: Optional[dict[str, object]] = None, + expected_headers: Optional[dict] = None, + ): + self._reset_wiremock() + import_mapping_endpoint = f"{self.http_host_with_port}/__admin/mappings/import" + + mapping_str = _get_mapping_str(mapping) + if expected_headers is not None: + mapping_str = self.add_expected_headers_to_mapping( + mapping_str, expected_headers + ) + + mapping_str = self._replace_placeholders_in_mapping(mapping_str, placeholders) + response = self._wiremock_post(import_mapping_endpoint, mapping_str) + if response.status_code != requests.codes.ok: + raise RuntimeError("Failed to import mapping") + + def import_mapping_with_default_placeholders( + self, + mapping: Union[str, dict, pathlib.Path], + expected_headers: Optional[dict] = None, + ): + placeholders = self.get_default_placeholders() + return self.import_mapping(mapping, placeholders, expected_headers) + + def add_mapping_with_default_placeholders( + self, + mapping: Union[str, dict, pathlib.Path], + expected_headers: Optional[dict] = None, + ): + placeholders = self.get_default_placeholders() + return self.add_mapping(mapping, placeholders, expected_headers) + + def add_mapping( + self, + mapping: Union[str, dict, pathlib.Path], + placeholders: Optional[dict[str, object]] = None, + expected_headers: Optional[dict] = None, + ): + add_mapping_endpoint = f"{self.http_host_with_port}/__admin/mappings" + + mapping_str = _get_mapping_str(mapping) + if expected_headers is not None: + mapping_str = 
self.add_expected_headers_to_mapping(
+                mapping_str, expected_headers
+            )
+
+        mapping_str = self._replace_placeholders_in_mapping(mapping_str, placeholders)
+        response = self._wiremock_post(add_mapping_endpoint, mapping_str)
+        if response.status_code != requests.codes.created:
+            raise RuntimeError("Failed to add mapping")
+
+    def _find_free_port(self, forbidden_ports: Union[List[int], None] = None) -> int:
+        max_retries = 1 if forbidden_ports is None else 3
+        if forbidden_ports is None:
+            forbidden_ports = []
+
+        retry_count = 0
+        while retry_count < max_retries:
+            retry_count += 1
+            with socket.socket() as sock:
+                sock.bind((self.wiremock_host, 0))
+                port = sock.getsockname()[1]
+                if port not in forbidden_ports:
+                    return port
+
+        raise RuntimeError(
+            f"Unable to find a free port for wiremock in {max_retries} attempts"
+        )
+
+    def __enter__(self):
+        self._start_wiremock()
+        logger.debug(
+            f"Starting wiremock process, listening on {self.wiremock_host}:{self.wiremock_http_port}"
+        )
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        logger.debug("Stopping wiremock process")
+        self._stop_wiremock()
+
+
+@contextmanager
+def get_clients_for_proxy_and_target(
+    proxy_mapping_template: Union[str, dict, pathlib.Path, None] = None,
+    additional_proxy_placeholders: Optional[dict[str, object]] = None,
+    additional_proxy_args: Optional[Iterable[str]] = None,
+):
+    """Context manager that starts two Wiremock instances – *target* and *proxy* – and
+    configures the proxy to forward **all** traffic to the target.
+
+    It yields a tuple ``(target_wm, proxy_wm)`` where both items are fully initialised
+    ``WiremockClient`` objects ready for use in tests. When the context exits both
+    Wiremock processes are shut down automatically.
+
+    Parameters
+    ----------
+    proxy_mapping_template
+        Mapping JSON (str / dict / pathlib.Path) to be used for configuring the proxy
+        Wiremock. If *None*, the default template at
+        ``test/data/wiremock/mappings/generic/proxy_forward_all.json`` is used.
+    additional_proxy_placeholders
+        Optional placeholders to be replaced in the proxy mapping *in addition* to the
+        automatically provided ``{{TARGET_HTTP_HOST_WITH_PORT}}``.
+    additional_proxy_args
+        Extra command-line arguments passed to the proxy Wiremock instance when it is
+        launched. Useful for tweaking Wiremock behaviour in specific tests.
+ """ + + # Resolve default mapping template if none provided + if proxy_mapping_template is None: + proxy_mapping_template = ( + pathlib.Path(__file__).parent.parent.parent.parent + / "test" + / "data" + / "wiremock" + / "mappings" + / "generic" + / "proxy_forward_all.json" + ) + + # Start the *target* Wiremock first – this will emulate Snowflake / IdP backend + with WiremockClient() as target_wm: + # Then start the *proxy* Wiremock and ensure it doesn't try to bind the same port + with WiremockClient( + forbidden_ports=[target_wm.wiremock_http_port], + additional_wiremock_process_args=additional_proxy_args, + ) as proxy_wm: + # Prepare placeholders so that proxy forwards to the *target* + placeholders: dict[str, object] = { + "{{TARGET_HTTP_HOST_WITH_PORT}}": target_wm.http_host_with_port + } + if additional_proxy_placeholders: + placeholders.update(additional_proxy_placeholders) + + # Configure proxy Wiremock to forward everything to target + proxy_wm.add_mapping(proxy_mapping_template, placeholders=placeholders) + + # Yield control back to the caller with both Wiremocks ready + yield target_wm, proxy_wm diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py index f173f6de87..590a85711b 100644 --- a/test/unit/aio/test_connection_async_unit.py +++ b/test/unit/aio/test_connection_async_unit.py @@ -27,6 +27,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa import snowflake.connector.aio +from snowflake.connector.aio import connect as async_connect from snowflake.connector.aio._network import SnowflakeRestful from snowflake.connector.aio.auth import ( AuthByDefault, @@ -773,3 +774,93 @@ async def test_invalid_authenticator(): ) await conn.connect() assert "Unknown authenticator: INVALID" in str(excinfo.value) + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +async def test_large_query_through_proxy_async( + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + wiremock_mapping_dir, + proxy_env_vars, + proxy_method, +): + target_wm, proxy_wm = wiremock_target_proxy_pair + + password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json" + multi_chunk_request_mapping = ( + wiremock_mapping_dir / "queries/select_large_request_successful.json" + ) + disconnect_mapping = ( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json" + chunk_1_mapping = wiremock_mapping_dir / "queries/chunk_1.json" + chunk_2_mapping = wiremock_mapping_dir / "queries/chunk_2.json" + + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping(password_mapping, expected_headers=expected_headers) + target_wm.add_mapping_with_default_placeholders( + multi_chunk_request_mapping, expected_headers + ) + target_wm.add_mapping(disconnect_mapping, expected_headers=expected_headers) + target_wm.add_mapping(telemetry_mapping, expected_headers=expected_headers) + target_wm.add_mapping_with_default_placeholders(chunk_1_mapping, expected_headers) + target_wm.add_mapping_with_default_placeholders(chunk_2_mapping, expected_headers) + + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "password": "testPassword", + "account": "testAccount", + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "protocol": "http", + "warehouse": "TEST_WH", + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { 
+ "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + "proxy_user": "proxyUser", + "proxy_password": "proxyPass", + } + ) + clear_proxy_env_vars() + else: + proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + row_count = 50_000 + conn = await async_connect(**connect_kwargs) + try: + cur = conn.cursor() + await cur.execute( + f"select seq4() as n from table(generator(rowcount => {row_count}));" + ) + assert len(cur._result_set.batches) > 1 + _ = [r async for r in cur] + finally: + await conn.close() + + async with aiohttp.ClientSession() as session: + async with session.get( + f"{proxy_wm.http_host_with_port}/__admin/requests" + ) as resp: + proxy_reqs = await resp.json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in proxy_reqs["requests"] + ) + + async with session.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ) as resp: + target_reqs = await resp.json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in target_reqs["requests"] + ) diff --git a/test/unit/aio/test_oauth_token_async.py b/test/unit/aio/test_oauth_token_async.py index 16bee7dc78..e54fd2dca5 100644 --- a/test/unit/aio/test_oauth_token_async.py +++ b/test/unit/aio/test_oauth_token_async.py @@ -4,10 +4,10 @@ import logging import pathlib -from typing import Any, Generator, Union from unittest import mock from unittest.mock import Mock, patch +import aiohttp import pytest try: @@ -19,18 +19,12 @@ import snowflake.connector.errors from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType -from ...wiremock.wiremock_utils import WiremockClient +from ...test_utils.wiremock.wiremock_utils import WiremockClient from ..test_oauth_token import omit_oauth_urls_check # noqa: F401 logger = logging.getLogger(__name__) -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: - with WiremockClient() as client: - yield client - - @pytest.fixture(scope="session") def wiremock_oauth_authorization_code_dir() -> pathlib.Path: return ( @@ -57,17 +51,6 @@ def wiremock_oauth_client_creds_dir() -> pathlib.Path: ) -@pytest.fixture(scope="session") -def wiremock_generic_mappings_dir() -> pathlib.Path: - return ( - pathlib.Path(__file__).parent.parent.parent - / "data" - / "wiremock" - / "mappings" - / "generic" - ) - - @pytest.fixture(scope="session") def wiremock_oauth_refresh_token_dir() -> pathlib.Path: return ( @@ -717,3 +700,159 @@ async def test_client_creds_expired_refresh_token_flow_async( cached_refresh_token = temp_cache_async.retrieve(refresh_token_key) assert cached_access_token == "expired-access-token-123" assert cached_refresh_token == "expired-refresh-token-123" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +async def test_client_credentials_flow_through_proxy_async( + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + temp_cache_async, + proxy_env_vars, + proxy_method, +): + """Run OAuth Client-Credentials flow and ensure it goes through proxy (async).""" + from snowflake.connector.aio import SnowflakeConnection + + target_wm, proxy_wm = wiremock_target_proxy_pair + + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping_with_default_placeholders( + wiremock_oauth_client_creds_dir / "successful_flow.json", expected_headers + ) + 
target_wm.add_mapping_with_default_placeholders( + wiremock_generic_mappings_dir / "snowflake_login_successful.json", + expected_headers, + ) + target_wm.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json", + expected_headers=expected_headers, + ) + + token_request_url = f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/token-request" + + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "authenticator": "OAUTH_CLIENT_CREDENTIALS", + "oauth_client_id": "cid", + "oauth_client_secret": "secret", + "account": "testAccount", + "protocol": "http", + "role": "ANALYST", + "oauth_token_request_url": token_request_url, + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "oauth_enable_refresh_tokens": True, + "client_store_temporary_credential": True, + "token_cache": temp_cache_async, + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { + "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + "proxy_user": "proxyUser", + "proxy_password": "proxyPass", + } + ) + clear_proxy_env_vars() + else: + proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection(**connect_kwargs) + await cnx.connect() + await cnx.close() + + async with aiohttp.ClientSession() as session: + async with session.get( + f"{proxy_wm.http_host_with_port}/__admin/requests" + ) as resp: + proxy_requests = await resp.json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in proxy_requests["requests"] + ) + + async with session.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ) as resp: + target_requests = await resp.json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in target_requests["requests"] + ) + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +async def test_client_credentials_flow_via_explicit_proxy( + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + webbrowser_mock_sync, + monkeypatch, + omit_oauth_urls_check, # noqa: F811 +) -> None: + from snowflake.connector.aio import SnowflakeConnection + + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + target_wm, proxy_wm = wiremock_target_proxy_pair + + target_wm.import_mapping_with_default_placeholders( + wiremock_oauth_authorization_code_dir / "successful_flow.json", + ) + target_wm.add_mapping_with_default_placeholders( + wiremock_generic_mappings_dir / "snowflake_login_successful.json", + ) + target_wm.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json", + ) + + with mock.patch("webbrowser.open", new=webbrowser_mock_sync.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = SnowflakeConnection( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + proxy_host=proxy_wm.wiremock_host, + proxy_port=str(proxy_wm.wiremock_http_port), + proxy_user="proxyUser", + proxy_password="proxyPass", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/token-request", + 
oauth_authorization_url=f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=target_wm.wiremock_host, + port=target_wm.wiremock_http_port, + ) + + await cnx.connect() + await cnx.close() + + async with aiohttp.ClientSession() as session: + async with session.get( + f"http://{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}/__admin/requests" + ) as resp: + proxy_requests = await resp.json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in proxy_requests["requests"] + ), "Proxy did not record token-request" + + async with session.get( + f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/__admin/requests" + ) as resp: + target_requests = await resp.json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in target_requests["requests"] + ), "Target did not receive token-request forwarded by proxy" diff --git a/test/unit/aio/test_programmatic_access_token_async.py b/test/unit/aio/test_programmatic_access_token_async.py index 65c697975c..356ec572c9 100644 --- a/test/unit/aio/test_programmatic_access_token_async.py +++ b/test/unit/aio/test_programmatic_access_token_async.py @@ -5,7 +5,6 @@ from __future__ import annotations import pathlib -from typing import Any, Generator import pytest @@ -17,13 +16,7 @@ import snowflake.connector.errors -from ...wiremock.wiremock_utils import WiremockClient - - -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[WiremockClient | Any, Any, None]: - with WiremockClient() as client: - yield client +from ...test_utils.wiremock.wiremock_utils import WiremockClient @pytest.mark.skipolddriver diff --git a/test/unit/aio/test_proxies_async.py b/test/unit/aio/test_proxies_async.py new file mode 100644 index 0000000000..786972de90 --- /dev/null +++ b/test/unit/aio/test_proxies_async.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import aiohttp +import pytest + +from snowflake.connector.aio import connect + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.timeout(15) +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +async def test_basic_query_through_proxy_async( + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + wiremock_mapping_dir, + proxy_env_vars, + proxy_method, +): + + target_wm, proxy_wm = wiremock_target_proxy_pair + + password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json" + select_mapping = wiremock_mapping_dir / "queries/select_1_successful.json" + disconnect_mapping = ( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json" + + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping_with_default_placeholders( + password_mapping, expected_headers + ) + target_wm.add_mapping_with_default_placeholders(select_mapping, expected_headers) + target_wm.add_mapping(disconnect_mapping) + target_wm.add_mapping(telemetry_mapping) + + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "password": "testPassword", + "account": "testAccount", + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "protocol": "http", + "warehouse": "TEST_WH", + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { + "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + 
} + ) + clear_proxy_env_vars() + else: + proxy_url = f"http://{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + conn = await connect(**connect_kwargs) + try: + cur = conn.cursor() + await cur.execute("SELECT 1") + row = await cur.fetchone() + assert row[0] == 1 + finally: + await conn.close() + + async with aiohttp.ClientSession() as session: + async with session.get( + f"{proxy_wm.http_host_with_port}/__admin/requests" + ) as resp: + proxy_reqs = await resp.json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in proxy_reqs["requests"] + ) + + async with session.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ) as resp: + target_reqs = await resp.json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in target_reqs["requests"] + ) diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index 9b8edb66de..3ef2fd6e36 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -42,6 +42,7 @@ AuthByDefault = AuthByOkta = AuthByOAuth = AuthByWebBrowser = MagicMock try: # pragma: no cover + import snowflake.connector.vendored.requests as requests from snowflake.connector.auth import AuthByUsrPwdMfa from snowflake.connector.config_manager import CONFIG_MANAGER from snowflake.connector.constants import ( @@ -808,3 +809,88 @@ def test_reraise_error_in_file_transfer_work_function_config( expected_value = bool(reraise_enabled) actual_value = conn._reraise_error_in_file_transfer_work_function assert actual_value == expected_value + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +def test_large_query_through_proxy( + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + wiremock_mapping_dir, + proxy_env_vars, + proxy_method, +): + target_wm, proxy_wm = wiremock_target_proxy_pair + + password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json" + multi_chunk_request_mapping = ( + wiremock_mapping_dir / "queries/select_large_request_successful.json" + ) + disconnect_mapping = ( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json" + chunk_1_mapping = wiremock_mapping_dir / "queries/chunk_1.json" + chunk_2_mapping = wiremock_mapping_dir / "queries/chunk_2.json" + + # Configure mappings with proxy header verification + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping(password_mapping, expected_headers=expected_headers) + target_wm.add_mapping_with_default_placeholders( + multi_chunk_request_mapping, expected_headers + ) + target_wm.add_mapping(disconnect_mapping, expected_headers=expected_headers) + target_wm.add_mapping(telemetry_mapping, expected_headers=expected_headers) + target_wm.add_mapping_with_default_placeholders(chunk_1_mapping, expected_headers) + target_wm.add_mapping_with_default_placeholders(chunk_2_mapping, expected_headers) + + # Configure proxy based on test parameter + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "password": "testPassword", + "account": "testAccount", + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "protocol": "http", + "warehouse": "TEST_WH", + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { + "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + "proxy_user": "proxyUser", + 
"proxy_password": "proxyPass", + } + ) + clear_proxy_env_vars() # Ensure no env vars interfere + else: # env_vars + proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + row_count = 50_000 + with snowflake.connector.connect(**connect_kwargs) as conn: + cursors = conn.execute_string( + f"select seq4() as n from table(generator(rowcount => {row_count}));" + ) + assert len(cursors[0]._result_set.batches) > 1 # We need to have remote results + assert list(cursors[0]) + + # Ensure proxy saw query + proxy_reqs = requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in proxy_reqs["requests"] + ) + + # Ensure backend saw query + target_reqs = requests.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ).json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in target_reqs["requests"] + ) diff --git a/test/unit/test_oauth_token.py b/test/unit/test_oauth_token.py index bc1e650adb..b19d9415d6 100644 --- a/test/unit/test_oauth_token.py +++ b/test/unit/test_oauth_token.py @@ -5,7 +5,6 @@ import logging import pathlib from threading import Thread -from typing import Any, Generator, Union from unittest import mock from unittest.mock import Mock, patch @@ -16,17 +15,11 @@ from snowflake.connector.auth import AuthByOauthCredentials from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType -from ..wiremock.wiremock_utils import WiremockClient +from ..test_utils.wiremock.wiremock_utils import WiremockClient logger = logging.getLogger(__name__) -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: - with WiremockClient() as client: - yield client - - @pytest.fixture(scope="session") def wiremock_oauth_authorization_code_dir() -> pathlib.Path: return ( @@ -53,17 +46,6 @@ def wiremock_oauth_client_creds_dir() -> pathlib.Path: ) -@pytest.fixture(scope="session") -def wiremock_generic_mappings_dir() -> pathlib.Path: - return ( - pathlib.Path(__file__).parent.parent - / "data" - / "wiremock" - / "mappings" - / "generic" - ) - - @pytest.fixture(scope="session") def wiremock_oauth_refresh_token_dir() -> pathlib.Path: return ( @@ -701,3 +683,156 @@ def test_client_creds_expired_refresh_token_flow( cached_refresh_token = temp_cache.retrieve(refresh_token_key) assert cached_access_token == "expired-access-token-123" assert cached_refresh_token == "expired-refresh-token-123" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +def test_client_credentials_flow_via_explicit_proxy( + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + temp_cache, + wiremock_mapping_dir, + proxy_env_vars, + proxy_method, +): + """Spin up two Wiremock instances (target & proxy) via shared fixture and run OAuth Client-Credentials flow through the proxy.""" + + target_wm, proxy_wm = wiremock_target_proxy_pair + + # Configure backend (Snowflake + IdP) responses with proxy header verification + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping_with_default_placeholders( + wiremock_oauth_client_creds_dir / "successful_flow.json", expected_headers + ) + target_wm.add_mapping_with_default_placeholders( + wiremock_generic_mappings_dir / "snowflake_login_successful.json", + expected_headers, + ) + target_wm.add_mapping( + 
wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json", + expected_headers=expected_headers, + ) + + token_request_url = f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/token-request" + + # Configure proxy based on test parameter + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "authenticator": "OAUTH_CLIENT_CREDENTIALS", + "oauth_client_id": "cid", + "oauth_client_secret": "secret", + "account": "testAccount", + "protocol": "http", + "role": "ANALYST", + "oauth_token_request_url": token_request_url, + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "oauth_enable_refresh_tokens": True, + "client_store_temporary_credential": True, + "token_cache": temp_cache, + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { + "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + "proxy_user": "proxyUser", + "proxy_password": "proxyPass", + } + ) + clear_proxy_env_vars() # Ensure no env vars interfere + else: # env_vars + proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = snowflake.connector.connect(**connect_kwargs) + assert cnx, "Connection object should be valid" + cnx.close() + + # Verify proxy & backend saw the token request + proxy_requests = requests.get( + f"{proxy_wm.http_host_with_port}/__admin/requests" + ).json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in proxy_requests["requests"] + ) + + target_requests = requests.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ).json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in target_requests["requests"] + ) + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_successful_flow_through_proxy( + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + webbrowser_mock, + monkeypatch, + omit_oauth_urls_check, +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + target_wm, proxy_wm = wiremock_target_proxy_pair + + target_wm.import_mapping_with_default_placeholders( + wiremock_oauth_authorization_code_dir / "successful_flow.json", + ) + target_wm.add_mapping_with_default_placeholders( + wiremock_generic_mappings_dir / "snowflake_login_successful.json", + ) + target_wm.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json", + ) + + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + proxy_host=proxy_wm.wiremock_host, + proxy_port=str(proxy_wm.wiremock_http_port), + proxy_user="proxyUser", + proxy_password="proxyPass", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=target_wm.wiremock_host, + 
port=target_wm.wiremock_http_port, + ) + + assert cnx, "invalid cnx" + cnx.close() + + # Verify: proxy Wiremock saw the token request + proxy_requests = requests.get( + f"http://{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}/__admin/requests" + ).json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in proxy_requests["requests"] + ), "Proxy did not record token-request" + + # Verify: target Wiremock also saw it (because proxy forwarded) + target_requests = requests.get( + f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/__admin/requests" + ).json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in target_requests["requests"] + ), "Target did not receive token-request forwarded by proxy" diff --git a/test/unit/test_programmatic_access_token.py b/test/unit/test_programmatic_access_token.py index 7d6ecb175e..fdf5bc0c9d 100644 --- a/test/unit/test_programmatic_access_token.py +++ b/test/unit/test_programmatic_access_token.py @@ -1,5 +1,4 @@ import pathlib -from typing import Any, Generator, Union import pytest @@ -9,13 +8,7 @@ except ImportError: pass -from ..wiremock.wiremock_utils import WiremockClient - - -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: - with WiremockClient() as client: - yield client +from ..test_utils.wiremock.wiremock_utils import WiremockClient @pytest.mark.skipolddriver diff --git a/test/unit/test_proxies.py b/test/unit/test_proxies.py index fbd2d47268..b32e1dcb09 100644 --- a/test/unit/test_proxies.py +++ b/test/unit/test_proxies.py @@ -8,32 +8,23 @@ import pytest import snowflake.connector +import snowflake.connector.vendored.requests as requests from snowflake.connector.errors import OperationalError -def test_set_proxies(): - from snowflake.connector.proxy import set_proxies +@pytest.mark.skipolddriver +def test_get_proxy_url(): + from snowflake.connector.proxy import get_proxy_url - assert set_proxies("proxyhost", "8080") == { - "http": "http://proxyhost:8080", - "https": "http://proxyhost:8080", - } - assert set_proxies("http://proxyhost", "8080") == { - "http": "http://proxyhost:8080", - "https": "http://proxyhost:8080", - } - assert set_proxies("http://proxyhost", "8080", "testuser", "testpass") == { - "http": "http://testuser:testpass@proxyhost:8080", - "https": "http://testuser:testpass@proxyhost:8080", - } - assert set_proxies("proxyhost", "8080", "testuser", "testpass") == { - "http": "http://testuser:testpass@proxyhost:8080", - "https": "http://testuser:testpass@proxyhost:8080", - } + assert get_proxy_url("host", "port", "user", "password") == ( + "http://user:password@host:port" + ) + assert get_proxy_url("host", "port") == "http://host:port" - # NOTE environment variable is set if the proxy parameter is specified. 
- del os.environ["HTTP_PROXY"] - del os.environ["HTTPS_PROXY"] + assert get_proxy_url("http://host", "port") == "http://host:port" + assert get_proxy_url("https://host", "port", "user", "password") == ( + "http://user:password@host:port" + ) @pytest.mark.skipolddriver @@ -91,3 +82,81 @@ def mock_proxy_manager_for_url_wiht_header(*args, **kwargs): assert "Unable to set 'Host' to proxy manager of type" not in caplog.text del os.environ["HTTPS_PROXY"] + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +def test_basic_query_through_proxy( + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + wiremock_mapping_dir, + proxy_env_vars, + proxy_method, +): + target_wm, proxy_wm = wiremock_target_proxy_pair + + password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json" + select_mapping = wiremock_mapping_dir / "queries/select_1_successful.json" + disconnect_mapping = ( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json" + + # Use expected headers to ensure requests go through proxy + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping_with_default_placeholders( + password_mapping, expected_headers + ) + target_wm.add_mapping_with_default_placeholders(select_mapping, expected_headers) + target_wm.add_mapping(disconnect_mapping) + target_wm.add_mapping(telemetry_mapping) + + # Configure proxy based on test parameter + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "password": "testPassword", + "account": "testAccount", + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "protocol": "http", + "warehouse": "TEST_WH", + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { + "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + } + ) + clear_proxy_env_vars() # Ensure no env vars interfere + else: # env_vars + proxy_url = f"http://{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + # Make connection via proxy + cnx = snowflake.connector.connect(**connect_kwargs) + cur = cnx.cursor() + cur.execute("SELECT 1") + result = cur.fetchone() + assert result[0] == 1 + cur.close() + cnx.close() + + # Ensure proxy saw query + proxy_reqs = requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in proxy_reqs["requests"] + ) + + # Ensure backend saw query + target_reqs = requests.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ).json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in target_reqs["requests"] + ) diff --git a/test/unit/test_wiremock_client.py b/test/unit/test_wiremock_client.py index b471f39df7..19625c42c0 100644 --- a/test/unit/test_wiremock_client.py +++ b/test/unit/test_wiremock_client.py @@ -1,7 +1,3 @@ -from typing import Any, Generator - -import pytest - # old driver support try: from snowflake.connector.vendored import requests @@ -9,16 +5,6 @@ import requests -from ..wiremock.wiremock_utils import WiremockClient - - -@pytest.mark.skipolddriver -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[WiremockClient, Any, None]: - with WiremockClient() as client: - yield client - - def test_wiremock(wiremock_client): connection_reset_by_peer_mapping = { "mappings": [ diff --git 
a/test/wiremock/wiremock_utils.py b/test/wiremock/wiremock_utils.py deleted file mode 100644 index 1d036a8023..0000000000 --- a/test/wiremock/wiremock_utils.py +++ /dev/null @@ -1,186 +0,0 @@ -import json -import logging -import pathlib -import socket -import subprocess -from time import sleep -from typing import List, Optional, Union - -try: - from snowflake.connector.vendored import requests -except ImportError: - import requests - -WIREMOCK_START_MAX_RETRY_COUNT = 12 -logger = logging.getLogger(__name__) - - -def _get_mapping_str(mapping: Union[str, dict, pathlib.Path]) -> str: - if isinstance(mapping, str): - return mapping - if isinstance(mapping, dict): - return json.dumps(mapping) - if isinstance(mapping, pathlib.Path): - if mapping.is_file(): - with open(mapping) as f: - return f.read() - else: - raise RuntimeError(f"File with mapping: {mapping} does not exist") - - raise RuntimeError(f"Mapping {mapping} is of an invalid type") - - -class WiremockClient: - def __init__(self, forbidden_ports: Optional[List[int]] = None) -> None: - self.wiremock_filename = "wiremock-standalone.jar" - self.wiremock_host = "localhost" - self.wiremock_http_port = None - self.wiremock_https_port = None - self.forbidden_ports = forbidden_ports if forbidden_ports is not None else [] - - self.wiremock_dir = pathlib.Path(__file__).parent.parent.parent / ".wiremock" - assert self.wiremock_dir.exists(), f"{self.wiremock_dir} does not exist" - - self.wiremock_jar_path = self.wiremock_dir / self.wiremock_filename - assert ( - self.wiremock_jar_path.exists() - ), f"{self.wiremock_jar_path} does not exist" - - def _start_wiremock(self): - self.wiremock_http_port = self._find_free_port( - forbidden_ports=self.forbidden_ports, - ) - self.wiremock_https_port = self._find_free_port( - forbidden_ports=self.forbidden_ports + [self.wiremock_http_port] - ) - self.wiremock_process = subprocess.Popen( - [ - "java", - "-jar", - self.wiremock_jar_path, - "--root-dir", - self.wiremock_dir, - "--enable-browser-proxying", # work as forward proxy - "--proxy-pass-through", - "false", # pass through only matched requests - "--port", - str(self.wiremock_http_port), - "--https-port", - str(self.wiremock_https_port), - "--https-keystore", - self.wiremock_dir / "ca-cert.jks", - "--ca-keystore", - self.wiremock_dir / "ca-cert.jks", - ] - ) - self._wait_for_wiremock() - - def _stop_wiremock(self): - response = self._wiremock_post( - f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/shutdown" - ) - if response.status_code != 200: - logger.info("Wiremock shutdown failed, the process will be killed") - self.wiremock_process.kill() - else: - logger.debug("Wiremock shutdown gracefully") - - def _wait_for_wiremock(self): - retry_count = 0 - while retry_count < WIREMOCK_START_MAX_RETRY_COUNT: - if self._health_check(): - return - retry_count += 1 - sleep(1) - - raise TimeoutError( - f"WiremockClient did not respond within {WIREMOCK_START_MAX_RETRY_COUNT} seconds" - ) - - def _health_check(self): - mappings_endpoint = ( - f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/health" - ) - try: - response = requests.get(mappings_endpoint) - except requests.exceptions.RequestException as e: - logger.warning(f"Wiremock healthcheck failed with exception: {e}") - return False - - if ( - response.status_code == requests.codes.ok - and response.json()["status"] != "healthy" - ): - logger.warning(f"Wiremock healthcheck failed with response: {response}") - return False - elif response.status_code != requests.codes.ok: - 
logger.warning( - f"Wiremock healthcheck failed with status code: {response.status_code}" - ) - return False - - return True - - def _reset_wiremock(self): - clean_journal_endpoint = ( - f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/requests" - ) - requests.delete(clean_journal_endpoint) - reset_endpoint = ( - f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/reset" - ) - response = self._wiremock_post(reset_endpoint) - if response.status_code != requests.codes.ok: - raise RuntimeError("Failed to reset WiremockClient") - - def _wiremock_post( - self, endpoint: str, body: Optional[str] = None - ) -> requests.Response: - headers = {"Accept": "application/json", "Content-Type": "application/json"} - return requests.post(endpoint, data=body, headers=headers) - - def import_mapping(self, mapping: Union[str, dict, pathlib.Path]): - self._reset_wiremock() - import_mapping_endpoint = f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/mappings/import" - mapping_str = _get_mapping_str(mapping) - response = self._wiremock_post(import_mapping_endpoint, mapping_str) - if response.status_code != requests.codes.ok: - raise RuntimeError("Failed to import mapping") - - def add_mapping(self, mapping: Union[str, dict, pathlib.Path]): - add_mapping_endpoint = ( - f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/mappings" - ) - mapping_str = _get_mapping_str(mapping) - response = self._wiremock_post(add_mapping_endpoint, mapping_str) - if response.status_code != requests.codes.created: - raise RuntimeError("Failed to add mapping") - - def _find_free_port(self, forbidden_ports: Union[List[int], None] = None) -> int: - max_retries = 1 if forbidden_ports is None else 3 - if forbidden_ports is None: - forbidden_ports = [] - - retry_count = 0 - while retry_count < max_retries: - retry_count += 1 - with socket.socket() as sock: - sock.bind((self.wiremock_host, 0)) - port = sock.getsockname()[1] - if port not in forbidden_ports: - return port - - raise RuntimeError( - f"Unable to find a free port for wiremock in {max_retries} attempts" - ) - - def __enter__(self): - self._start_wiremock() - logger.debug( - f"Starting wiremock process, listening on {self.wiremock_host}:{self.wiremock_http_port}" - ) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - logger.debug("Stopping wiremock process") - self._stop_wiremock()
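
Note on wiring (not shown in this diff): the fixtures moved into test/test_utils/cross_module_fixtures/ (proxy_env_vars, wiremock_client, wiremock_mapping_dir, wiremock_target_proxy_pair, ...) only become visible to modules such as test/unit/test_proxies.py once the fixture modules are registered as pytest plugins from a conftest.py. That conftest change is outside this excerpt; a typical registration, using module paths that mirror the new layout, would look roughly like the following sketch (the repository's actual conftest wiring may differ):

# conftest.py (sketch; assumed, not part of this diff)
pytest_plugins = [
    "test.test_utils.cross_module_fixtures.http_fixtures",
    "test.test_utils.cross_module_fixtures.wiremock_fixtures",
]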
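
Note on the new placeholder mechanism: WiremockClient.get_default_placeholders() maps the literal string {{WIREMOCK_HTTP_HOST_WITH_PORT}} to the client's own http_host_with_port, and _replace_placeholders_in_mapping() does a plain string replace on the serialized mapping before it is posted to Wiremock. A minimal sketch of relying on this from a test follows; the stub dict is invented for illustration, and it assumes the repository root is on sys.path and that the .wiremock directory with wiremock-standalone.jar exists (the constructor asserts this).

from test.test_utils.wiremock.wiremock_utils import WiremockClient

# Hypothetical stub: redirect to an absolute URL on the same Wiremock instance.
redirect_stub = {
    "request": {"method": "GET", "url": "/start"},
    "response": {
        "status": 302,
        # Replaced with e.g. "http://localhost:54321/landing" before registration.
        "headers": {"Location": "{{WIREMOCK_HTTP_HOST_WITH_PORT}}/landing"},
    },
}

with WiremockClient() as wm:
    # add_mapping_with_default_placeholders() serializes the dict, substitutes the
    # placeholder with wm.http_host_with_port, and POSTs it to /__admin/mappings.
    wm.add_mapping_with_default_placeholders(redirect_stub)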
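
Note on add_expected_headers_to_mapping(): it injects the given matchers under request.headers of every stub, which is how the proxy tests make the target Wiremock match only requests that actually travelled through the forwarding proxy (hence the {"Via": {"contains": "wiremock"}} matcher). A small sketch of the transformation, under the same jar/.wiremock assumptions as above:

import json

from test.test_utils.wiremock.wiremock_utils import WiremockClient

wm = WiremockClient()  # constructor only locates the jar; no process is started here
stub = {
    "request": {"method": "POST", "url": "/session/v1/login-request"},
    "response": {"status": 200},
}
patched = json.loads(
    wm.add_expected_headers_to_mapping(
        json.dumps(stub), {"Via": {"contains": "wiremock"}}
    )
)
# The Via matcher is now part of the request matching criteria.
assert patched["request"]["headers"] == {"Via": {"contains": "wiremock"}}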