Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- v3.16(TBD)
- Bumped numpy dependency from <2.1.0 to <=2.2.4
- Added Windows support for Python 3.13.
- Add `ocsp_root_certs_dict_lock_timeout` connection parameter to set the timeout (in seconds) for acquiring the lock on the OCSP root certs dictionary. Default value for this parameter is -1 which indicates no timeout.

- v3.15.1(May 20, 2025)
- Added basic arrow support for Interval types.
Expand Down
5 changes: 5 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ def _get_private_bytes_from_file(
True,
bool,
), # SNOW-XXXXX: remove the check_arrow_conversion_error_on_every_column flag
"ocsp_root_certs_dict_lock_timeout": (
-1,
int,
),
}

APPLICATION_RE = re.compile(r"[\w\d_]+")
Expand Down Expand Up @@ -445,6 +449,7 @@ class SnowflakeConnection:
token_file_path: The file path of the token file. If both token and token_file_path are provided, the token in token_file_path will be used.
unsafe_file_write: When true, files downloaded by GET will be saved with 644 permissions. Otherwise, files will be saved with safe - owner-only permissions: 600.
check_arrow_conversion_error_on_every_column: When true, the error check after the conversion from arrow to python types will happen for every column in the row. This is a new behaviour which fixes the bug that caused the type errors to trigger silently when occurring at any place other than last column in a row. To revert the previous (faulty) behaviour, please set this flag to false.
ocsp_root_certs_dict_lock_timeout: Timeout for the OCSP root certs dict lock in seconds. Default value is -1, which means no timeout.
"""

OCSP_ENV_LOCK = Lock()
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/connector/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,12 @@ def __init__(
ssl_wrap_socket.FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME = (
self._connection._ocsp_response_cache_filename if self._connection else None
)
# OCSP root timeout
ssl_wrap_socket.FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT = (
self._connection._ocsp_root_certs_dict_lock_timeout
if self._connection
else -1
)

# This is to address the issue where requests hangs
_ = "dummy".encode("idna").decode("utf-8")
Expand Down
9 changes: 8 additions & 1 deletion src/snowflake/connector/ocsp_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,7 @@ def __init__(
use_ocsp_cache_server=None,
use_post_method: bool = True,
use_fail_open: bool = True,
root_certs_dict_lock_timeout: int = -1,
**kwargs,
) -> None:
self.test_mode = os.getenv("SF_OCSP_TEST_MODE", None)
Expand All @@ -1040,6 +1041,7 @@ def __init__(
logger.debug("WARNING - DRIVER CONFIGURED IN TEST MODE")

self._use_post_method = use_post_method
self._root_certs_dict_lock_timeout = root_certs_dict_lock_timeout
self.OCSP_CACHE_SERVER = OCSPServer(
top_level_domain=extract_top_level_domain_from_hostname(
kwargs.pop("hostname", None)
Expand Down Expand Up @@ -1410,7 +1412,10 @@ def _check_ocsp_response_cache_server(

def _lazy_read_ca_bundle(self) -> None:
"""Reads the local cabundle file and cache it in memory."""
with SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK:
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.acquire(
timeout=self._root_certs_dict_lock_timeout
)
try:
if SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
# return if already loaded
return
Expand Down Expand Up @@ -1471,6 +1476,8 @@ def _lazy_read_ca_bundle(self) -> None:
"No CA bundle file is found in the system. "
"Set REQUESTS_CA_BUNDLE to the file."
)
finally:
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.release()

@staticmethod
def _calculate_tolerable_validity(this_update: float, next_update: float) -> int:
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/connector/ssl_wrap_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

DEFAULT_OCSP_MODE: OCSPMode = OCSPMode.FAIL_OPEN
FEATURE_OCSP_MODE: OCSPMode = DEFAULT_OCSP_MODE
FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT: int = -1

"""
OCSP Response cache file name
Expand Down Expand Up @@ -84,6 +85,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
ocsp_response_cache_uri=FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME,
use_fail_open=FEATURE_OCSP_MODE == OCSPMode.FAIL_OPEN,
hostname=server_hostname,
root_certs_dict_lock_timeout=FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT,
).validate(server_hostname, ret.connection)
if not v:
raise OperationalError(
Expand Down
90 changes: 47 additions & 43 deletions test/unit/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ class OCSPMode(Enum):
url_3 = f"https://{hostname_2}/rgm1-s-sfctst0/stages/another-url"


mock_conn = mock.Mock()
mock_conn.disable_request_pooling = False
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE


def close_sessions(rest: SnowflakeRestful, num_session_pools: int) -> None:
"""Helper function to call SnowflakeRestful.close(). Asserts close was called on all SessionPools."""
with mock.patch("snowflake.connector.network.SessionPool.close") as close_mock:
Expand All @@ -50,59 +45,68 @@ def create_session(

@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session")
def test_no_url_multiple_sessions(make_session_mock):
rest = SnowflakeRestful(connection=mock_conn)
with mock.patch("snowflake.connector.SnowflakeConnection") as mock_conn:
mock_conn.disable_request_pooling = False
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE
rest = SnowflakeRestful(connection=mock_conn)

create_session(rest, 2)
create_session(rest, 2)

assert make_session_mock.call_count == 2
assert make_session_mock.call_count == 2

assert list(rest._sessions_map.keys()) == [None]
assert list(rest._sessions_map.keys()) == [None]

session_pool = rest._sessions_map[None]
assert len(session_pool._idle_sessions) == 2
assert len(session_pool._active_sessions) == 0
session_pool = rest._sessions_map[None]
assert len(session_pool._idle_sessions) == 2
assert len(session_pool._active_sessions) == 0

close_sessions(rest, 1)
close_sessions(rest, 1)


@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session")
def test_multiple_urls_multiple_sessions(make_session_mock):
rest = SnowflakeRestful(connection=mock_conn)
with mock.patch("snowflake.connector.SnowflakeConnection") as mock_conn:
mock_conn.disable_request_pooling = False
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE
rest = SnowflakeRestful(connection=mock_conn)

for url in [url_1, url_2, None]:
create_session(rest, num_sessions=2, url=url)
for url in [url_1, url_2, None]:
create_session(rest, num_sessions=2, url=url)

assert make_session_mock.call_count == 6
assert make_session_mock.call_count == 6

hostnames = list(rest._sessions_map.keys())
for hostname in [hostname_1, hostname_2, None]:
assert hostname in hostnames
hostnames = list(rest._sessions_map.keys())
for hostname in [hostname_1, hostname_2, None]:
assert hostname in hostnames

for pool in rest._sessions_map.values():
assert len(pool._idle_sessions) == 2
assert len(pool._active_sessions) == 0
for pool in rest._sessions_map.values():
assert len(pool._idle_sessions) == 2
assert len(pool._active_sessions) == 0

close_sessions(rest, 3)
close_sessions(rest, 3)


@mock.patch("snowflake.connector.network.SnowflakeRestful.make_requests_session")
def test_multiple_urls_reuse_sessions(make_session_mock):
rest = SnowflakeRestful(connection=mock_conn)
for url in [url_1, url_2, url_3, None]:
# create 10 sessions, one after another
for _ in range(10):
create_session(rest, url=url)

# only one session is created and reused thereafter
assert make_session_mock.call_count == 3

hostnames = list(rest._sessions_map.keys())
assert len(hostnames) == 3
for hostname in [hostname_1, hostname_2, None]:
assert hostname in hostnames

for pool in rest._sessions_map.values():
assert len(pool._idle_sessions) == 1
assert len(pool._active_sessions) == 0

close_sessions(rest, 3)
with mock.patch("snowflake.connector.SnowflakeConnection") as mock_conn:
mock_conn.disable_request_pooling = False
mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE
rest = SnowflakeRestful(connection=mock_conn)
for url in [url_1, url_2, url_3, None]:
# create 10 sessions, one after another
for _ in range(10):
create_session(rest, url=url)

# only one session is created and reused thereafter
assert make_session_mock.call_count == 3

hostnames = list(rest._sessions_map.keys())
assert len(hostnames) == 3
for hostname in [hostname_1, hostname_2, None]:
assert hostname in hostnames

for pool in rest._sessions_map.values():
assert len(pool._idle_sessions) == 1
assert len(pool._active_sessions) == 0

close_sessions(rest, 3)
1 change: 0 additions & 1 deletion test/unit/test_wiremock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
except ImportError:
import requests


from ..wiremock.wiremock_utils import WiremockClient


Expand Down
Loading