diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 3be367c8ca..60afc6a0ea 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -28,6 +28,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne - Fix case-sensitivity of `Oauth` and `programmatic_access_token` authenticator values. - Relaxed `pyarrow` version constraint, versions >= 19 can now be used. - Populate type_code in ResultMetadata for interval types. + - Proxy setup with connection parameters added. - v3.16.0(July 04,2025) - Bumped numpy dependency from <2.1.0 to <=2.2.4. diff --git a/src/snowflake/connector/auth/_oauth_base.py b/src/snowflake/connector/auth/_oauth_base.py index 85deaf7f13..2ff1241638 100644 --- a/src/snowflake/connector/auth/_oauth_base.py +++ b/src/snowflake/connector/auth/_oauth_base.py @@ -20,9 +20,12 @@ ) from ..errors import Error, ProgrammingError from ..network import OAUTH_AUTHENTICATOR +from ..proxy import get_proxy_url from ..secret_detector import SecretDetector from ..token_cache import TokenCache, TokenKey, TokenType from ..vendored import urllib3 +from ..vendored.requests.utils import get_environ_proxies, select_proxy +from ..vendored.urllib3.poolmanager import ProxyManager from .by_plugin import AuthByPlugin, AuthType if TYPE_CHECKING: @@ -319,7 +322,13 @@ def _get_refresh_token_response( fields["scope"] = self._scope try: # TODO(SNOW-2229411) Session manager should be used here. It may require additional security validation (since we would transition from PoolManager to requests.Session) and some parameters would be passed implicitly. OAuth token exchange must NOT reuse pooled HTTP sessions. We should create a fresh SessionManager with use_pooling=False for each call. 
- return urllib3.PoolManager().request_encode_body( + proxy_url = self._resolve_proxy_url(conn, self._token_request_url) + http_client = ( + ProxyManager(proxy_url=proxy_url) + if proxy_url + else urllib3.PoolManager() + ) + return http_client.request_encode_body( "POST", self._token_request_url, encode_multipart=False, @@ -359,7 +368,11 @@ def _get_request_token_response( fields: dict[str, str], ) -> (str | None, str | None): # TODO(SNOW-2229411) Session manager should be used here. It may require additional security validation (since we would transition from PoolManager to requests.Session) and some parameters would be passed implicitly. Token request must bypass HTTP connection pools. - resp = urllib3.PoolManager().request_encode_body( + proxy_url = self._resolve_proxy_url(connection, self._token_request_url) + http_client = ( + ProxyManager(proxy_url=proxy_url) if proxy_url else urllib3.PoolManager() + ) + resp = http_client.request_encode_body( "POST", self._token_request_url, headers=self._create_token_request_headers(), @@ -400,3 +413,25 @@ def _create_token_request_headers(self) -> dict[str, str]: "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded; charset=UTF-8", } + + @staticmethod + def _resolve_proxy_url( + connection: SnowflakeConnection, request_url: str + ) -> str | None: + # TODO(SNOW-2229411) Session manager should be used instead. It may require additional security validation. 
+ """Resolve proxy URL from explicit config first, then environment variables.""" + # First try explicit proxy configuration from connection parameters + proxy_url = get_proxy_url( + connection.proxy_host, + connection.proxy_port, + connection.proxy_user, + connection.proxy_password, + ) + + if proxy_url: + return proxy_url + + # Fall back to environment variables (HTTP_PROXY, HTTPS_PROXY) + # Use proper proxy selection that considers the URL scheme + proxies = get_environ_proxies(request_url) + return select_proxy(request_url, proxies) diff --git a/src/snowflake/connector/auth/oauth_code.py b/src/snowflake/connector/auth/oauth_code.py index 1c0c41eb6d..a5aaf31fb9 100644 --- a/src/snowflake/connector/auth/oauth_code.py +++ b/src/snowflake/connector/auth/oauth_code.py @@ -269,6 +269,7 @@ def _do_authorization_request( "login. If you can't see it, check existing browser windows, " "or your OS settings. Press CTRL+C to abort and try again..." ) + # TODO(SNOW-2229411) Investigate if Session manager / Http Config should be used here. code, state = ( self._receive_authorization_callback(callback_server, connection) if webbrowser.open(authorization_request) diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 000907d5c4..38f4e5301d 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -27,7 +27,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from . import errors, proxy +from . 
import errors from ._query_context_cache import QueryContextCache from ._utils import ( _DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER, @@ -924,6 +924,10 @@ def connect(self, **kwargs) -> None: self._http_config = HttpConfig( adapter_factory=ProxySupportAdapterFactory(), use_pooling=(not self.disable_request_pooling), + proxy_host=self.proxy_host, + proxy_port=self.proxy_port, + proxy_user=self.proxy_user, + proxy_password=self.proxy_password, ) self._session_manager = SessionManager(self._http_config) @@ -1125,10 +1129,6 @@ def __open_connection(self): use_numpy=self._numpy, support_negative_year=self._support_negative_year ) - proxy.set_proxies( - self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password - ) - self._rest = SnowflakeRestful( host=self.host, port=self.port, diff --git a/src/snowflake/connector/proxy.py b/src/snowflake/connector/proxy.py index 6b54e29ee5..996fd563ba 100644 --- a/src/snowflake/connector/proxy.py +++ b/src/snowflake/connector/proxy.py @@ -1,43 +1,28 @@ #!/usr/bin/env python from __future__ import annotations -import os - -def set_proxies( +def get_proxy_url( proxy_host: str | None, proxy_port: str | None, proxy_user: str | None = None, proxy_password: str | None = None, -) -> dict[str, str] | None: - """Sets proxy dict for requests.""" - PREFIX_HTTP = "http://" - PREFIX_HTTPS = "https://" - proxies = None +) -> str | None: + http_prefix = "http://" + https_prefix = "https://" + if proxy_host and proxy_port: - if proxy_host.startswith(PREFIX_HTTP): - proxy_host = proxy_host[len(PREFIX_HTTP) :] - elif proxy_host.startswith(PREFIX_HTTPS): - proxy_host = proxy_host[len(PREFIX_HTTPS) :] - if proxy_user or proxy_password: - proxy_auth = "{proxy_user}:{proxy_password}@".format( - proxy_user=proxy_user if proxy_user is not None else "", - proxy_password=proxy_password if proxy_password is not None else "", - ) + if proxy_host.startswith(http_prefix): + host = proxy_host[len(http_prefix) :] + elif 
proxy_host.startswith(https_prefix): + host = proxy_host[len(https_prefix) :] else: - proxy_auth = "" - proxies = { - "http": "http://{proxy_auth}{proxy_host}:{proxy_port}".format( - proxy_host=proxy_host, - proxy_port=str(proxy_port), - proxy_auth=proxy_auth, - ), - "https": "http://{proxy_auth}{proxy_host}:{proxy_port}".format( - proxy_host=proxy_host, - proxy_port=str(proxy_port), - proxy_auth=proxy_auth, - ), - } - os.environ["HTTP_PROXY"] = proxies["http"] - os.environ["HTTPS_PROXY"] = proxies["https"] - return proxies + host = proxy_host + auth = ( + f"{proxy_user or ''}:{proxy_password or ''}@" + if proxy_user or proxy_password + else "" + ) + return f"{http_prefix}{auth}{host}:{proxy_port}" + + return None diff --git a/src/snowflake/connector/result_batch.py b/src/snowflake/connector/result_batch.py index 2ff29128ca..d6e7c4a128 100644 --- a/src/snowflake/connector/result_batch.py +++ b/src/snowflake/connector/result_batch.py @@ -26,7 +26,7 @@ from .options import installed_pandas from .options import pyarrow as pa from .secret_detector import SecretDetector -from .session_manager import SessionManager +from .session_manager import HttpConfig, SessionManager from .time_util import TimerContextManager logger = getLogger(__name__) @@ -261,6 +261,8 @@ def __init__( [s._to_result_metadata_v1() for s in schema] if schema is not None else None ) self._use_dict_result = use_dict_result + # Passed to contain the configured Http behavior in case the connection is no longer active for the download + # Can be overridden with setters if needed. 
self._session_manager = session_manager self._metrics: dict[str, int] = {} self._data: str | list[tuple[Any, ...]] | None = None @@ -300,6 +302,25 @@ def uncompressed_size(self) -> int | None: def column_names(self) -> list[str]: return [col.name for col in self._schema] + @property + def session_manager(self) -> SessionManager | None: + return self._session_manager + + @session_manager.setter + def session_manager(self, session_manager: SessionManager | None) -> None: + self._session_manager = session_manager + + @property + def http_config(self): + return self._session_manager.config + + @http_config.setter + def http_config(self, config: HttpConfig) -> None: + if self._session_manager: + self._session_manager.config = config + else: + self._session_manager = SessionManager(config=config) + def __iter__( self, ) -> Iterator[dict | Exception] | Iterator[tuple | Exception]: diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py index 770f0167f1..5a6faac88a 100644 --- a/src/snowflake/connector/session_manager.py +++ b/src/snowflake/connector/session_manager.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generator, Mapping from .compat import urlparse +from .proxy import get_proxy_url from .vendored import requests from .vendored.requests import Response, Session from .vendored.requests.adapters import BaseAdapter, HTTPAdapter @@ -76,8 +77,15 @@ def get_connection( proxy_manager = self.proxy_manager_for(proxy) if isinstance(proxy_manager, ProxyManager): - # Add Host to proxy header SNOW-232777 - proxy_manager.proxy_headers["Host"] = parsed_url.hostname + # Add Host to proxy header SNOW-232777 and SNOW-694457 + + # RFC 7230 / 5.4 – a proxy’s Host header must repeat the request authority + # verbatim: host[:port] with IPv6 still in [brackets]. We take that + # straight from urlparse(url).netloc, which preserves port and brackets (and case-sensitive hostname). 
+ # Note: netloc also keeps user-info (user:pass@host) if present in URL. The driver never sends + # URLs with embedded credentials, so we leave them unhandled — for full support + # we’d need to manually concatenate hostname with optional port and IPv6 brackets. + proxy_manager.proxy_headers["Host"] = parsed_url.netloc else: logger.debug( f"Unable to set 'Host' to proxy manager of type {type(proxy_manager)} as" @@ -112,6 +120,10 @@ class HttpConfig: ) use_pooling: bool = True max_retries: int | None = REQUESTS_RETRY + proxy_host: str | None = None + proxy_port: str | None = None + proxy_user: str | None = None + proxy_password: str | None = None def copy_with(self, **overrides: Any) -> HttpConfig: """Return a new HttpConfig with overrides applied.""" @@ -293,13 +305,13 @@ class SessionManager(_RequestVerbsUsingSessionMixin): **Two Operating Modes**: - use_pooling=False: One-shot sessions (create, use, close) - suitable for infrequent requests - use_pooling=True: Per-hostname session pools - reuses TCP connections, avoiding handshake - and SSL/TLS negotiation overhead for repeated requests to the same host + and SSL/TLS negotiation overhead for repeated requests to the same host. **Key Benefits**: - Centralized HTTP configuration management and easy propagation across the codebase - Consistent proxy setup (SNOW-694457) and headers customization (SNOW-2043816) - HTTPAdapter customization for connection-level request manipulation - - Performance optimization through connection reuse for high-traffic scenarios + - Performance optimization through connection reuse for high-traffic scenarios. **Usage**: Create the base session manager, then use clone() for derived managers to ensure proper config propagation. 
Pre-commit checks enforce usage to prevent code drift back to @@ -315,7 +327,6 @@ def __init__(self, config: HttpConfig | None = None, **http_config_kwargs) -> No logger.debug("Creating a config for the SessionManager") config = HttpConfig(**http_config_kwargs) self._cfg: HttpConfig = config - self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( lambda: SessionPool(self) ) @@ -338,6 +349,19 @@ def from_config(cls, cfg: HttpConfig, **overrides: Any) -> SessionManager: def config(self) -> HttpConfig: return self._cfg + @config.setter + def config(self, cfg: HttpConfig) -> None: + self._cfg = cfg + + @property + def proxy_url(self) -> str: + return get_proxy_url( + self._cfg.proxy_host, + self._cfg.proxy_port, + self._cfg.proxy_user, + self._cfg.proxy_password, + ) + @property def use_pooling(self) -> bool: return self._cfg.use_pooling @@ -395,6 +419,7 @@ def _mount_adapters(self, session: requests.Session) -> None: def make_session(self) -> Session: session = requests.Session() self._mount_adapters(session) + session.proxies = {"http": self.proxy_url, "https": self.proxy_url} return session @contextlib.contextmanager diff --git a/test/conftest.py b/test/conftest.py index e8a8081b20..50b7f287c3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,6 +5,8 @@ from contextlib import contextmanager from logging import getLogger from pathlib import Path +from test.test_utils.cross_module_fixtures.http_fixtures import * # NOQA +from test.test_utils.cross_module_fixtures.wiremock_fixtures import * # NOQA from typing import Generator import pytest diff --git a/test/data/wiremock/mappings/auth/password/successful_flow.json b/test/data/wiremock/mappings/auth/password/successful_flow.json new file mode 100644 index 0000000000..58045d6fe3 --- /dev/null +++ b/test/data/wiremock/mappings/auth/password/successful_flow.json @@ -0,0 +1,60 @@ +{ + "mappings": [ + { + "request": { + "urlPathPattern": "/session/v1/login-request.*", + "method": "POST", + 
"bodyPatterns": [ + { + "equalToJson" : { + "data": { + "LOGIN_NAME": "testUser", + "PASSWORD": "testPassword" + } + }, + "ignoreExtraElements" : true + } + ] + }, + "response": { + "status": 200, + "jsonBody": { + "data": { + "masterToken": "master token", + "token": "session token", + "validityInSeconds": 3600, + "masterValidityInSeconds": 14400, + "displayUserName": "TEST_USER", + "serverVersion": "8.48.0 b2024121104444034239f05", + "firstLogin": false, + "remMeToken": null, + "remMeValidityInSeconds": 0, + "healthCheckInterval": 45, + "newClientForUpgrade": "3.12.3", + "sessionId": 1172562260498, + "parameters": [ + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + } + ], + "sessionInfo": { + "databaseName": "TEST_DB", + "schemaName": "TEST_GO", + "warehouseName": "TEST_XSMALL", + "roleName": "ANALYST" + }, + "idToken": null, + "idTokenValidityInSeconds": 0, + "responseData": null, + "mfaToken": null, + "mfaTokenValidityInSeconds": 0 + }, + "code": null, + "message": null, + "success": true + } + } + } + ] +} diff --git a/test/data/wiremock/mappings/generic/proxy_forward_all.json b/test/data/wiremock/mappings/generic/proxy_forward_all.json new file mode 100644 index 0000000000..62ba091bf2 --- /dev/null +++ b/test/data/wiremock/mappings/generic/proxy_forward_all.json @@ -0,0 +1,12 @@ +{ + "request": { + "urlPattern": "/.*", + "method": "ANY" + }, + "response": { + "proxyBaseUrl": "{{TARGET_HTTP_HOST_WITH_PORT}}", + "additionalProxyRequestHeaders": { + "Via": "1.1 wiremock-proxy" + } + } +} diff --git a/test/data/wiremock/mappings/generic/telemetry.json b/test/data/wiremock/mappings/generic/telemetry.json new file mode 100644 index 0000000000..9b734a0cf2 --- /dev/null +++ b/test/data/wiremock/mappings/generic/telemetry.json @@ -0,0 +1,18 @@ +{ + "scenarioName": "Successful telemetry flow", + "request": { + "urlPathPattern": "/telemetry/send", + "method": "POST" + }, + "response": { + "status": 200, + "jsonBody": { + "data": { + "code": null, + "data": "Log 
Received", + "message": null, + "success": true + } + } + } +} diff --git a/test/data/wiremock/mappings/queries/chunk_1.json b/test/data/wiremock/mappings/queries/chunk_1.json new file mode 100644 index 0000000000..246874d3c4 --- /dev/null +++ b/test/data/wiremock/mappings/queries/chunk_1.json @@ -0,0 +1,14 @@ +{ + "request": { + "method": "GET", + "url": "/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_1?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip" + }, + "response": { + "status": 200, + "headers": { + "Content-Encoding": "gzip", + "x-amz-server-side-encryption-customer-algorithm": "AES256" + }, + "base64Body": "H4sIAM7YUGgC/4s21FHAiWIB81FB/x4AAAA=" + } +} diff --git a/test/data/wiremock/mappings/queries/chunk_2.json b/test/data/wiremock/mappings/queries/chunk_2.json new file mode 100644 index 0000000000..60f2756d0e --- /dev/null +++ b/test/data/wiremock/mappings/queries/chunk_2.json @@ -0,0 +1,14 @@ +{ + "request": { + "method": "GET", + "url": "/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_2?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip" + }, + "response": { + "status": 200, + "headers": { + "Content-Encoding": "gzip", + "x-amz-server-side-encryption-customer-algorithm": "AES256" + }, + "base64Body": "H4sIAM7YUGgC/4s21FHAiWIB81FB/x4AAAA=" + } +} diff --git a/test/data/wiremock/mappings/queries/select_1_successful.json b/test/data/wiremock/mappings/queries/select_1_successful.json new file mode 100644 index 0000000000..99fdcb7103 --- /dev/null +++ b/test/data/wiremock/mappings/queries/select_1_successful.json @@ -0,0 +1,199 @@ +{ + "scenarioName": "Successful SELECT 1 flow", + "request": { + "urlPathPattern": "/queries/v1/query-request.*", + "method": "POST", + "headers": { + "Authorization": { + "equalTo": "Snowflake Token=\"session token\"" + } + } + }, + "response": { + 
"status": 200, + "jsonBody": { + "data": { + "parameters": [ + { + "name": "TIMESTAMP_OUTPUT_FORMAT", + "value": "YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM" + }, + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + }, + { + "name": "TIME_OUTPUT_FORMAT", + "value": "HH24:MI:SS" + }, + { + "name": "CLIENT_RESULT_CHUNK_SIZE", + "value": 16 + }, + { + "name": "TIMESTAMP_TZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE", + "value": false + }, + { + "name": "QUERY_CONTEXT_CACHE_SIZE", + "value": 5 + }, + { + "name": "CLIENT_METADATA_USE_SESSION_DATABASE", + "value": false + }, + { + "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", + "value": false + }, + { + "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", + "value": true + }, + { + "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", + "value": "YYYY-MM-DD HH24:MI:SS.FF3" + }, + { + "name": "CLIENT_RESULT_PREFETCH_THREADS", + "value": 1 + }, + { + "name": "CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX", + "value": false + }, + { + "name": "CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ", + "value": true + }, + { + "name": "CLIENT_MEMORY_LIMIT", + "value": 1536 + }, + { + "name": "CLIENT_TIMESTAMP_TYPE_MAPPING", + "value": "TIMESTAMP_LTZ" + }, + { + "name": "TIMEZONE", + "value": "America/Los_Angeles" + }, + { + "name": "SERVICE_NAME", + "value": "" + }, + { + "name": "CLIENT_RESULT_PREFETCH_SLOTS", + "value": 2 + }, + { + "name": "CLIENT_TELEMETRY_ENABLED", + "value": true + }, + { + "name": "CLIENT_DISABLE_INCIDENTS", + "value": true + }, + { + "name": "CLIENT_USE_V1_QUERY_API", + "value": true + }, + { + "name": "CLIENT_RESULT_COLUMN_CASE_INSENSITIVE", + "value": false + }, + { + "name": "CSV_TIMESTAMP_FORMAT", + "value": "" + }, + { + "name": "BINARY_OUTPUT_FORMAT", + "value": "HEX" + }, + { + "name": "CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS", + "value": false + }, + { + "name": "CLIENT_TELEMETRY_SESSIONLESS_ENABLED", + "value": true + }, + { + "name": "DATE_OUTPUT_FORMAT", + "value": "YYYY-MM-DD" + }, + { + 
"name": "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD", + "value": 65280 + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", + "value": 3600 + }, + { + "name": "CLIENT_SESSION_CLONE", + "value": false + }, + { + "name": "AUTOCOMMIT", + "value": true + }, + { + "name": "TIMESTAMP_LTZ_OUTPUT_FORMAT", + "value": "" + } + ], + "rowtype": [ + { + "name": "1", + "database": "", + "schema": "", + "table": "", + "nullable": false, + "length": null, + "type": "fixed", + "scale": 0, + "precision": 1, + "byteLength": null, + "collation": null + } + ], + "rowset": [ + [ + "1" + ] + ], + "total": 1, + "returned": 1, + "queryId": "01ba13b4-0104-e9fd-0000-0111029ca00e", + "databaseProvider": null, + "finalDatabaseName": null, + "finalSchemaName": null, + "finalWarehouseName": "TEST_XSMALL", + "numberOfBinds": 0, + "arrayBindSupported": false, + "statementTypeId": 4096, + "version": 1, + "sendResultTime": 1738317395581, + "queryResultFormat": "json", + "queryContext": { + "entries": [ + { + "id": 0, + "timestamp": 1738317395574564, + "priority": 0, + "context": "CPbPTg==" + } + ] + } + }, + "code": null, + "message": null, + "success": true + } + } +} diff --git a/test/data/wiremock/mappings/queries/select_large_request_successful.json b/test/data/wiremock/mappings/queries/select_large_request_successful.json new file mode 100644 index 0000000000..61ee3135a6 --- /dev/null +++ b/test/data/wiremock/mappings/queries/select_large_request_successful.json @@ -0,0 +1,413 @@ +{ + "scenarioName": "Successful SELECT 1 flow", + "request": { + "urlPathPattern": "/queries/v1/query-request.*", + "method": "POST", + "headers": { + "Authorization": { + "equalTo": "Snowflake Token=\"session token\"" + } + } + }, + "response": { + "status": 200, + "jsonBody": { + "data": { + "parameters": [ + { + "name": "CLIENT_PREFETCH_THREADS", + "value": 4 + }, + { + "name": "TIMESTAMP_OUTPUT_FORMAT", + "value": "DY, DD MON YYYY HH24:MI:SS TZHTZM" + }, + { + "name": 
"PYTHON_SNOWPARK_CLIENT_MIN_VERSION_FOR_AST", + "value": "1.29.0" + }, + { + "name": "TIME_OUTPUT_FORMAT", + "value": "HH24:MI:SS" + }, + { + "name": "CLIENT_RESULT_CHUNK_SIZE", + "value": 160 + }, + { + "name": "TIMESTAMP_TZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE", + "value": false + }, + { + "name": "PYTHON_SNOWPARK_USE_CTE_OPTIMIZATION_VERSION", + "value": "1.31.1" + }, + { + "name": "CLIENT_METADATA_USE_SESSION_DATABASE", + "value": false + }, + { + "name": "QUERY_CONTEXT_CACHE_SIZE", + "value": 5 + }, + { + "name": "PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_VERSION", + "value": "" + }, + { + "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", + "value": false + }, + { + "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_RESULT_PREFETCH_THREADS", + "value": 1 + }, + { + "name": "CLIENT_METADATA_REQUEST_USE_CONNECTION_CTX", + "value": false + }, + { + "name": "CLIENT_HONOR_CLIENT_TZ_FOR_TIMESTAMP_NTZ", + "value": true + }, + { + "name": "CLIENT_MEMORY_LIMIT", + "value": 1536 + }, + { + "name": "CLIENT_TIMESTAMP_TYPE_MAPPING", + "value": "TIMESTAMP_LTZ" + }, + { + "name": "TIMEZONE", + "value": "UTC" + }, + { + "name": "PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER", + "value": true + }, + { + "name": "SNOWPARK_REQUEST_TIMEOUT_IN_SECONDS", + "value": 86400 + }, + { + "name": "PYTHON_SNOWPARK_USE_AST", + "value": false + }, + { + "name": "SERVICE_NAME", + "value": "" + }, + { + "name": "PYTHON_CONNECTOR_USE_NANOARROW", + "value": true + }, + { + "name": "CLIENT_RESULT_PREFETCH_SLOTS", + "value": 2 + }, + { + "name": "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND", + "value": 10000000 + }, + { + "name": "PYTHON_SNOWPARK_GENERATE_MULTILINE_QUERIES", + "value": true + }, + { + "name": "CLIENT_DISABLE_INCIDENTS", + "value": true + }, + { + "name": "CSV_TIMESTAMP_FORMAT", + "value": "" + }, + { + "name": "BINARY_OUTPUT_FORMAT", + "value": "HEX" + }, + { + "name": 
"CLIENT_TELEMETRY_SESSIONLESS_ENABLED", + "value": true + }, + { + "name": "DATE_OUTPUT_FORMAT", + "value": "YYYY-MM-DD" + }, + { + "name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", + "value": 3600 + }, + { + "name": "PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED", + "value": false + }, + { + "name": "AUTOCOMMIT", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED", + "value": false + }, + { + "name": "CLIENT_SESSION_CLONE", + "value": false + }, + { + "name": "TIMESTAMP_LTZ_OUTPUT_FORMAT", + "value": "" + }, + { + "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", + "value": false + }, + { + "name": "PYTHON_SNOWPARK_DATAFRAME_JOIN_ALIAS_FIX_VERSION", + "value": "" + }, + { + "name": "PYTHON_SNOWPARK_COLLECT_TELEMETRY_AT_CRITICAL_PATH_VERSION", + "value": "1.28.0" + }, + { + "name": "PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED_VERSION", + "value": "" + }, + { + "name": "CLIENT_TELEMETRY_ENABLED", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_ELIMINATE_NUMERIC_SQL_VALUE_CAST_ENABLED", + "value": false + }, + { + "name": "CLIENT_USE_V1_QUERY_API", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", + "value": true + }, + { + "name": "CLIENT_RESULT_COLUMN_CASE_INSENSITIVE", + "value": false + }, + { + "name": "CLIENT_ENABLE_LOG_INFO_STATEMENT_PARAMETERS", + "value": false + }, + { + "name": "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD", + "value": 65280 + }, + { + "name": "PYTHON_SNOWPARK_COMPILATION_STAGE_ENABLED", + "value": true + }, + { + "name": "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND", + "value": 12000000 + }, + { + "name": "PYTHON_SNOWPARK_CLIENT_AST_MODE", + "value": 0 + } + ], + "rowtype": [ + { + "name": "C0", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + 
"byteLength": null, + "collation": null + }, + { + "name": "C1", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C2", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C3", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C4", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C5", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C6", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C7", + "database": 
"TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C8", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + }, + { + "name": "C9", + "database": "TESTDB", + "schema": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "table": "PYTHON_TESTS_DB65D4FD_B0D6_4B1E_A7EB_E95DC198CC2A", + "length": null, + "type": "fixed", + "scale": 0, + "precision": 38, + "nullable": true, + "byteLength": null, + "collation": null + } + ], + + "rowset": [ + [ + "1" + ] + ], + "qrmk": "+ZSmIj7I0L0BnU3zdVnSaHH5MW6cwY0GmLtz/Un5zSM=", + "chunkHeaders": { + "x-amz-server-side-encryption-customer-key": "+ZSmIj7I0L0BnU3zdVnSaHH5MW6cwY0GmLtz/Un5zSM=", + "x-amz-server-side-encryption-customer-key-md5": "ByrEgrMhjgAEMRr1QA/nGg==" + }, + "chunks": [ + { + "url": "{{WIREMOCK_HTTP_HOST_WITH_PORT}}/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_1?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip", + "rowCount": 4096, + "uncompressedSize": 331328, + "compressedSize": 326422 + }, + { + "url": "{{WIREMOCK_HTTP_HOST_WITH_PORT}}/amazonaws/test/s3testaccount/stage/results/01bd1448-0100-0001-0000-0000001006f5_0/main/data_0_0_0_2?x-amz-server-side-encryption-customer-algorithm=AES256&response-content-encoding=gzip", + "rowCount": 4096, + "uncompressedSize": 331328, + "compressedSize": 326176 + } + ], + "total": 50000, + "returned": 50000, + "queryId": "01bd137c-0100-0001-0000-0000001005b1", + "databaseProvider": null, + 
"finalDatabaseName": "TESTDB", + "finalSchemaName": "PYTHON_CONNECTOR_TESTS_8F2FF429_1381_4ED5_83F2_6C0D7B29C1AE", + "finalWarehouseName": "REGRESS", + "finalRoleName": "ACCOUNTADMIN", + "numberOfBinds": 0, + "arrayBindSupported": false, + "statementTypeId": 4096, + "version": 1, + "sendResultTime": 1750110502822, + "queryResultFormat": "json", + "queryContext": { + "entries": [ + { + "id": 0, + "timestamp": 1748552075465658, + "priority": 0, + "context": "CAQ=" + } + ] + } + }, + "code": null, + "message": null, + "success": true + } + } +} diff --git a/test/integ/test_connection.py b/test/integ/test_connection.py index 8e18e3577a..3a85a20b5c 100644 --- a/test/integ/test_connection.py +++ b/test/integ/test_connection.py @@ -389,6 +389,8 @@ def test_invalid_account_timeout(conn_cnx): @pytest.mark.timeout(15) def test_invalid_proxy(conn_cnx): + http_proxy = os.environ.get("HTTP_PROXY") + https_proxy = os.environ.get("HTTPS_PROXY") with pytest.raises(OperationalError): with conn_cnx( protocol="http", @@ -398,9 +400,41 @@ def test_invalid_proxy(conn_cnx): proxy_port="3333", ): pass - # NOTE environment variable is set if the proxy parameter is specified. - del os.environ["HTTP_PROXY"] - del os.environ["HTTPS_PROXY"] + # NOTE environment variable is set ONLY FOR THE OLD DRIVER if the proxy parameter is specified. + # So this deletion is needed for old driver tests only. 
+ if http_proxy is not None: + os.environ["HTTP_PROXY"] = http_proxy + else: + try: + del os.environ["HTTP_PROXY"] + except KeyError: + pass + if https_proxy is not None: + os.environ["HTTPS_PROXY"] = https_proxy + else: + try: + del os.environ["HTTPS_PROXY"] + except KeyError: + pass + + +@pytest.mark.skipolddriver +@pytest.mark.timeout(15) +def test_invalid_proxy_not_impacting_env_vars(conn_cnx): + http_proxy = os.environ.get("HTTP_PROXY") + https_proxy = os.environ.get("HTTPS_PROXY") + with pytest.raises(OperationalError): + with conn_cnx( + protocol="http", + account="testaccount", + login_timeout=5, + proxy_host="localhost", + proxy_port="3333", + ): + pass + # Proxy environment variables should not change + assert os.environ.get("HTTP_PROXY") == http_proxy + assert os.environ.get("HTTPS_PROXY") == https_proxy @pytest.mark.timeout(15) diff --git a/test/wiremock/__init__.py b/test/test_utils/__init__.py similarity index 100% rename from test/wiremock/__init__.py rename to test/test_utils/__init__.py diff --git a/test/test_utils/cross_module_fixtures/__init__.py b/test/test_utils/cross_module_fixtures/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test_utils/cross_module_fixtures/http_fixtures.py b/test/test_utils/cross_module_fixtures/http_fixtures.py new file mode 100644 index 0000000000..a34d349be9 --- /dev/null +++ b/test/test_utils/cross_module_fixtures/http_fixtures.py @@ -0,0 +1,36 @@ +import os + +import pytest + + +@pytest.fixture +def proxy_env_vars(): + """Manages HTTP_PROXY and HTTPS_PROXY environment variables for testing.""" + original_http_proxy = os.environ.get("HTTP_PROXY") + original_https_proxy = os.environ.get("HTTPS_PROXY") + + def set_proxy_env_vars(proxy_url: str): + """Set both HTTP_PROXY and HTTPS_PROXY to the given URL.""" + os.environ["HTTP_PROXY"] = proxy_url + os.environ["HTTPS_PROXY"] = proxy_url + + def clear_proxy_env_vars(): + """Clear proxy environment variables.""" + if "HTTP_PROXY" in 
os.environ: + del os.environ["HTTP_PROXY"] + if "HTTPS_PROXY" in os.environ: + del os.environ["HTTPS_PROXY"] + + # Yield the helper functions + yield set_proxy_env_vars, clear_proxy_env_vars + + # Cleanup: restore original values + if original_http_proxy is not None: + os.environ["HTTP_PROXY"] = original_http_proxy + elif "HTTP_PROXY" in os.environ: + del os.environ["HTTP_PROXY"] + + if original_https_proxy is not None: + os.environ["HTTPS_PROXY"] = original_https_proxy + elif "HTTPS_PROXY" in os.environ: + del os.environ["HTTPS_PROXY"] diff --git a/test/test_utils/cross_module_fixtures/wiremock_fixtures.py b/test/test_utils/cross_module_fixtures/wiremock_fixtures.py new file mode 100644 index 0000000000..ddf7c22d12 --- /dev/null +++ b/test/test_utils/cross_module_fixtures/wiremock_fixtures.py @@ -0,0 +1,83 @@ +import pathlib +import uuid +from contextlib import contextmanager +from functools import partial +from typing import Any, Callable, ContextManager, Generator, Union + +import pytest + +import snowflake.connector + +from ..wiremock.wiremock_utils import WiremockClient, get_clients_for_proxy_and_target + + +@pytest.fixture(scope="session") +def wiremock_mapping_dir() -> pathlib.Path: + return ( + pathlib.Path(__file__).parent.parent.parent / "data" / "wiremock" / "mappings" + ) + + +@pytest.fixture(scope="session") +def wiremock_generic_mappings_dir(wiremock_mapping_dir) -> pathlib.Path: + return wiremock_mapping_dir / "generic" + + +@pytest.fixture(scope="session") +def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: + with WiremockClient() as client: + yield client + + +@pytest.fixture +def default_db_wiremock_parameters(wiremock_client: WiremockClient) -> dict[str, Any]: + db_params = { + "account": "testAccount", + "user": "testUser", + "password": "testPassword", + "host": wiremock_client.wiremock_host, + "port": wiremock_client.wiremock_http_port, + "protocol": "http", + "name": "python_tests_" + str(uuid.uuid4()).replace("-",
"_"), + } + return db_params + + +@contextmanager +def db_wiremock( + default_db_wiremock_parameters: dict[str, Any], + **kwargs, +) -> Generator[snowflake.connector.SnowflakeConnection, None, None]: + ret = default_db_wiremock_parameters + ret.update(kwargs) + cnx = snowflake.connector.connect(**ret) + try: + yield cnx + finally: + cnx.close() + + +@pytest.fixture +def conn_cnx_wiremock( + default_db_wiremock_parameters, +) -> Callable[..., ContextManager[snowflake.connector.SnowflakeConnection]]: + return partial( + db_wiremock, default_db_wiremock_parameters=default_db_wiremock_parameters + ) + + +@pytest.fixture +def wiremock_target_proxy_pair(wiremock_generic_mappings_dir): + """Starts a *target* Wiremock and a *proxy* Wiremock pre-configured to forward to it. + + The fixture yields a tuple ``(target_wm, proxy_wm)`` of ``WiremockClient`` + instances. It is a thin wrapper around + ``test.test_utils.wiremock.wiremock_utils.get_clients_for_proxy_and_target``. + """ + wiremock_proxy_mapping_path = ( + wiremock_generic_mappings_dir / "proxy_forward_all.json" + ) + with get_clients_for_proxy_and_target( + proxy_mapping_template=wiremock_proxy_mapping_path + ) as pair: + yield pair diff --git a/test/test_utils/wiremock/__init__.py b/test/test_utils/wiremock/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test_utils/wiremock/wiremock_utils.py b/test/test_utils/wiremock/wiremock_utils.py new file mode 100644 index 0000000000..7b7d15da54 --- /dev/null +++ b/test/test_utils/wiremock/wiremock_utils.py @@ -0,0 +1,347 @@ +import json +import logging +import pathlib +import socket +import subprocess +from contextlib import contextmanager +from time import sleep +from typing import Iterable, List, Optional, Union + +try: + from snowflake.connector.vendored import requests +except ImportError: + import requests + +WIREMOCK_START_MAX_RETRY_COUNT = 12 +logger = logging.getLogger(__name__) + + +def _get_mapping_str(mapping: Union[str, dict, pathlib.Path]) -> str:
+ if isinstance(mapping, str): + return mapping + if isinstance(mapping, dict): + return json.dumps(mapping) + if isinstance(mapping, pathlib.Path): + if mapping.is_file(): + with open(mapping) as f: + return f.read() + else: + raise RuntimeError(f"File with mapping: {mapping} does not exist") + + raise RuntimeError(f"Mapping {mapping} is of an invalid type") + + +class WiremockClient: + HTTP_HOST_PLACEHOLDER: str = "{{WIREMOCK_HTTP_HOST_WITH_PORT}}" + + def __init__( + self, + forbidden_ports: Optional[List[int]] = None, + additional_wiremock_process_args: Optional[Iterable[str]] = None, + ) -> None: + self.wiremock_filename = "wiremock-standalone.jar" + self.wiremock_host = "localhost" + self.wiremock_http_port = None + self.wiremock_https_port = None + self.forbidden_ports = forbidden_ports if forbidden_ports is not None else [] + + self.wiremock_dir = ( + pathlib.Path(__file__).parent.parent.parent.parent / ".wiremock" + ) + assert self.wiremock_dir.exists(), f"{self.wiremock_dir} does not exist" + + self.wiremock_jar_path = self.wiremock_dir / self.wiremock_filename + assert ( + self.wiremock_jar_path.exists() + ), f"{self.wiremock_jar_path} does not exist" + self._additional_wiremock_process_args = ( + additional_wiremock_process_args or list() + ) + + @property + def http_host_with_port(self) -> str: + return f"http://{self.wiremock_host}:{self.wiremock_http_port}" + + def get_http_placeholders(self) -> dict[str, str]: + """Placeholder that substitutes the target Wiremock's host:port in JSON.""" + return {self.HTTP_HOST_PLACEHOLDER: self.http_host_with_port} + + def add_expected_headers_to_mapping( + self, + mapping_str: str, + expected_headers: dict, + ) -> str: + """Add expected headers to all request matchers in mapping string.""" + mapping_dict = json.loads(mapping_str) + + def add_headers_to_request(request_dict: dict) -> None: + if "headers" not in request_dict: + request_dict["headers"] = {} + request_dict["headers"].update(expected_headers) + + if 
"request" in mapping_dict: + add_headers_to_request(mapping_dict["request"]) + + if "mappings" in mapping_dict: + for single_mapping in mapping_dict["mappings"]: + if "request" in single_mapping: + add_headers_to_request(single_mapping["request"]) + + return json.dumps(mapping_dict) + + def get_default_placeholders(self) -> dict[str, str]: + return self.get_http_placeholders() + + def _start_wiremock(self): + self.wiremock_http_port = self._find_free_port( + forbidden_ports=self.forbidden_ports, + ) + self.wiremock_https_port = self._find_free_port( + forbidden_ports=self.forbidden_ports + [self.wiremock_http_port] + ) + self.wiremock_process = subprocess.Popen( + [ + "java", + "-jar", + self.wiremock_jar_path, + "--root-dir", + self.wiremock_dir, + "--enable-browser-proxying", # work as forward proxy + "--proxy-pass-through", + "false", # pass through only matched requests + "--port", + str(self.wiremock_http_port), + "--https-port", + str(self.wiremock_https_port), + "--https-keystore", + self.wiremock_dir / "ca-cert.jks", + "--ca-keystore", + self.wiremock_dir / "ca-cert.jks", + ] + + self._additional_wiremock_process_args + ) + self._wait_for_wiremock() + + def _stop_wiremock(self): + if self.wiremock_process.poll() is not None: + logger.warning("Wiremock process already exited, skipping shutdown") + return + + try: + response = self._wiremock_post( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/shutdown" + ) + if response.status_code != 200: + logger.info("Wiremock shutdown failed, the process will be killed") + self.wiremock_process.kill() + else: + logger.debug("Wiremock shutdown gracefully") + except requests.exceptions.RequestException as e: + logger.warning(f"Shutdown request failed: {e}. 
Killing process directly.") + self.wiremock_process.kill() + + def _wait_for_wiremock(self): + retry_count = 0 + while retry_count < WIREMOCK_START_MAX_RETRY_COUNT: + if self._health_check(): + return + retry_count += 1 + sleep(1) + + raise TimeoutError( + f"WiremockClient did not respond within {WIREMOCK_START_MAX_RETRY_COUNT} seconds" + ) + + def _health_check(self): + mappings_endpoint = ( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/health" + ) + try: + response = requests.get(mappings_endpoint) + except requests.exceptions.RequestException as e: + logger.warning(f"Wiremock healthcheck failed with exception: {e}") + return False + + if ( + response.status_code == requests.codes.ok + and response.json()["status"] != "healthy" + ): + logger.warning(f"Wiremock healthcheck failed with response: {response}") + return False + elif response.status_code != requests.codes.ok: + logger.warning( + f"Wiremock healthcheck failed with status code: {response.status_code}" + ) + return False + + return True + + def _reset_wiremock(self): + clean_journal_endpoint = ( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/requests" + ) + requests.delete(clean_journal_endpoint) + reset_endpoint = ( + f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/reset" + ) + response = self._wiremock_post(reset_endpoint) + if response.status_code != requests.codes.ok: + raise RuntimeError("Failed to reset WiremockClient") + + def _wiremock_post( + self, endpoint: str, body: Optional[str] = None + ) -> requests.Response: + headers = {"Accept": "application/json", "Content-Type": "application/json"} + return requests.post(endpoint, data=body, headers=headers) + + def _replace_placeholders_in_mapping( + self, mapping_str: str, placeholders: Optional[dict[str, object]] + ) -> str: + if placeholders: + for key, value in placeholders.items(): + mapping_str = mapping_str.replace(str(key), str(value)) + return mapping_str + + def import_mapping( + self, + 
mapping: Union[str, dict, pathlib.Path], + placeholders: Optional[dict[str, object]] = None, + expected_headers: Optional[dict] = None, + ): + self._reset_wiremock() + import_mapping_endpoint = f"{self.http_host_with_port}/__admin/mappings/import" + + mapping_str = _get_mapping_str(mapping) + if expected_headers is not None: + mapping_str = self.add_expected_headers_to_mapping( + mapping_str, expected_headers + ) + + mapping_str = self._replace_placeholders_in_mapping(mapping_str, placeholders) + response = self._wiremock_post(import_mapping_endpoint, mapping_str) + if response.status_code != requests.codes.ok: + raise RuntimeError("Failed to import mapping") + + def import_mapping_with_default_placeholders( + self, + mapping: Union[str, dict, pathlib.Path], + expected_headers: Optional[dict] = None, + ): + placeholders = self.get_default_placeholders() + return self.import_mapping(mapping, placeholders, expected_headers) + + def add_mapping_with_default_placeholders( + self, + mapping: Union[str, dict, pathlib.Path], + expected_headers: Optional[dict] = None, + ): + placeholders = self.get_default_placeholders() + return self.add_mapping(mapping, placeholders, expected_headers) + + def add_mapping( + self, + mapping: Union[str, dict, pathlib.Path], + placeholders: Optional[dict[str, object]] = None, + expected_headers: Optional[dict] = None, + ): + add_mapping_endpoint = f"{self.http_host_with_port}/__admin/mappings" + + mapping_str = _get_mapping_str(mapping) + if expected_headers is not None: + mapping_str = self.add_expected_headers_to_mapping( + mapping_str, expected_headers + ) + + mapping_str = self._replace_placeholders_in_mapping(mapping_str, placeholders) + response = self._wiremock_post(add_mapping_endpoint, mapping_str) + if response.status_code != requests.codes.created: + raise RuntimeError("Failed to add mapping") + + def _find_free_port(self, forbidden_ports: Union[List[int], None] = None) -> int: + max_retries = 1 if forbidden_ports is None else 3 
+ if forbidden_ports is None: + forbidden_ports = [] + + retry_count = 0 + while retry_count < max_retries: + retry_count += 1 + with socket.socket() as sock: + sock.bind((self.wiremock_host, 0)) + port = sock.getsockname()[1] + if port not in forbidden_ports: + return port + + raise RuntimeError( + f"Unable to find a free port for wiremock in {max_retries} attempts" + ) + + def __enter__(self): + self._start_wiremock() + logger.debug( + f"Starting wiremock process, listening on {self.wiremock_host}:{self.wiremock_http_port}" + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + logger.debug("Stopping wiremock process") + self._stop_wiremock() + + +@contextmanager +def get_clients_for_proxy_and_target( + proxy_mapping_template: Union[str, dict, pathlib.Path, None] = None, + additional_proxy_placeholders: Optional[dict[str, object]] = None, + additional_proxy_args: Optional[Iterable[str]] = None, +): + """Context manager that starts two Wiremock instances – *target* and *proxy* – and + configures the proxy to forward **all** traffic to the target. + + It yields a tuple ``(target_wm, proxy_wm)`` where both items are fully initialised + ``WiremockClient`` objects ready for use in tests. When the context exits both + Wiremock processes are shut down automatically. + + Parameters + ---------- + proxy_mapping_template + Mapping JSON (str / dict / pathlib.Path) to be used for configuring the proxy + Wiremock. If *None*, the default template at + ``test/data/wiremock/mappings/generic/proxy_forward_all.json`` is used. + additional_proxy_placeholders + Optional placeholders to be replaced in the proxy mapping *in addition* to the + automatically provided ``{{TARGET_HTTP_HOST_WITH_PORT}}``. + additional_proxy_args + Extra command-line arguments passed to the proxy Wiremock instance when it is + launched. Useful for tweaking Wiremock behaviour in specific tests.
+ """ + + # Resolve default mapping template if none provided + if proxy_mapping_template is None: + proxy_mapping_template = ( + pathlib.Path(__file__).parent.parent.parent.parent + / "test" + / "data" + / "wiremock" + / "mappings" + / "generic" + / "proxy_forward_all.json" + ) + + # Start the *target* Wiremock first – this will emulate Snowflake / IdP backend + with WiremockClient() as target_wm: + # Then start the *proxy* Wiremock and ensure it doesn't try to bind the same port + with WiremockClient( + forbidden_ports=[target_wm.wiremock_http_port], + additional_wiremock_process_args=additional_proxy_args, + ) as proxy_wm: + # Prepare placeholders so that proxy forwards to the *target* + placeholders: dict[str, object] = { + "{{TARGET_HTTP_HOST_WITH_PORT}}": target_wm.http_host_with_port + } + if additional_proxy_placeholders: + placeholders.update(additional_proxy_placeholders) + + # Configure proxy Wiremock to forward everything to target + proxy_wm.add_mapping(proxy_mapping_template, placeholders=placeholders) + + # Yield control back to the caller with both Wiremocks ready + yield target_wm, proxy_wm diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index 8220759aa6..9a4f4218db 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -42,6 +42,7 @@ AuthByDefault = AuthByOkta = AuthByOAuth = AuthByWebBrowser = MagicMock try: # pragma: no cover + import snowflake.connector.vendored.requests as requests from snowflake.connector.auth import AuthByUsrPwdMfa from snowflake.connector.config_manager import CONFIG_MANAGER from snowflake.connector.constants import ( @@ -806,3 +807,88 @@ def test_reraise_error_in_file_transfer_work_function_config( expected_value = bool(reraise_enabled) actual_value = conn._reraise_error_in_file_transfer_work_function assert actual_value == expected_value + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +def 
test_large_query_through_proxy( + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + wiremock_mapping_dir, + proxy_env_vars, + proxy_method, +): + target_wm, proxy_wm = wiremock_target_proxy_pair + + password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json" + multi_chunk_request_mapping = ( + wiremock_mapping_dir / "queries/select_large_request_successful.json" + ) + disconnect_mapping = ( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json" + chunk_1_mapping = wiremock_mapping_dir / "queries/chunk_1.json" + chunk_2_mapping = wiremock_mapping_dir / "queries/chunk_2.json" + + # Configure mappings with proxy header verification + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping(password_mapping, expected_headers=expected_headers) + target_wm.add_mapping_with_default_placeholders( + multi_chunk_request_mapping, expected_headers + ) + target_wm.add_mapping(disconnect_mapping, expected_headers=expected_headers) + target_wm.add_mapping(telemetry_mapping, expected_headers=expected_headers) + target_wm.add_mapping_with_default_placeholders(chunk_1_mapping, expected_headers) + target_wm.add_mapping_with_default_placeholders(chunk_2_mapping, expected_headers) + + # Configure proxy based on test parameter + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "password": "testPassword", + "account": "testAccount", + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "protocol": "http", + "warehouse": "TEST_WH", + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { + "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + "proxy_user": "proxyUser", + "proxy_password": "proxyPass", + } + ) + clear_proxy_env_vars() # Ensure no env vars interfere + else: # env_vars + proxy_url = 
f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + row_count = 50_000 + with snowflake.connector.connect(**connect_kwargs) as conn: + cursors = conn.execute_string( + f"select seq4() as n from table(generator(rowcount => {row_count}));" + ) + assert len(cursors[0]._result_set.batches) > 1 # We need to have remote results + assert list(cursors[0]) + + # Ensure proxy saw query + proxy_reqs = requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in proxy_reqs["requests"] + ) + + # Ensure backend saw query + target_reqs = requests.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ).json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in target_reqs["requests"] + ) diff --git a/test/unit/test_oauth_token.py b/test/unit/test_oauth_token.py index bc1e650adb..b19d9415d6 100644 --- a/test/unit/test_oauth_token.py +++ b/test/unit/test_oauth_token.py @@ -5,7 +5,6 @@ import logging import pathlib from threading import Thread -from typing import Any, Generator, Union from unittest import mock from unittest.mock import Mock, patch @@ -16,17 +15,11 @@ from snowflake.connector.auth import AuthByOauthCredentials from snowflake.connector.token_cache import TokenCache, TokenKey, TokenType -from ..wiremock.wiremock_utils import WiremockClient +from ..test_utils.wiremock.wiremock_utils import WiremockClient logger = logging.getLogger(__name__) -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: - with WiremockClient() as client: - yield client - - @pytest.fixture(scope="session") def wiremock_oauth_authorization_code_dir() -> pathlib.Path: return ( @@ -53,17 +46,6 @@ def wiremock_oauth_client_creds_dir() -> pathlib.Path: ) -@pytest.fixture(scope="session") -def wiremock_generic_mappings_dir() -> pathlib.Path: - return ( - 
pathlib.Path(__file__).parent.parent - / "data" - / "wiremock" - / "mappings" - / "generic" - ) - - @pytest.fixture(scope="session") def wiremock_oauth_refresh_token_dir() -> pathlib.Path: return ( @@ -701,3 +683,156 @@ def test_client_creds_expired_refresh_token_flow( cached_refresh_token = temp_cache.retrieve(refresh_token_key) assert cached_access_token == "expired-access-token-123" assert cached_refresh_token == "expired-refresh-token-123" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +def test_client_credentials_flow_via_explicit_proxy( + wiremock_oauth_client_creds_dir, + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + temp_cache, + wiremock_mapping_dir, + proxy_env_vars, + proxy_method, +): + """Spin up two Wiremock instances (target & proxy) via shared fixture and run OAuth Client-Credentials flow through the proxy.""" + + target_wm, proxy_wm = wiremock_target_proxy_pair + + # Configure backend (Snowflake + IdP) responses with proxy header verification + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping_with_default_placeholders( + wiremock_oauth_client_creds_dir / "successful_flow.json", expected_headers + ) + target_wm.add_mapping_with_default_placeholders( + wiremock_generic_mappings_dir / "snowflake_login_successful.json", + expected_headers, + ) + target_wm.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json", + expected_headers=expected_headers, + ) + + token_request_url = f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/token-request" + + # Configure proxy based on test parameter + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "authenticator": "OAUTH_CLIENT_CREDENTIALS", + "oauth_client_id": "cid", + "oauth_client_secret": "secret", + "account": "testAccount", + "protocol": "http", + "role": "ANALYST", + "oauth_token_request_url": 
token_request_url, + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "oauth_enable_refresh_tokens": True, + "client_store_temporary_credential": True, + "token_cache": temp_cache, + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { + "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + "proxy_user": "proxyUser", + "proxy_password": "proxyPass", + } + ) + clear_proxy_env_vars() # Ensure no env vars interfere + else: # env_vars + proxy_url = f"http://proxyUser:proxyPass@{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = snowflake.connector.connect(**connect_kwargs) + assert cnx, "Connection object should be valid" + cnx.close() + + # Verify proxy & backend saw the token request + proxy_requests = requests.get( + f"{proxy_wm.http_host_with_port}/__admin/requests" + ).json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in proxy_requests["requests"] + ) + + target_requests = requests.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ).json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in target_requests["requests"] + ) + + +@pytest.mark.skipolddriver +@patch("snowflake.connector.auth._http_server.AuthHttpServer.DEFAULT_TIMEOUT", 30) +def test_oauth_code_successful_flow_through_proxy( + wiremock_oauth_authorization_code_dir, + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + webbrowser_mock, + monkeypatch, + omit_oauth_urls_check, +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + target_wm, proxy_wm = wiremock_target_proxy_pair + + target_wm.import_mapping_with_default_placeholders( + wiremock_oauth_authorization_code_dir / "successful_flow.json", + ) + target_wm.add_mapping_with_default_placeholders( + wiremock_generic_mappings_dir / 
"snowflake_login_successful.json", + ) + target_wm.add_mapping( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json", + ) + + with mock.patch("webbrowser.open", new=webbrowser_mock.open): + with mock.patch("secrets.token_urlsafe", return_value="abc123"): + cnx = snowflake.connector.connect( + user="testUser", + authenticator="OAUTH_AUTHORIZATION_CODE", + oauth_client_id="123", + account="testAccount", + protocol="http", + role="ANALYST", + proxy_host=proxy_wm.wiremock_host, + proxy_port=str(proxy_wm.wiremock_http_port), + proxy_user="proxyUser", + proxy_password="proxyPass", + oauth_client_secret="testClientSecret", + oauth_token_request_url=f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/token-request", + oauth_authorization_url=f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/oauth/authorize", + oauth_redirect_uri="http://localhost:8009/snowflake/oauth-redirect", + host=target_wm.wiremock_host, + port=target_wm.wiremock_http_port, + ) + + assert cnx, "invalid cnx" + cnx.close() + + # Verify: proxy Wiremock saw the token request + proxy_requests = requests.get( + f"http://{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}/__admin/requests" + ).json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in proxy_requests["requests"] + ), "Proxy did not record token-request" + + # Verify: target Wiremock also saw it (because proxy forwarded) + target_requests = requests.get( + f"http://{target_wm.wiremock_host}:{target_wm.wiremock_http_port}/__admin/requests" + ).json() + assert any( + req["request"]["url"].endswith("/oauth/token-request") + for req in target_requests["requests"] + ), "Target did not receive token-request forwarded by proxy" diff --git a/test/unit/test_programmatic_access_token.py b/test/unit/test_programmatic_access_token.py index 7d6ecb175e..fdf5bc0c9d 100644 --- a/test/unit/test_programmatic_access_token.py +++ 
b/test/unit/test_programmatic_access_token.py @@ -1,5 +1,4 @@ import pathlib -from typing import Any, Generator, Union import pytest @@ -9,13 +8,7 @@ except ImportError: pass -from ..wiremock.wiremock_utils import WiremockClient - - -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[Union[WiremockClient, Any], Any, None]: - with WiremockClient() as client: - yield client +from ..test_utils.wiremock.wiremock_utils import WiremockClient @pytest.mark.skipolddriver diff --git a/test/unit/test_proxies.py b/test/unit/test_proxies.py index fbd2d47268..b32e1dcb09 100644 --- a/test/unit/test_proxies.py +++ b/test/unit/test_proxies.py @@ -8,32 +8,23 @@ import pytest import snowflake.connector +import snowflake.connector.vendored.requests as requests from snowflake.connector.errors import OperationalError -def test_set_proxies(): - from snowflake.connector.proxy import set_proxies +@pytest.mark.skipolddriver +def test_get_proxy_url(): + from snowflake.connector.proxy import get_proxy_url - assert set_proxies("proxyhost", "8080") == { - "http": "http://proxyhost:8080", - "https": "http://proxyhost:8080", - } - assert set_proxies("http://proxyhost", "8080") == { - "http": "http://proxyhost:8080", - "https": "http://proxyhost:8080", - } - assert set_proxies("http://proxyhost", "8080", "testuser", "testpass") == { - "http": "http://testuser:testpass@proxyhost:8080", - "https": "http://testuser:testpass@proxyhost:8080", - } - assert set_proxies("proxyhost", "8080", "testuser", "testpass") == { - "http": "http://testuser:testpass@proxyhost:8080", - "https": "http://testuser:testpass@proxyhost:8080", - } + assert get_proxy_url("host", "port", "user", "password") == ( + "http://user:password@host:port" + ) + assert get_proxy_url("host", "port") == "http://host:port" - # NOTE environment variable is set if the proxy parameter is specified. 
- del os.environ["HTTP_PROXY"] - del os.environ["HTTPS_PROXY"] + assert get_proxy_url("http://host", "port") == "http://host:port" + assert get_proxy_url("https://host", "port", "user", "password") == ( + "http://user:password@host:port" + ) @pytest.mark.skipolddriver @@ -91,3 +82,81 @@ def mock_proxy_manager_for_url_wiht_header(*args, **kwargs): assert "Unable to set 'Host' to proxy manager of type" not in caplog.text del os.environ["HTTPS_PROXY"] + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"]) +def test_basic_query_through_proxy( + wiremock_generic_mappings_dir, + wiremock_target_proxy_pair, + wiremock_mapping_dir, + proxy_env_vars, + proxy_method, +): + target_wm, proxy_wm = wiremock_target_proxy_pair + + password_mapping = wiremock_mapping_dir / "auth/password/successful_flow.json" + select_mapping = wiremock_mapping_dir / "queries/select_1_successful.json" + disconnect_mapping = ( + wiremock_generic_mappings_dir / "snowflake_disconnect_successful.json" + ) + telemetry_mapping = wiremock_generic_mappings_dir / "telemetry.json" + + # Use expected headers to ensure requests go through proxy + expected_headers = {"Via": {"contains": "wiremock"}} + + target_wm.import_mapping_with_default_placeholders( + password_mapping, expected_headers + ) + target_wm.add_mapping_with_default_placeholders(select_mapping, expected_headers) + target_wm.add_mapping(disconnect_mapping) + target_wm.add_mapping(telemetry_mapping) + + # Configure proxy based on test parameter + set_proxy_env_vars, clear_proxy_env_vars = proxy_env_vars + connect_kwargs = { + "user": "testUser", + "password": "testPassword", + "account": "testAccount", + "host": target_wm.wiremock_host, + "port": target_wm.wiremock_http_port, + "protocol": "http", + "warehouse": "TEST_WH", + } + + if proxy_method == "explicit_args": + connect_kwargs.update( + { + "proxy_host": proxy_wm.wiremock_host, + "proxy_port": str(proxy_wm.wiremock_http_port), + } + ) + 
clear_proxy_env_vars() # Ensure no env vars interfere + else: # env_vars + proxy_url = f"http://{proxy_wm.wiremock_host}:{proxy_wm.wiremock_http_port}" + set_proxy_env_vars(proxy_url) + + # Make connection via proxy + cnx = snowflake.connector.connect(**connect_kwargs) + cur = cnx.cursor() + cur.execute("SELECT 1") + result = cur.fetchone() + assert result[0] == 1 + cur.close() + cnx.close() + + # Ensure proxy saw query + proxy_reqs = requests.get(f"{proxy_wm.http_host_with_port}/__admin/requests").json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in proxy_reqs["requests"] + ) + + # Ensure backend saw query + target_reqs = requests.get( + f"{target_wm.http_host_with_port}/__admin/requests" + ).json() + assert any( + "/queries/v1/query-request" in r["request"]["url"] + for r in target_reqs["requests"] + ) diff --git a/test/unit/test_wiremock_client.py b/test/unit/test_wiremock_client.py index b471f39df7..19625c42c0 100644 --- a/test/unit/test_wiremock_client.py +++ b/test/unit/test_wiremock_client.py @@ -1,7 +1,3 @@ -from typing import Any, Generator - -import pytest - # old driver support try: from snowflake.connector.vendored import requests @@ -9,16 +5,6 @@ import requests -from ..wiremock.wiremock_utils import WiremockClient - - -@pytest.mark.skipolddriver -@pytest.fixture(scope="session") -def wiremock_client() -> Generator[WiremockClient, Any, None]: - with WiremockClient() as client: - yield client - - def test_wiremock(wiremock_client): connection_reset_by_peer_mapping = { "mappings": [ diff --git a/test/wiremock/wiremock_utils.py b/test/wiremock/wiremock_utils.py deleted file mode 100644 index 1d036a8023..0000000000 --- a/test/wiremock/wiremock_utils.py +++ /dev/null @@ -1,186 +0,0 @@ -import json -import logging -import pathlib -import socket -import subprocess -from time import sleep -from typing import List, Optional, Union - -try: - from snowflake.connector.vendored import requests -except ImportError: - import requests - 
# Maximum number of one-second-spaced health-check attempts made while
# waiting for a freshly spawned Wiremock process to come up.
WIREMOCK_START_MAX_RETRY_COUNT = 12
# Upper bound (seconds) for any single HTTP call to the Wiremock admin API,
# so that an unresponsive process cannot block the caller forever.
WIREMOCK_REQUEST_TIMEOUT_SECONDS = 10
# How long (seconds) to wait for the process to exit after a shutdown
# request before it is killed.
WIREMOCK_STOP_TIMEOUT_SECONDS = 10
logger = logging.getLogger(__name__)


def _get_mapping_str(mapping: Union[str, dict, pathlib.Path]) -> str:
    """Return *mapping* as the string that is sent to the Wiremock admin API.

    Args:
        mapping: Either an already serialized mapping (returned unchanged),
            a dict (serialized with ``json.dumps``) or a path to a file
            whose content is returned.

    Raises:
        RuntimeError: If the path does not point to an existing file, or if
            *mapping* is none of the supported types.
    """
    if isinstance(mapping, str):
        return mapping
    if isinstance(mapping, dict):
        return json.dumps(mapping)
    if isinstance(mapping, pathlib.Path):
        if not mapping.is_file():
            raise RuntimeError(f"File with mapping: {mapping} does not exist")
        # Read as UTF-8 explicitly instead of relying on the locale default.
        return mapping.read_text(encoding="utf-8")

    raise RuntimeError(f"Mapping {mapping} is of an invalid type")


class WiremockClient:
    """Context manager that runs a standalone Wiremock process.

    Entering the context picks free HTTP/HTTPS ports, spawns
    ``java -jar wiremock-standalone.jar`` and waits until the admin health
    endpoint reports ``healthy``. Leaving the context shuts the process
    down (killing it if the graceful shutdown does not work).
    """

    def __init__(self, forbidden_ports: Optional[List[int]] = None) -> None:
        """Locate the Wiremock jar; no process is started here.

        Args:
            forbidden_ports: Ports that must not be picked for this instance.
        """
        self.wiremock_filename = "wiremock-standalone.jar"
        self.wiremock_host = "localhost"
        self.wiremock_http_port = None
        self.wiremock_https_port = None
        # Populated by _start_wiremock; None means "not running".
        self.wiremock_process = None
        self.forbidden_ports = forbidden_ports if forbidden_ports is not None else []

        self.wiremock_dir = pathlib.Path(__file__).parent.parent.parent / ".wiremock"
        assert self.wiremock_dir.exists(), f"{self.wiremock_dir} does not exist"

        self.wiremock_jar_path = self.wiremock_dir / self.wiremock_filename
        assert (
            self.wiremock_jar_path.exists()
        ), f"{self.wiremock_jar_path} does not exist"

    def _admin_url(self, path: str) -> str:
        """Build the URL of the admin endpoint *path* on the HTTP port."""
        return f"http://{self.wiremock_host}:{self.wiremock_http_port}/__admin/{path}"

    def _start_wiremock(self):
        """Spawn the Wiremock process and block until it is healthy.

        Raises:
            TimeoutError: If the process does not become healthy in time.
                The spawned process is terminated before the error is
                propagated so that it is not leaked.
        """
        self.wiremock_http_port = self._find_free_port(
            forbidden_ports=self.forbidden_ports,
        )
        self.wiremock_https_port = self._find_free_port(
            forbidden_ports=self.forbidden_ports + [self.wiremock_http_port]
        )
        keystore_path = str(self.wiremock_dir / "ca-cert.jks")
        self.wiremock_process = subprocess.Popen(
            [
                "java",
                "-jar",
                str(self.wiremock_jar_path),
                "--root-dir",
                str(self.wiremock_dir),
                "--enable-browser-proxying",  # work as forward proxy
                "--proxy-pass-through",
                "false",  # pass through only matched requests
                "--port",
                str(self.wiremock_http_port),
                "--https-port",
                str(self.wiremock_https_port),
                "--https-keystore",
                keystore_path,
                "--ca-keystore",
                keystore_path,
            ]
        )
        try:
            self._wait_for_wiremock()
        except BaseException:
            # __exit__ is never invoked when __enter__ fails, so the child
            # has to be cleaned up here or it would outlive the test run.
            self.wiremock_process.kill()
            self.wiremock_process.wait()
            self.wiremock_process = None
            raise

    def _stop_wiremock(self):
        """Shut the Wiremock process down and reap it.

        A graceful shutdown through the admin API is attempted first; if the
        request fails or is rejected the process is killed. Does nothing when
        no process is running.
        """
        process = self.wiremock_process
        if process is None:
            return

        try:
            response = self._wiremock_post(self._admin_url("shutdown"))
            graceful = response.status_code == 200
        except requests.exceptions.RequestException as e:
            logger.info("Wiremock shutdown request failed with exception: %s", e)
            graceful = False

        if graceful:
            logger.debug("Wiremock shutdown gracefully")
        else:
            logger.info("Wiremock shutdown failed, the process will be killed")
            process.kill()

        # Always wait for the child so it does not linger as a zombie.
        try:
            process.wait(timeout=WIREMOCK_STOP_TIMEOUT_SECONDS)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
        self.wiremock_process = None

    def _wait_for_wiremock(self):
        """Poll the health endpoint once per second until it succeeds.

        Raises:
            TimeoutError: After WIREMOCK_START_MAX_RETRY_COUNT failed checks.
        """
        for _ in range(WIREMOCK_START_MAX_RETRY_COUNT):
            if self._health_check():
                return
            sleep(1)

        raise TimeoutError(
            f"WiremockClient did not respond within {WIREMOCK_START_MAX_RETRY_COUNT} seconds"
        )

    def _health_check(self):
        """Return True only if the admin health endpoint reports ``healthy``."""
        try:
            response = requests.get(
                self._admin_url("health"), timeout=WIREMOCK_REQUEST_TIMEOUT_SECONDS
            )
        except requests.exceptions.RequestException as e:
            logger.warning("Wiremock healthcheck failed with exception: %s", e)
            return False

        if response.status_code != requests.codes.ok:
            logger.warning(
                "Wiremock healthcheck failed with status code: %s",
                response.status_code,
            )
            return False

        try:
            status = response.json().get("status")
        except (ValueError, AttributeError):
            # Body is not JSON (or not a JSON object): treat as unhealthy
            # instead of aborting the retry loop with an exception.
            status = None
        if status != "healthy":
            logger.warning("Wiremock healthcheck failed with response: %s", response)
            return False

        return True

    def _reset_wiremock(self):
        """Clear the request journal and reset Wiremock.

        Raises:
            RuntimeError: If the reset request is not answered with 200.
        """
        requests.delete(
            self._admin_url("requests"), timeout=WIREMOCK_REQUEST_TIMEOUT_SECONDS
        )
        response = self._wiremock_post(self._admin_url("reset"))
        if response.status_code != requests.codes.ok:
            raise RuntimeError("Failed to reset WiremockClient")

    def _wiremock_post(
        self, endpoint: str, body: Optional[str] = None
    ) -> "requests.Response":
        """POST *body* to *endpoint* with JSON headers and return the response."""
        headers = {"Accept": "application/json", "Content-Type": "application/json"}
        return requests.post(
            endpoint,
            data=body,
            headers=headers,
            timeout=WIREMOCK_REQUEST_TIMEOUT_SECONDS,
        )

    def import_mapping(self, mapping: Union[str, dict, pathlib.Path]):
        """Reset Wiremock, then import *mapping* via ``mappings/import``.

        Raises:
            RuntimeError: If the reset or the import request fails.
        """
        self._reset_wiremock()
        mapping_str = _get_mapping_str(mapping)
        response = self._wiremock_post(self._admin_url("mappings/import"), mapping_str)
        if response.status_code != requests.codes.ok:
            raise RuntimeError("Failed to import mapping")

    def add_mapping(self, mapping: Union[str, dict, pathlib.Path]):
        """Add *mapping* via the ``mappings`` endpoint without resetting first.

        Raises:
            RuntimeError: If the request is not answered with 201 Created.
        """
        mapping_str = _get_mapping_str(mapping)
        response = self._wiremock_post(self._admin_url("mappings"), mapping_str)
        if response.status_code != requests.codes.created:
            raise RuntimeError("Failed to add mapping")

    def _find_free_port(self, forbidden_ports: Union[List[int], None] = None) -> int:
        """Ask the OS for a free port that is not in *forbidden_ports*.

        One attempt is made when no forbidden list is given, three otherwise.

        Raises:
            RuntimeError: If every attempt yielded a forbidden port.
        """
        max_retries = 1 if forbidden_ports is None else 3
        forbidden = set(forbidden_ports or ())

        for _ in range(max_retries):
            with socket.socket() as sock:
                # Port 0 lets the OS pick any currently unused port.
                sock.bind((self.wiremock_host, 0))
                port = sock.getsockname()[1]
            if port not in forbidden:
                return port

        raise RuntimeError(
            f"Unable to find a free port for wiremock in {max_retries} attempts"
        )

    def __enter__(self):
        """Start Wiremock and return this client."""
        self._start_wiremock()
        logger.debug(
            f"Starting wiremock process, listening on {self.wiremock_host}:{self.wiremock_http_port}"
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop Wiremock; exceptions from the ``with`` body are not suppressed."""
        logger.debug("Stopping wiremock process")
        self._stop_wiremock()