Skip to content

Commit cf2a731

Browse files
[async] Applied #2429 to async code - part 2:
Storage client ResultBatch WIF replacing aiohttp with session manager and propagating it down partially - ocsp and session_manager file merge with _ssl_connector - analogically to ProxySupportAdapter
1 parent 01ada93 commit cf2a731

File tree

13 files changed

+320
-177
lines changed

13 files changed

+320
-177
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ repos:
5757
exclude: |
5858
(?x)^(
5959
src/snowflake/connector/session_manager\.py|
60+
src/snowflake/connector/aio/_session_manager\.py|
6061
src/snowflake/connector/vendored/.*
6162
)$
6263
args: [--show-fixes]

src/snowflake/connector/aio/_connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ async def __open_connection(self):
200200
protocol=self._protocol,
201201
inject_client_pause=self._inject_client_pause,
202202
connection=self,
203-
session_manager=self._session_manager,
203+
session_manager=self._session_manager, # connection shares the session pool used for making Backend related requests
204204
)
205205
logger.debug("REST API object was created: %s:%s", self.host, self.port)
206206

@@ -592,6 +592,7 @@ def _init_connection_parameters(
592592
PLATFORM,
593593
)
594594

595+
# Placeholder attributes; will be initialized in connect()
595596
self._http_config: AioHttpConfig | None = None
596597
self._session_manager: SessionManager | None = None
597598
self._rest = None

src/snowflake/connector/aio/_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def add_retry_params(self, full_url: str) -> str:
567567
include_retry_reason = self._connection._enable_retry_reason_in_query_response
568568
include_retry_params = kwargs.pop("_include_retry_params", False)
569569

