Skip to content

Commit 0bb9d98

Browse files
authored
SNOW-668660: Inefficient file lock for OCSP cache mechanism (#1319)
1 parent 8b8c51e commit 0bb9d98

File tree

2 files changed

+81
-50
lines changed

2 files changed

+81
-50
lines changed

src/snowflake/connector/cache.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,7 @@ def _getitem(
101101
self._hit(k)
102102
return v
103103

104-
def _setitem(
105-
self,
106-
k: K,
107-
v: V,
108-
) -> None:
104+
def _setitem(self, k: K, v: V) -> None:
109105
"""Non-locking version of __setitem__.
110106
111107
This should only be used by internal functions when already
@@ -407,6 +403,35 @@ def __init__(
407403
if os.path.exists(self.file_path):
408404
self._load()
409405

406+
def _getitem(
407+
self,
408+
k: K,
409+
*,
410+
should_record_hits: bool = True,
411+
) -> V:
412+
"""Non-locking version of __getitem__.
413+
414+
This should only be used by internal functions when already
415+
holding self._lock.
416+
"""
417+
if k not in self._cache:
418+
loaded = self._load_if_should()
419+
if (not loaded) or k not in self._cache:
420+
self._miss(k)
421+
raise KeyError
422+
t, v = self._cache[k]
423+
if is_expired(t):
424+
loaded = self._load_if_should()
425+
expire_item = True
426+
if loaded:
427+
t, v = self._cache[k]
428+
expire_item = is_expired(t)
429+
if expire_item:
430+
# Raises KeyError
431+
self._expire(k)
432+
self._hit(k)
433+
return v
434+
410435
def __getitem__(self, k: K) -> V:
411436
"""Returns an element if it hasn't expired yet in a thread-safe way."""
412437
self._lock.acquire()

src/snowflake/connector/ocsp_snowflake.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -695,14 +695,21 @@ def is_cache_fresh(current_time, ts):
695695
return current_time - OCSPCache.CACHE_EXPIRATION <= ts
696696

697697
@staticmethod
698-
def find_cache(ocsp, cert_id, subject, **kwargs):
698+
def find_cache(
699+
ocsp: SnowflakeOCSP, cert_id: CertId, subject: Certificate | None, **kwargs: Any
700+
) -> tuple[bool, bytes | None]:
699701
subject_name = ocsp.subject_name(subject) if subject else None
700702
current_time = int(time.time())
701703
cache_key: tuple[bytes, bytes, bytes] = kwargs.get(
702704
"cache_key", ocsp.decode_cert_id_key(cert_id)
703705
)
704-
if cache_key in OCSP_RESPONSE_VALIDATION_CACHE:
705-
ocsp_response_validation_result = OCSP_RESPONSE_VALIDATION_CACHE[cache_key]
706+
lock_cache: bool = kwargs.get("lock_cache", True)
707+
try:
708+
ocsp_response_validation_result = (
709+
OCSP_RESPONSE_VALIDATION_CACHE[cache_key]
710+
if lock_cache
711+
else OCSP_RESPONSE_VALIDATION_CACHE._getitem(cache_key)
712+
)
706713
try:
707714
# is_valid_time can raise exception if the cache
708715
# entry is a SSD.
@@ -715,52 +722,28 @@ def find_cache(ocsp, cert_id, subject, **kwargs):
715722
logger.debug("hit cache for subject: %s", subject_name)
716723
return True, ocsp_response_validation_result.ocsp_response
717724
else:
718-
OCSPCache.delete_cache(ocsp, cert_id, cache_key=cache_key)
725+
OCSPCache.delete_cache(
726+
ocsp, cert_id, cache_key=cache_key, lock_cache=lock_cache
727+
)
719728
except Exception as ex:
720729
logger.debug(f"Could not validate cache entry {cert_id} {ex}")
721730
OCSPCache.CACHE_UPDATED = True
722-
if subject_name:
723-
logger.debug("not hit cache for subject: %s", subject_name)
731+
except KeyError:
732+
if subject_name:
733+
logger.debug(f"cache miss for subject: '{subject_name}'")
724734
return False, None
725735

726-
@staticmethod
727-
def update_or_delete_cache(ocsp, cert_id, ocsp_response, ts):
728-
try:
729-
current_time = int(time.time())
730-
found, _ = OCSPCache.find_cache(ocsp, cert_id, None)
731-
if current_time - OCSPCache.CACHE_EXPIRATION <= ts:
732-
# creation time must be new enough
733-
OCSPCache.update_cache(ocsp, cert_id, ocsp_response)
734-
elif found:
735-
# invalidate the cache if exists
736-
OCSPCache.delete_cache(ocsp, cert_id)
737-
except Exception as ex:
738-
logger.debug("Caught here > %s", ex)
739-
raise ex
740-
741-
@staticmethod
742-
def update_cache(
743-
ocsp: SnowflakeOCSP, cert_id: CertId, ocsp_response, **kwargs: Any
744-
):
745-
# Every time this is called the in memory cache will
746-
# be updated and written to disk.
747-
cache_key: tuple[bytes, bytes, bytes] = kwargs.get(
748-
"cache_key", ocsp.decode_cert_id_key(cert_id)
749-
)
750-
OCSP_RESPONSE_VALIDATION_CACHE[cache_key] = OCSPResponseValidationResult(
751-
ocsp_response=ocsp_response,
752-
ts=int(time.time()),
753-
validated=False,
754-
)
755-
OCSPCache.CACHE_UPDATED = True
756-
757736
@staticmethod
758737
def delete_cache(ocsp: SnowflakeOCSP, cert_id: CertId, **kwargs: Any):
759738
cache_key: tuple[bytes, bytes, bytes] = kwargs.get(
760739
"cache_key", ocsp.decode_cert_id_key(cert_id)
761740
)
741+
lock_cache: bool = kwargs.get("lock_cache", True)
762742
try:
763-
del OCSP_RESPONSE_VALIDATION_CACHE[cache_key]
743+
if lock_cache:
744+
del OCSP_RESPONSE_VALIDATION_CACHE[cache_key]
745+
else:
746+
OCSP_RESPONSE_VALIDATION_CACHE._delitem(cache_key)
764747
OCSPCache.CACHE_UPDATED = True
765748
except KeyError:
766749
pass
@@ -1597,13 +1580,36 @@ def _process_unknown_status(self, cert_id):
15971580
def decode_ocsp_response_cache(self, ocsp_response_cache_json):
15981581
"""Decodes OCSP response cache from JSON."""
15991582
try:
1600-
for cert_id_base64, (ts, ocsp_response) in ocsp_response_cache_json.items():
1601-
cert_id = self.decode_cert_id_base64(cert_id_base64)
1602-
if not self.is_valid_time(cert_id, b64decode(ocsp_response)):
1603-
continue
1604-
SnowflakeOCSP.OCSP_CACHE.update_or_delete_cache(
1605-
self, cert_id, b64decode(ocsp_response), ts
1606-
)
1583+
with OCSP_RESPONSE_VALIDATION_CACHE._lock:
1584+
new_cache_dict = {}
1585+
for cert_id_base64, (
1586+
ts,
1587+
ocsp_response,
1588+
) in ocsp_response_cache_json.items():
1589+
cert_id = self.decode_cert_id_base64(cert_id_base64)
1590+
b64decoded_ocsp_response = b64decode(ocsp_response)
1591+
if not self.is_valid_time(cert_id, b64decoded_ocsp_response):
1592+
continue
1593+
current_time = int(time.time())
1594+
cache_key: tuple[bytes, bytes, bytes] = self.decode_cert_id_key(
1595+
cert_id
1596+
)
1597+
found, _ = OCSPCache.find_cache(
1598+
self, cert_id, None, cache_key=cache_key, lock_cache=False
1599+
)
1600+
if OCSPCache.is_cache_fresh(current_time, ts):
1601+
new_cache_dict[cache_key] = OCSPResponseValidationResult(
1602+
ocsp_response=b64decoded_ocsp_response,
1603+
ts=current_time,
1604+
validated=False,
1605+
)
1606+
elif found:
1607+
OCSPCache.delete_cache(
1608+
self, cert_id, cache_key=cache_key, lock_cache=False
1609+
)
1610+
if new_cache_dict:
1611+
OCSP_RESPONSE_VALIDATION_CACHE._update(new_cache_dict)
1612+
OCSPCache.CACHE_UPDATED = True
16071613
except Exception as ex:
16081614
logger.debug("Caught here - %s", ex)
16091615
ermsg = "Exception raised while decoding OCSP Response Cache {}".format(

0 commit comments

Comments
 (0)