|
9 | 9 | import stat |
10 | 10 | import tempfile |
11 | 11 | import threading |
| 12 | +import time |
12 | 13 | import warnings |
13 | 14 | import weakref |
14 | 15 | from unittest import mock |
| 16 | +from unittest.mock import MagicMock, PropertyMock, patch |
15 | 17 | from uuid import uuid4 |
16 | 18 |
|
17 | 19 | import pytest |
@@ -1487,6 +1489,165 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_enabled( |
1487 | 1489 | assert "snowflake.connector.ocsp_snowflake" not in caplog.text |
1488 | 1490 |
|
1489 | 1491 |
|
| 1492 | +@pytest.mark.skipolddriver |
| 1493 | +def test_root_certs_dict_lock_timeout_fail_open(db_parameters, caplog): |
| 1494 | + def mock_acquire_times_out(timeout=None): |
| 1495 | + """Mock acquire method that always times out after the specified timeout.""" |
| 1496 | + if timeout is not None and timeout > 0: |
| 1497 | + time.sleep(timeout) |
| 1498 | + return False |
| 1499 | + |
| 1500 | + config = { |
| 1501 | + "user": db_parameters["user"], |
| 1502 | + "password": db_parameters["password"], |
| 1503 | + "host": db_parameters["host"], |
| 1504 | + "port": db_parameters["port"], |
| 1505 | + "account": db_parameters["account"], |
| 1506 | + "schema": db_parameters["schema"], |
| 1507 | + "database": db_parameters["database"], |
| 1508 | + "protocol": db_parameters["protocol"], |
| 1509 | + "timezone": "UTC", |
| 1510 | + "ocsp_fail_open": True, |
| 1511 | + "ocsp_root_certs_dict_lock_timeout": 0.1, |
| 1512 | + } |
| 1513 | + |
| 1514 | + caplog.set_level(logging.INFO, "snowflake.connector.ocsp_snowflake") |
| 1515 | + |
| 1516 | + with patch( |
| 1517 | + "snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK" |
| 1518 | + ) as mock_lock: |
| 1519 | + snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT = {} |
| 1520 | + |
| 1521 | + mock_lock.acquire = MagicMock(side_effect=mock_acquire_times_out) |
| 1522 | + mock_lock.release = MagicMock() |
| 1523 | + |
| 1524 | + conn = snowflake.connector.connect(**config) |
| 1525 | + |
| 1526 | + try: |
| 1527 | + with conn.cursor() as cur: |
| 1528 | + assert cur.execute("select 1").fetchall() == [(1,)] |
| 1529 | + |
| 1530 | + if mock_lock.acquire.called: |
| 1531 | + mock_lock.acquire.assert_called_with(timeout=0.1) |
| 1532 | + assert conn._ocsp_root_certs_dict_lock_timeout == 0.1 |
| 1533 | + finally: |
| 1534 | + conn.close() |
| 1535 | + |
| 1536 | + |
| 1537 | +@pytest.mark.skipolddriver |
| 1538 | +def test_root_certs_dict_lock_timeout(db_parameters, caplog): |
| 1539 | + config_fail_close = { |
| 1540 | + "user": db_parameters["user"], |
| 1541 | + "password": db_parameters["password"], |
| 1542 | + "host": db_parameters["host"], |
| 1543 | + "port": db_parameters["port"], |
| 1544 | + "account": db_parameters["account"], |
| 1545 | + "schema": db_parameters["schema"], |
| 1546 | + "database": db_parameters["database"], |
| 1547 | + "protocol": db_parameters["protocol"], |
| 1548 | + "timezone": "UTC", |
| 1549 | + "ocsp_fail_open": False, |
| 1550 | + "ocsp_root_certs_dict_lock_timeout": 1, |
| 1551 | + } |
| 1552 | + |
| 1553 | + caplog.set_level(logging.INFO, "snowflake.connector.ocsp_snowflake") |
| 1554 | + |
| 1555 | + with patch( |
| 1556 | + "snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK" |
| 1557 | + ) as mock_lock: |
| 1558 | + snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT = {} |
| 1559 | + |
| 1560 | + type(mock_lock).acquire = PropertyMock(return_value=lambda timeout: False) |
| 1561 | + type(mock_lock).release = PropertyMock(return_value=lambda: None) |
| 1562 | + |
| 1563 | + conn = snowflake.connector.connect(**config_fail_close) |
| 1564 | + with conn.cursor() as cur: |
| 1565 | + assert cur.execute("select 1").fetchall() == [(1,)] |
| 1566 | + |
| 1567 | + assert conn._ocsp_root_certs_dict_lock_timeout == 1 |
| 1568 | + conn.close() |
| 1569 | + |
| 1570 | + caplog.clear() |
| 1571 | + |
| 1572 | + config_fail_open = { |
| 1573 | + "user": db_parameters["user"], |
| 1574 | + "password": db_parameters["password"], |
| 1575 | + "host": db_parameters["host"], |
| 1576 | + "port": db_parameters["port"], |
| 1577 | + "account": db_parameters["account"], |
| 1578 | + "schema": db_parameters["schema"], |
| 1579 | + "database": db_parameters["database"], |
| 1580 | + "protocol": db_parameters["protocol"], |
| 1581 | + "timezone": "UTC", |
| 1582 | + "ocsp_fail_open": True, # fail-open mode |
| 1583 | + "ocsp_root_certs_dict_lock_timeout": 2, # 2 second timeout |
| 1584 | + } |
| 1585 | + |
| 1586 | + caplog.set_level(logging.INFO, "snowflake.connector.ocsp_snowflake") |
| 1587 | + |
| 1588 | + with patch( |
| 1589 | + "snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK" |
| 1590 | + ) as mock_lock: |
| 1591 | + snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT = {} |
| 1592 | + |
| 1593 | + type(mock_lock).acquire = PropertyMock(return_value=lambda timeout: False) |
| 1594 | + type(mock_lock).release = PropertyMock(return_value=lambda: None) |
| 1595 | + |
| 1596 | + conn = snowflake.connector.connect(**config_fail_open) |
| 1597 | + with conn.cursor() as cur: |
| 1598 | + assert cur.execute("select 1").fetchall() == [(1,)] |
| 1599 | + |
| 1600 | + assert conn._ocsp_root_certs_dict_lock_timeout == 2 |
| 1601 | + conn.close() |
| 1602 | + |
| 1603 | + caplog.clear() |
| 1604 | + |
| 1605 | + config_short_timeout = { |
| 1606 | + "user": db_parameters["user"], |
| 1607 | + "password": db_parameters["password"], |
| 1608 | + "host": db_parameters["host"], |
| 1609 | + "port": db_parameters["port"], |
| 1610 | + "account": db_parameters["account"], |
| 1611 | + "schema": db_parameters["schema"], |
| 1612 | + "database": db_parameters["database"], |
| 1613 | + "protocol": db_parameters["protocol"], |
| 1614 | + "timezone": "UTC", |
| 1615 | + "ocsp_fail_open": True, |
| 1616 | + "ocsp_root_certs_dict_lock_timeout": 0.001, |
| 1617 | + } |
| 1618 | + |
| 1619 | + conn = snowflake.connector.connect(**config_short_timeout) |
| 1620 | + try: |
| 1621 | + with conn.cursor() as cur: |
| 1622 | + assert cur.execute("select 1").fetchall() == [(1,)] |
| 1623 | + |
| 1624 | + assert conn._ocsp_root_certs_dict_lock_timeout == 0.001 |
| 1625 | + finally: |
| 1626 | + conn.close() |
| 1627 | + |
| 1628 | + config_no_timeout = { |
| 1629 | + "user": db_parameters["user"], |
| 1630 | + "password": db_parameters["password"], |
| 1631 | + "host": db_parameters["host"], |
| 1632 | + "port": db_parameters["port"], |
| 1633 | + "account": db_parameters["account"], |
| 1634 | + "schema": db_parameters["schema"], |
| 1635 | + "database": db_parameters["database"], |
| 1636 | + "protocol": db_parameters["protocol"], |
| 1637 | + "timezone": "UTC", |
| 1638 | + "ocsp_fail_open": True, |
| 1639 | + } |
| 1640 | + |
| 1641 | + conn = snowflake.connector.connect(**config_no_timeout) |
| 1642 | + try: |
| 1643 | + with conn.cursor() as cur: |
| 1644 | + assert cur.execute("select 1").fetchall() == [(1,)] |
| 1645 | + |
| 1646 | + assert conn._ocsp_root_certs_dict_lock_timeout == -1 |
| 1647 | + finally: |
| 1648 | + conn.close() |
| 1649 | + |
| 1650 | + |
1490 | 1651 | @pytest.mark.skipolddriver |
1491 | 1652 | def test_ocsp_mode_insecure_mode_deprecation_warning(conn_cnx): |
1492 | 1653 | with warnings.catch_warnings(record=True) as w: |
|
0 commit comments