Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,43 @@ async def _get_metadata(
"database_version": ret_dict["databaseVersion"],
}

async def resolve_connect_settings(
self,
dns_name: str,
location: str,
) -> dict[str, Any]:
"""Asynchronously calls the resolveConnectSettings endpoint to resolve a
PSC DNS name to a connection name.

Args:
dns_name (str): The DNS name of the Cloud SQL instance.
location (str): The region/location of the instance.

Returns:
A dictionary containing the resolve response (e.g. connectionName).
"""
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}

url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/dns/{dns_name}/locations/{location}:resolveConnectSettings"

resp = await self._client.get(url, headers=headers)
if resp.status >= 500:
resp = await retry_50x(self._client.get, url, headers=headers)
try:
ret_dict = await resp.json()
if resp.status >= 400:
message = ret_dict.get("error", {}).get("message")
if message:
resp.reason = message
except Exception:
pass
finally:
resp.raise_for_status()

return ret_dict

async def _get_ephemeral(
self,
project: str,
Expand Down
32 changes: 30 additions & 2 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import asyncio
from functools import partial
import ipaddress
import logging
import os
import socket
Expand Down Expand Up @@ -49,6 +50,23 @@

logger = logging.getLogger(name=__name__)


def _is_ip_address(ip: str) -> bool:
try:
ipaddress.ip_address(ip)
return True
except ValueError:
return False


def _get_fallback_ip(current_ip: str, ip_addresses: dict[str, str]) -> str:
if _is_ip_address(current_ip):
return current_ip
fallback = ip_addresses.get("PRIVATE")
if fallback is None:
fallback = ip_addresses.get("PRIMARY")
return fallback if fallback else current_ip

ASYNC_DRIVERS = ["asyncpg"]
SERVER_PROXY_PORT = 3307
_DEFAULT_SCHEME = "https://"
Expand Down Expand Up @@ -316,6 +334,8 @@ async def connect_async(
user_agent=self._user_agent,
driver=driver,
)
if hasattr(self._resolver, "set_client"):
self._resolver.set_client(self._client)
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)

conn_name = await self._resolver.resolve(instance_connection_string)
Expand Down Expand Up @@ -405,17 +425,25 @@ async def connect_async(
"using it to connect"
)
else:
fallback_ip = _get_fallback_ip(
ip_address, conn_info.ip_addrs
)
logger.debug(
f"['{instance_connection_string}']: Custom DNS name "
f"'{conn_info.conn_name.domain_name}' resolved but returned no "
f"entries, using '{ip_address}' from instance metadata"
f"entries, using '{fallback_ip}' from instance metadata"
)
ip_address = fallback_ip
except Exception as e:
fallback_ip = _get_fallback_ip(
ip_address, conn_info.ip_addrs
)
logger.debug(
f"['{instance_connection_string}']: Custom DNS name "
f"'{conn_info.conn_name.domain_name}' did not resolve to an IP "
f"address: {e}, using '{ip_address}' from instance metadata"
f"address: {e}, using '{fallback_ip}' from instance metadata"
)
ip_address = fallback_ip

logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
# format `user` param for automatic IAM database authn
Expand Down
136 changes: 101 additions & 35 deletions google/cloud/sql/connector/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import re
from typing import Any, List

import dns.asyncresolver

Expand All @@ -24,6 +25,10 @@
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError

PSC_DNS_PATTERN = re.compile(
r"^([a-f0-9]{12})\.([^.]+)\.([a-z0-9]+-[a-z0-9]+)\.(sql|sql-psa|sql-psc)\.goog\.?$"
)


class DefaultResolver:
"""DefaultResolver simply validates and parses instance connection name."""
Expand All @@ -38,54 +43,115 @@ class DnsResolver(dns.asyncresolver.Resolver):
TXT records in DNS.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._client: Any = None

def set_client(self, client: Any) -> None:
self._client = client

async def resolve(self, dns: str) -> ConnectionName: # type: ignore
try:
conn_name = _parse_connection_name(dns)
except ValueError:
# The connection name was not project:region:instance format.
# Check if connection name is a valid DNS domain name
if _is_valid_domain(dns):
# Attempt to query a TXT record to get connection name.
conn_name = await self.query_dns(dns)
else:
current = dns
visited = {current}

# Max 10 iterations to prevent infinite CNAME loops
for _ in range(10):
try:
domain_val = dns if current != dns else ""
conn_name = _parse_connection_name_with_domain_name(
current, domain_val
)
return conn_name
except ValueError:
pass

dns_normalized = current.rstrip(".")
match = PSC_DNS_PATTERN.match(dns_normalized.lower())
if match:
region = match.group(3)
if self._client is None:
raise ValueError(
"SQLAdmin client is not configured in the resolver."
)

dns_name_with_dot = dns_normalized + "."
resp = await self._client.resolve_connect_settings(
dns_name_with_dot, region
)
resolved_conn_name = resp["connectionName"]
return _parse_connection_name_with_domain_name(
resolved_conn_name, dns
)

