Skip to content

Commit a5ae8ed

Browse files
authored
SNOW-1506571: make connector domain agnostic (#1994)
1 parent 01b2c17 commit a5ae8ed

File tree

14 files changed

+224
-20
lines changed

14 files changed

+224
-20
lines changed

DESCRIPTION.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1111
- v3.12.0(TBD)
1212
- Set default connection timeout of 10 seconds and socket read timeout of 10 minutes for HTTP calls in file transfer.
1313
- Optimized `to_pandas()` performance by fully parallel downloading logic.
14-
- Fixed a bug that specifying client_session_keep_alive_heartbeat_frequency in snowflake-sqlalchemy could crash the connector
14+
- Fixed a bug that specifying client_session_keep_alive_heartbeat_frequency in snowflake-sqlalchemy could crash the connector.
15+
- Added support for connectivity to multiple domains.
1516

1617
- v3.11.0(June 17,2024)
1718
- Added support for `token_file_path` connection parameter to read an OAuth token from a file when connecting to Snowflake.

src/snowflake/connector/connection.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from .config_manager import CONFIG_MANAGER, _get_default_connection_params
5353
from .connection_diagnostic import ConnectionDiagnostic
5454
from .constants import (
55+
_DOMAIN_NAME_MAP,
5556
ENV_VAR_PARTNER,
5657
PARAMETER_AUTOCOMMIT,
5758
PARAMETER_CLIENT_PREFETCH_THREADS,
@@ -107,6 +108,7 @@
107108
from .telemetry import TelemetryClient, TelemetryData, TelemetryField
108109
from .telemetry_oob import TelemetryService
109110
from .time_util import HeartBeatTimer, get_time_millis
111+
from .url_util import extract_top_level_domain_from_hostname
110112
from .util_text import construct_hostname, parse_account, split_statements
111113

112114
DEFAULT_CLIENT_PREFETCH_THREADS = 4
@@ -935,7 +937,7 @@ def __open_connection(self):
935937
os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"],
936938
)
937939

938-
if self.host.endswith(".privatelink.snowflakecomputing.com"):
940+
if ".privatelink.snowflakecomputing." in self.host:
939941
SnowflakeConnection.setup_ocsp_privatelink(self.application, self.host)
940942
else:
941943
if "SF_OCSP_RESPONSE_CACHE_SERVER_URL" in os.environ:
@@ -1182,6 +1184,10 @@ def __config(self, **kwargs):
11821184
if "protocol" not in kwargs:
11831185
self._protocol = "https"
11841186

1187+
logger.info(
1188+
f"Connecting to {_DOMAIN_NAME_MAP.get(extract_top_level_domain_from_hostname(self._host), 'GLOBAL')} Snowflake domain"
1189+
)
1190+
11851191
# If using a custom auth class, we should set the authenticator
11861192
# type to be the same as the custom auth class
11871193
if self._auth_class:

src/snowflake/connector/connection_diagnostic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from .compat import IS_WINDOWS, urlparse
2525
from .cursor import SnowflakeCursor
26+
from .url_util import extract_top_level_domain_from_hostname
2627
from .vendored import urllib3
2728

2829
logger = getLogger(__name__)
@@ -89,15 +90,21 @@ def __init__(
8990
self.__append_message(
9091
host_type, f"Host based on specified account: {self.host}"
9192
)
92-
if ".com.snowflakecomputing.com" in self.host:
93-
self.host = host.split(".com.snow", 1)[0] + ".com"
93+
94+
top_level_domain = extract_top_level_domain_from_hostname(host)
95+
if (
96+
f".{top_level_domain}.snowflakecomputing.{top_level_domain}" in self.host
97+
): # repeated domain name pattern
98+
self.host = (
99+
host.split(f".{top_level_domain}.snow", 1)[0] + f".{top_level_domain}"
100+
)
94101
logger.warning(
95-
f"Account should not have snowflakecomputing.com in it. You provided {host}. "
102+
f"Account should not have snowflakecomputing.{top_level_domain} in it. You provided {host}. "
96103
f"Continuing with fixed host."
97104
)
98105
self.__append_message(
99106
host_type,
100-
f"We removed extra .snowflakecomputing.com and will continue with host: "
107+
f"We removed extra .snowflakecomputing.{top_level_domain} and will continue with host: "
101108
f"{self.host}",
102109
)
103110
else:
@@ -183,7 +190,7 @@ def __init__(
183190
self.ocsp_urls.append(f"ocsp.{self.host}")
184191
self.allowlist_sql = "select system$allowlist_privatelink();"
185192
else:
186-
self.ocsp_urls.append("ocsp.snowflakecomputing.com")
193+
self.ocsp_urls.append(f"ocsp.snowflakecomputing.{top_level_domain}")
187194

188195
self.allowlist_retrieval_success: bool = False
189196
self.cursor: SnowflakeCursor | None = None

src/snowflake/connector/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@
3636
DBAPI_TYPE_NUMBER = 2
3737
DBAPI_TYPE_TIMESTAMP = 3
3838

39+
# Top-level domain used as the fallback when none can be derived from a hostname.
_DEFAULT_HOSTNAME_TLD = "com"
# Top-level domain that identifies the China domain.
_CHINA_HOSTNAME_TLD = "cn"
# Matches the last dot-separated label of a hostname when it consists only of
# ASCII letters; the match includes the leading dot (e.g. ".cn").
_TOP_LEVEL_DOMAIN_REGEX = r"\.[a-zA-Z]{1,63}$"
# Matches hostnames that end in ".snowflakecomputing.<tld>" or
# ".snowflakecomputing.<sld>.<tld>". The leading "\." pins the match to a label
# boundary, so look-alike names such as "evilsnowflakecomputing.com" are not
# treated as Snowflake hosts (the plain ".snowflakecomputing.com" suffix check
# this pattern generalizes had the same requirement).
# NOTE(review): allowing one or two trailing labels also accepts names like
# "x.snowflakecomputing.com.example" - confirm that is acceptable to callers.
_SNOWFLAKE_HOST_SUFFIX_REGEX = r"\.snowflakecomputing(\.[a-zA-Z]{1,63}){1,2}$"
43+
3944

4045
class FieldType(NamedTuple):
4146
name: str
@@ -420,3 +425,6 @@ class IterUnit(Enum):
420425
# TODO: all env variables definitions should be here
421426
ENV_VAR_PARTNER = "SF_PARTNER"
422427
ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE"
428+
429+
430+
# Maps a hostname top-level domain to the display name of the Snowflake domain it denotes.
_DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"}

src/snowflake/connector/network.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import itertools
1212
import json
1313
import logging
14+
import re
1415
import time
1516
import traceback
1617
import uuid
@@ -46,6 +47,7 @@
4647
urlparse,
4748
)
4849
from .constants import (
50+
_SNOWFLAKE_HOST_SUFFIX_REGEX,
4951
HTTP_HEADER_ACCEPT,
5052
HTTP_HEADER_CONTENT_TYPE,
5153
HTTP_HEADER_SERVICE_NAME,
@@ -156,6 +158,7 @@
156158
REQUEST_GUID = "request_guid"
157159
SNOWFLAKE_HOST_SUFFIX = ".snowflakecomputing.com"
158160

161+
159162
SNOWFLAKE_CONNECTOR_VERSION = SNOWFLAKE_CONNECTOR_VERSION
160163
PYTHON_VERSION = PYTHON_VERSION
161164
OPERATING_SYSTEM = OPERATING_SYSTEM
@@ -856,7 +859,7 @@ def add_retry_params(self, full_url: str) -> str:
856859
def add_request_guid(full_url: str) -> str:
857860
"""Adds request_guid parameter for HTTP request tracing."""
858861
parsed_url = urlparse(full_url)
859-
if not parsed_url.hostname.endswith(SNOWFLAKE_HOST_SUFFIX):
862+
if not re.search(_SNOWFLAKE_HOST_SUFFIX_REGEX, parsed_url.hostname):
860863
return full_url
861864
request_guid = str(uuid.uuid4())
862865
suffix = urlencode({REQUEST_GUID: request_guid})

src/snowflake/connector/ocsp_snowflake.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
from .backoff_policies import exponential_backoff
6363
from .cache import SFDictCache, SFDictFileCache
6464
from .telemetry import TelemetryField, generate_telemetry_data_dict
65-
from .url_util import url_encode_str
65+
from .url_util import extract_top_level_domain_from_hostname, url_encode_str
6666

6767

6868
class OCSPResponseValidationResult(NamedTuple):
@@ -268,15 +268,20 @@ def generate_telemetry_data(
268268
class OCSPServer:
269269
MAX_RETRY = int(os.getenv("OCSP_MAX_RETRY", "3"))
270270

271-
def __init__(self) -> None:
272-
self.DEFAULT_CACHE_SERVER_URL = "http://ocsp.snowflakecomputing.com"
271+
def __init__(self, **kwargs) -> None:
272+
top_level_domain = kwargs.pop(
273+
"top_level_domain", constants._DEFAULT_HOSTNAME_TLD
274+
)
275+
self.DEFAULT_CACHE_SERVER_URL = (
276+
f"http://ocsp.snowflakecomputing.{top_level_domain}"
277+
)
273278
"""
274279
The following will change to something like
275280
http://ocspssd.snowflakecomputing.com/ocsp/
276281
once the endpoint is up in the backend
277282
"""
278283
self.NEW_DEFAULT_CACHE_SERVER_BASE_URL = (
279-
"https://ocspssd.snowflakecomputing.com/ocsp/"
284+
f"https://ocspssd.snowflakecomputing.{top_level_domain}/ocsp/"
280285
)
281286
if not OCSPServer.is_enabled_new_ocsp_endpoint():
282287
self.CACHE_SERVER_URL = os.getenv(
@@ -307,12 +312,13 @@ def reset_ocsp_endpoint(self, hname) -> None:
307312
on the hostname the customer is trying to connect to. The deployment or in case of client failover, the
308313
replication ID is copied from the hostname.
309314
"""
310-
if hname.endswith("privatelink.snowflakecomputing.com"):
315+
top_level_domain = extract_top_level_domain_from_hostname(hname)
316+
if "privatelink.snowflakecomputing." in hname:
311317
temp_ocsp_endpoint = "".join(["https://ocspssd.", hname, "/ocsp/"])
312-
elif hname.endswith("global.snowflakecomputing.com"):
318+
elif "global.snowflakecomputing." in hname:
313319
rep_id_begin = hname[hname.find("-") :]
314320
temp_ocsp_endpoint = "".join(["https://ocspssd", rep_id_begin, "/ocsp/"])
315-
elif not hname.endswith("snowflakecomputing.com"):
321+
elif not hname.endswith(f"snowflakecomputing.{top_level_domain}"):
316322
temp_ocsp_endpoint = self.NEW_DEFAULT_CACHE_SERVER_BASE_URL
317323
else:
318324
hname_wo_acc = hname[hname.find(".") :]
@@ -832,8 +838,8 @@ class SnowflakeOCSP:
832838

833839
OCSP_WHITELIST = re.compile(
834840
r"^"
835-
r"(.*\.snowflakecomputing\.com$"
836-
r"|(?:|.*\.)s3.*\.amazonaws\.com$" # start with s3 or .s3 in the middle
841+
r"(.*\.snowflakecomputing(\.[a-zA-Z]{1,63}){1,2}$"
842+
r"|(?:|.*\.)s3.*\.amazonaws(\.[a-zA-Z]{1,63}){1,2}$" # start with s3 or .s3 in the middle
837843
r"|.*\.okta\.com$"
838844
r"|(?:|.*\.)storage\.googleapis\.com$"
839845
r"|.*\.blob\.core\.windows\.net$"
@@ -881,14 +887,19 @@ def __init__(
881887
use_ocsp_cache_server=None,
882888
use_post_method: bool = True,
883889
use_fail_open: bool = True,
890+
**kwargs,
884891
) -> None:
885892
self.test_mode = os.getenv("SF_OCSP_TEST_MODE", None)
886893

887894
if self.test_mode == "true":
888895
logger.debug("WARNING - DRIVER CONFIGURED IN TEST MODE")
889896

890897
self._use_post_method = use_post_method
891-
self.OCSP_CACHE_SERVER = OCSPServer()
898+
self.OCSP_CACHE_SERVER = OCSPServer(
899+
top_level_domain=extract_top_level_domain_from_hostname(
900+
kwargs.pop("hostname", None)
901+
)
902+
)
892903

893904
self.debug_ocsp_failure_url = None
894905

src/snowflake/connector/ssl_wrap_socket.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
8787
v = SFOCSP(
8888
ocsp_response_cache_uri=FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME,
8989
use_fail_open=FEATURE_OCSP_MODE == OCSPMode.FAIL_OPEN,
90+
hostname=server_hostname,
9091
).validate(server_hostname, ret.connection)
9192
if not v:
9293
raise OperationalError(

src/snowflake/connector/url_util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import urllib.parse
99
from logging import getLogger
1010

11+
from .constants import _TOP_LEVEL_DOMAIN_REGEX
12+
1113
logger = getLogger(__name__)
1214

1315

@@ -41,3 +43,11 @@ def url_encode_str(target: str | None) -> str:
4143
logger.debug("The string to be URL encoded is None")
4244
return ""
4345
return urllib.parse.quote_plus(target, safe="")
46+
47+
48+
def extract_top_level_domain_from_hostname(hostname: str | None = None) -> str:
    """Return the lower-cased top-level domain of ``hostname``.

    Args:
        hostname: Host name to inspect. ``None``, an empty string, or a name
            made only of dots is treated as "no hostname".

    Returns:
        The final label of ``hostname`` as matched by
        ``_TOP_LEVEL_DOMAIN_REGEX``, without its leading dot and lower-cased.
        ``"com"`` is returned when there is no hostname or the pattern does
        not match (for example a name without a dot).
    """
    # A fully-qualified name may end with the root dot ("host.example.cn.").
    # Drop it first, otherwise the end-anchored pattern finds nothing and the
    # caller silently gets the "com" fallback for a non-.com host.
    normalized = hostname.rstrip(".") if hostname else ""
    if not normalized:
        return "com"
    # RFC1034 for TLD spec, and https://data.iana.org/TLD/tlds-alpha-by-domain.txt for full TLD list
    match = re.search(_TOP_LEVEL_DOMAIN_REGEX, normalized)
    if match is None:
        return "com"
    # group(0) starts with the dot that precedes the label; strip it.
    return match.group(0)[1:].lower()

src/snowflake/connector/util_text.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ def construct_hostname(region: str | None, account: str) -> str:
240240
if region:
241241
if account.find(".") > 0:
242242
account = account[0 : account.find(".")]
243-
host = f"{account}.{region}.snowflakecomputing.com"
243+
top_level_domain = (
244+
"com"
245+
if not any(substring in region for substring in ["cn-", "CN-"])
246+
else "cn"
247+
)
248+
host = f"{account}.{region}.snowflakecomputing.{top_level_domain}"
244249
else:
245250
host = f"{account}.snowflakecomputing.com"
246251
return host

test/integ/test_connection.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pathlib
1212
import queue
1313
import stat
14+
import tempfile
1415
import threading
1516
import warnings
1617
import weakref
@@ -39,6 +40,7 @@
3940
from snowflake.connector.telemetry import TelemetryField
4041

4142
from ..randomize import random_string
43+
from .conftest import RUNNING_ON_GH
4244

4345
try: # pragma: no cover
4446
from ..parameters import CONNECTION_PARAMETERS_ADMIN
@@ -1370,3 +1372,35 @@ def test_token_file_path(tmp_path, db_parameters):
13701372
assert conn._token == fake_token
13711373
conn = snowflake.connector.connect(**db_parameters, token_file_path=token_file_path)
13721374
assert conn._token == fake_token
1375+
1376+
1377+
@pytest.mark.skipolddriver
@pytest.mark.skipif(not RUNNING_ON_GH, reason="no ocsp in the environment")
def test_mock_non_existing_server(conn_cnx, caplog):
    """Connect while the OCSP cache server points at a non-existing domain.

    The in-memory OCSP validation cache is replaced with an empty one, the
    OCSP cache file name is redirected to a fresh temporary file, and the
    top-level domain seen by the OCSP code is forced to a bogus value. The
    connection is expected to succeed anyway, and the debug log must show the
    cache-file read failure, the fallback to the OCSP server, and the cache
    file being written.
    """
    from snowflake.connector.cache import SFDictCache

    expected_messages = [
        "Failed to read OCSP response cache file",
        "It will validate with OCSP server.",
        "writing OCSP response cache file to",
    ]

    # disabling local cache and pointing ocsp cache server to a non-existing url
    # connection should still work as it will directly validate the certs against CA servers
    with tempfile.NamedTemporaryFile() as tmp, caplog.at_level(logging.DEBUG):
        # Patch the helper where it is looked up: ocsp_snowflake binds it with
        # ``from .url_util import extract_top_level_domain_from_hostname``, so
        # replacing the attribute on url_util would leave the OCSP code calling
        # the real function.
        with mock.patch(
            "snowflake.connector.ocsp_snowflake.extract_top_level_domain_from_hostname",
            return_value="nonexistingtopleveldomain",
        ), mock.patch(
            "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE",
            SFDictCache(),
        ), mock.patch(
            "snowflake.connector.ocsp_snowflake.OCSPCache.OCSP_RESPONSE_CACHE_FILE_NAME",
            tmp.name,
        ):
            with conn_cnx():
                pass
        # Checked after the connection is closed so every OCSP log line has
        # already been captured.
        assert all(s in caplog.text for s in expected_messages)

0 commit comments

Comments
 (0)