Skip to content

Commit 1e8162b

Browse files
committed
skip all ca bundle read if lock wasn't acquired, add tests
1 parent eac7963 commit 1e8162b

File tree

2 files changed

+229
-61
lines changed

2 files changed

+229
-61
lines changed

src/snowflake/connector/ocsp_snowflake.py

Lines changed: 68 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,72 +1412,79 @@ def _check_ocsp_response_cache_server(
14121412

14131413
def _lazy_read_ca_bundle(self) -> None:
14141414
"""Reads the local cabundle file and cache it in memory."""
1415-
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.acquire(
1415+
lock_acquired = SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.acquire(
14161416
timeout=self._root_certs_dict_lock_timeout
14171417
)
1418-
try:
1419-
if SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
1420-
# return if already loaded
1421-
return
1422-
1418+
if lock_acquired:
14231419
try:
1424-
ca_bundle = environ.get("REQUESTS_CA_BUNDLE") or environ.get(
1425-
"CURL_CA_BUNDLE"
1426-
)
1427-
if ca_bundle and path.exists(ca_bundle):
1428-
# if the user/application specifies cabundle.
1429-
self.read_cert_bundle(ca_bundle)
1430-
else:
1431-
import sys
1432-
1433-
# This import that depends on these libraries is to import certificates from them,
1434-
# we would like to have these as up to date as possible.
1435-
from requests import certs
1420+
if SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
1421+
# return if already loaded
1422+
return
14361423

1437-
if (
1438-
hasattr(certs, "__file__")
1439-
and path.exists(certs.__file__)
1440-
and path.exists(
1441-
path.join(path.dirname(certs.__file__), "cacert.pem")
1442-
)
1443-
):
1444-
# if cacert.pem exists next to certs.py in request
1445-
# package.
1446-
ca_bundle = path.join(
1447-
path.dirname(certs.__file__), "cacert.pem"
1448-
)
1424+
try:
1425+
ca_bundle = environ.get("REQUESTS_CA_BUNDLE") or environ.get(
1426+
"CURL_CA_BUNDLE"
1427+
)
1428+
if ca_bundle and path.exists(ca_bundle):
1429+
# if the user/application specifies cabundle.
14491430
self.read_cert_bundle(ca_bundle)
1450-
elif hasattr(sys, "_MEIPASS"):
1451-
# if pyinstaller includes cacert.pem
1452-
cabundle_candidates = [
1453-
["botocore", "vendored", "requests", "cacert.pem"],
1454-
["requests", "cacert.pem"],
1455-
["cacert.pem"],
1456-
]
1457-
for filename in cabundle_candidates:
1458-
ca_bundle = path.join(sys._MEIPASS, *filename)
1459-
if path.exists(ca_bundle):
1460-
self.read_cert_bundle(ca_bundle)
1461-
break
1462-
else:
1463-
logger.error("No cabundle file is found in _MEIPASS")
1464-
try:
1465-
import certifi
1466-
1467-
self.read_cert_bundle(certifi.where())
1468-
except Exception:
1469-
logger.debug("no certifi is installed. ignored.")
1470-
1471-
except Exception as e:
1472-
logger.error("Failed to read ca_bundle: %s", e)
1473-
1474-
if not SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
1475-
logger.error(
1476-
"No CA bundle file is found in the system. "
1477-
"Set REQUESTS_CA_BUNDLE to the file."
1478-
)
1479-
finally:
1480-
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.release()
1431+
else:
1432+
import sys
1433+
1434+
# This import that depends on these libraries is to import certificates from them,
1435+
# we would like to have these as up to date as possible.
1436+
from requests import certs
1437+
1438+
if (
1439+
hasattr(certs, "__file__")
1440+
and path.exists(certs.__file__)
1441+
and path.exists(
1442+
path.join(path.dirname(certs.__file__), "cacert.pem")
1443+
)
1444+
):
1445+
# if cacert.pem exists next to certs.py in request
1446+
# package.
1447+
ca_bundle = path.join(
1448+
path.dirname(certs.__file__), "cacert.pem"
1449+
)
1450+
self.read_cert_bundle(ca_bundle)
1451+
elif hasattr(sys, "_MEIPASS"):
1452+
# if pyinstaller includes cacert.pem
1453+
cabundle_candidates = [
1454+
["botocore", "vendored", "requests", "cacert.pem"],
1455+
["requests", "cacert.pem"],
1456+
["cacert.pem"],
1457+
]
1458+
for filename in cabundle_candidates:
1459+
ca_bundle = path.join(sys._MEIPASS, *filename)
1460+
if path.exists(ca_bundle):
1461+
self.read_cert_bundle(ca_bundle)
1462+
break
1463+
else:
1464+
logger.error("No cabundle file is found in _MEIPASS")
1465+
try:
1466+
import certifi
1467+
1468+
self.read_cert_bundle(certifi.where())
1469+
except Exception:
1470+
logger.debug("no certifi is installed. ignored.")
1471+
1472+
except Exception as e:
1473+
logger.error("Failed to read ca_bundle: %s", e)
1474+
1475+
if not SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
1476+
logger.error(
1477+
"No CA bundle file is found in the system. "
1478+
"Set REQUESTS_CA_BUNDLE to the file."
1479+
)
1480+
finally:
1481+
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.release()
1482+
else:
1483+
logger.info(
1484+
"Failed to acquire lock for ROOT_CERTIFICATES_DICT_LOCK. "
1485+
"Skipping reading CA bundle."
1486+
)
1487+
return
14811488

14821489
@staticmethod
14831490
def _calculate_tolerable_validity(this_update: float, next_update: float) -> int:

test/integ/test_connection.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
import stat
1010
import tempfile
1111
import threading
12+
import time
1213
import warnings
1314
import weakref
1415
from unittest import mock
16+
from unittest.mock import MagicMock, PropertyMock, patch
1517
from uuid import uuid4
1618

1719
import pytest
@@ -1487,6 +1489,165 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_enabled(
14871489
assert "snowflake.connector.ocsp_snowflake" not in caplog.text
14881490

14891491

1492+
@pytest.mark.skipolddriver
1493+
def test_root_certs_dict_lock_timeout_fail_open(db_parameters, caplog):
1494+
def mock_acquire_times_out(timeout=None):
1495+
"""Mock acquire method that always times out after the specified timeout."""
1496+
if timeout is not None and timeout > 0:
1497+
time.sleep(timeout)
1498+
return False
1499+
1500+
config = {
1501+
"user": db_parameters["user"],
1502+
"password": db_parameters["password"],
1503+
"host": db_parameters["host"],
1504+
"port": db_parameters["port"],
1505+
"account": db_parameters["account"],
1506+
"schema": db_parameters["schema"],
1507+
"database": db_parameters["database"],
1508+
"protocol": db_parameters["protocol"],
1509+
"timezone": "UTC",
1510+
"ocsp_fail_open": True,
1511+
"ocsp_root_certs_dict_lock_timeout": 0.1,
1512+
}
1513+
1514+
caplog.set_level(logging.INFO, "snowflake.connector.ocsp_snowflake")
1515+
1516+
with patch(
1517+
"snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK"
1518+
) as mock_lock:
1519+
snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT = {}
1520+
1521+
mock_lock.acquire = MagicMock(side_effect=mock_acquire_times_out)
1522+
mock_lock.release = MagicMock()
1523+
1524+
conn = snowflake.connector.connect(**config)
1525+
1526+
try:
1527+
with conn.cursor() as cur:
1528+
assert cur.execute("select 1").fetchall() == [(1,)]
1529+
1530+
if mock_lock.acquire.called:
1531+
mock_lock.acquire.assert_called_with(timeout=0.1)
1532+
assert conn._ocsp_root_certs_dict_lock_timeout == 0.1
1533+
finally:
1534+
conn.close()
1535+
1536+
1537+
@pytest.mark.skipolddriver
1538+
def test_root_certs_dict_lock_timeout(db_parameters, caplog):
1539+
config_fail_close = {
1540+
"user": db_parameters["user"],
1541+
"password": db_parameters["password"],
1542+
"host": db_parameters["host"],
1543+
"port": db_parameters["port"],
1544+
"account": db_parameters["account"],
1545+
"schema": db_parameters["schema"],
1546+
"database": db_parameters["database"],
1547+
"protocol": db_parameters["protocol"],
1548+
"timezone": "UTC",
1549+
"ocsp_fail_open": False,
1550+
"ocsp_root_certs_dict_lock_timeout": 1,
1551+
}
1552+
1553+
caplog.set_level(logging.INFO, "snowflake.connector.ocsp_snowflake")
1554+
1555+
with patch(
1556+
"snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK"
1557+
) as mock_lock:
1558+
snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT = {}
1559+
1560+
type(mock_lock).acquire = PropertyMock(return_value=lambda timeout: False)
1561+
type(mock_lock).release = PropertyMock(return_value=lambda: None)
1562+
1563+
conn = snowflake.connector.connect(**config_fail_close)
1564+
with conn.cursor() as cur:
1565+
assert cur.execute("select 1").fetchall() == [(1,)]
1566+
1567+
assert conn._ocsp_root_certs_dict_lock_timeout == 1
1568+
conn.close()
1569+
1570+
caplog.clear()
1571+
1572+
config_fail_open = {
1573+
"user": db_parameters["user"],
1574+
"password": db_parameters["password"],
1575+
"host": db_parameters["host"],
1576+
"port": db_parameters["port"],
1577+
"account": db_parameters["account"],
1578+
"schema": db_parameters["schema"],
1579+
"database": db_parameters["database"],
1580+
"protocol": db_parameters["protocol"],
1581+
"timezone": "UTC",
1582+
"ocsp_fail_open": True, # fail-open mode
1583+
"ocsp_root_certs_dict_lock_timeout": 2, # 2 second timeout
1584+
}
1585+
1586+
caplog.set_level(logging.INFO, "snowflake.connector.ocsp_snowflake")
1587+
1588+
with patch(
1589+
"snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK"
1590+
) as mock_lock:
1591+
snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT = {}
1592+
1593+
type(mock_lock).acquire = PropertyMock(return_value=lambda timeout: False)
1594+
type(mock_lock).release = PropertyMock(return_value=lambda: None)
1595+
1596+
conn = snowflake.connector.connect(**config_fail_open)
1597+
with conn.cursor() as cur:
1598+
assert cur.execute("select 1").fetchall() == [(1,)]
1599+
1600+
assert conn._ocsp_root_certs_dict_lock_timeout == 2
1601+
conn.close()
1602+
1603+
caplog.clear()
1604+
1605+
config_short_timeout = {
1606+
"user": db_parameters["user"],
1607+
"password": db_parameters["password"],
1608+
"host": db_parameters["host"],
1609+
"port": db_parameters["port"],
1610+
"account": db_parameters["account"],
1611+
"schema": db_parameters["schema"],
1612+
"database": db_parameters["database"],
1613+
"protocol": db_parameters["protocol"],
1614+
"timezone": "UTC",
1615+
"ocsp_fail_open": True,
1616+
"ocsp_root_certs_dict_lock_timeout": 0.001,
1617+
}
1618+
1619+
conn = snowflake.connector.connect(**config_short_timeout)
1620+
try:
1621+
with conn.cursor() as cur:
1622+
assert cur.execute("select 1").fetchall() == [(1,)]
1623+
1624+
assert conn._ocsp_root_certs_dict_lock_timeout == 0.001
1625+
finally:
1626+
conn.close()
1627+
1628+
config_no_timeout = {
1629+
"user": db_parameters["user"],
1630+
"password": db_parameters["password"],
1631+
"host": db_parameters["host"],
1632+
"port": db_parameters["port"],
1633+
"account": db_parameters["account"],
1634+
"schema": db_parameters["schema"],
1635+
"database": db_parameters["database"],
1636+
"protocol": db_parameters["protocol"],
1637+
"timezone": "UTC",
1638+
"ocsp_fail_open": True,
1639+
}
1640+
1641+
conn = snowflake.connector.connect(**config_no_timeout)
1642+
try:
1643+
with conn.cursor() as cur:
1644+
assert cur.execute("select 1").fetchall() == [(1,)]
1645+
1646+
assert conn._ocsp_root_certs_dict_lock_timeout == -1
1647+
finally:
1648+
conn.close()
1649+
1650+
14901651
@pytest.mark.skipolddriver
14911652
def test_ocsp_mode_insecure_mode_deprecation_warning(conn_cnx):
14921653
with warnings.catch_warnings(record=True) as w:

0 commit comments

Comments
 (0)