if not _is_valid_domain(current):
raise ValueError(
"Arg `instance_connection_string` must have "
"format: PROJECT:REGION:INSTANCE or be a valid DNS domain "
f"name, got {dns}."
)
return conn_name

async def resolve_a_record(self, dns: str) -> List[str]:
try:
# Attempt to query the A records.
records = await super().resolve(dns, "A", raise_on_no_answer=True)
# return IP addresses as strings
return [record.to_text() for record in records]
except Exception:
# On any error, return empty list
return []
cname_found = False
try:
cname = await self.resolve_cname(current)
cname_found = True
except DnsResolutionError:
pass

if cname_found:
if cname in visited:
raise DnsResolutionError(f"CNAME loop detected for `{dns}`")
visited.add(cname)
current = cname
continue

try:
rdata = await self.resolve_txt(current)
except DnsResolutionError as e:
raise DnsResolutionError(
f"Unable to resolve TXT record for `{dns}`"
) from e

async def query_dns(self, dns: str) -> ConnectionName:
try:
# Attempt to query the TXT records.
records = await super().resolve(dns, "TXT", raise_on_no_answer=True)
# Sort the TXT record values alphabetically, strip quotes as record
# values can be returned as raw strings
rdata = [record.to_text().strip('"') for record in records]
rdata.sort()
# Attempt to parse records, returning the first valid record.
for record in rdata:
try:
conn_name = _parse_connection_name_with_domain_name(record, dns)
conn_name = _parse_connection_name_with_domain_name(
record, dns
)
return conn_name
except Exception:
continue
# If all records failed to parse, throw error

raise DnsResolutionError(
f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`"
f"Unable to parse TXT record for `{current}` -> `{rdata[0]}`"
if rdata
else f"Unable to resolve TXT record for `{current}`"
)
# Don't override above DnsResolutionError
except DnsResolutionError:
raise

raise DnsResolutionError(
f"CNAME loop detected or max resolution depth reached for `{dns}`"
)

async def resolve_cname(self, dns: str) -> str:
try:
answers = await super().resolve(dns, "CNAME", raise_on_no_answer=True)
return str(answers[0].target).rstrip(".")
except Exception as e:
raise DnsResolutionError(
f"Unable to resolve CNAME record for `{dns}`"
) from e

async def resolve_txt(self, dns: str) -> List[str]:
try:
answers = await super().resolve(dns, "TXT", raise_on_no_answer=True)
return [record.to_text().strip('"') for record in answers]
except Exception as e:
raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e
raise DnsResolutionError(
f"Unable to resolve TXT record for `{dns}`"
) from e

async def resolve_a_record(self, dns: str) -> List[str]:
try:
records = await super().resolve(dns, "A", raise_on_no_answer=True)
return [record.to_text() for record in records]
except Exception:
return []
60 changes: 60 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,63 @@ async def test_Connector_connect_async_custom_dns_resolver_fallback(
fake_client.instance.ip_addrs = original_ips


@pytest.mark.asyncio
async def test_Connector_connect_async_custom_dns_resolver_fallback_psc_to_private_ip(
fake_credentials: Credentials, fake_client: CloudSQLClient
) -> None:
"""Test that Connector.connect_async falls back to Private IP if CNAME/PSC DNS resolution fails."""

with patch(
"google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record"
) as mock_resolve_a:
# DNS resolution fails
mock_resolve_a.return_value = []

with patch(
"google.cloud.sql.connector.resolver.DnsResolver.resolve"
) as mock_resolve:
conn_name_with_domain = ConnectionName(
"test-project", "test-region", "test-instance", "db.example.com"
)
mock_resolve.return_value = conn_name_with_domain

async with Connector(
credentials=fake_credentials,
loop=asyncio.get_running_loop(),
resolver=DnsResolver,
ip_type="PSC", # Use PSC IP type
) as connector:
connector._client = fake_client

original_ips = fake_client.instance.ip_addrs
# Configure instance to be PSC enabled, but also have a PRIVATE IP fallback!
fake_client.instance.psc_enabled = True
fake_client.instance.ip_addrs = {
"PSC": "1ad3b5d73f10.3oxon2yfo9tob.us-east1.sql.goog",
"PRIVATE": "10.0.0.1",
}

try:
with patch(
"google.cloud.sql.connector.asyncpg.connect"
) as mock_connect:
mock_connect.return_value = True

connection = await connector.connect_async(
"db.example.com",
"asyncpg",
user="my-user",
password="my-pass",
db="my-db",
)

# Verify mock_connect fell back to PRIVATE IP "10.0.0.1"!
args, _ = mock_connect.call_args
assert args[0] == "10.0.0.1"
assert connection is True
finally:
# Restore original IPs
fake_client.instance.ip_addrs = original_ips
fake_client.instance.psc_enabled = False


Loading
Loading