diff --git a/src/snowflake/connector/aio/_cursor.py b/src/snowflake/connector/aio/_cursor.py index 24a9b5da03..81f90d3893 100644 --- a/src/snowflake/connector/aio/_cursor.py +++ b/src/snowflake/connector/aio/_cursor.py @@ -664,11 +664,8 @@ async def execute( ) logger.debug("PUT OR GET: %s", self.is_file_transfer) if self.is_file_transfer: - from ._file_transfer_agent import SnowflakeFileTransferAgent - # Decide whether to use the old, or new code path - sf_file_transfer_agent = SnowflakeFileTransferAgent( - self, + sf_file_transfer_agent = self._create_file_transfer_agent( query, ret, put_callback=_put_callback, @@ -684,9 +681,6 @@ async def execute( skip_upload_on_content_match=_skip_upload_on_content_match, source_from_stream=file_stream, multipart_threshold=data.get("threshold"), - use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1, - unsafe_file_write=self._connection.unsafe_file_write, - reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function, ) await sf_file_transfer_agent.execute() data = sf_file_transfer_agent.result() @@ -1082,8 +1076,6 @@ async def _download( _do_reset (bool, optional): Whether to reset the cursor before downloading, by default we will reset the cursor. """ - from ._file_transfer_agent import SnowflakeFileTransferAgent - if _do_reset: self.reset() @@ -1097,11 +1089,9 @@ async def _download( ) # Execute the file operation based on the interpretation above. - file_transfer_agent = SnowflakeFileTransferAgent( - self, + file_transfer_agent = self._create_file_transfer_agent( "", # empty command because it is triggered by directly calling this util not by a SQL query ret, - reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function, ) await file_transfer_agent.execute() await self._init_result_and_meta(file_transfer_agent.result()) @@ -1122,8 +1112,6 @@ async def _upload( _do_reset (bool, optional): Whether to reset the cursor before uploading, by default we will reset the cursor. """ - from ._file_transfer_agent import SnowflakeFileTransferAgent - if _do_reset: self.reset() @@ -1137,12 +1125,10 @@ async def _upload( ) # Execute the file operation based on the interpretation above. - file_transfer_agent = SnowflakeFileTransferAgent( - self, + file_transfer_agent = self._create_file_transfer_agent( "", # empty command because it is triggered by directly calling this util not by a SQL query ret, force_put_overwrite=False, # _upload should respect user decision on overwriting - reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function, ) await file_transfer_agent.execute() await self._init_result_and_meta(file_transfer_agent.result()) @@ -1191,8 +1177,6 @@ async def _upload_stream( _do_reset (bool, optional): Whether to reset the cursor before uploading, by default we will reset the cursor. """ - from ._file_transfer_agent import SnowflakeFileTransferAgent - if _do_reset: self.reset() @@ -1207,13 +1191,11 @@ async def _upload_stream( ) # Execute the file operation based on the interpretation above. 
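The hunks above collapse several near-identical `SnowflakeFileTransferAgent(self, ...)` constructions into calls to a new cursor-level factory (defined later in this patch). A minimal, self-contained sketch of the pattern with hypothetical stand-in classes, not the connector's real signatures:

```python
from __future__ import annotations

from typing import Any


class _Agent:
    """Stand-in for SnowflakeFileTransferAgent: it just records what it was built with."""

    def __init__(self, cursor, command: str, ret: dict[str, Any], **kwargs: Any) -> None:
        self.cursor, self.command, self.ret, self.kwargs = cursor, command, ret, kwargs


class _Connection:
    # Hypothetical flags mirroring the connection attributes the real helper reads.
    enable_stage_s3_privatelink_for_us_east_1 = True
    unsafe_file_write = False


class _Cursor:
    def __init__(self, connection: _Connection) -> None:
        self._connection = connection

    def _create_file_transfer_agent(
        self, command: str, ret: dict[str, Any], /, **kwargs: Any
    ) -> _Agent:
        # Connection-derived defaults live in one place; call sites only pass what
        # is specific to them (callbacks, streams, overwrite flags, ...).
        return _Agent(
            self,
            command,
            ret,
            use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
            unsafe_file_write=self._connection.unsafe_file_write,
            **kwargs,
        )


cur = _Cursor(_Connection())
agent = cur._create_file_transfer_agent("PUT file:///tmp/x @stage", {}, force_put_overwrite=False)
assert agent.kwargs["use_s3_regional_url"] is True
assert agent.kwargs["force_put_overwrite"] is False
```

Because the connection-derived flags are injected in exactly one place, adding another such flag later only touches the factory method, not every call site.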
- file_transfer_agent = SnowflakeFileTransferAgent( - self, + file_transfer_agent = self._create_file_transfer_agent( "", # empty command because it is triggered by directly calling this util not by a SQL query ret, source_from_stream=input_stream, force_put_overwrite=False, # _upload should respect user decision on overwriting - reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function, ) await file_transfer_agent.execute() await self._init_result_and_meta(file_transfer_agent.result()) @@ -1320,6 +1302,27 @@ async def query_result(self, qid: str) -> SnowflakeCursor: ) return self + def _create_file_transfer_agent( + self, + command: str, + ret: dict[str, Any], + /, + **kwargs, + ) -> SnowflakeFileTransferAgent: + from snowflake.connector.aio._file_transfer_agent import ( + SnowflakeFileTransferAgent, + ) + + return SnowflakeFileTransferAgent( + self, + command, + ret, + use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1, + unsafe_file_write=self._connection.unsafe_file_write, + reraise_error_in_file_transfer_work_function=self._connection._reraise_error_in_file_transfer_work_function, + **kwargs, + ) + class DictCursor(DictCursorSync, SnowflakeCursor): pass diff --git a/src/snowflake/connector/aio/_session_manager.py b/src/snowflake/connector/aio/_session_manager.py index aba3e0b840..2371fc5539 100644 --- a/src/snowflake/connector/aio/_session_manager.py +++ b/src/snowflake/connector/aio/_session_manager.py @@ -34,6 +34,7 @@ from ..session_manager import BaseHttpConfig from ..session_manager import SessionManager as SessionManagerSync from ..session_manager import SessionPool as SessionPoolSync +from ..session_manager import _ConfigDirectAccessMixin logger = logging.getLogger(__name__) @@ -328,7 +329,29 @@ async def delete( ) -class SessionManager(_RequestVerbsUsingSessionMixin, SessionManagerSync): +class _AsyncHttpConfigDirectAccessMixin(_ConfigDirectAccessMixin, abc.ABC): + @property + @abc.abstractmethod + def config(self) -> AioHttpConfig: ... + + @config.setter + @abc.abstractmethod + def config(self, value) -> AioHttpConfig: ... + + @property + def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]: + return self.config.connector_factory + + @connector_factory.setter + def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None: + self.config: AioHttpConfig = self.config.copy_with(connector_factory=value) + + +class SessionManager( + _RequestVerbsUsingSessionMixin, + SessionManagerSync, + _AsyncHttpConfigDirectAccessMixin, +): """ Async HTTP session manager for aiohttp.ClientSession instances. 
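`_AsyncHttpConfigDirectAccessMixin` above (and its sync counterpart `_ConfigDirectAccessMixin` later in this patch) exposes mutable-looking properties over an immutable config object: the getter reads from `self.config`, the setter swaps in a `copy_with(...)` copy. A small self-contained sketch of that pattern, using a toy frozen dataclass rather than the real `AioHttpConfig`:

```python
from __future__ import annotations

import abc
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class _Config:
    """Toy stand-in for AioHttpConfig / HttpConfig: immutable, copied on change."""

    use_pooling: bool = True
    max_retries: int | None = 7

    def copy_with(self, **overrides) -> _Config:
        return dataclasses.replace(self, **overrides)


class _ConfigAccessMixin(abc.ABC):
    """Convenience properties that delegate to the config held by the subclass."""

    @property
    @abc.abstractmethod
    def config(self) -> _Config: ...

    @config.setter
    @abc.abstractmethod
    def config(self, value: _Config) -> None: ...

    @property
    def max_retries(self) -> int | None:
        return self.config.max_retries

    @max_retries.setter
    def max_retries(self, value: int | None) -> None:
        # Setting a "field" really swaps the whole frozen config for a modified copy.
        self.config = self.config.copy_with(max_retries=value)


class _Manager(_ConfigAccessMixin):
    def __init__(self, config: _Config) -> None:
        self._cfg = config

    @property
    def config(self) -> _Config:
        return self._cfg

    @config.setter
    def config(self, value: _Config) -> None:
        self._cfg = value


m = _Manager(_Config())
m.max_retries = 0
assert m.config.max_retries == 0 and m.config.use_pooling is True
```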
@@ -363,14 +386,6 @@ def from_config(cls, cfg: AioHttpConfig, **overrides: Any) -> SessionManager: cfg = cfg.copy_with(**overrides) return cls(config=cfg) - @property - def connector_factory(self) -> Callable[..., aiohttp.BaseConnector]: - return self._cfg.connector_factory - - @connector_factory.setter - def connector_factory(self, value: Callable[..., aiohttp.BaseConnector]) -> None: - self._cfg: AioHttpConfig = self._cfg.copy_with(connector_factory=value) - def make_session(self) -> aiohttp.ClientSession: """Create a new aiohttp.ClientSession with configured connector.""" connector = self._cfg.get_connector( @@ -432,18 +447,18 @@ async def close(self): def clone( self, - *, - use_pooling: bool | None = None, - connector_factory: ConnectorFactory | None = None, + **http_config_overrides, ) -> SessionManager: - """Return a new async SessionManager sharing this instance's config.""" - overrides: dict[str, Any] = {} - if use_pooling is not None: - overrides["use_pooling"] = use_pooling - if connector_factory is not None: - overrides["connector_factory"] = connector_factory - - return self.from_config(self._cfg, **overrides) + """Return a new *stateless* SessionManager sharing this instance’s config. + + "Shallow clone" - the configuration object (HttpConfig) is reused as-is, + while *stateful* aspects such as the per-host SessionPool mapping are + reset, so the two managers do not share live `requests.Session` + objects. + Optional kwargs (e.g. *use_pooling* / *adapter_factory* / max_retries etc.) - overrides to create a modified + copy of the HttpConfig before instantiation. + """ + return self.from_config(self._cfg, **http_config_overrides) async def request( diff --git a/src/snowflake/connector/auth/_auth.py b/src/snowflake/connector/auth/_auth.py index cb3d227fe6..5dca31a361 100644 --- a/src/snowflake/connector/auth/_auth.py +++ b/src/snowflake/connector/auth/_auth.py @@ -138,7 +138,7 @@ def base_auth_data( "SOCKET_TIMEOUT": socket_timeout, "PLATFORM": detect_platforms( platform_detection_timeout_seconds=platform_detection_timeout_seconds, - session_manager=session_manager, + session_manager=session_manager.clone(max_retries=0), ), }, }, diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py index a8ec738986..6ade7f3d8e 100644 --- a/src/snowflake/connector/cursor.py +++ b/src/snowflake/connector/cursor.py @@ -79,7 +79,10 @@ from pyarrow import Table from .connection import SnowflakeConnection - from .file_transfer_agent import SnowflakeProgressPercentage + from .file_transfer_agent import ( + SnowflakeFileTransferAgent, + SnowflakeProgressPercentage, + ) from .result_batch import ResultBatch T = TypeVar("T", bound=collections.abc.Sequence) @@ -1064,11 +1067,7 @@ def execute( ) logger.debug("PUT OR GET: %s", self.is_file_transfer) if self.is_file_transfer: - from .file_transfer_agent import SnowflakeFileTransferAgent - - # Decide whether to use the old, or new code path - sf_file_transfer_agent = SnowflakeFileTransferAgent( - self, + sf_file_transfer_agent = self._create_file_transfer_agent( query, ret, put_callback=_put_callback, @@ -1084,13 +1083,6 @@ def execute( skip_upload_on_content_match=_skip_upload_on_content_match, source_from_stream=file_stream, multipart_threshold=data.get("threshold"), - use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1, - iobound_tpe_limit=self._connection.iobound_tpe_limit, - unsafe_file_write=self._connection.unsafe_file_write, - 
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer( - self._connection - ), - reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function, ) sf_file_transfer_agent.execute() data = sf_file_transfer_agent.result() @@ -1785,8 +1777,6 @@ def _download( _do_reset (bool, optional): Whether to reset the cursor before downloading, by default we will reset the cursor. """ - from .file_transfer_agent import SnowflakeFileTransferAgent - if _do_reset: self.reset() @@ -1800,14 +1790,9 @@ def _download( ) # Execute the file operation based on the interpretation above. - file_transfer_agent = SnowflakeFileTransferAgent( - self, + file_transfer_agent = self._create_file_transfer_agent( "", # empty command because it is triggered by directly calling this util not by a SQL query ret, - snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer( - self._connection - ), - reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function, ) file_transfer_agent.execute() self._init_result_and_meta(file_transfer_agent.result()) @@ -1828,7 +1813,6 @@ def _upload( _do_reset (bool, optional): Whether to reset the cursor before uploading, by default we will reset the cursor. """ - from .file_transfer_agent import SnowflakeFileTransferAgent if _do_reset: self.reset() @@ -1843,15 +1827,10 @@ def _upload( ) # Execute the file operation based on the interpretation above. - file_transfer_agent = SnowflakeFileTransferAgent( - self, + file_transfer_agent = self._create_file_transfer_agent( "", # empty command because it is triggered by directly calling this util not by a SQL query ret, force_put_overwrite=False, # _upload should respect user decision on overwriting - snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer( - self._connection - ), - reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function, ) file_transfer_agent.execute() self._init_result_and_meta(file_transfer_agent.result()) @@ -1898,7 +1877,6 @@ def _upload_stream( _do_reset (bool, optional): Whether to reset the cursor before uploading, by default we will reset the cursor. """ - from .file_transfer_agent import SnowflakeFileTransferAgent if _do_reset: self.reset() @@ -1914,19 +1892,37 @@ def _upload_stream( ) # Execute the file operation based on the interpretation above. 
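Both the async and sync `_create_file_transfer_agent` helpers declare `command` and `ret` before a `/`, i.e. as positional-only parameters, so those names can never collide with the `**kwargs` passthrough. A toy illustration (hypothetical function, not the connector's API):

```python
from typing import Any


def _create_agent(command: str, ret: dict[str, Any], /, **kwargs: Any) -> dict[str, Any]:
    """Toy helper: everything before '/' must be passed positionally."""
    return {"command": command, "ret": ret, **kwargs}


# 'command' and 'ret' can never arrive through **kwargs, so forwarded keyword
# arguments cannot shadow them:
_create_agent("PUT file:///tmp/x @stage", {}, source_from_stream=None)  # fine
try:
    _create_agent(command="PUT file:///tmp/x @stage", ret={})  # rejected
except TypeError as exc:
    print(exc)  # missing 2 required positional arguments: 'command' and 'ret'
```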
- file_transfer_agent = SnowflakeFileTransferAgent( - self, + file_transfer_agent = self._create_file_transfer_agent( "", # empty command because it is triggered by directly calling this util not by a SQL query ret, source_from_stream=input_stream, force_put_overwrite=False, # _upload_stream should respect user decision on overwriting + ) + file_transfer_agent.execute() + self._init_result_and_meta(file_transfer_agent.result()) + + def _create_file_transfer_agent( + self, + command: str, + ret: dict[str, Any], + /, + **kwargs, + ) -> SnowflakeFileTransferAgent: + from .file_transfer_agent import SnowflakeFileTransferAgent + + return SnowflakeFileTransferAgent( + self, + command, + ret, + use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1, + iobound_tpe_limit=self._connection.iobound_tpe_limit, + unsafe_file_write=self._connection.unsafe_file_write, snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer( self._connection ), - reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function, + reraise_error_in_file_transfer_work_function=self._connection._reraise_error_in_file_transfer_work_function, + **kwargs, ) - file_transfer_agent.execute() - self._init_result_and_meta(file_transfer_agent.result()) class DictCursor(SnowflakeCursor): diff --git a/src/snowflake/connector/platform_detection.py b/src/snowflake/connector/platform_detection.py index ec615be24d..2ad1893501 100644 --- a/src/snowflake/connector/platform_detection.py +++ b/src/snowflake/connector/platform_detection.py @@ -405,7 +405,7 @@ def detect_platforms( logger.debug( "No session manager provided. HTTP settings may not be preserved. Using default." ) - session_manager = SessionManager(use_pooling=False) + session_manager = SessionManager(use_pooling=False, max_retries=0) # Run environment-only checks synchronously (no network calls, no threading overhead) platforms = { diff --git a/src/snowflake/connector/session_manager.py b/src/snowflake/connector/session_manager.py index 43eeb87ee4..fe47190bca 100644 --- a/src/snowflake/connector/session_manager.py +++ b/src/snowflake/connector/session_manager.py @@ -16,7 +16,7 @@ from .vendored.requests.adapters import BaseAdapter, HTTPAdapter from .vendored.requests.exceptions import InvalidProxyURL from .vendored.requests.utils import prepend_scheme_if_needed, select_proxy -from .vendored.urllib3 import PoolManager +from .vendored.urllib3 import PoolManager, Retry from .vendored.urllib3.poolmanager import ProxyManager from .vendored.urllib3.util.url import parse_url @@ -119,7 +119,7 @@ class BaseHttpConfig: """Immutable HTTP configuration shared by SessionManager instances.""" use_pooling: bool = True - max_retries: int | None = REQUESTS_RETRY + max_retries: int | Retry | None = REQUESTS_RETRY proxy_host: str | None = None proxy_port: str | None = None proxy_user: str | None = None @@ -217,6 +217,40 @@ def close(self) -> None: self._idle_sessions.clear() +class _ConfigDirectAccessMixin(abc.ABC): + @property + @abc.abstractmethod + def config(self) -> HttpConfig: ... + + @config.setter + @abc.abstractmethod + def config(self, value) -> HttpConfig: ... 
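Elsewhere in this patch `_mount_adapters` switches to `self._cfg.get_adapter()` and `max_retries` is widened to `int | Retry | None`. The implementations of `get_adapter`/`copy_with` are not shown in these hunks; the sketch below assumes behaviour consistent with the new unit tests (the configured `max_retries` is merged into the factory call, explicit kwargs win) and uses the public `requests`/`urllib3` packages instead of the connector's vendored copies:

```python
from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from typing import Any, Callable

from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


@dataclass(frozen=True)
class _HttpConfig:
    """Simplified stand-in for BaseHttpConfig: immutable, copied on change."""

    max_retries: int | Retry | None = 3
    adapter_factory: Callable[..., HTTPAdapter] = HTTPAdapter

    def copy_with(self, **overrides: Any) -> _HttpConfig:
        return dataclasses.replace(self, **overrides)

    def get_adapter(self, **kwargs: Any) -> HTTPAdapter:
        # Fill in the configured retry policy only if the caller did not pass one,
        # which matches what the new unit tests assert (explicit kwargs win).
        kwargs.setdefault("max_retries", self.max_retries)
        return self.adapter_factory(**kwargs)


cfg = _HttpConfig(max_retries=Retry(total=3, backoff_factor=0.3))
adapter = cfg.get_adapter()
assert adapter.max_retries.total == 3

# copy_with() never mutates the original, so clones cannot affect each other.
no_retry_cfg = cfg.copy_with(max_retries=0)
assert no_retry_cfg.max_retries == 0 and cfg.max_retries.total == 3
```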
+ + @property + def use_pooling(self) -> bool: + return self.config.use_pooling + + @use_pooling.setter + def use_pooling(self, value: bool) -> None: + self.config = self.config.copy_with(use_pooling=value) + + @property + def adapter_factory(self) -> Callable[..., HTTPAdapter]: + return self.config.adapter_factory + + @adapter_factory.setter + def adapter_factory(self, value: Callable[..., HTTPAdapter]) -> None: + self.config = self.config.copy_with(adapter_factory=value) + + @property + def max_retries(self) -> Retry | int: + return self.config.max_retries + + @max_retries.setter + def max_retries(self, value: Retry | int) -> None: + self.config = self.config.copy_with(max_retries=value) + + class _RequestVerbsUsingSessionMixin(abc.ABC): """ Mixin that provides HTTP methods (get, post, put, etc.) mirroring requests.Session, maintaining their default argument behavior (e.g., HEAD uses allow_redirects=False). @@ -327,7 +361,7 @@ def delete( return session.delete(url, headers=headers, timeout=timeout, **kwargs) -class SessionManager(_RequestVerbsUsingSessionMixin): +class SessionManager(_RequestVerbsUsingSessionMixin, _ConfigDirectAccessMixin): """ Central HTTP session manager that handles all external requests from the Snowflake driver. @@ -394,22 +428,6 @@ def proxy_url(self) -> str: self._cfg.proxy_password, ) - @property - def use_pooling(self) -> bool: - return self._cfg.use_pooling - - @use_pooling.setter - def use_pooling(self, value: bool) -> None: - self._cfg = self._cfg.copy_with(use_pooling=value) - - @property - def adapter_factory(self) -> Callable[..., HTTPAdapter]: - return self._cfg.adapter_factory - - @adapter_factory.setter - def adapter_factory(self, value: Callable[..., HTTPAdapter]) -> None: - self._cfg = self._cfg.copy_with(adapter_factory=value) - @property def sessions_map(self) -> dict[str, SessionPool]: return self._sessions_map @@ -435,9 +453,7 @@ def get_session_pool_manager(session: Session, url: str) -> PoolManager | None: def _mount_adapters(self, session: requests.Session) -> None: try: # Its important that each separate session manager creates its own adapters - because they are storing internally PoolManagers - which shouldn't be reused if not in scope of the same adapter. - adapter = self._cfg.adapter_factory( - max_retries=self._cfg.max_retries or REQUESTS_RETRY - ) + adapter = self._cfg.get_adapter() if adapter is not None: session.mount("http://", adapter) session.mount("https://", adapter) @@ -505,27 +521,18 @@ def close(self): def clone( self, - *, - use_pooling: bool | None = None, - adapter_factory: AdapterFactory | None = None, + **http_config_overrides, ) -> SessionManager: """Return a new *stateless* SessionManager sharing this instance’s config. - "Shallow" means the configuration object (HttpConfig) is reused as-is, + "Shallow clone" - the configuration object (HttpConfig) is reused as-is, while *stateful* aspects such as the per-host SessionPool mapping are reset, so the two managers do not share live `requests.Session` objects. - Optional *use_pooling* / *adapter_factory* overrides create a modified - copy of the config before instantiation. + Optional kwargs (e.g. *use_pooling* / *adapter_factory* / max_retries etc.) - overrides to create a modified + copy of the HttpConfig before instantiation. 
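`clone()` now forwards arbitrary `HttpConfig` overrides, which is what lets `base_auth_data` hand platform detection a retry-free manager (`session_manager.clone(max_retries=0)` in the auth hunk above). A toy model of the intended "shared config, fresh state" semantics, with hypothetical class names:

```python
from __future__ import annotations

import dataclasses
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class _Cfg:
    """Toy stand-in for HttpConfig."""

    use_pooling: bool = True
    max_retries: int = 7

    def copy_with(self, **overrides: Any) -> _Cfg:
        return dataclasses.replace(self, **overrides)


class _Manager:
    """Toy model of SessionManager.clone(): shared (possibly overridden) config, fresh state."""

    def __init__(self, config: _Cfg) -> None:
        self._cfg = config
        self._sessions_map: dict[str, list[str]] = {}  # stateful part, never shared

    def clone(self, **http_config_overrides: Any) -> _Manager:
        cfg = self._cfg.copy_with(**http_config_overrides) if http_config_overrides else self._cfg
        return _Manager(cfg)


base = _Manager(_Cfg())
base._sessions_map["host-a"] = ["live-session"]

# What base_auth_data / platform detection now do: a throwaway manager with retries
# switched off, so best-effort metadata probes fail fast instead of retrying.
probe = base.clone(max_retries=0)
assert probe._cfg.max_retries == 0 and base._cfg.max_retries == 7
assert probe._sessions_map == {}  # no live sessions carried over from the parent
```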
""" - - overrides: dict[str, Any] = {} - if use_pooling is not None: - overrides["use_pooling"] = use_pooling - if adapter_factory is not None: - overrides["adapter_factory"] = adapter_factory - - return SessionManager.from_config(self._cfg, **overrides) + return SessionManager.from_config(self._cfg, **http_config_overrides) def __getstate__(self): state = self.__dict__.copy() diff --git a/test/integ/aio_it/test_connection_async.py b/test/integ/aio_it/test_connection_async.py index c6f043f461..2cc68c5fdb 100644 --- a/test/integ/aio_it/test_connection_async.py +++ b/test/integ/aio_it/test_connection_async.py @@ -58,6 +58,13 @@ except ImportError: pass +from test.integ.test_connection import ( + _assert_log_bytes_within_tolerance, + _calculate_log_bytes, + _find_matching_patterns, + _log_pattern_analysis, +) + async def test_basic(conn_testaccount): """Basic Connection test.""" @@ -413,7 +420,7 @@ async def test_invalid_account_timeout(conn_cnx): pass -@pytest.mark.timeout(15) +@pytest.mark.timeout(20) async def test_invalid_proxy(conn_cnx): http_proxy = os.environ.get("HTTP_PROXY") https_proxy = os.environ.get("HTTPS_PROXY") @@ -445,7 +452,7 @@ async def test_invalid_proxy(conn_cnx): @pytest.mark.skipolddriver -@pytest.mark.timeout(15) +@pytest.mark.timeout(20) async def test_invalid_proxy_not_impacting_env_vars(conn_cnx): http_proxy = os.environ.get("HTTP_PROXY") https_proxy = os.environ.get("HTTPS_PROXY") @@ -1614,3 +1621,85 @@ async def test_snowflake_version(): assert re.match( version_pattern, await conn.snowflake_version ), f"snowflake_version should match pattern 'x.y.z', but got '{await conn.snowflake_version}'" + + +@pytest.mark.skipolddriver +async def test_logs_size_during_basic_query_stays_unchanged(conn_cnx, caplog): + """Test that the amount of bytes logged during normal select 1 flow is within acceptable range. Related to: SNOW-2268606""" + caplog.set_level(logging.INFO, "snowflake.connector") + caplog.clear() + + # Test-specific constants + EXPECTED_BYTES = 145 + ACCEPTABLE_DELTA = 0.6 + EXPECTED_PATTERNS = [ + "Snowflake Connector for Python Version: ", # followed by version info + "Connecting to GLOBAL Snowflake domain", + ] + + async with conn_cnx() as conn: + async with conn.cursor() as cur: + await (await cur.execute("select 1")).fetchall() + + actual_messages = [record.getMessage() for record in caplog.records] + total_log_bytes = _calculate_log_bytes(actual_messages) + + if total_log_bytes != EXPECTED_BYTES: + logging.warning( + f"There was a change in a size of the logs produced by the basic Snowflake query. " + f"Expected: {EXPECTED_BYTES}, got: {total_log_bytes}. " + f"We may need to update the test_logs_size_during_basic_query_stays_unchanged - i.e. EXACT_EXPECTED_LOGS_BYTES constant." 
+ ) + + # Check if patterns match to decide whether to show all messages + matched_patterns, missing_patterns, unmatched_messages = ( + _find_matching_patterns(actual_messages, EXPECTED_PATTERNS) + ) + patterns_match_perfectly = ( + len(missing_patterns) == 0 and len(unmatched_messages) == 0 + ) + + _log_pattern_analysis( + actual_messages, + EXPECTED_PATTERNS, + matched_patterns, + missing_patterns, + unmatched_messages, + show_all_messages=patterns_match_perfectly, + ) + + _assert_log_bytes_within_tolerance( + total_log_bytes, EXPECTED_BYTES, ACCEPTABLE_DELTA + ) + + +@pytest.mark.skipolddriver +async def test_no_new_warnings_or_errors_on_successful_basic_select(conn_cnx, caplog): + """Test that the number of warning/error log entries stays the same during successful basic select operations. Related to: SNOW-2268606""" + caplog.set_level(logging.WARNING, "snowflake.connector") + baseline_warning_count = 0 + baseline_error_count = 0 + + # Execute basic select operations and check counts remain the same + caplog.clear() + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Execute basic select operations + result1 = await (await cur.execute("select 1")).fetchall() + assert result1 == [(1,)] + + # Count warning/error log entries after operations + test_warning_count = len( + [r for r in caplog.records if r.levelno >= logging.WARNING] + ) + test_error_count = len([r for r in caplog.records if r.levelno >= logging.ERROR]) + + # Assert counts stay the same (no new warnings or errors) + assert test_warning_count == baseline_warning_count, ( + f"Warning count increased from {baseline_warning_count} to {test_warning_count}. " + f"New warnings: {[r.getMessage() for r in caplog.records if r.levelno == logging.WARNING]}" + ) + assert test_error_count == baseline_error_count, ( + f"Error count increased from {baseline_error_count} to {test_error_count}. " + f"New errors: {[r.getMessage() for r in caplog.records if r.levelno >= logging.ERROR]}" + ) diff --git a/test/integ/test_connection.py b/test/integ/test_connection.py index c2dd3a3470..154d1a407b 100644 --- a/test/integ/test_connection.py +++ b/test/integ/test_connection.py @@ -60,6 +60,8 @@ except ImportError: pass +logger = logging.getLogger(__name__) + def test_basic(conn_testaccount): """Basic Connection test.""" @@ -417,7 +419,7 @@ def test_invalid_account_timeout(conn_cnx): pass -@pytest.mark.timeout(15) +@pytest.mark.timeout(20) def test_invalid_proxy(conn_cnx): http_proxy = os.environ.get("HTTP_PROXY") https_proxy = os.environ.get("HTTPS_PROXY") @@ -449,7 +451,7 @@ def test_invalid_proxy(conn_cnx): @pytest.mark.skipolddriver -@pytest.mark.timeout(15) +@pytest.mark.timeout(20) def test_invalid_proxy_not_impacting_env_vars(conn_cnx): http_proxy = os.environ.get("HTTP_PROXY") https_proxy = os.environ.get("HTTPS_PROXY") @@ -1398,6 +1400,176 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_disabled( assert "This connection does not perform OCSP checks." in caplog.text +def _message_matches_pattern(message, pattern): + """Check if a log message matches a pattern (exact match or starts with pattern).""" + return message == pattern or message.startswith(pattern) + + +def _find_matching_patterns(messages, patterns): + """Find which patterns match the given messages. 
+ + Returns: + tuple: (matched_patterns, missing_patterns, unmatched_messages) + """ + matched_patterns = set() + unmatched_messages = [] + + for message in messages: + found_match = False + for pattern in patterns: + if _message_matches_pattern(message, pattern): + matched_patterns.add(pattern) + found_match = True + break + if not found_match: + unmatched_messages.append(message) + + missing_patterns = set(patterns) - matched_patterns + return matched_patterns, missing_patterns, unmatched_messages + + +def _calculate_log_bytes(messages): + """Calculate total byte size of log messages.""" + return sum(len(message.encode("utf-8")) for message in messages) + + +def _log_pattern_analysis( + actual_messages, + expected_patterns, + matched_patterns, + missing_patterns, + unmatched_messages, + show_all_messages=False, +): + """Log detailed analysis of pattern differences. + + Args: + actual_messages: List of actual log messages + expected_patterns: List of expected log patterns + matched_patterns: Set of patterns that were found + missing_patterns: Set of patterns that were not found + unmatched_messages: List of messages that didn't match any pattern + show_all_messages: If True, log all actual messages for debugging + """ + + if missing_patterns: + logger.warning(f"Missing expected log patterns ({len(missing_patterns)}):") + for pattern in sorted(missing_patterns): + logger.warning(f" - MISSING: '{pattern}'") + + if unmatched_messages: + logger.warning(f"New/unexpected log messages ({len(unmatched_messages)}):") + for message in unmatched_messages: + message_bytes = len(message.encode("utf-8")) + logger.warning(f" + NEW: '{message}' ({message_bytes} bytes)") + + # Log summary + logger.warning("Log analysis summary:") + logger.warning(f" - Expected patterns: {len(expected_patterns)}") + logger.warning(f" - Matched patterns: {len(matched_patterns)}") + logger.warning(f" - Missing patterns: {len(missing_patterns)}") + logger.warning(f" - Actual messages: {len(actual_messages)}") + logger.warning(f" - Unmatched messages: {len(unmatched_messages)}") + + # Show all messages if requested (useful when patterns match but bytes don't) + if show_all_messages: + logger.warning("All actual log messages:") + for i, message in enumerate(actual_messages): + message_bytes = len(message.encode("utf-8")) + logger.warning(f" [{i:2d}] '{message}' ({message_bytes} bytes)") + + +def _assert_log_bytes_within_tolerance(actual_bytes, expected_bytes, tolerance): + """Assert that log bytes are within acceptable tolerance.""" + assert actual_bytes == pytest.approx(expected_bytes, rel=tolerance), ( + f"Log bytes {actual_bytes} is not approximately equal to expected {expected_bytes} " + f"within {tolerance*100}% tolerance. " + f"This may indicate unwanted logs being produced or changes in logging behavior." + ) + + +@pytest.mark.skipolddriver +def test_logs_size_during_basic_query_stays_unchanged(conn_cnx, caplog): + """Test that the amount of bytes logged during normal select 1 flow is within acceptable range. 
Related to: SNOW-2268606""" + caplog.set_level(logging.INFO, "snowflake.connector") + caplog.clear() + + # Test-specific constants + EXPECTED_BYTES = 145 + ACCEPTABLE_DELTA = 0.6 + EXPECTED_PATTERNS = [ + "Snowflake Connector for Python Version: ", # followed by version info + "Connecting to GLOBAL Snowflake domain", + ] + + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute("select 1").fetchall() + + actual_messages = [record.getMessage() for record in caplog.records] + total_log_bytes = _calculate_log_bytes(actual_messages) + + if total_log_bytes != EXPECTED_BYTES: + logger.warning( + f"There was a change in a size of the logs produced by the basic Snowflake query. " + f"Expected: {EXPECTED_BYTES}, got: {total_log_bytes}. " + f"We may need to update the test_logs_size_during_basic_query_stays_unchanged - i.e. EXACT_EXPECTED_LOGS_BYTES constant." + ) + + # Check if patterns match to decide whether to show all messages + matched_patterns, missing_patterns, unmatched_messages = ( + _find_matching_patterns(actual_messages, EXPECTED_PATTERNS) + ) + patterns_match_perfectly = ( + len(missing_patterns) == 0 and len(unmatched_messages) == 0 + ) + + _log_pattern_analysis( + actual_messages, + EXPECTED_PATTERNS, + matched_patterns, + missing_patterns, + unmatched_messages, + show_all_messages=patterns_match_perfectly, + ) + + _assert_log_bytes_within_tolerance( + total_log_bytes, EXPECTED_BYTES, ACCEPTABLE_DELTA + ) + + +@pytest.mark.skipolddriver +def test_no_new_warnings_or_errors_on_successful_basic_select(conn_cnx, caplog): + """Test that the number of warning/error log entries stays the same during successful basic select operations. Related to: SNOW-2268606""" + caplog.set_level(logging.WARNING, "snowflake.connector") + baseline_warning_count = 0 + baseline_error_count = 0 + + # Execute basic select operations and check counts remain the same + caplog.clear() + with conn_cnx() as conn: + with conn.cursor() as cur: + # Execute basic select operations + result1 = cur.execute("select 1").fetchall() + assert result1 == [(1,)] + + # Count warning/error log entries after operations + test_warning_count = len( + [r for r in caplog.records if r.levelno >= logging.WARNING] + ) + test_error_count = len([r for r in caplog.records if r.levelno >= logging.ERROR]) + + # Assert counts stay the same (no new warnings or errors) + assert test_warning_count == baseline_warning_count, ( + f"Warning count increased from {baseline_warning_count} to {test_warning_count}. " + f"New warnings: {[r.getMessage() for r in caplog.records if r.levelno == logging.WARNING]}" + ) + assert test_error_count == baseline_error_count, ( + f"Error count increased from {baseline_error_count} to {test_error_count}. 
" + f"New errors: {[r.getMessage() for r in caplog.records if r.levelno >= logging.ERROR]}" + ) + + @pytest.mark.skipolddriver def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_enabled( conn_cnx, is_public_test, is_local_dev_setup, caplog diff --git a/test/unit/aio/test_auth_workload_identity_async.py b/test/unit/aio/test_auth_workload_identity_async.py index bb563d6591..013f4af6f8 100644 --- a/test/unit/aio/test_auth_workload_identity_async.py +++ b/test/unit/aio/test_auth_workload_identity_async.py @@ -5,7 +5,6 @@ import asyncio import json import logging -import os from base64 import b64decode from unittest import mock from urllib.parse import parse_qs, urlparse @@ -409,12 +408,10 @@ async def test_explicit_azure_omits_client_id_if_not_set(fake_azure_metadata_ser async def test_explicit_azure_uses_explicit_client_id_if_set( - fake_azure_metadata_service, + fake_azure_metadata_service, monkeypatch ): - with mock.patch.dict( - os.environ, {"MANAGED_IDENTITY_CLIENT_ID": "custom-client-id"} - ): - auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) - await auth_class.prepare(conn=None) + monkeypatch.setenv("MANAGED_IDENTITY_CLIENT_ID", "custom-client-id") + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + await auth_class.prepare(conn=None) assert fake_azure_metadata_service.requested_client_id == "custom-client-id" diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py index 590a85711b..f75f905a7b 100644 --- a/test/unit/aio/test_connection_async_unit.py +++ b/test/unit/aio/test_connection_async_unit.py @@ -7,7 +7,6 @@ import json import logging -import os import stat import sys from contextlib import asynccontextmanager @@ -193,12 +192,12 @@ def test_is_still_running(): ) -async def test_partner_env_var(mock_post_requests): +async def test_partner_env_var(mock_post_requests, monkeypatch): PARTNER_NAME = "Amanda" - with patch.dict(os.environ, {ENV_VAR_PARTNER: PARTNER_NAME}): - async with fake_db_conn() as conn: - assert conn.application == PARTNER_NAME + monkeypatch.setenv(ENV_VAR_PARTNER, PARTNER_NAME) + async with fake_db_conn() as conn: + assert conn.application == PARTNER_NAME assert ( mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == PARTNER_NAME diff --git a/test/unit/aio/test_cursor_async_unit.py b/test/unit/aio/test_cursor_async_unit.py index 39894c3bad..019e1b4cc1 100644 --- a/test/unit/aio/test_cursor_async_unit.py +++ b/test/unit/aio/test_cursor_async_unit.py @@ -29,6 +29,8 @@ def __init__(self): self._log_max_query_length = 0 self._reuse_results = None self._reraise_error_in_file_transfer_work_function = False + self._enable_stage_s3_privatelink_for_us_east_1 = False + self._unsafe_file_write = False @pytest.mark.parametrize( @@ -125,6 +127,8 @@ async def test_download(self, MockFileTransferAgent): # - download_as_stream of connection._stream_downloader fake_conn._file_operation_parser.parse_file_operation.assert_called_once() fake_conn._stream_downloader.download_as_stream.assert_not_called() + MockFileTransferAgent.assert_called_once() + assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False) mock_file_transfer_agent_instance.execute.assert_called_once() @patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent") @@ -143,6 +147,8 @@ async def test_upload(self, MockFileTransferAgent): # - download_as_stream of connection._stream_downloader 
fake_conn._file_operation_parser.parse_file_operation.assert_called_once() fake_conn._stream_downloader.download_as_stream.assert_not_called() + MockFileTransferAgent.assert_called_once() + assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False) mock_file_transfer_agent_instance.execute.assert_called_once() @patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent") @@ -161,6 +167,7 @@ async def test_download_stream(self, MockFileTransferAgent): # - execute in SnowflakeFileTransferAgent fake_conn._file_operation_parser.parse_file_operation.assert_called_once() fake_conn._stream_downloader.download_as_stream.assert_called_once() + MockFileTransferAgent.assert_not_called() mock_file_transfer_agent_instance.execute.assert_not_called() @patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent") @@ -180,6 +187,8 @@ async def test_upload_stream(self, MockFileTransferAgent): # - download_as_stream of connection._stream_downloader fake_conn._file_operation_parser.parse_file_operation.assert_called_once() fake_conn._stream_downloader.download_as_stream.assert_not_called() + MockFileTransferAgent.assert_called_once() + assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False) mock_file_transfer_agent_instance.execute.assert_called_once() def _setup_mocks(self, MockFileTransferAgent): @@ -191,6 +200,9 @@ def _setup_mocks(self, MockFileTransferAgent): fake_conn._file_operation_parser.parse_file_operation = AsyncMock() fake_conn._stream_downloader = MagicMock() fake_conn._stream_downloader.download_as_stream = AsyncMock() + # this should be true on all new AWS deployments to use regional endpoints for staging operations + fake_conn._enable_stage_s3_privatelink_for_us_east_1 = True + fake_conn._unsafe_file_write = False cursor = SnowflakeCursor(fake_conn) cursor.reset = MagicMock() diff --git a/test/unit/aio/test_ocsp.py b/test/unit/aio/test_ocsp.py index f1adb75134..234d978fa4 100644 --- a/test/unit/aio/test_ocsp.py +++ b/test/unit/aio/test_ocsp.py @@ -207,7 +207,7 @@ async def test_ocsp_wo_cache_file(session_manager): OCSPCache.reset_cache_dir() -async def test_ocsp_fail_open_w_single_endpoint(session_manager): +async def test_ocsp_fail_open_w_single_endpoint(session_manager, monkeypatch): SnowflakeOCSP.clear_cache() try: @@ -216,33 +216,28 @@ async def test_ocsp_fail_open_w_single_endpoint(session_manager): # File doesn't exist, which is fine for this test pass - environ["SF_OCSP_TEST_MODE"] = "true" - environ["SF_TEST_OCSP_URL"] = "http://httpbin.org/delay/10" - environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] = "5" + monkeypatch.setenv("SF_OCSP_TEST_MODE", "true") + monkeypatch.setenv("SF_TEST_OCSP_URL", "http://httpbin.org/delay/10") + monkeypatch.setenv("SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT", "5") ocsp = SFOCSP(use_ocsp_cache_server=False) - try: - async with _asyncio_connect("snowflake.okta.com") as connection: - assert await ocsp.validate( - "snowflake.okta.com", connection, session_manager=session_manager - ), "Failed to validate: {}".format("snowflake.okta.com") - finally: - del environ["SF_OCSP_TEST_MODE"] - del environ["SF_TEST_OCSP_URL"] - del environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + "snowflake.okta.com", connection, session_manager=session_manager + ), "Failed to validate: {}".format("snowflake.okta.com") @pytest.mark.skipif( ER_OCSP_RESPONSE_CERT_STATUS_REVOKED is None, 
reason="No ER_OCSP_RESPONSE_CERT_STATUS_REVOKED is available.", ) -async def test_ocsp_fail_close_w_single_endpoint(session_manager): +async def test_ocsp_fail_close_w_single_endpoint(session_manager, monkeypatch): SnowflakeOCSP.clear_cache() - environ["SF_OCSP_TEST_MODE"] = "true" - environ["SF_TEST_OCSP_URL"] = "http://httpbin.org/delay/10" - environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] = "5" + monkeypatch.setenv("SF_OCSP_TEST_MODE", "true") + monkeypatch.setenv("SF_TEST_OCSP_URL", "http://httpbin.org/delay/10") + monkeypatch.setenv("SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT", "5") OCSPCache.del_cache_file() @@ -254,21 +249,16 @@ async def test_ocsp_fail_close_w_single_endpoint(session_manager): "snowflake.okta.com", connection, session_manager=session_manager ) - try: - assert ( - ex.value.errno == ER_OCSP_RESPONSE_FETCH_FAILURE - ), "Connection should have failed" - finally: - del environ["SF_OCSP_TEST_MODE"] - del environ["SF_TEST_OCSP_URL"] - del environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] + assert ( + ex.value.errno == ER_OCSP_RESPONSE_FETCH_FAILURE + ), "Connection should have failed" -async def test_ocsp_bad_validity(session_manager): +async def test_ocsp_bad_validity(session_manager, monkeypatch): SnowflakeOCSP.clear_cache() - environ["SF_OCSP_TEST_MODE"] = "true" - environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"] = "true" + monkeypatch.setenv("SF_OCSP_TEST_MODE", "true") + monkeypatch.setenv("SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY", "true") try: OCSPCache.del_cache_file() @@ -282,12 +272,10 @@ async def test_ocsp_bad_validity(session_manager): assert await ocsp.validate( "snowflake.okta.com", connection, session_manager=session_manager ), "Connection should have passed with fail open" - del environ["SF_OCSP_TEST_MODE"] - del environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"] -async def test_ocsp_single_endpoint(session_manager): - environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] = "True" +async def test_ocsp_single_endpoint(session_manager, monkeypatch): + monkeypatch.setenv("SF_OCSP_ACTIVATE_NEW_ENDPOINT", "True") SnowflakeOCSP.clear_cache() ocsp = SFOCSP() ocsp.OCSP_CACHE_SERVER.NEW_DEFAULT_CACHE_SERVER_BASE_URL = "https://snowflake.preprod3.us-west-2-dev.external-zone.snowflakecomputing.com:8085/ocsp/" @@ -296,8 +284,6 @@ async def test_ocsp_single_endpoint(session_manager): "snowflake.okta.com", connection, session_manager=session_manager ), "Failed to validate: {}".format("snowflake.okta.com") - del environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] - async def test_ocsp_by_post_method(session_manager): """OCSP tests.""" @@ -327,7 +313,7 @@ async def test_ocsp_with_file_cache(tmpdir, session_manager): async def test_ocsp_with_bogus_cache_files( - tmpdir, random_ocsp_response_validation_cache, session_manager + tmpdir, random_ocsp_response_validation_cache, session_manager, monkeypatch ): with mock.patch( "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE", @@ -337,7 +323,7 @@ async def test_ocsp_with_bogus_cache_files( """Attempts to use bogus OCSP response data.""" cache_file_name, target_hosts = await _store_cache_in_file( - tmpdir, session_manager + tmpdir, session_manager, monkeypatch=monkeypatch ) ocsp = SFOCSP() @@ -369,7 +355,7 @@ async def test_ocsp_with_bogus_cache_files( async def test_ocsp_with_outdated_cache( - tmpdir, random_ocsp_response_validation_cache, session_manager + tmpdir, random_ocsp_response_validation_cache, session_manager, monkeypatch ): with mock.patch( 
"snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE", @@ -379,7 +365,7 @@ async def test_ocsp_with_outdated_cache( """Attempts to use outdated OCSP response cache file.""" cache_file_name, target_hosts = await _store_cache_in_file( - tmpdir, session_manager + tmpdir, session_manager, monkeypatch=monkeypatch ) ocsp = SFOCSP() @@ -410,10 +396,10 @@ async def test_ocsp_with_outdated_cache( ), "must be empty. outdated cache should not be loaded" -async def _store_cache_in_file(tmpdir, session_manager, target_hosts=None): +async def _store_cache_in_file(tmpdir, session_manager, monkeypatch, target_hosts=None): if target_hosts is None: target_hosts = TARGET_HOSTS - os.environ["SF_OCSP_RESPONSE_CACHE_DIR"] = str(tmpdir) + monkeypatch.setenv("SF_OCSP_RESPONSE_CACHE_DIR", str(tmpdir)) OCSPCache.reset_cache_dir() filename = path.join(str(tmpdir), "ocsp_response_cache.json") diff --git a/test/unit/aio/test_session_manager_async.py b/test/unit/aio/test_session_manager_async.py index bcb428fb71..9f54a20506 100644 --- a/test/unit/aio/test_session_manager_async.py +++ b/test/unit/aio/test_session_manager_async.py @@ -3,11 +3,13 @@ from unittest import mock +import aiohttp import pytest from snowflake.connector.aio._session_manager import ( AioHttpConfig, SessionManager, + SnowflakeSSLConnector, SnowflakeSSLConnectorFactory, ) from snowflake.connector.constants import OCSPMode @@ -348,3 +350,87 @@ async def test_pickle_session_manager(): await manager.close() await unpickled.close() + + +@pytest.fixture +def mock_connector_with_factory(): + """Fixture providing a mock connector factory and connector.""" + mock_connector_factory = mock.MagicMock() + mock_connector = mock.MagicMock() + mock_connector_factory.return_value = mock_connector + return mock_connector, mock_connector_factory + + +@pytest.mark.parametrize( + "ocsp_mode,extra_kwargs,expected_kwargs", + [ + # Test with OCSPMode.FAIL_OPEN + extra kwargs (should all appear) + ( + OCSPMode.FAIL_OPEN, + {"timeout": 30, "pool_connections": 10}, + { + "timeout": 30, + "pool_connections": 10, + "snowflake_ocsp_mode": OCSPMode.FAIL_OPEN, + }, + ), + # Test with OCSPMode.FAIL_CLOSED + no extra kwargs + ( + OCSPMode.FAIL_CLOSED, + {}, + {"snowflake_ocsp_mode": OCSPMode.FAIL_CLOSED}, + ), + # Checks that None values also cause kwargs name to occur + ( + None, + {}, + {"snowflake_ocsp_mode": None}, + ), + # Test override by extra kwargs: config has FAIL_OPEN but extra_kwargs override with FAIL_CLOSED + ( + OCSPMode.FAIL_OPEN, + {"snowflake_ocsp_mode": OCSPMode.FAIL_CLOSED}, + {"snowflake_ocsp_mode": OCSPMode.FAIL_CLOSED}, + ), + ], +) +async def test_aio_http_config_get_connector_parametrized( + mock_connector_with_factory, ocsp_mode, extra_kwargs, expected_kwargs +): + """Test that AioHttpConfig.get_connector properly passes kwargs and snowflake_ocsp_mode to connector factory. 
+ + This mirrors the sync test behavior where: + - Config attributes are passed to the factory + - Extra kwargs can override config attributes + - All resulting attributes appear in the factory call + """ + mock_connector, mock_connector_factory = mock_connector_with_factory + + config = AioHttpConfig( + connector_factory=mock_connector_factory, snowflake_ocsp_mode=ocsp_mode + ) + result = config.get_connector(**extra_kwargs) + + # Verify the connector factory was called with correct arguments + mock_connector_factory.assert_called_once_with(**expected_kwargs) + assert result is mock_connector + + +async def test_aio_http_config_get_connector_with_real_connector_factory(): + """Test get_connector with the actual SnowflakeSSLConnectorFactory. + + Verifies that with a real factory, we get a real SnowflakeSSLConnector instance + with the snowflake_ocsp_mode properly set. + """ + config = AioHttpConfig( + connector_factory=SnowflakeSSLConnectorFactory(), + snowflake_ocsp_mode=OCSPMode.FAIL_CLOSED, + ) + + connector = config.get_connector(session_manager=SessionManager()) + + # Verify we get a real SnowflakeSSLConnector instance + assert isinstance(connector, aiohttp.BaseConnector) + assert isinstance(connector, SnowflakeSSLConnector) + # Verify snowflake_ocsp_mode was set correctly + assert connector._snowflake_ocsp_mode == OCSPMode.FAIL_CLOSED diff --git a/test/unit/test_auth_workload_identity.py b/test/unit/test_auth_workload_identity.py index 1880d1b7d1..bdaacd6962 100644 --- a/test/unit/test_auth_workload_identity.py +++ b/test/unit/test_auth_workload_identity.py @@ -1,6 +1,5 @@ import json import logging -import os from base64 import b64decode from unittest import mock from urllib.parse import parse_qs, urlparse @@ -414,11 +413,11 @@ def test_explicit_azure_omits_client_id_if_not_set(fake_azure_metadata_service): assert fake_azure_metadata_service.requested_client_id is None -def test_explicit_azure_uses_explicit_client_id_if_set(fake_azure_metadata_service): - with mock.patch.dict( - os.environ, {"MANAGED_IDENTITY_CLIENT_ID": "custom-client-id"} - ): - auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) - auth_class.prepare(conn=None) +def test_explicit_azure_uses_explicit_client_id_if_set( + fake_azure_metadata_service, monkeypatch +): + monkeypatch.setenv("MANAGED_IDENTITY_CLIENT_ID", "custom-client-id") + auth_class = AuthByWorkloadIdentity(provider=AttestationProvider.AZURE) + auth_class.prepare(conn=None) assert fake_azure_metadata_service.requested_client_id == "custom-client-id" diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index 3ef2fd6e36..76e9588e8d 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -3,7 +3,6 @@ import json import logging -import os import stat import sys from pathlib import Path @@ -201,11 +200,11 @@ def test_is_still_running(): @pytest.mark.skipolddriver -def test_partner_env_var(mock_post_requests): +def test_partner_env_var(mock_post_requests, monkeypatch): PARTNER_NAME = "Amanda" - with patch.dict(os.environ, {ENV_VAR_PARTNER: PARTNER_NAME}): - assert fake_connector().application == PARTNER_NAME + monkeypatch.setenv(ENV_VAR_PARTNER, PARTNER_NAME) + assert fake_connector().application == PARTNER_NAME assert ( mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == PARTNER_NAME diff --git a/test/unit/test_cursor.py b/test/unit/test_cursor.py index 6970e6acfb..c936a3928e 100644 --- a/test/unit/test_cursor.py +++ b/test/unit/test_cursor.py @@ -25,6 +25,9 @@ def 
__init__(self): self._log_max_query_length = 0 self._reuse_results = None self._reraise_error_in_file_transfer_work_function = False + self._enable_stage_s3_privatelink_for_us_east_1 = False + self._iobound_tpe_limit = None + self._unsafe_file_write = False @pytest.mark.parametrize( @@ -121,6 +124,8 @@ def test_download(self, MockFileTransferAgent): # - download_as_stream of connection._stream_downloader fake_conn._file_operation_parser.parse_file_operation.assert_called_once() fake_conn._stream_downloader.download_as_stream.assert_not_called() + MockFileTransferAgent.assert_called_once() + assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False) mock_file_transfer_agent_instance.execute.assert_called_once() @patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent") @@ -139,6 +144,8 @@ def test_upload(self, MockFileTransferAgent): # - download_as_stream of connection._stream_downloader fake_conn._file_operation_parser.parse_file_operation.assert_called_once() fake_conn._stream_downloader.download_as_stream.assert_not_called() + MockFileTransferAgent.assert_called_once() + assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False) mock_file_transfer_agent_instance.execute.assert_called_once() @patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent") @@ -157,6 +164,7 @@ def test_download_stream(self, MockFileTransferAgent): # - execute in SnowflakeFileTransferAgent fake_conn._file_operation_parser.parse_file_operation.assert_called_once() fake_conn._stream_downloader.download_as_stream.assert_called_once() + MockFileTransferAgent.assert_not_called() mock_file_transfer_agent_instance.execute.assert_not_called() @patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent") @@ -176,6 +184,8 @@ def test_upload_stream(self, MockFileTransferAgent): # - download_as_stream of connection._stream_downloader fake_conn._file_operation_parser.parse_file_operation.assert_called_once() fake_conn._stream_downloader.download_as_stream.assert_not_called() + MockFileTransferAgent.assert_called_once() + assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False) mock_file_transfer_agent_instance.execute.assert_called_once() def _setup_mocks(self, MockFileTransferAgent): @@ -185,6 +195,10 @@ def _setup_mocks(self, MockFileTransferAgent): fake_conn = FakeConnection() fake_conn._file_operation_parser = MagicMock() fake_conn._stream_downloader = MagicMock() + # this should be true on all new AWS deployments to use regional endpoints for staging operations + fake_conn._enable_stage_s3_privatelink_for_us_east_1 = True + fake_conn._iobound_tpe_limit = 1 + fake_conn._unsafe_file_write = False cursor = SnowflakeCursor(fake_conn) cursor.reset = MagicMock() diff --git a/test/unit/test_ocsp.py b/test/unit/test_ocsp.py index 0b14285ac6..06286ca617 100644 --- a/test/unit/test_ocsp.py +++ b/test/unit/test_ocsp.py @@ -10,7 +10,7 @@ import platform import time from concurrent.futures.thread import ThreadPoolExecutor -from os import environ, path +from os import path from unittest import mock import asn1crypto.x509 @@ -78,7 +78,7 @@ def overwrite_ocsp_cache(tmpdir): @pytest.fixture(autouse=True) -def worker_specific_cache_dir(tmpdir, request): +def worker_specific_cache_dir(tmpdir, request, monkeypatch): """Create worker-specific cache directory to avoid file lock conflicts in parallel execution. 
     Note: Tests that explicitly manage their own cache directories (like test_ocsp_cache_when_server_is_down)
@@ -88,13 +88,12 @@ def worker_specific_cache_dir(tmpdir, request):
     # Get worker ID for parallel execution (pytest-xdist)
     worker_id = os.environ.get("PYTEST_XDIST_WORKER", "master")

-    # Store original cache dir environment variable
-    original_cache_dir = os.environ.get("SF_OCSP_RESPONSE_CACHE_DIR")
+    # monkeypatch will automatically handle restoration

     # Set worker-specific cache directory to prevent main cache file conflicts
     worker_cache_dir = tmpdir.join(f"ocsp_cache_{worker_id}")
     worker_cache_dir.ensure(dir=True)
-    os.environ["SF_OCSP_RESPONSE_CACHE_DIR"] = str(worker_cache_dir)
+    monkeypatch.setenv("SF_OCSP_RESPONSE_CACHE_DIR", str(worker_cache_dir))

     # Only handle the OCSP_RESPONSE_VALIDATION_CACHE to prevent conflicts
     # Let tests manage SF_OCSP_RESPONSE_CACHE_DIR themselves if they need to
@@ -131,11 +130,7 @@ def worker_specific_cache_dir(tmpdir, request):
         # If modules not available, just yield the directory
         yield str(tmpdir)
     finally:
-        # Restore original cache directory environment variable
-        if original_cache_dir is not None:
-            os.environ["SF_OCSP_RESPONSE_CACHE_DIR"] = original_cache_dir
-        else:
-            os.environ.pop("SF_OCSP_RESPONSE_CACHE_DIR", None)
+        # monkeypatch will automatically restore the original environment variable

         # Reset cache dir back to original state
         try:
@@ -235,7 +230,7 @@ def test_ocsp_wo_cache_server():
         assert ocsp.validate(url, connection), f"Failed to validate: {url}"


-def test_ocsp_wo_cache_file():
+def test_ocsp_wo_cache_file(monkeypatch):
     """OCSP tests without File cache.

     Notes:
@@ -248,7 +243,7 @@ def test_ocsp_wo_cache_file():
     except FileNotFoundError:
         # File doesn't exist, which is fine for this test
         pass
-    environ["SF_OCSP_RESPONSE_CACHE_DIR"] = "/etc"
+    monkeypatch.setenv("SF_OCSP_RESPONSE_CACHE_DIR", "/etc")
     OCSPCache.reset_cache_dir()

     try:
@@ -257,11 +252,10 @@ def test_ocsp_wo_cache_file():
             connection = _openssl_connect(url)
             assert ocsp.validate(url, connection), f"Failed to validate: {url}"
     finally:
-        del environ["SF_OCSP_RESPONSE_CACHE_DIR"]
         OCSPCache.reset_cache_dir()


-def test_ocsp_fail_open_w_single_endpoint():
+def test_ocsp_fail_open_w_single_endpoint(monkeypatch):
     SnowflakeOCSP.clear_cache()

     try:
@@ -270,33 +264,28 @@ def test_ocsp_fail_open_w_single_endpoint():
         # File doesn't exist, which is fine for this test
         pass

-    environ["SF_OCSP_TEST_MODE"] = "true"
-    environ["SF_TEST_OCSP_URL"] = "http://httpbin.org/delay/10"
-    environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] = "5"
+    monkeypatch.setenv("SF_OCSP_TEST_MODE", "true")
+    monkeypatch.setenv("SF_TEST_OCSP_URL", "http://httpbin.org/delay/10")
+    monkeypatch.setenv("SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT", "5")

     ocsp = SFOCSP(use_ocsp_cache_server=False)
     connection = _openssl_connect("snowflake.okta.com")

-    try:
-        assert ocsp.validate(
-            "snowflake.okta.com", connection
-        ), "Failed to validate: {}".format("snowflake.okta.com")
-    finally:
-        del environ["SF_OCSP_TEST_MODE"]
-        del environ["SF_TEST_OCSP_URL"]
-        del environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"]
+    assert ocsp.validate(
+        "snowflake.okta.com", connection
+    ), "Failed to validate: {}".format("snowflake.okta.com")


 @pytest.mark.skipif(
     ER_OCSP_RESPONSE_CERT_STATUS_REVOKED is None,
     reason="No ER_OCSP_RESPONSE_CERT_STATUS_REVOKED is available.",
 )
-def test_ocsp_fail_close_w_single_endpoint():
+def test_ocsp_fail_close_w_single_endpoint(monkeypatch):
     SnowflakeOCSP.clear_cache()

-    environ["SF_OCSP_TEST_MODE"] = "true"
-    environ["SF_TEST_OCSP_URL"] = "http://httpbin.org/delay/10"
-    environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] = "5"
+    monkeypatch.setenv("SF_OCSP_TEST_MODE", "true")
+    monkeypatch.setenv("SF_TEST_OCSP_URL", "http://httpbin.org/delay/10")
+    monkeypatch.setenv("SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT", "5")

     OCSPCache.del_cache_file()

@@ -306,21 +295,16 @@ def test_ocsp_fail_close_w_single_endpoint():
     with pytest.raises(RevocationCheckError) as ex:
         ocsp.validate("snowflake.okta.com", connection)

-    try:
-        assert (
-            ex.value.errno == ER_OCSP_RESPONSE_FETCH_FAILURE
-        ), "Connection should have failed"
-    finally:
-        del environ["SF_OCSP_TEST_MODE"]
-        del environ["SF_TEST_OCSP_URL"]
-        del environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"]
+    assert (
+        ex.value.errno == ER_OCSP_RESPONSE_FETCH_FAILURE
+    ), "Connection should have failed"


-def test_ocsp_bad_validity():
+def test_ocsp_bad_validity(monkeypatch):
     SnowflakeOCSP.clear_cache()

-    environ["SF_OCSP_TEST_MODE"] = "true"
-    environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"] = "true"
+    monkeypatch.setenv("SF_OCSP_TEST_MODE", "true")
+    monkeypatch.setenv("SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY", "true")

     try:
         OCSPCache.del_cache_file()
@@ -334,12 +318,10 @@ def test_ocsp_bad_validity():
         assert ocsp.validate(
             "snowflake.okta.com", connection
         ), "Connection should have passed with fail open"
-    del environ["SF_OCSP_TEST_MODE"]
-    del environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"]


-def test_ocsp_single_endpoint():
-    environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] = "True"
+def test_ocsp_single_endpoint(monkeypatch):
+    monkeypatch.setenv("SF_OCSP_ACTIVATE_NEW_ENDPOINT", "True")
     SnowflakeOCSP.clear_cache()
     ocsp = SFOCSP()
     ocsp.OCSP_CACHE_SERVER.NEW_DEFAULT_CACHE_SERVER_BASE_URL = "https://snowflake.preprod3.us-west-2-dev.external-zone.snowflakecomputing.com:8085/ocsp/"
@@ -348,8 +330,6 @@ def test_ocsp_single_endpoint():
         "snowflake.okta.com", connection
     ), "Failed to validate: {}".format("snowflake.okta.com")

-    del environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"]
-

 def test_ocsp_by_post_method():
     """OCSP tests."""
@@ -375,7 +355,9 @@ def test_ocsp_with_file_cache(tmpdir):


 @pytest.mark.skipolddriver
-def test_ocsp_with_bogus_cache_files(tmpdir, random_ocsp_response_validation_cache):
+def test_ocsp_with_bogus_cache_files(
+    tmpdir, random_ocsp_response_validation_cache, monkeypatch
+):
     with mock.patch(
         "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE",
         random_ocsp_response_validation_cache,
@@ -383,7 +365,7 @@ def test_ocsp_with_bogus_cache_files(tmpdir, random_ocsp_response_validation_cac
         from snowflake.connector.ocsp_snowflake import OCSPResponseValidationResult

         """Attempts to use bogus OCSP response data."""
-        cache_file_name, target_hosts = _store_cache_in_file(tmpdir)
+        cache_file_name, target_hosts = _store_cache_in_file(monkeypatch, tmpdir)

         ocsp = SFOCSP()
         OCSPCache.read_ocsp_response_cache_file(ocsp, cache_file_name)
@@ -414,7 +396,9 @@ def test_ocsp_with_bogus_cache_files(tmpdir, random_ocsp_response_validation_cac


 @pytest.mark.skipolddriver
-def test_ocsp_with_outdated_cache(tmpdir, random_ocsp_response_validation_cache):
+def test_ocsp_with_outdated_cache(
+    tmpdir, random_ocsp_response_validation_cache, monkeypatch
+):
     with mock.patch(
         "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE",
         random_ocsp_response_validation_cache,
@@ -422,7 +406,7 @@ def test_ocsp_with_outdated_cache(tmpdir, random_ocsp_response_validation_cache)
         from snowflake.connector.ocsp_snowflake import OCSPResponseValidationResult

         """Attempts to use outdated OCSP response cache file."""
-        cache_file_name, target_hosts = _store_cache_in_file(tmpdir)
+        cache_file_name, target_hosts = _store_cache_in_file(monkeypatch, tmpdir)

         ocsp = SFOCSP()

@@ -452,10 +436,8 @@ def test_ocsp_with_outdated_cache(tmpdir, random_ocsp_response_validation_cache)
         ), "must be empty. outdated cache should not be loaded"


-def _store_cache_in_file(tmpdir, target_hosts=None):
-    if target_hosts is None:
-        target_hosts = TARGET_HOSTS
-    os.environ["SF_OCSP_RESPONSE_CACHE_DIR"] = str(tmpdir)
+def _store_cache_in_file(monkeypatch, tmpdir):
+    monkeypatch.setenv("SF_OCSP_RESPONSE_CACHE_DIR", str(tmpdir))
     OCSPCache.reset_cache_dir()
     filename = path.join(str(tmpdir), "ocsp_response_cache.json")

@@ -464,13 +446,13 @@ def _store_cache_in_file(tmpdir):
     ocsp = SFOCSP(
         ocsp_response_cache_uri="file://" + filename, use_ocsp_cache_server=False
     )
-    for hostname in target_hosts:
+    for hostname in TARGET_HOSTS:
         connection = _openssl_connect(hostname)
         assert ocsp.validate(hostname, connection), "Failed to validate: {}".format(
             hostname
         )
     assert path.exists(filename), "OCSP response cache file"
-    return filename, target_hosts
+    return filename, TARGET_HOSTS


 def test_ocsp_with_invalid_cache_file():
@@ -658,11 +640,11 @@ def test_building_retry_url():
     assert OCSP_SERVER.OCSP_RETRY_URL is None


-def test_building_new_retry():
+def test_building_new_retry(monkeypatch):
     OCSP_SERVER = OCSPServer()
     OCSP_SERVER.OCSP_RETRY_URL = None
     hname = "a1.us-east-1.snowflakecomputing.com"
-    os.environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] = "true"
+    monkeypatch.setenv("SF_OCSP_ACTIVATE_NEW_ENDPOINT", "true")
     OCSP_SERVER.reset_ocsp_endpoint(hname)
     assert (
         OCSP_SERVER.CACHE_SERVER_URL
@@ -698,8 +680,6 @@ def test_building_new_retry():
         == "https://ocspssd.snowflakecomputing.com/ocsp/retry"
     )

-    del os.environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"]
-

 @pytest.mark.parametrize(
     "hash_algorithm",
diff --git a/test/unit/test_proxies.py b/test/unit/test_proxies.py
index b32e1dcb09..f7ec07d562 100644
--- a/test/unit/test_proxies.py
+++ b/test/unit/test_proxies.py
@@ -2,7 +2,6 @@
 from __future__ import annotations

 import logging
-import os
 import unittest.mock

 import pytest
@@ -28,10 +27,10 @@ def test_get_proxy_url():


 @pytest.mark.skipolddriver
-def test_socks_5_proxy_missing_proxy_header_attribute(caplog):
+def test_socks_5_proxy_missing_proxy_header_attribute(caplog, monkeypatch):
     from snowflake.connector.vendored.urllib3.poolmanager import ProxyManager

-    os.environ["HTTPS_PROXY"] = "socks5://localhost:8080"
+    monkeypatch.setenv("HTTPS_PROXY", "socks5://localhost:8080")

     class MockSOCKSProxyManager:
         def __init__(self):
@@ -81,8 +80,6 @@ def mock_proxy_manager_for_url_wiht_header(*args, **kwargs):
     )
     assert "Unable to set 'Host' to proxy manager of type" not in caplog.text

-    del os.environ["HTTPS_PROXY"]
-

 @pytest.mark.skipolddriver
 @pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"])
diff --git a/test/unit/test_session_manager.py b/test/unit/test_session_manager.py
index 83ae89c8ad..915051f6ce 100644
--- a/test/unit/test_session_manager.py
+++ b/test/unit/test_session_manager.py
@@ -3,7 +3,15 @@

 from unittest import mock

-from snowflake.connector.session_manager import ProxySupportAdapter, SessionManager
+import pytest
+
+from snowflake.connector.session_manager import (
+    HttpConfig,
+    ProxySupportAdapter,
+    ProxySupportAdapterFactory,
+    SessionManager,
+)
+from snowflake.connector.vendored.urllib3 import Retry

 # Module and class path constants for easier refactoring
 SESSION_MANAGER_MODULE = "snowflake.connector.session_manager"
@@ -234,3 +242,86 @@ def test_context_var_weakref_does_not_leak():
     reset_current_session_manager(token)

     assert get_current_session_manager(create_default_if_missing=False) is None
+
+
+@pytest.fixture
+def mock_adapter_with_factory():
+    """Fixture providing a mock adapter factory and adapter."""
+    mock_adapter_factory = mock.MagicMock()
+    mock_adapter = mock.MagicMock()
+    mock_adapter_factory.return_value = mock_adapter
+    return mock_adapter, mock_adapter_factory
+
+
+@pytest.mark.parametrize(
+    "max_retries,extra_kwargs,expected_kwargs",
+    [
+        # Test with integer max_retries
+        (
+            5,
+            {"timeout": 30, "pool_connections": 10},
+            {"timeout": 30, "pool_connections": 10, "max_retries": 5},
+        ),
+        # Test with None max_retries
+        (None, {}, {"max_retries": None}),
+        # Test with no extra kwargs
+        (7, {}, {"max_retries": 7}),
+        # Test override by extra kwargs
+        (0.2, {"max_retries": 0.7}, {"max_retries": 0.7}),
+    ],
+)
+def test_http_config_get_adapter_parametrized(
+    mock_adapter_with_factory, max_retries, extra_kwargs, expected_kwargs
+):
+    """Test that HttpConfig.get_adapter properly passes kwargs and max_retries to adapter factory."""
+    mock_adapter, mock_adapter_factory = mock_adapter_with_factory
+
+    config = HttpConfig(adapter_factory=mock_adapter_factory, max_retries=max_retries)
+    result = config.get_adapter(**extra_kwargs)
+
+    # Verify the adapter factory was called with correct arguments
+    mock_adapter_factory.assert_called_once_with(**expected_kwargs)
+    assert result is mock_adapter
+
+
+def test_http_config_get_adapter_with_retry_object(mock_adapter_with_factory):
+    """Test get_adapter with Retry object as max_retries."""
+    mock_adapter, mock_adapter_factory = mock_adapter_with_factory
+
+    retry_config = Retry(total=3, backoff_factor=0.3)
+    config = HttpConfig(adapter_factory=mock_adapter_factory, max_retries=retry_config)
+
+    result = config.get_adapter(pool_maxsize=20)
+
+    # Verify the call was made with the Retry object
+    mock_adapter_factory.assert_called_once()
+    call_args = mock_adapter_factory.call_args
+    assert call_args.kwargs["pool_maxsize"] == 20
+    assert call_args.kwargs["max_retries"] is retry_config  # Same object reference
+    assert result is mock_adapter
+
+
+def test_http_config_get_adapter_kwargs_override(mock_adapter_with_factory):
+    """Test that an explicit max_retries kwarg takes precedence over the config's max_retries."""
+    mock_adapter, mock_adapter_factory = mock_adapter_with_factory
+
+    config = HttpConfig(adapter_factory=mock_adapter_factory, max_retries=5)
+
+    # A max_retries passed to get_adapter should override the config's value
+    result = config.get_adapter(max_retries=10, timeout=30)
+
+    # Verify that the kwargs max_retries (10) takes precedence over the config's max_retries (5)
+    mock_adapter_factory.assert_called_once_with(max_retries=10, timeout=30)
+    assert result is mock_adapter
+
+
+def test_http_config_get_adapter_with_real_factory():
+    """Test get_adapter with the actual ProxySupportAdapterFactory."""
+    config = HttpConfig(adapter_factory=ProxySupportAdapterFactory(), max_retries=3)
+
+    adapter = config.get_adapter()
+
+    # Verify we get a real ProxySupportAdapter instance
+    assert isinstance(adapter, ProxySupportAdapter)
+    # Verify max_retries was set correctly
+    assert adapter.max_retries.total == 3
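
Note on the recurring pattern in the test changes above: every hunk swaps manual os.environ mutation (with try/finally or del cleanup) for pytest's built-in monkeypatch fixture, which undoes each setenv at teardown even when the test fails. A minimal sketch of the pattern, independent of the connector test suite (the test name and asserted value below are illustrative only, not taken from the diff):

import os


def test_reads_cache_dir_from_env(monkeypatch, tmp_path):
    # monkeypatch.setenv records an undo action; the variable is restored
    # (or removed) automatically when the test finishes, so no try/finally
    # or `del os.environ[...]` cleanup is needed.
    monkeypatch.setenv("SF_OCSP_RESPONSE_CACHE_DIR", str(tmp_path))

    assert os.environ["SF_OCSP_RESPONSE_CACHE_DIR"] == str(tmp_path)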