570-
async with self._use_session(full_url) as session:
570+
async with self.use_session(full_url) as session:
571571
retry_ctx = RetryCtx(
572572
_include_retry_params=include_retry_params,
573573
_include_retry_reason=include_retry_reason,

src/snowflake/connector/aio/_ocsp_snowflake.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
import os
66
import time
77
from logging import getLogger
8-
from typing import Any
8+
from typing import TYPE_CHECKING, Any
99

10-
import aiohttp
1110
from aiohttp.client_proto import ResponseHandler
1211
from asn1crypto.ocsp import CertId
1312
from asn1crypto.x509 import Certificate
@@ -32,17 +31,22 @@
3231
from snowflake.connector.ocsp_snowflake import SnowflakeOCSP as SnowflakeOCSPSync
3332
from snowflake.connector.url_util import extract_top_level_domain_from_hostname
3433

34+
if TYPE_CHECKING:
35+
from snowflake.connector.aio._session_manager import SessionManager
36+
3537
logger = getLogger(__name__)
3638

3739

3840
class OCSPServer(OCSPServerSync):
39-
async def download_cache_from_server(self, ocsp):
41+
async def download_cache_from_server(
42+
self, ocsp, *, session_manager: SessionManager
43+
):
4044
if self.CACHE_SERVER_ENABLED:
4145
# if any of them is not cache, download the cache file from
4246
# OCSP response cache server.
4347
try:
4448
retval = await OCSPServer._download_ocsp_response_cache(
45-
ocsp, self.CACHE_SERVER_URL
49+
ocsp, self.CACHE_SERVER_URL, session_manager=session_manager
4650
)
4751
if not retval:
4852
raise RevocationCheckError(
@@ -69,7 +73,9 @@ async def download_cache_from_server(self, ocsp):
6973
raise
7074

7175
@staticmethod
72-
async def _download_ocsp_response_cache(ocsp, url, do_retry: bool = True) -> bool:
76+
async def _download_ocsp_response_cache(
77+
ocsp, url, *, session_manager: SessionManager, do_retry: bool = True
78+
) -> bool:
7379
"""Downloads OCSP response cache from the cache server."""
7480
headers = {HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT}
7581
sf_timeout = SnowflakeOCSP.OCSP_CACHE_SERVER_CONNECTION_TIMEOUT
@@ -88,7 +94,7 @@ async def _download_ocsp_response_cache(ocsp, url, do_retry: bool = True) -> boo
8894
if sf_cache_server_url is not None:
8995
url = sf_cache_server_url
9096

91-
async with aiohttp.ClientSession() as session:
97+
async with session_manager.use_session() as session:
9298
max_retry = SnowflakeOCSP.OCSP_CACHE_SERVER_MAX_RETRY if do_retry else 1
9399
sleep_time = 1
94100
backoff = exponential_backoff()()
@@ -174,6 +180,8 @@ async def validate(
174180
self,
175181
hostname: str | None,
176182
connection: ResponseHandler,
183+
*,
184+
session_manager: SessionManager,
177185
no_exception: bool = False,
178186
) -> (
179187
list[
@@ -218,20 +226,31 @@ async def validate(
218226
return None
219227

220228
return await self._validate(
221-
hostname, cert_data, telemetry_data, do_retry, no_exception
229+
hostname,
230+
cert_data,
231+
telemetry_data,
232+
session_manager=session_manager,
233+
do_retry=do_retry,
234+
no_exception=no_exception,
222235
)
223236

224237
async def _validate(
225238
self,
226239
hostname: str | None,
227240
cert_data: list[tuple[Certificate, Certificate]],
228241
telemetry_data: OCSPTelemetryData,
242+
*,
243+
session_manager: SessionManager,
229244
do_retry: bool = True,
230245
no_exception: bool = False,
231246
) -> list[tuple[Exception | None, Certificate, Certificate, CertId, bytes]]:
232247
"""Validate certs sequentially if OCSP response cache server is used."""
233248
results = await self._validate_certificates_sequential(
234-
cert_data, telemetry_data, hostname, do_retry=do_retry
249+
cert_data,
250+
telemetry_data,
251+
hostname=hostname,
252+
do_retry=do_retry,
253+
session_manager=session_manager,
235254
)
236255

237256
SnowflakeOCSP.OCSP_CACHE.update_file(self)
@@ -253,6 +272,8 @@ async def _validate_issue_subject(
253272
issuer: Certificate,
254273
subject: Certificate,
255274
telemetry_data: OCSPTelemetryData,
275+
*,
276+
session_manager: SessionManager,
256277
hostname: str | None = None,
257278
do_retry: bool = True,
258279
) -> tuple[
@@ -275,7 +296,8 @@ async def _validate_issue_subject(
275296
issuer,
276297
subject,
277298
telemetry_data,
278-
hostname,
299+
hostname=hostname,
300+
session_manager=session_manager,
279301
do_retry=do_retry,
280302
cache_key=cache_key,
281303
)
@@ -292,6 +314,8 @@ async def _validate_issue_subject(
292314
async def _check_ocsp_response_cache_server(
293315
self,
294316
cert_data: list[tuple[Certificate, Certificate]],
317+
*,
318+
session_manager: SessionManager,
295319
) -> None:
296320
"""Checks if OCSP response is in cache, and if not it downloads the OCSP response cache from the server.
297321
@@ -308,17 +332,23 @@ async def _check_ocsp_response_cache_server(
308332
break
309333

310334
if not in_cache:
311-
await self.OCSP_CACHE_SERVER.download_cache_from_server(self)
335+
await self.OCSP_CACHE_SERVER.download_cache_from_server(
336+
self, session_manager=session_manager
337+
)
312338

313339
async def _validate_certificates_sequential(
314340
self,
315341
cert_data: list[tuple[Certificate, Certificate]],
316342
telemetry_data: OCSPTelemetryData,
343+
*,
344+
session_manager: SessionManager,
317345
hostname: str | None = None,
318346
do_retry: bool = True,
319347
) -> list[tuple[Exception | None, Certificate, Certificate, CertId, bytes]]:
320348
try:
321-
await self._check_ocsp_response_cache_server(cert_data)
349+
await self._check_ocsp_response_cache_server(
350+
cert_data, session_manager=session_manager
351+
)
322352
except RevocationCheckError as rce:
323353
telemetry_data.set_event_sub_type(
324354
OCSPTelemetryData.ERROR_CODE_MAP[rce.errno]
@@ -339,6 +369,7 @@ async def _validate_certificates_sequential(
339369
hostname=hostname,
340370
telemetry_data=telemetry_data,
341371
do_retry=do_retry,
372+
session_manager=session_manager,
342373
)
343374
for issuer, subject in cert_data
344375
]
@@ -363,6 +394,8 @@ async def validate_by_direct_connection(
363394
issuer: Certificate,
364395
subject: Certificate,
365396
telemetry_data: OCSPTelemetryData,
397+
*,
398+
session_manager: SessionManager,
366399
hostname: str = None,
367400
do_retry: bool = True,
368401
**kwargs: Any,
@@ -377,7 +410,13 @@ async def validate_by_direct_connection(
377410
telemetry_data.set_cache_hit(False)
378411
logger.debug("getting OCSP response from CA's OCSP server")
379412
ocsp_response = await self._fetch_ocsp_response(
380-
req, subject, cert_id, telemetry_data, hostname, do_retry
413+
req,
414+
subject,
415+
cert_id,
416+
telemetry_data,
417+
session_manager=session_manager,
418+
hostname=hostname,
419+
do_retry=do_retry,
381420
)
382421
else:
383422
ocsp_url = self.extract_ocsp_url(subject)
@@ -428,6 +467,8 @@ async def _fetch_ocsp_response(
428467
subject,
429468
cert_id,
430469
telemetry_data,
470+
*,
471+
session_manager: SessionManager,
431472
hostname=None,
432473
do_retry: bool = True,
433474
):
@@ -497,7 +538,7 @@ async def _fetch_ocsp_response(
497538
if not self.is_enabled_fail_open():
498539
sf_max_retry = SnowflakeOCSP.CA_OCSP_RESPONDER_MAX_RETRY_FC
499540

500-
async with aiohttp.ClientSession() as session:
541+
async with session_manager.use_session() as session:
501542
max_retry = sf_max_retry if do_retry else 1
502543
sleep_time = 1
503544
backoff = exponential_backoff()()

src/snowflake/connector/aio/_result_batch.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
raise_failed_request_error,
1414
raise_okta_unauthorized_error,
1515
)
16+
from snowflake.connector.aio._session_manager import SessionManager
1617
from snowflake.connector.aio._time_util import TimerContextManager
1718
from snowflake.connector.arrow_context import ArrowConverterContext
1819
from snowflake.connector.backoff_policies import exponential_backoff
@@ -111,6 +112,7 @@ def remote_chunk_info(c: dict[str, Any]) -> RemoteChunkInfo:
111112
column_converters,
112113
cursor._use_dict_result,
113114
json_result_force_utf8_decoding=cursor._connection._json_result_force_utf8_decoding,
115+
session_manager=cursor._connection._session_manager.clone(),
114116
)
115117
for c in chunks
116118
]
@@ -125,6 +127,7 @@ def remote_chunk_info(c: dict[str, Any]) -> RemoteChunkInfo:
125127
cursor._connection._numpy,
126128
schema,
127129
cursor._connection._arrow_number_to_decimal,
130+
session_manager=cursor._connection._session_manager.clone(),
128131
)
129132
for c in chunks
130133
]
@@ -137,6 +140,7 @@ def remote_chunk_info(c: dict[str, Any]) -> RemoteChunkInfo:
137140
schema,
138141
column_converters,
139142
cursor._use_dict_result,
143+
session_manager=cursor._connection._session_manager.clone(),
140144
)
141145
elif rowset_b64 is not None:
142146
first_chunk = ArrowResultBatch.from_data(
@@ -147,6 +151,7 @@ def remote_chunk_info(c: dict[str, Any]) -> RemoteChunkInfo:
147151
cursor._connection._numpy,
148152
schema,
149153
cursor._connection._arrow_number_to_decimal,
154+
session_manager=cursor._connection._session_manager.clone(),
150155
)
151156
else:
152157
logger.error(f"Don't know how to construct ResultBatches from response: {data}")
@@ -158,6 +163,7 @@ def remote_chunk_info(c: dict[str, Any]) -> RemoteChunkInfo:
158163
cursor._connection._numpy,
159164
schema,
160165
cursor._connection._arrow_number_to_decimal,
166+
session_manager=cursor._connection._session_manager.clone(),
161167
)
162168

163169
return [first_chunk] + rest_of_chunks
@@ -204,7 +210,7 @@ async def _download(
204210
async def download_chunk(http_session):
205211
response, content, encoding = None, None, None
206212
logger.debug(
207-
f"downloading result batch id: {self.id} with existing session {http_session}"
213+
f"downloading result batch id: {self.id} with session {http_session}"
208214
)
209215
response = await http_session.get(**request_data)
210216
if response.status == OK:
@@ -234,18 +240,29 @@ async def download_chunk(http_session):
234240
request_data["timeout"] = aiohttp.ClientTimeout(
235241
total=DOWNLOAD_TIMEOUT
236242
)
237-
# Try to reuse a connection if possible
238-
if connection and connection._rest is not None:
239-
async with connection._rest._use_session() as session:
243+
# Use SessionManager with same fallback pattern as sync version
244+
if (
245+
connection
246+
and connection.rest
247+
and connection.rest.session_manager is not None
248+
):
249+
# If connection was explicitly passed and not closed yet - we can reuse SessionManager with session pooling
250+
async with connection.rest.use_session() as session:
240251
logger.debug(
241252
f"downloading result batch id: {self.id} with existing session {session}"
242253
)
243254
response, content, encoding = await download_chunk(session)
255+
elif self._session_manager is not None:
256+
# If connection is not accessible or was already closed, but cursors are now used to fetch the data - we will only reuse the http setup (through cloned SessionManager without session pooling)
257+
async with self._session_manager.use_session() as session:
258+
response, content, encoding = await download_chunk(session)
244259
else:
245-
async with aiohttp.ClientSession() as session:
246-
logger.debug(
247-
f"downloading result batch id: {self.id} with new session"
248-
)
260+
# If there was no session manager cloned, then we are using a default Session Manager setup, since it is very unlikely to enter this part outside of testing
261+
logger.debug(
262+
f"downloading result batch id: {self.id} with new session through local session manager"
263+
)
264+
local_session_manager = SessionManager(use_pooling=False)
265+
async with local_session_manager.use_session() as session:
249266
response, content, encoding = await download_chunk(session)
250267

251268
if response.status == OK:

0 commit comments

Comments
 (0)