66from __future__ import annotations
77
88import codecs
9+ import importlib
910import json
1011import os
1112import platform
3031from asn1crypto .x509 import Certificate
3132from OpenSSL .SSL import Connection
3233
34+ from snowflake .connector import SNOWFLAKE_CONNECTOR_VERSION
3335from snowflake .connector .compat import OK , urlsplit , urlunparse
3436from snowflake .connector .constants import HTTP_HEADER_USER_AGENT
3537from snowflake .connector .errorcode import (
5860
5961from . import constants
6062from .backoff_policies import exponential_backoff
61- from .cache import SFDictCache , SFDictFileCache
63+ from .cache import CacheEntry , SFDictCache , SFDictFileCache
6264from .telemetry import TelemetryField , generate_telemetry_data_dict
6365from .url_util import extract_top_level_domain_from_hostname , url_encode_str
66+ from .util_text import _base64_bytes_to_str
6467
6568
6669class OCSPResponseValidationResult (NamedTuple ):
@@ -72,27 +75,180 @@ class OCSPResponseValidationResult(NamedTuple):
7275 ts : int | None = None
7376 validated : bool = False
7477
78+ def _serialize (self ):
79+ def serialize_exception (exc ):
80+ # serialization exception is not supported for all exceptions
81+ # in the ocsp_snowflake.py, most exceptions are RevocationCheckError which is easy to serialize.
82+ # however, it would require non-trivial effort to serialize other exceptions especially 3rd part errors
83+ # as there can be un-serializable members and nondeterministic constructor arguments.
84+ # here we do a general best efforts serialization for other exceptions recording only the error message.
85+ if not exc :
86+ return None
87+
88+ exc_type = type (exc )
89+ ret = {"class" : exc_type .__name__ , "module" : exc_type .__module__ }
90+ if isinstance (exc , RevocationCheckError ):
91+ ret .update ({"errno" : exc .errno , "msg" : exc .raw_msg })
92+ else :
93+ ret .update ({"msg" : str (exc )})
94+ return ret
95+
96+ return json .dumps (
97+ {
98+ "exception" : serialize_exception (self .exception ),
99+ "issuer" : (
100+ _base64_bytes_to_str (self .issuer .dump ()) if self .issuer else None
101+ ),
102+ "subject" : (
103+ _base64_bytes_to_str (self .subject .dump ()) if self .subject else None
104+ ),
105+ "cert_id" : (
106+ _base64_bytes_to_str (self .cert_id .dump ()) if self .cert_id else None
107+ ),
108+ "ocsp_response" : _base64_bytes_to_str (self .ocsp_response ),
109+ "ts" : self .ts ,
110+ "validated" : self .validated ,
111+ }
112+ )
113+
114+ @classmethod
115+ def _deserialize (cls , json_str : str ) -> OCSPResponseValidationResult :
116+ json_obj = json .loads (json_str )
117+
118+ def deserialize_exception (exception_dict : dict | None ) -> Exception | None :
119+ # as pointed out in the serialization method, here we do the best effort deserialization
120+ # for non-RevocationCheckError exceptions. If we can not deserialize the exception, we will
121+ # return a RevocationCheckError with a message indicating the failure.
122+ if not exception_dict :
123+ return
124+ exc_class = exception_dict .get ("class" )
125+ exc_module = exception_dict .get ("module" )
126+ try :
127+ if (
128+ exc_class == "RevocationCheckError"
129+ and exc_module == "snowflake.connector.errors"
130+ ):
131+ return RevocationCheckError (
132+ msg = exception_dict ["msg" ],
133+ errno = exception_dict ["errno" ],
134+ )
135+ else :
136+ module = importlib .import_module (exc_module )
137+ exc_cls = getattr (module , exc_class )
138+ return exc_cls (exception_dict ["msg" ])
139+ except Exception as deserialize_exc :
140+ logger .debug (
141+ f"hitting error { str (deserialize_exc )} while deserializing exception,"
142+ f" the original error error class and message are { exc_class } and { exception_dict ['msg' ]} "
143+ )
144+ return RevocationCheckError (
145+ f"Got error { str (deserialize_exc )} while deserializing ocsp cache, please try "
146+ f"cleaning up the "
147+ f"OCSP cache under directory { OCSP_RESPONSE_VALIDATION_CACHE .file_path } " ,
148+ errno = ER_OCSP_RESPONSE_LOAD_FAILURE ,
149+ )
150+
151+ return OCSPResponseValidationResult (
152+ exception = deserialize_exception (json_obj .get ("exception" )),
153+ issuer = (
154+ Certificate .load (b64decode (json_obj .get ("issuer" )))
155+ if json_obj .get ("issuer" )
156+ else None
157+ ),
158+ subject = (
159+ Certificate .load (b64decode (json_obj .get ("subject" )))
160+ if json_obj .get ("subject" )
161+ else None
162+ ),
163+ cert_id = (
164+ CertId .load (b64decode (json_obj .get ("cert_id" )))
165+ if json_obj .get ("cert_id" )
166+ else None
167+ ),
168+ ocsp_response = (
169+ b64decode (json_obj .get ("ocsp_response" ))
170+ if json_obj .get ("ocsp_response" )
171+ else None
172+ ),
173+ ts = json_obj .get ("ts" ),
174+ validated = json_obj .get ("validated" ),
175+ )
176+
177+
178+ class _OCSPResponseValidationResultCache (SFDictFileCache ):
179+ def _serialize (self ) -> bytes :
180+ entries = {
181+ (
182+ _base64_bytes_to_str (k [0 ]),
183+ _base64_bytes_to_str (k [1 ]),
184+ _base64_bytes_to_str (k [2 ]),
185+ ): (v .expiry .isoformat (), v .entry ._serialize ())
186+ for k , v in self ._cache .items ()
187+ }
188+
189+ return json .dumps (
190+ {
191+ "cache_keys" : list (entries .keys ()),
192+ "cache_items" : list (entries .values ()),
193+ "entry_lifetime" : self ._entry_lifetime .total_seconds (),
194+ "file_path" : str (self .file_path ),
195+ "file_timeout" : self .file_timeout ,
196+ "last_loaded" : (
197+ self .last_loaded .isoformat () if self .last_loaded else None
198+ ),
199+ "telemetry" : self .telemetry ,
200+ "connector_version" : SNOWFLAKE_CONNECTOR_VERSION , # reserved for schema version control
201+ }
202+ ).encode ()
203+
204+ @classmethod
205+ def _deserialize (cls , opened_fd ) -> _OCSPResponseValidationResultCache :
206+ data = json .loads (opened_fd .read ().decode ())
207+ cache_instance = cls (
208+ file_path = data ["file_path" ],
209+ entry_lifetime = int (data ["entry_lifetime" ]),
210+ file_timeout = data ["file_timeout" ],
211+ load_if_file_exists = False ,
212+ )
213+ cache_instance .file_path = os .path .expanduser (data ["file_path" ])
214+ cache_instance .telemetry = data ["telemetry" ]
215+ cache_instance .last_loaded = (
216+ datetime .fromisoformat (data ["last_loaded" ]) if data ["last_loaded" ] else None
217+ )
218+ for k , v in zip (data ["cache_keys" ], data ["cache_items" ]):
219+ cache_instance ._cache [
220+ (b64decode (k [0 ]), b64decode (k [1 ]), b64decode (k [2 ]))
221+ ] = CacheEntry (
222+ datetime .fromisoformat (v [0 ]),
223+ OCSPResponseValidationResult ._deserialize (v [1 ]),
224+ )
225+ return cache_instance
226+
75227
76228try :
77229 OCSP_RESPONSE_VALIDATION_CACHE : SFDictFileCache [
78230 tuple [bytes , bytes , bytes ],
79231 OCSPResponseValidationResult ,
80- ] = SFDictFileCache (
232+ ] = _OCSPResponseValidationResultCache (
81233 entry_lifetime = constants .DAY_IN_SECONDS ,
82234 file_path = {
83235 "linux" : os .path .join (
84- "~" , ".cache" , "snowflake" , "ocsp_response_validation_cache"
236+ "~" , ".cache" , "snowflake" , "ocsp_response_validation_cache.json "
85237 ),
86238 "darwin" : os .path .join (
87- "~" , "Library" , "Caches" , "Snowflake" , "ocsp_response_validation_cache"
239+ "~" ,
240+ "Library" ,
241+ "Caches" ,
242+ "Snowflake" ,
243+ "ocsp_response_validation_cache.json" ,
88244 ),
89245 "windows" : os .path .join (
90246 "~" ,
91247 "AppData" ,
92248 "Local" ,
93249 "Snowflake" ,
94250 "Caches" ,
95- "ocsp_response_validation_cache" ,
251+ "ocsp_response_validation_cache.json " ,
96252 ),
97253 },
98254 )
0 commit comments