2020from functools import partial
2121import logging
2222import os
23+ import socket
2324from threading import Thread
2425from types import TracebackType
25- from typing import Any , Optional , Union
26+ from typing import Any , Callable , Optional , Union
2627
2728import google .auth
2829from google .auth .credentials import Credentials
3536from google .cloud .sql .connector .enums import RefreshStrategy
3637from google .cloud .sql .connector .instance import RefreshAheadCache
3738from google .cloud .sql .connector .lazy import LazyRefreshCache
39+ from google .cloud .sql .connector .monitored_cache import MonitoredCache
3840import google .cloud .sql .connector .pg8000 as pg8000
3941import google .cloud .sql .connector .pymysql as pymysql
4042import google .cloud .sql .connector .pytds as pytds
4648logger = logging .getLogger (name = __name__ )
4749
4850ASYNC_DRIVERS = ["asyncpg" ]
51+ SERVER_PROXY_PORT = 3307
4952_DEFAULT_SCHEME = "https://"
5053_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
5154_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -67,6 +70,7 @@ def __init__(
6770 universe_domain : Optional [str ] = None ,
6871 refresh_strategy : str | RefreshStrategy = RefreshStrategy .BACKGROUND ,
6972 resolver : type [DefaultResolver ] | type [DnsResolver ] = DefaultResolver ,
73+ failover_period : int = 30 ,
7074 ) -> None :
7175 """Initializes a Connector instance.
7276
@@ -114,6 +118,11 @@ def __init__(
114118 name. To resolve a DNS record to an instance connection name, use
115119 DnsResolver.
116120 Default: DefaultResolver
121+
122+ failover_period (int): The time interval in seconds between each
123+ attempt to check if a failover has occured for a given instance.
124+ Must be used with `resolver=DnsResolver` to have any effect.
125+ Default: 30
117126 """
118127 # if refresh_strategy is str, convert to RefreshStrategy enum
119128 if isinstance (refresh_strategy , str ):
@@ -143,9 +152,7 @@ def __init__(
143152 )
144153 # initialize dict to store caches, key is a tuple consisting of instance
145154 # connection name string and enable_iam_auth boolean flag
146- self ._cache : dict [
147- tuple [str , bool ], Union [RefreshAheadCache , LazyRefreshCache ]
148- ] = {}
155+ self ._cache : dict [tuple [str , bool ], MonitoredCache ] = {}
149156 self ._client : Optional [CloudSQLClient ] = None
150157
151158 # initialize credentials
@@ -167,6 +174,7 @@ def __init__(
167174 self ._enable_iam_auth = enable_iam_auth
168175 self ._user_agent = user_agent
169176 self ._resolver = resolver ()
177+ self ._failover_period = failover_period
170178 # if ip_type is str, convert to IPTypes enum
171179 if isinstance (ip_type , str ):
172180 ip_type = IPTypes ._from_str (ip_type )
@@ -285,15 +293,19 @@ async def connect_async(
285293 driver = driver ,
286294 )
287295 enable_iam_auth = kwargs .pop ("enable_iam_auth" , self ._enable_iam_auth )
288- if (instance_connection_string , enable_iam_auth ) in self ._cache :
289- cache = self ._cache [(instance_connection_string , enable_iam_auth )]
296+
297+ conn_name = await self ._resolver .resolve (instance_connection_string )
298+ # Cache entry must exist and not be closed
299+ if (str (conn_name ), enable_iam_auth ) in self ._cache and not self ._cache [
300+ (str (conn_name ), enable_iam_auth )
301+ ].closed :
302+ monitored_cache = self ._cache [(str (conn_name ), enable_iam_auth )]
290303 else :
291- conn_name = await self ._resolver .resolve (instance_connection_string )
292304 if self ._refresh_strategy == RefreshStrategy .LAZY :
293305 logger .debug (
294306 f"['{ conn_name } ']: Refresh strategy is set to lazy refresh"
295307 )
296- cache = LazyRefreshCache (
308+ cache : Union [ LazyRefreshCache , RefreshAheadCache ] = LazyRefreshCache (
297309 conn_name ,
298310 self ._client ,
299311 self ._keys ,
@@ -309,8 +321,14 @@ async def connect_async(
309321 self ._keys ,
310322 enable_iam_auth ,
311323 )
324+ # wrap cache as a MonitoredCache
325+ monitored_cache = MonitoredCache (
326+ cache ,
327+ self ._failover_period ,
328+ self ._resolver ,
329+ )
312330 logger .debug (f"['{ conn_name } ']: Connection info added to cache" )
313- self ._cache [(instance_connection_string , enable_iam_auth )] = cache
331+ self ._cache [(str ( conn_name ) , enable_iam_auth )] = monitored_cache
314332
315333 connect_func = {
316334 "pymysql" : pymysql .connect ,
@@ -321,7 +339,7 @@ async def connect_async(
321339
322340 # only accept supported database drivers
323341 try :
324- connector = connect_func [driver ]
342+ connector : Callable = connect_func [driver ] # type: ignore
325343 except KeyError :
326344 raise KeyError (f"Driver '{ driver } ' is not supported." )
327345
@@ -339,14 +357,14 @@ async def connect_async(
339357
340358 # attempt to get connection info for Cloud SQL instance
341359 try :
342- conn_info = await cache .connect_info ()
360+ conn_info = await monitored_cache .connect_info ()
343361 # validate driver matches intended database engine
344362 DriverMapping .validate_engine (driver , conn_info .database_version )
345363 ip_address = conn_info .get_preferred_ip (ip_type )
346364 except Exception :
347365 # with an error from Cloud SQL Admin API call or IP type, invalidate
348366 # the cache and re-raise the error
349- await self ._remove_cached (instance_connection_string , enable_iam_auth )
367+ await self ._remove_cached (str ( conn_name ) , enable_iam_auth )
350368 raise
351369 logger .debug (f"['{ conn_info .conn_name } ']: Connecting to { ip_address } :3307" )
352370 # format `user` param for automatic IAM database authn
@@ -367,18 +385,28 @@ async def connect_async(
367385 await conn_info .create_ssl_context (enable_iam_auth ),
368386 ** kwargs ,
369387 )
370- # synchronous drivers are blocking and run using executor
388+ # Create socket with SSLContext for sync drivers
389+ ctx = await conn_info .create_ssl_context (enable_iam_auth )
390+ sock = ctx .wrap_socket (
391+ socket .create_connection ((ip_address , SERVER_PROXY_PORT )),
392+ server_hostname = ip_address ,
393+ )
394+ # If this connection was opened using a domain name, then store it
395+ # for later in case we need to forcibly close it on failover.
396+ if conn_info .conn_name .domain_name :
397+ monitored_cache .sockets .append (sock )
398+ # Synchronous drivers are blocking and run using executor
371399 connect_partial = partial (
372400 connector ,
373401 ip_address ,
374- await conn_info . create_ssl_context ( enable_iam_auth ) ,
402+ sock ,
375403 ** kwargs ,
376404 )
377405 return await self ._loop .run_in_executor (None , connect_partial )
378406
379407 except Exception :
380408 # with any exception, we attempt a force refresh, then throw the error
381- await cache .force_refresh ()
409+ await monitored_cache .force_refresh ()
382410 raise
383411
384412 async def _remove_cached (
@@ -456,6 +484,7 @@ async def create_async_connector(
456484 universe_domain : Optional [str ] = None ,
457485 refresh_strategy : str | RefreshStrategy = RefreshStrategy .BACKGROUND ,
458486 resolver : type [DefaultResolver ] | type [DnsResolver ] = DefaultResolver ,
487+ failover_period : int = 30 ,
459488) -> Connector :
460489 """Helper function to create Connector object for asyncio connections.
461490
@@ -507,6 +536,11 @@ async def create_async_connector(
507536 DnsResolver.
508537 Default: DefaultResolver
509538
539+ failover_period (int): The time interval in seconds between each
540+ attempt to check if a failover has occured for a given instance.
541+ Must be used with `resolver=DnsResolver` to have any effect.
542+ Default: 30
543+
510544 Returns:
511545 A Connector instance configured with running event loop.
512546 """
@@ -525,4 +559,5 @@ async def create_async_connector(
525559 universe_domain = universe_domain ,
526560 refresh_strategy = refresh_strategy ,
527561 resolver = resolver ,
562+ failover_period = failover_period ,
528563 )
0 commit comments