Skip to content

Commit 80bfec9

Browse files
sfc-gh-pbulawa authored and sfc-gh-pczajka committed
SNOW-1789751: Add GCP regional and virtual endpoints support (#2233)
1 parent 53fcf4a commit 80bfec9

File tree

8 files changed

+228
-17
lines changed

8 files changed

+228
-17
lines changed

src/snowflake/connector/azure_storage_client.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,6 @@ def __init__(
6464
credentials: StorageCredential | None,
6565
chunk_size: int,
6666
stage_info: dict[str, Any],
67-
use_s3_regional_url: bool = False,
6867
unsafe_file_write: bool = False,
6968
) -> None:
7069
super().__init__(

src/snowflake/connector/connection.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -312,6 +312,10 @@ def _get_private_bytes_from_file(
312312
None,
313313
(type(None), int),
314314
), # SNOW-1817982: limit iobound TPE sizes when executing PUT/GET
315+
"gcs_use_virtual_endpoints": (
316+
False,
317+
bool,
318+
), # use https://{bucket}.storage.googleapis.com instead of https://storage.googleapis.com/{bucket}
315319
"unsafe_file_write": (
316320
False,
317321
bool,
@@ -395,6 +399,7 @@ class SnowflakeConnection:
395399
before the connector shuts down. Default value is false.
396400
token_file_path: The file path of the token file. If both token and token_file_path are provided, the token in token_file_path will be used.
397401
unsafe_file_write: When true, files downloaded by GET will be saved with 644 permissions. Otherwise, files will be saved with safe - owner-only permissions: 600.
402+
gcs_use_virtual_endpoints: When true, the virtual endpoint url is used, see: https://cloud.google.com/storage/docs/request-endpoints#xml-api
398403
"""
399404

400405
OCSP_ENV_LOCK = Lock()
@@ -783,6 +788,14 @@ def unsafe_file_write(self) -> bool:
783788
def unsafe_file_write(self, value: bool) -> None:
784789
self._unsafe_file_write = value
785790

791+
@property
792+
def gcs_use_virtual_endpoints(self) -> bool:
793+
return self._gcs_use_virtual_endpoints
794+
795+
@gcs_use_virtual_endpoints.setter
796+
def gcs_use_virtual_endpoints(self, value: bool) -> None:
797+
self._gcs_use_virtual_endpoints = value
798+
786799
def connect(self, **kwargs) -> None:
787800
"""Establishes connection to Snowflake."""
788801
logger.debug("connect")

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1061,6 +1061,7 @@ def execute(
10611061
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
10621062
iobound_tpe_limit=self._connection.iobound_tpe_limit,
10631063
unsafe_file_write=self._connection.unsafe_file_write,
1064+
gcs_use_virtual_endpoints=self._connection.gcs_use_virtual_endpoints,
10641065
)
10651066
sf_file_transfer_agent.execute()
10661067
data = sf_file_transfer_agent.result()

src/snowflake/connector/file_transfer_agent.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -356,6 +356,7 @@ def __init__(
356356
use_s3_regional_url: bool = False,
357357
iobound_tpe_limit: int | None = None,
358358
unsafe_file_write: bool = False,
359+
gcs_use_virtual_endpoints: bool = False,
359360
) -> None:
360361
self._cursor = cursor
361362
self._command = command
@@ -388,6 +389,7 @@ def __init__(
388389
self._credentials: StorageCredential | None = None
389390
self._iobound_tpe_limit = iobound_tpe_limit
390391
self._unsafe_file_write = unsafe_file_write
392+
self._gcs_use_virtual_endpoints = gcs_use_virtual_endpoints
391393

392394
def execute(self) -> None:
393395
self._parse_command()
@@ -683,7 +685,6 @@ def _create_file_transfer_client(
683685
self._credentials,
684686
AZURE_CHUNK_SIZE,
685687
self._stage_info,
686-
use_s3_regional_url=self._use_s3_regional_url,
687688
unsafe_file_write=self._unsafe_file_write,
688689
)
689690
elif self._stage_location_type == S3_FS:
@@ -703,7 +704,6 @@ def _create_file_transfer_client(
703704
self._stage_info,
704705
self._cursor._connection,
705706
self._command,
706-
use_s3_regional_url=self._use_s3_regional_url,
707707
unsafe_file_write=self._unsafe_file_write,
708708
)
709709
raise Exception(f"{self._stage_location_type} is an unknown stage type")

src/snowflake/connector/gcs_storage_client.py

Lines changed: 77 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -36,13 +36,15 @@
3636
GCS_FILE_HEADER_DIGEST = "gcs-file-header-digest"
3737
GCS_FILE_HEADER_CONTENT_LENGTH = "gcs-file-header-content-length"
3838
GCS_FILE_HEADER_ENCRYPTION_METADATA = "gcs-file-header-encryption-metadata"
39+
GCS_REGION_ME_CENTRAL_2 = "me-central2"
3940
CONTENT_CHUNK_SIZE = 10 * kilobyte
4041
ACCESS_TOKEN = "GCS_ACCESS_TOKEN"
4142

4243

4344
class GcsLocation(NamedTuple):
4445
bucket_name: str
4546
path: str
47+
endpoint: str = "https://storage.googleapis.com"
4648

4749

4850
class SnowflakeGCSRestClient(SnowflakeStorageClient):
@@ -53,7 +55,6 @@ def __init__(
5355
stage_info: dict[str, Any],
5456
cnx: SnowflakeConnection,
5557
command: str,
56-
use_s3_regional_url: bool = False,
5758
unsafe_file_write: bool = False,
5859
) -> None:
5960
"""Creates a client object with given stage credentials.
@@ -79,6 +80,15 @@ def __init__(
7980
# presigned_url in meta is for downloading
8081
self.presigned_url: str = meta.presigned_url or stage_info.get("presignedUrl")
8182
self.security_token = credentials.creds.get("GCS_ACCESS_TOKEN")
83+
self.use_regional_url = (
84+
"region" in stage_info
85+
and stage_info["region"].lower() == GCS_REGION_ME_CENTRAL_2
86+
or "useRegionalUrl" in stage_info
87+
and stage_info["useRegionalUrl"]
88+
)
89+
self.endpoint: str | None = (
90+
None if "endPoint" not in stage_info else stage_info["endPoint"]
91+
)
8292

8393
if self.security_token:
8494
logger.debug(f"len(GCS_ACCESS_TOKEN): {len(self.security_token)}")
@@ -91,7 +101,7 @@ def _has_expired_token(self, response: requests.Response) -> bool:
91101

92102
def _has_expired_presigned_url(self, response: requests.Response) -> bool:
93103
# Presigned urls can be generated for any xml-api operation
94-
# offered by GCS. Hence the error codes expected are similar
104+
# offered by GCS. Hence, the error codes expected are similar
95105
# to xml api.
96106
# https://cloud.google.com/storage/docs/xml-api/reference-status
97107

@@ -152,7 +162,14 @@ def generate_url_and_rest_args() -> (
152162
):
153163
if not self.presigned_url:
154164
upload_url = self.generate_file_url(
155-
self.stage_info["location"], meta.dst_file_name.lstrip("/")
165+
self.stage_info["location"],
166+
meta.dst_file_name.lstrip("/"),
167+
self.use_regional_url,
168+
(
169+
None
170+
if "region" not in self.stage_info
171+
else self.stage_info["region"]
172+
),
156173
)
157174
access_token = self.security_token
158175
else:
@@ -182,7 +199,15 @@ def generate_url_and_rest_args() -> (
182199
gcs_headers = {}
183200
if not self.presigned_url:
184201
download_url = self.generate_file_url(
185-
self.stage_info["location"], meta.src_file_name.lstrip("/")
202+
self.stage_info["location"],
203+
meta.src_file_name.lstrip("/"),
204+
self.use_regional_url,
205+
(
206+
None
207+
if "region" not in self.stage_info
208+
else self.stage_info["region"]
209+
),
210+
self.endpoint,
186211
)
187212
access_token = self.security_token
188213
gcs_headers["Authorization"] = f"Bearer {access_token}"
@@ -339,7 +364,14 @@ def get_file_header(self, filename: str) -> FileHeader | None:
339364

340365
def generate_url_and_authenticated_headers():
341366
url = self.generate_file_url(
342-
self.stage_info["location"], filename.lstrip("/")
367+
self.stage_info["location"],
368+
filename.lstrip("/"),
369+
self.use_regional_url,
370+
(
371+
None
372+
if "region" not in self.stage_info
373+
else self.stage_info["region"]
374+
),
343375
)
344376
gcs_headers = {"Authorization": f"Bearer {self.security_token}"}
345377
rest_args = {"headers": gcs_headers}
@@ -383,7 +415,13 @@ def generate_url_and_authenticated_headers():
383415
return None
384416

385417
@staticmethod
386-
def extract_bucket_name_and_path(stage_location: str) -> GcsLocation:
418+
def get_location(
419+
stage_location: str,
420+
use_regional_url: str = False,
421+
region: str = None,
422+
endpoint: str = None,
423+
use_virtual_endpoints: bool = False,
424+
) -> GcsLocation:
387425
container_name = stage_location
388426
path = ""
389427

@@ -393,13 +431,40 @@ def extract_bucket_name_and_path(stage_location: str) -> GcsLocation:
393431
path = stage_location[stage_location.index("/") + 1 :]
394432
if path and not path.endswith("/"):
395433
path += "/"
396-
397-
return GcsLocation(bucket_name=container_name, path=path)
434+
if endpoint:
435+
if endpoint.endswith("/"):
436+
endpoint = endpoint[:-1]
437+
return GcsLocation(bucket_name=container_name, path=path, endpoint=endpoint)
438+
elif use_virtual_endpoints:
439+
return GcsLocation(
440+
bucket_name=container_name,
441+
path=path,
442+
endpoint=f"https://{container_name}.storage.googleapis.com",
443+
)
444+
elif use_regional_url:
445+
return GcsLocation(
446+
bucket_name=container_name,
447+
path=path,
448+
endpoint=f"https://storage.{region.lower()}.rep.googleapis.com",
449+
)
450+
else:
451+
return GcsLocation(bucket_name=container_name, path=path)
398452

399453
@staticmethod
400-
def generate_file_url(stage_location: str, filename: str) -> str:
401-
gcs_location = SnowflakeGCSRestClient.extract_bucket_name_and_path(
402-
stage_location
454+
def generate_file_url(
455+
stage_location: str,
456+
filename: str,
457+
use_regional_url: str = False,
458+
region: str = None,
459+
endpoint: str = None,
460+
use_virtual_endpoints: bool = False,
461+
) -> str:
462+
gcs_location = SnowflakeGCSRestClient.get_location(
463+
stage_location, use_regional_url, region, endpoint
403464
)
404465
full_file_path = f"{gcs_location.path}{filename}"
405-
return f"https://storage.googleapis.com/{gcs_location.bucket_name}/{quote(full_file_path)}"
466+
467+
if use_virtual_endpoints:
468+
return f"{gcs_location.endpoint}/{quote(full_file_path)}"
469+
else:
470+
return f"{gcs_location.endpoint}/{gcs_location.bucket_name}/{quote(full_file_path)}"

src/snowflake/connector/s3_storage_client.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,13 @@ def __init__(
8686
self.stage_info["location"]
8787
)
8888
)
89-
self.use_s3_regional_url = use_s3_regional_url
89+
self.use_s3_regional_url = (
90+
use_s3_regional_url
91+
or "useS3RegionalUrl" in stage_info
92+
and stage_info["useS3RegionalUrl"]
93+
or "useRegionalUrl" in stage_info
94+
and stage_info["useRegionalUrl"]
95+
)
9096
self.location_type = stage_info.get("locationType")
9197

9298
# if GS sends us an endpoint, it's likely for FIPS. Use it.

test/integ/test_connection.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1369,6 +1369,34 @@ def test_server_session_keep_alive(conn_cnx):
13691369
mock_delete_session.assert_called_once()
13701370

13711371

1372+
@pytest.mark.skipolddriver
1373+
@pytest.mark.parametrize(
1374+
"value",
1375+
[
1376+
True,
1377+
False,
1378+
],
1379+
)
1380+
def test_gcs_use_virtual_endpoints(conn_cnx, value):
1381+
with mock.patch(
1382+
"snowflake.connector.network.SnowflakeRestful.fetch",
1383+
return_value={"data": {"token": None, "masterToken": None}, "success": True},
1384+
):
1385+
with snowflake.connector.connect(
1386+
user="test-user",
1387+
password="test-password",
1388+
host="test-host",
1389+
port="443",
1390+
account="test-account",
1391+
gcs_use_virtual_endpoints=value,
1392+
) as cnx:
1393+
assert cnx
1394+
cnx.commit = cnx.rollback = (
1395+
lambda: None
1396+
) # Skip tear down, there's only a mocked rest api
1397+
assert cnx.gcs_use_virtual_endpoints == value
1398+
1399+
13721400
@pytest.mark.skipolddriver
13731401
def test_ocsp_mode_disable_ocsp_checks(
13741402
conn_cnx, is_public_test, is_local_dev_setup, caplog

0 commit comments

Comments
 (0)