Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,22 +1082,23 @@ def connect(self, **kwargs) -> None:

self._crl_config: CRLConfig = CRLConfig.from_connection(self)

no_proxy_csv_str = (
",".join(str(x) for x in self.no_proxy)
if (
self.no_proxy is not None
and isinstance(self.no_proxy, Iterable)
and not isinstance(self.no_proxy, (str, bytes))
)
else self.no_proxy
)
self._http_config = HttpConfig(
adapter_factory=ProxySupportAdapterFactory(),
use_pooling=(not self.disable_request_pooling),
proxy_host=self.proxy_host,
proxy_port=self.proxy_port,
proxy_user=self.proxy_user,
proxy_password=self.proxy_password,
no_proxy=(
",".join(str(x) for x in self.no_proxy)
if (
self.no_proxy is not None
and isinstance(self.no_proxy, Iterable)
and not isinstance(self.no_proxy, (str, bytes))
)
else self.no_proxy
),
no_proxy=no_proxy_csv_str,
)
self._session_manager = SessionManagerFactory.get_manager(self._http_config)

Expand Down
35 changes: 18 additions & 17 deletions src/snowflake/connector/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def close(self) -> None:
self._idle_sessions.clear()


class _ConfigDirectAccessMixin(abc.ABC):
class _BaseConfigDirectAccessMixin(abc.ABC):
@property
@abc.abstractmethod
def config(self) -> HttpConfig: ...
Expand All @@ -245,14 +245,6 @@ def use_pooling(self) -> bool:
def use_pooling(self, value: bool) -> None:
self.config = self.config.copy_with(use_pooling=value)

@property
def adapter_factory(self) -> Callable[..., HTTPAdapter]:
return self.config.adapter_factory

@adapter_factory.setter
def adapter_factory(self, value: Callable[..., HTTPAdapter]) -> None:
self.config = self.config.copy_with(adapter_factory=value)

@property
def max_retries(self) -> Retry | int:
return self.config.max_retries
Expand All @@ -262,6 +254,16 @@ def max_retries(self, value: Retry | int) -> None:
self.config = self.config.copy_with(max_retries=value)


class _HttpConfigDirectAccessMixin(_BaseConfigDirectAccessMixin):
@property
def adapter_factory(self) -> Callable[..., HTTPAdapter]:
return self.config.adapter_factory

@adapter_factory.setter
def adapter_factory(self, value: Callable[..., HTTPAdapter]) -> None:
self.config = self.config.copy_with(adapter_factory=value)


class _RequestVerbsUsingSessionMixin(abc.ABC):
"""
Mixin that provides HTTP methods (get, post, put, etc.) mirroring requests.Session, maintaining their default argument behavior (e.g., HEAD uses allow_redirects=False).
Expand Down Expand Up @@ -372,7 +374,7 @@ def delete(
return session.delete(url, headers=headers, timeout=timeout, **kwargs)


class SessionManager(_RequestVerbsUsingSessionMixin, _ConfigDirectAccessMixin):
class SessionManager(_RequestVerbsUsingSessionMixin, _HttpConfigDirectAccessMixin):
"""
Central HTTP session manager that handles all external requests from the Snowflake driver.

