1616
1717import asyncio
1818import os
19- from typing import Any
19+ from typing import Any , Union
2020
2121import asyncpg
2222import sqlalchemy
2323import sqlalchemy .ext .asyncio
2424
2525from google .cloud .sql .connector import Connector
26+ from google .cloud .sql .connector import DefaultResolver
27+ from google .cloud .sql .connector import DnsResolver
2628
2729
2830async def create_sqlalchemy_engine (
@@ -31,6 +33,7 @@ async def create_sqlalchemy_engine(
3133 password : str ,
3234 db : str ,
3335 refresh_strategy : str = "background" ,
36+ resolver : Union [type [DefaultResolver ], type [DnsResolver ]] = DefaultResolver ,
3437) -> tuple [sqlalchemy .ext .asyncio .engine .AsyncEngine , Connector ]:
3538 """Creates a connection pool for a Cloud SQL instance and returns the pool
3639 and the connector. Callers are responsible for closing the pool and the
@@ -64,9 +67,16 @@ async def create_sqlalchemy_engine(
6467 Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
6568 or "background". For serverless environments use "lazy" to avoid
6669 errors resulting from CPU being throttled.
70+ resolver (Optional[google.cloud.sql.connector.DefaultResolver]):
71+ Resolver class for resolving instance connection name. Use
72+ google.cloud.sql.connector.DnsResolver when resolving DNS domain
73+ names or google.cloud.sql.connector.DefaultResolver for regular
74+ instance connection names ("my-project:my-region:my-instance").
6775 """
6876 loop = asyncio .get_running_loop ()
69- connector = Connector (loop = loop , refresh_strategy = refresh_strategy )
77+ connector = Connector (
78+ loop = loop , refresh_strategy = refresh_strategy , resolver = resolver
79+ )
7080
7181 async def getconn () -> asyncpg .Connection :
7282 conn : asyncpg .Connection = await connector .connect_async (
@@ -183,6 +193,24 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None:
183193 await connector .close_async ()
184194
185195
196+ async def test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg () -> None :
197+ """Basic test to get time from database."""
198+ inst_conn_name = os .environ ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME" ]
199+ user = os .environ ["POSTGRES_USER" ]
200+ password = os .environ ["POSTGRES_CUSTOMER_CAS_PASS" ]
201+ db = os .environ ["POSTGRES_DB" ]
202+
203+ pool , connector = await create_sqlalchemy_engine (
204+ inst_conn_name , user , password , db , resolver = DnsResolver
205+ )
206+
207+ async with pool .connect () as conn :
208+ res = (await conn .execute (sqlalchemy .text ("SELECT 1" ))).fetchone ()
209+ assert res [0 ] == 1
210+
211+ await connector .close_async ()
212+
213+
186214async def test_connection_with_asyncpg () -> None :
187215 """Basic test to get time from database."""
188216 inst_conn_name = os .environ ["POSTGRES_CONNECTION_NAME" ]
0 commit comments