@@ -695,14 +695,21 @@ def is_cache_fresh(current_time, ts):
695695        return  current_time  -  OCSPCache .CACHE_EXPIRATION  <=  ts 
696696
697697    @staticmethod  
698-     def  find_cache (ocsp , cert_id , subject , ** kwargs ):
698+     def  find_cache (
699+         ocsp : SnowflakeOCSP , cert_id : CertId , subject : Certificate  |  None , ** kwargs : Any 
700+     ) ->  tuple [bool , bytes  |  None ]:
699701        subject_name  =  ocsp .subject_name (subject ) if  subject  else  None 
700702        current_time  =  int (time .time ())
701703        cache_key : tuple [bytes , bytes , bytes ] =  kwargs .get (
702704            "cache_key" , ocsp .decode_cert_id_key (cert_id )
703705        )
704-         if  cache_key  in  OCSP_RESPONSE_VALIDATION_CACHE :
705-             ocsp_response_validation_result  =  OCSP_RESPONSE_VALIDATION_CACHE [cache_key ]
706+         lock_cache : bool  =  kwargs .get ("lock_cache" , True )
707+         try :
708+             ocsp_response_validation_result  =  (
709+                 OCSP_RESPONSE_VALIDATION_CACHE [cache_key ]
710+                 if  lock_cache 
711+                 else  OCSP_RESPONSE_VALIDATION_CACHE ._getitem (cache_key )
712+             )
706713            try :
707714                # is_valid_time can raise exception if the cache 
708715                # entry is a SSD. 
@@ -715,52 +722,28 @@ def find_cache(ocsp, cert_id, subject, **kwargs):
715722                        logger .debug ("hit cache for subject: %s" , subject_name )
716723                    return  True , ocsp_response_validation_result .ocsp_response 
717724                else :
718-                     OCSPCache .delete_cache (ocsp , cert_id , cache_key = cache_key )
725+                     OCSPCache .delete_cache (
726+                         ocsp , cert_id , cache_key = cache_key , lock_cache = lock_cache 
727+                     )
719728            except  Exception  as  ex :
720729                logger .debug (f"Could not validate cache entry { cert_id }   { ex }  " )
721730            OCSPCache .CACHE_UPDATED  =  True 
722-         if  subject_name :
723-             logger .debug ("not hit cache for subject: %s" , subject_name )
731+         except  KeyError :
732+             if  subject_name :
733+                 logger .debug (f"cache miss for subject: '{ subject_name }  '" )
724734        return  False , None 
725735
726-     @staticmethod  
727-     def  update_or_delete_cache (ocsp , cert_id , ocsp_response , ts ):
728-         try :
729-             current_time  =  int (time .time ())
730-             found , _  =  OCSPCache .find_cache (ocsp , cert_id , None )
731-             if  current_time  -  OCSPCache .CACHE_EXPIRATION  <=  ts :
732-                 # creation time must be new enough 
733-                 OCSPCache .update_cache (ocsp , cert_id , ocsp_response )
734-             elif  found :
735-                 # invalidate the cache if exists 
736-                 OCSPCache .delete_cache (ocsp , cert_id )
737-         except  Exception  as  ex :
738-             logger .debug ("Caught here > %s" , ex )
739-             raise  ex 
740- 
741-     @staticmethod  
742-     def  update_cache (
743-         ocsp : SnowflakeOCSP , cert_id : CertId , ocsp_response , ** kwargs : Any 
744-     ):
745-         # Every time this is called the in memory cache will 
746-         # be updated and written to disk. 
747-         cache_key : tuple [bytes , bytes , bytes ] =  kwargs .get (
748-             "cache_key" , ocsp .decode_cert_id_key (cert_id )
749-         )
750-         OCSP_RESPONSE_VALIDATION_CACHE [cache_key ] =  OCSPResponseValidationResult (
751-             ocsp_response = ocsp_response ,
752-             ts = int (time .time ()),
753-             validated = False ,
754-         )
755-         OCSPCache .CACHE_UPDATED  =  True 
756- 
757736    @staticmethod  
758737    def  delete_cache (ocsp : SnowflakeOCSP , cert_id : CertId , ** kwargs : Any ):
759738        cache_key : tuple [bytes , bytes , bytes ] =  kwargs .get (
760739            "cache_key" , ocsp .decode_cert_id_key (cert_id )
761740        )
741+         lock_cache : bool  =  kwargs .get ("lock_cache" , True )
762742        try :
763-             del  OCSP_RESPONSE_VALIDATION_CACHE [cache_key ]
743+             if  lock_cache :
744+                 del  OCSP_RESPONSE_VALIDATION_CACHE [cache_key ]
745+             else :
746+                 OCSP_RESPONSE_VALIDATION_CACHE ._delitem (cache_key )
764747            OCSPCache .CACHE_UPDATED  =  True 
765748        except  KeyError :
766749            pass 
@@ -1597,13 +1580,36 @@ def _process_unknown_status(self, cert_id):
15971580    def  decode_ocsp_response_cache (self , ocsp_response_cache_json ):
15981581        """Decodes OCSP response cache from JSON.""" 
15991582        try :
1600-             for  cert_id_base64 , (ts , ocsp_response ) in  ocsp_response_cache_json .items ():
1601-                 cert_id  =  self .decode_cert_id_base64 (cert_id_base64 )
1602-                 if  not  self .is_valid_time (cert_id , b64decode (ocsp_response )):
1603-                     continue 
1604-                 SnowflakeOCSP .OCSP_CACHE .update_or_delete_cache (
1605-                     self , cert_id , b64decode (ocsp_response ), ts 
1606-                 )
1583+             with  OCSP_RESPONSE_VALIDATION_CACHE ._lock :
1584+                 new_cache_dict  =  {}
1585+                 for  cert_id_base64 , (
1586+                     ts ,
1587+                     ocsp_response ,
1588+                 ) in  ocsp_response_cache_json .items ():
1589+                     cert_id  =  self .decode_cert_id_base64 (cert_id_base64 )
1590+                     b64decoded_ocsp_response  =  b64decode (ocsp_response )
1591+                     if  not  self .is_valid_time (cert_id , b64decoded_ocsp_response ):
1592+                         continue 
1593+                     current_time  =  int (time .time ())
1594+                     cache_key : tuple [bytes , bytes , bytes ] =  self .decode_cert_id_key (
1595+                         cert_id 
1596+                     )
1597+                     found , _  =  OCSPCache .find_cache (
1598+                         self , cert_id , None , cache_key = cache_key , lock_cache = False 
1599+                     )
1600+                     if  OCSPCache .is_cache_fresh (current_time , ts ):
1601+                         new_cache_dict [cache_key ] =  OCSPResponseValidationResult (
1602+                             ocsp_response = b64decoded_ocsp_response ,
1603+                             ts = current_time ,
1604+                             validated = False ,
1605+                         )
1606+                     elif  found :
1607+                         OCSPCache .delete_cache (
1608+                             self , cert_id , cache_key = cache_key , lock_cache = False 
1609+                         )
1610+             if  new_cache_dict :
1611+                 OCSP_RESPONSE_VALIDATION_CACHE ._update (new_cache_dict )
1612+                 OCSPCache .CACHE_UPDATED  =  True 
16071613        except  Exception  as  ex :
16081614            logger .debug ("Caught here - %s" , ex )
16091615            ermsg  =  "Exception raised while decoding OCSP Response Cache {}" .format (
0 commit comments