|
62 | 62 | from .backoff_policies import exponential_backoff |
63 | 63 | from .cache import SFDictCache, SFDictFileCache |
64 | 64 | from .telemetry import TelemetryField, generate_telemetry_data_dict |
65 | | -from .url_util import url_encode_str |
| 65 | +from .url_util import extract_top_level_domain_from_hostname, url_encode_str |
66 | 66 |
|
67 | 67 |
|
68 | 68 | class OCSPResponseValidationResult(NamedTuple): |
@@ -268,15 +268,20 @@ def generate_telemetry_data( |
268 | 268 | class OCSPServer: |
269 | 269 | MAX_RETRY = int(os.getenv("OCSP_MAX_RETRY", "3")) |
270 | 270 |
|
271 | | - def __init__(self) -> None: |
272 | | - self.DEFAULT_CACHE_SERVER_URL = "http://ocsp.snowflakecomputing.com" |
| 271 | + def __init__(self, **kwargs) -> None: |
| 272 | + top_level_domain = kwargs.pop( |
| 273 | + "top_level_domain", constants._DEFAULT_HOSTNAME_TLD |
| 274 | + ) |
| 275 | + self.DEFAULT_CACHE_SERVER_URL = ( |
| 276 | + f"http://ocsp.snowflakecomputing.{top_level_domain}" |
| 277 | + ) |
273 | 278 | """ |
274 | 279 | The following will change to something like |
275 | 280 | http://ocspssd.snowflakecomputing.com/ocsp/ |
276 | 281 | once the endpoint is up in the backend |
277 | 282 | """ |
278 | 283 | self.NEW_DEFAULT_CACHE_SERVER_BASE_URL = ( |
279 | | - "https://ocspssd.snowflakecomputing.com/ocsp/" |
| 284 | + f"https://ocspssd.snowflakecomputing.{top_level_domain}/ocsp/" |
280 | 285 | ) |
281 | 286 | if not OCSPServer.is_enabled_new_ocsp_endpoint(): |
282 | 287 | self.CACHE_SERVER_URL = os.getenv( |
@@ -307,12 +312,13 @@ def reset_ocsp_endpoint(self, hname) -> None: |
307 | 312 | on the hostname the customer is trying to connect to. The deployment or in case of client failover, the |
308 | 313 | replication ID is copied from the hostname. |
309 | 314 | """ |
310 | | - if hname.endswith("privatelink.snowflakecomputing.com"): |
| 315 | + top_level_domain = extract_top_level_domain_from_hostname(hname) |
| 316 | + if "privatelink.snowflakecomputing." in hname: |
311 | 317 | temp_ocsp_endpoint = "".join(["https://ocspssd.", hname, "/ocsp/"]) |
312 | | - elif hname.endswith("global.snowflakecomputing.com"): |
| 318 | + elif "global.snowflakecomputing." in hname: |
313 | 319 | rep_id_begin = hname[hname.find("-") :] |
314 | 320 | temp_ocsp_endpoint = "".join(["https://ocspssd", rep_id_begin, "/ocsp/"]) |
315 | | - elif not hname.endswith("snowflakecomputing.com"): |
| 321 | + elif not hname.endswith(f"snowflakecomputing.{top_level_domain}"): |
316 | 322 | temp_ocsp_endpoint = self.NEW_DEFAULT_CACHE_SERVER_BASE_URL |
317 | 323 | else: |
318 | 324 | hname_wo_acc = hname[hname.find(".") :] |
@@ -832,8 +838,8 @@ class SnowflakeOCSP: |
832 | 838 |
|
833 | 839 | OCSP_WHITELIST = re.compile( |
834 | 840 | r"^" |
835 | | - r"(.*\.snowflakecomputing\.com$" |
836 | | - r"|(?:|.*\.)s3.*\.amazonaws\.com$" # start with s3 or .s3 in the middle |
| 841 | + r"(.*\.snowflakecomputing(\.[a-zA-Z]{1,63}){1,2}$" |
| 842 | + r"|(?:|.*\.)s3.*\.amazonaws(\.[a-zA-Z]{1,63}){1,2}$" # start with s3 or .s3 in the middle |
837 | 843 | r"|.*\.okta\.com$" |
838 | 844 | r"|(?:|.*\.)storage\.googleapis\.com$" |
839 | 845 | r"|.*\.blob\.core\.windows\.net$" |
@@ -881,14 +887,19 @@ def __init__( |
881 | 887 | use_ocsp_cache_server=None, |
882 | 888 | use_post_method: bool = True, |
883 | 889 | use_fail_open: bool = True, |
| 890 | + **kwargs, |
884 | 891 | ) -> None: |
885 | 892 | self.test_mode = os.getenv("SF_OCSP_TEST_MODE", None) |
886 | 893 |
|
887 | 894 | if self.test_mode == "true": |
888 | 895 | logger.debug("WARNING - DRIVER CONFIGURED IN TEST MODE") |
889 | 896 |
|
890 | 897 | self._use_post_method = use_post_method |
891 | | - self.OCSP_CACHE_SERVER = OCSPServer() |
| 898 | + self.OCSP_CACHE_SERVER = OCSPServer( |
| 899 | + top_level_domain=extract_top_level_domain_from_hostname( |
| 900 | + kwargs.pop("hostname", None) |
| 901 | + ) |
| 902 | + ) |
892 | 903 |
|
893 | 904 | self.debug_ocsp_failure_url = None |
894 | 905 |
|
|
0 commit comments