Expand Down Expand Up @@ -562,7 +564,8 @@ def clone(
Optional kwargs (e.g. *use_pooling* / *adapter_factory* / max_retries etc.) - overrides to create a modified
copy of the HttpConfig before instantiation.
"""
return SessionManager.from_config(self._cfg, **http_config_overrides)
# return SessionManager.from_config(self._cfg, **http_config_overrides)
return self.from_config(self._cfg, **http_config_overrides)

def __getstate__(self):
state = self.__dict__.copy()
Expand Down Expand Up @@ -612,6 +615,10 @@ class ProxySessionManager(SessionManager):
def make_session(self, *, url: str | None = None) -> Session:
session = requests.Session()
self._mount_adapters(session)

# TODO: adding this makes all proxies "not see the requests" - but they actually go since we can see MATCHED-stub id
# session.trust_env = False

proxies = (
{
"no_proxy": self._cfg.no_proxy,
Expand All @@ -626,12 +633,6 @@ def make_session(self, *, url: str | None = None) -> Session:
session.proxies = proxies
return session

def clone(
self,
**http_config_overrides,
) -> SessionManager:
return ProxySessionManager.from_config(self._cfg, **http_config_overrides)


class SessionManagerFactory:
@staticmethod
Expand Down
21 changes: 21 additions & 0 deletions test/test_utils/cross_module_fixtures/wiremock_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
WiremockClient,
get_clients_for_proxy_and_target,
get_clients_for_proxy_target_and_storage,
get_clients_for_two_proxies_and_target,
)


Expand Down Expand Up @@ -102,3 +103,23 @@ def wiremock_backend_storage_proxy(wiremock_generic_mappings_dir):
proxy_mapping_template=wiremock_proxy_mapping_path
) as triple:
yield triple


@pytest.fixture
def wiremock_two_proxies_backend(wiremock_generic_mappings_dir):
"""Starts backend (DB) and two proxy Wiremocks.

Returns a tuple ``(backend_wm, proxy1_wm, proxy2_wm)`` to make roles explicit.
- proxy1_wm: Configured to forward to backend
- proxy2_wm: Configured to forward to backend

Use when you need to test proxy selection logic with simple setup,
such as connection parameters taking precedence over environment variables.
"""
wiremock_proxy_mapping_path = (
wiremock_generic_mappings_dir / "proxy_forward_all.json"
)
with get_clients_for_two_proxies_and_target(
proxy_mapping_template=wiremock_proxy_mapping_path
) as triple:
yield triple
52 changes: 52 additions & 0 deletions test/test_utils/wiremock/wiremock_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,55 @@ def get_clients_for_proxy_target_and_storage(
forbidden = [target_wm.wiremock_http_port, proxy_wm.wiremock_http_port]
with WiremockClient(forbidden_ports=forbidden) as storage_wm:
yield target_wm, storage_wm, proxy_wm


@contextmanager
def get_clients_for_two_proxies_and_target(
proxy_mapping_template: Union[str, dict, pathlib.Path, None] = None,
additional_proxy_placeholders: Optional[dict[str, object]] = None,
additional_proxy_args: Optional[Iterable[str]] = None,
):
"""Context manager that starts three Wiremock instances – one *target* (DB) and two *proxies*.

Both proxies are configured to forward all traffic to *target* using the same
mapping mechanism. This allows the test to verify which proxy was actually used
by checking the request history.

Yields a tuple ``(target_wm, proxy1_wm, proxy2_wm)`` where:
- target_wm: The backend/DB Wiremock
- proxy1_wm: First proxy configured to forward to target
- proxy2_wm: Second proxy configured to forward to target

All processes are shut down automatically on context exit.

Note:
Use this helper for tests that need to verify proxy selection logic,
such as connection parameters taking precedence over environment variables.
"""
# Reuse existing helper to set up target+proxy1
with get_clients_for_proxy_and_target(
proxy_mapping_template=proxy_mapping_template,
additional_proxy_placeholders=additional_proxy_placeholders,
additional_proxy_args=additional_proxy_args,
) as (target_wm, proxy1_wm):
# Start second proxy and configure it to forward to target as well
forbidden = [target_wm.wiremock_http_port, proxy1_wm.wiremock_http_port]
with WiremockClient(forbidden_ports=forbidden) as proxy2_wm:
# Configure proxy2 to forward to target with the same mapping
if proxy_mapping_template is None:
proxy_mapping_template = (
pathlib.Path(__file__).parent.parent.parent.parent
/ "test"
/ "data"
/ "wiremock"
/ "mappings"
/ "generic"
/ "proxy_forward_all.json"
)
placeholders: dict[str, object] = {
"{{TARGET_HTTP_HOST_WITH_PORT}}": target_wm.http_host_with_port
}
if additional_proxy_placeholders:
placeholders.update(additional_proxy_placeholders)
proxy2_wm.add_mapping(proxy_mapping_template, placeholders=placeholders)
yield target_wm, proxy1_wm, proxy2_wm
Loading
Loading