60
60
from ..util_text import split_statements
61
61
from ._cursor import SnowflakeCursor
62
62
from ._network import SnowflakeRestful
63
+ from ._time_util import HeartBeatTimer
63
64
from .auth import Auth , AuthByDefault , AuthByPlugin
64
65
65
66
logger = getLogger (__name__ )
@@ -87,7 +88,19 @@ def __init__(
87
88
# get the imported modules from sys.modules
88
89
# self._log_telemetry_imported_packages() # TODO: async telemetry support
89
90
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
90
- # atexit.register(self._close_at_exit) # TODO: async atexit support/test
91
+ atexit .register (self ._close_at_exit )
92
+
93
+ def __enter__ (self ):
94
+ # async connection does not support sync context manager
95
+ raise TypeError (
96
+ "'SnowflakeConnection' object does not support the context manager protocol"
97
+ )
98
+
99
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
100
+ # async connection does not support sync context manager
101
+ raise TypeError (
102
+ "'SnowflakeConnection' object does not support the context manager protocol"
103
+ )
91
104
92
105
async def __aenter__ (self ) -> SnowflakeConnection :
93
106
"""Context manager."""
@@ -135,7 +148,9 @@ async def __open_connection(self):
135
148
)
136
149
137
150
if ".privatelink.snowflakecomputing." in self .host :
138
- SnowflakeConnection .setup_ocsp_privatelink (self .application , self .host )
151
+ await SnowflakeConnection .setup_ocsp_privatelink (
152
+ self .application , self .host
153
+ )
139
154
else :
140
155
if "SF_OCSP_RESPONSE_CACHE_SERVER_URL" in os .environ :
141
156
del os .environ ["SF_OCSP_RESPONSE_CACHE_SERVER_URL" ]
@@ -164,11 +179,10 @@ async def __open_connection(self):
164
179
PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY
165
180
] = self ._validate_client_session_keep_alive_heartbeat_frequency ()
166
181
167
- # TODO: client_prefetch_threads support
168
- # if self.client_prefetch_threads:
169
- # self._session_parameters[PARAMETER_CLIENT_PREFETCH_THREADS] = (
170
- # self._validate_client_prefetch_threads()
171
- # )
182
+ if self .client_prefetch_threads :
183
+ self ._session_parameters [PARAMETER_CLIENT_PREFETCH_THREADS ] = (
184
+ self ._validate_client_prefetch_threads ()
185
+ )
172
186
173
187
# Setup authenticator
174
188
auth = Auth (self .rest )
@@ -203,7 +217,7 @@ async def __open_connection(self):
203
217
elif self ._authenticator == DEFAULT_AUTHENTICATOR :
204
218
self .auth_class = AuthByDefault (
205
219
password = self ._password ,
206
- timeout = self ._login_timeout ,
220
+ timeout = self .login_timeout ,
207
221
backoff_generator = self ._backoff_generator ,
208
222
)
209
223
else :
@@ -222,10 +236,21 @@ async def __open_connection(self):
222
236
# This will be called after the heartbeat frequency has actually been set.
223
237
# By this point it should have been decided if the heartbeat has to be enabled
224
238
# and what would the heartbeat frequency be
225
- # TODO: implement asyncio heartbeat/timer
226
- raise NotImplementedError (
227
- "asyncio client_session_keep_alive is not supported"
239
+ await self ._add_heartbeat ()
240
+
241
+ async def _add_heartbeat (self ) -> None :
242
+ if not self ._heartbeat_task :
243
+ self ._heartbeat_task = HeartBeatTimer (
244
+ self .client_session_keep_alive_heartbeat_frequency , self ._heartbeat_tick
228
245
)
246
+ await self ._heartbeat_task .start ()
247
+ logger .debug ("started heartbeat" )
248
+
249
+ async def _heartbeat_tick (self ) -> None :
250
+ """Execute a hearbeat if connection isn't closed yet."""
251
+ if not self .is_closed ():
252
+ logger .debug ("heartbeating!" )
253
+ await self .rest ._heartbeat ()
229
254
230
255
async def _all_async_queries_finished (self ) -> bool :
231
256
"""Checks whether all async queries started by this Connection have finished executing."""
@@ -322,6 +347,13 @@ async def _authenticate(self, auth_instance: AuthByPlugin):
322
347
continue
323
348
break
324
349
350
+ async def _cancel_heartbeat (self ) -> None :
351
+ """Cancel a heartbeat thread."""
352
+ if self ._heartbeat_task :
353
+ await self ._heartbeat_task .stop ()
354
+ self ._heartbeat_task = None
355
+ logger .debug ("stopped heartbeat" )
356
+
325
357
def _init_connection_parameters (
326
358
self ,
327
359
connection_init_kwargs : dict ,
@@ -353,7 +385,7 @@ def _init_connection_parameters(
353
385
for name , (value , _ ) in DEFAULT_CONFIGURATION .items ():
354
386
setattr (self , f"_{ name } " , value )
355
387
356
- self .heartbeat_thread = None
388
+ self ._heartbeat_task = None
357
389
is_kwargs_empty = not connection_init_kwargs
358
390
359
391
if "application" not in connection_init_kwargs :
@@ -403,7 +435,7 @@ async def _cancel_query(
403
435
404
436
def _close_at_exit (self ):
405
437
with suppress (Exception ):
406
- asyncio .get_event_loop (). run_until_complete (self .close (retry = False ))
438
+ asyncio .run (self .close (retry = False ))
407
439
408
440
async def _get_query_status (
409
441
self , sf_qid : str
@@ -587,8 +619,7 @@ async def close(self, retry: bool = True) -> None:
587
619
# will hang if the application doesn't close the connection and
588
620
# CLIENT_SESSION_KEEP_ALIVE is set, because the heartbeat runs on
589
621
# a separate thread.
590
- # TODO: async heartbeat support
591
- # self._cancel_heartbeat()
622
+ await self ._cancel_heartbeat ()
592
623
593
624
# close telemetry first, since it needs rest to send remaining data
594
625
logger .info ("closed" )
@@ -600,7 +631,12 @@ async def close(self, retry: bool = True) -> None:
600
631
and not self ._server_session_keep_alive
601
632
):
602
633
logger .info ("No async queries seem to be running, deleting session" )
603
- await self .rest .delete_session (retry = retry )
634
+ try :
635
+ await self .rest .delete_session (retry = retry )
636
+ except Exception as e :
637
+ logger .debug (
638
+ "Exception encountered in deleting session. ignoring...: %s" , e
639
+ )
604
640
else :
605
641
logger .info (
606
642
"There are {} async queries still running, not deleting session" .format (
@@ -837,33 +873,17 @@ async def get_query_status_throw_if_error(self, sf_qid: str) -> QueryStatus:
837
873
"""
838
874
status , status_resp = await self ._get_query_status (sf_qid )
839
875
self ._cache_query_status (sf_qid , status )
840
- queries = status_resp ["data" ]["queries" ]
841
876
if self .is_an_error (status ):
842
- if sf_qid in self ._async_sfqids :
843
- self ._async_sfqids .pop (sf_qid , None )
844
- message = status_resp .get ("message" )
845
- if message is None :
846
- message = ""
847
- code = queries [0 ].get ("errorCode" , - 1 )
848
- sql_state = None
849
- if "data" in status_resp :
850
- message += (
851
- queries [0 ].get ("errorMessage" , "" ) if len (queries ) > 0 else ""
852
- )
853
- sql_state = status_resp ["data" ].get ("sqlState" )
854
- Error .errorhandler_wrapper (
855
- self ,
856
- None ,
857
- ProgrammingError ,
858
- {
859
- "msg" : message ,
860
- "errno" : int (code ),
861
- "sqlstate" : sql_state ,
862
- "sfqid" : sf_qid ,
863
- },
864
- )
877
+ self ._process_error_query_status (sf_qid , status_resp )
865
878
return status
866
879
880
+ @staticmethod
881
+ async def setup_ocsp_privatelink (app , hostname ) -> None :
882
+ async with SnowflakeConnection .OCSP_ENV_LOCK :
883
+ ocsp_cache_server = f"http://ocsp.{ hostname } /ocsp_response_cache.json"
884
+ os .environ ["SF_OCSP_RESPONSE_CACHE_SERVER_URL" ] = ocsp_cache_server
885
+ logger .debug ("OCSP Cache Server is updated: %s" , ocsp_cache_server )
886
+
867
887
async def rollback (self ) -> None :
868
888
"""Rolls back the current transaction."""
869
889
await self .cursor ().execute ("ROLLBACK" )
0 commit comments