55import os
66import time
77from logging import getLogger
8- from typing import Any
8+ from typing import TYPE_CHECKING , Any
99
10- import aiohttp
1110from aiohttp .client_proto import ResponseHandler
1211from asn1crypto .ocsp import CertId
1312from asn1crypto .x509 import Certificate
3231from snowflake .connector .ocsp_snowflake import SnowflakeOCSP as SnowflakeOCSPSync
3332from snowflake .connector .url_util import extract_top_level_domain_from_hostname
3433
34+ if TYPE_CHECKING :
35+ from snowflake .connector .aio ._session_manager import SessionManager
36+
3537logger = getLogger (__name__ )
3638
3739
3840class OCSPServer (OCSPServerSync ):
39- async def download_cache_from_server (self , ocsp ):
41+ async def download_cache_from_server (
42+ self , ocsp , * , session_manager : SessionManager
43+ ):
4044 if self .CACHE_SERVER_ENABLED :
4145 # if any of them is not cache, download the cache file from
4246 # OCSP response cache server.
4347 try :
4448 retval = await OCSPServer ._download_ocsp_response_cache (
45- ocsp , self .CACHE_SERVER_URL
49+ ocsp , self .CACHE_SERVER_URL , session_manager = session_manager
4650 )
4751 if not retval :
4852 raise RevocationCheckError (
@@ -69,7 +73,9 @@ async def download_cache_from_server(self, ocsp):
6973 raise
7074
7175 @staticmethod
72- async def _download_ocsp_response_cache (ocsp , url , do_retry : bool = True ) -> bool :
76+ async def _download_ocsp_response_cache (
77+ ocsp , url , * , session_manager : SessionManager , do_retry : bool = True
78+ ) -> bool :
7379 """Downloads OCSP response cache from the cache server."""
7480 headers = {HTTP_HEADER_USER_AGENT : PYTHON_CONNECTOR_USER_AGENT }
7581 sf_timeout = SnowflakeOCSP .OCSP_CACHE_SERVER_CONNECTION_TIMEOUT
@@ -88,7 +94,7 @@ async def _download_ocsp_response_cache(ocsp, url, do_retry: bool = True) -> boo
8894 if sf_cache_server_url is not None :
8995 url = sf_cache_server_url
9096
91- async with aiohttp . ClientSession () as session :
97+ async with session_manager . use_session () as session :
9298 max_retry = SnowflakeOCSP .OCSP_CACHE_SERVER_MAX_RETRY if do_retry else 1
9399 sleep_time = 1
94100 backoff = exponential_backoff ()()
@@ -174,6 +180,8 @@ async def validate(
174180 self ,
175181 hostname : str | None ,
176182 connection : ResponseHandler ,
183+ * ,
184+ session_manager : SessionManager ,
177185 no_exception : bool = False ,
178186 ) -> (
179187 list [
@@ -218,20 +226,31 @@ async def validate(
218226 return None
219227
220228 return await self ._validate (
221- hostname , cert_data , telemetry_data , do_retry , no_exception
229+ hostname ,
230+ cert_data ,
231+ telemetry_data ,
232+ session_manager = session_manager ,
233+ do_retry = do_retry ,
234+ no_exception = no_exception ,
222235 )
223236
224237 async def _validate (
225238 self ,
226239 hostname : str | None ,
227240 cert_data : list [tuple [Certificate , Certificate ]],
228241 telemetry_data : OCSPTelemetryData ,
242+ * ,
243+ session_manager : SessionManager ,
229244 do_retry : bool = True ,
230245 no_exception : bool = False ,
231246 ) -> list [tuple [Exception | None , Certificate , Certificate , CertId , bytes ]]:
232247 """Validate certs sequentially if OCSP response cache server is used."""
233248 results = await self ._validate_certificates_sequential (
234- cert_data , telemetry_data , hostname , do_retry = do_retry
249+ cert_data ,
250+ telemetry_data ,
251+ hostname = hostname ,
252+ do_retry = do_retry ,
253+ session_manager = session_manager ,
235254 )
236255
237256 SnowflakeOCSP .OCSP_CACHE .update_file (self )
@@ -253,6 +272,8 @@ async def _validate_issue_subject(
253272 issuer : Certificate ,
254273 subject : Certificate ,
255274 telemetry_data : OCSPTelemetryData ,
275+ * ,
276+ session_manager : SessionManager ,
256277 hostname : str | None = None ,
257278 do_retry : bool = True ,
258279 ) -> tuple [
@@ -275,7 +296,8 @@ async def _validate_issue_subject(
275296 issuer ,
276297 subject ,
277298 telemetry_data ,
278- hostname ,
299+ hostname = hostname ,
300+ session_manager = session_manager ,
279301 do_retry = do_retry ,
280302 cache_key = cache_key ,
281303 )
@@ -292,6 +314,8 @@ async def _validate_issue_subject(
292314 async def _check_ocsp_response_cache_server (
293315 self ,
294316 cert_data : list [tuple [Certificate , Certificate ]],
317+ * ,
318+ session_manager : SessionManager ,
295319 ) -> None :
296320 """Checks if OCSP response is in cache, and if not it downloads the OCSP response cache from the server.
297321
@@ -308,17 +332,23 @@ async def _check_ocsp_response_cache_server(
308332 break
309333
310334 if not in_cache :
311- await self .OCSP_CACHE_SERVER .download_cache_from_server (self )
335+ await self .OCSP_CACHE_SERVER .download_cache_from_server (
336+ self , session_manager = session_manager
337+ )
312338
313339 async def _validate_certificates_sequential (
314340 self ,
315341 cert_data : list [tuple [Certificate , Certificate ]],
316342 telemetry_data : OCSPTelemetryData ,
343+ * ,
344+ session_manager : SessionManager ,
317345 hostname : str | None = None ,
318346 do_retry : bool = True ,
319347 ) -> list [tuple [Exception | None , Certificate , Certificate , CertId , bytes ]]:
320348 try :
321- await self ._check_ocsp_response_cache_server (cert_data )
349+ await self ._check_ocsp_response_cache_server (
350+ cert_data , session_manager = session_manager
351+ )
322352 except RevocationCheckError as rce :
323353 telemetry_data .set_event_sub_type (
324354 OCSPTelemetryData .ERROR_CODE_MAP [rce .errno ]
@@ -339,6 +369,7 @@ async def _validate_certificates_sequential(
339369 hostname = hostname ,
340370 telemetry_data = telemetry_data ,
341371 do_retry = do_retry ,
372+ session_manager = session_manager ,
342373 )
343374 for issuer , subject in cert_data
344375 ]
@@ -363,6 +394,8 @@ async def validate_by_direct_connection(
363394 issuer : Certificate ,
364395 subject : Certificate ,
365396 telemetry_data : OCSPTelemetryData ,
397+ * ,
398+ session_manager : SessionManager ,
366399 hostname : str = None ,
367400 do_retry : bool = True ,
368401 ** kwargs : Any ,
@@ -377,7 +410,13 @@ async def validate_by_direct_connection(
377410 telemetry_data .set_cache_hit (False )
378411 logger .debug ("getting OCSP response from CA's OCSP server" )
379412 ocsp_response = await self ._fetch_ocsp_response (
380- req , subject , cert_id , telemetry_data , hostname , do_retry
413+ req ,
414+ subject ,
415+ cert_id ,
416+ telemetry_data ,
417+ session_manager = session_manager ,
418+ hostname = hostname ,
419+ do_retry = do_retry ,
381420 )
382421 else :
383422 ocsp_url = self .extract_ocsp_url (subject )
@@ -428,6 +467,8 @@ async def _fetch_ocsp_response(
428467 subject ,
429468 cert_id ,
430469 telemetry_data ,
470+ * ,
471+ session_manager : SessionManager ,
431472 hostname = None ,
432473 do_retry : bool = True ,
433474 ):
@@ -497,7 +538,7 @@ async def _fetch_ocsp_response(
497538 if not self .is_enabled_fail_open ():
498539 sf_max_retry = SnowflakeOCSP .CA_OCSP_RESPONDER_MAX_RETRY_FC
499540
500- async with aiohttp . ClientSession () as session :
541+ async with session_manager . use_session () as session :
501542 max_retry = sf_max_retry if do_retry else 1
502543 sleep_time = 1
503544 backoff = exponential_backoff ()()
0 commit comments