Skip to content

Commit 2a0f804

Browse files
[ASYNC] Apply #2233 to async code
1 parent 80bfec9 commit 2a0f804

File tree

8 files changed

+220
-13
lines changed

8 files changed

+220
-13
lines changed

src/snowflake/connector/aio/_azure_storage_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def __init__(
4949
credentials: StorageCredential | None,
5050
chunk_size: int,
5151
stage_info: dict[str, Any],
52-
use_s3_regional_url: bool = False,
5352
unsafe_file_write: bool = False,
5453
) -> None:
5554
SnowflakeAzureRestClientSync.__init__(

src/snowflake/connector/aio/_cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,7 @@ async def execute(
663663
multipart_threshold=data.get("threshold"),
664664
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
665665
unsafe_file_write=self._connection.unsafe_file_write,
666+
gcs_use_virtual_endpoints=self._connection.gcs_use_virtual_endpoints,
666667
)
667668
await sf_file_transfer_agent.execute()
668669
data = sf_file_transfer_agent.result()

src/snowflake/connector/aio/_file_transfer_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
source_from_stream: IO[bytes] | None = None,
6464
use_s3_regional_url: bool = False,
6565
unsafe_file_write: bool = False,
66+
gcs_use_virtual_endpoints: bool = False,
6667
) -> None:
6768
super().__init__(
6869
cursor=cursor,
@@ -82,6 +83,7 @@ def __init__(
8283
source_from_stream=source_from_stream,
8384
use_s3_regional_url=use_s3_regional_url,
8485
unsafe_file_write=unsafe_file_write,
86+
gcs_use_virtual_endpoints=gcs_use_virtual_endpoints,
8587
)
8688

8789
async def execute(self) -> None:
@@ -281,7 +283,6 @@ async def _create_file_transfer_client(
281283
self._credentials,
282284
AZURE_CHUNK_SIZE,
283285
self._stage_info,
284-
use_s3_regional_url=self._use_s3_regional_url,
285286
unsafe_file_write=self._unsafe_file_write,
286287
)
287288
elif self._stage_location_type == S3_FS:
@@ -303,7 +304,6 @@ async def _create_file_transfer_client(
303304
self._stage_info,
304305
self._cursor._connection,
305306
self._command,
306-
use_s3_regional_url=self._use_s3_regional_url,
307307
unsafe_file_write=self._unsafe_file_write,
308308
)
309309
if client.security_token:

src/snowflake/connector/aio/_gcs_storage_client.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
GCS_METADATA_ENCRYPTIONDATAPROP,
2828
GCS_METADATA_MATDESC_KEY,
2929
GCS_METADATA_SFC_DIGEST,
30+
GCS_REGION_ME_CENTRAL_2,
3031
)
3132

3233

@@ -38,7 +39,6 @@ def __init__(
3839
stage_info: dict[str, Any],
3940
cnx: SnowflakeConnection,
4041
command: str,
41-
use_s3_regional_url: bool = False,
4242
unsafe_file_write: bool = False,
4343
) -> None:
4444
"""Creates a client object with given stage credentials.
@@ -65,6 +65,15 @@ def __init__(
6565
# presigned_url in meta is for downloading
6666
self.presigned_url: str = meta.presigned_url or stage_info.get("presignedUrl")
6767
self.security_token = credentials.creds.get("GCS_ACCESS_TOKEN")
68+
self.use_regional_url = (
69+
"region" in stage_info
70+
and stage_info["region"].lower() == GCS_REGION_ME_CENTRAL_2
71+
or "useRegionalUrl" in stage_info
72+
and stage_info["useRegionalUrl"]
73+
)
74+
self.endpoint: str | None = (
75+
None if "endPoint" not in stage_info else stage_info["endPoint"]
76+
)
6877

6978
async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:
7079
return self.security_token and response.status == 401
@@ -73,7 +82,7 @@ async def _has_expired_presigned_url(
7382
self, response: aiohttp.ClientResponse
7483
) -> bool:
7584
# Presigned urls can be generated for any xml-api operation
76-
# offered by GCS. Hence the error codes expected are similar
85+
# offered by GCS. Hence, the error codes expected are similar
7786
# to xml api.
7887
# https://cloud.google.com/storage/docs/xml-api/reference-status
7988

@@ -132,7 +141,14 @@ def generate_url_and_rest_args() -> (
132141
):
133142
if not self.presigned_url:
134143
upload_url = self.generate_file_url(
135-
self.stage_info["location"], meta.dst_file_name.lstrip("/")
144+
self.stage_info["location"],
145+
meta.dst_file_name.lstrip("/"),
146+
self.use_regional_url,
147+
(
148+
None
149+
if "region" not in self.stage_info
150+
else self.stage_info["region"]
151+
),
136152
)
137153
access_token = self.security_token
138154
else:
@@ -162,7 +178,15 @@ def generate_url_and_rest_args() -> (
162178
gcs_headers = {}
163179
if not self.presigned_url:
164180
download_url = self.generate_file_url(
165-
self.stage_info["location"], meta.src_file_name.lstrip("/")
181+
self.stage_info["location"],
182+
meta.src_file_name.lstrip("/"),
183+
self.use_regional_url,
184+
(
185+
None
186+
if "region" not in self.stage_info
187+
else self.stage_info["region"]
188+
),
189+
self.endpoint,
166190
)
167191
access_token = self.security_token
168192
gcs_headers["Authorization"] = f"Bearer {access_token}"
@@ -279,7 +303,14 @@ async def get_file_header(self, filename: str) -> FileHeader | None:
279303

280304
def generate_url_and_authenticated_headers():
281305
url = self.generate_file_url(
282-
self.stage_info["location"], filename.lstrip("/")
306+
self.stage_info["location"],
307+
filename.lstrip("/"),
308+
self.use_regional_url,
309+
(
310+
None
311+
if "region" not in self.stage_info
312+
else self.stage_info["region"]
313+
),
283314
)
284315
gcs_headers = {"Authorization": f"Bearer {self.security_token}"}
285316
rest_args = {"headers": gcs_headers}

src/snowflake/connector/aio/_s3_storage_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,13 @@ def __init__(
7474
self.stage_info["location"]
7575
)
7676
)
77-
self.use_s3_regional_url = use_s3_regional_url
77+
self.use_s3_regional_url = (
78+
use_s3_regional_url
79+
or "useS3RegionalUrl" in stage_info
80+
and stage_info["useS3RegionalUrl"]
81+
or "useRegionalUrl" in stage_info
82+
and stage_info["useRegionalUrl"]
83+
)
7884
self.location_type = stage_info.get("locationType")
7985

8086
# if GS sends us an endpoint, it's likely for FIPS. Use it.

test/integ/aio/test_connection_async.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,3 +1686,34 @@ async def test_no_auth_connection_negative_case():
16861686
await conn.execute_string("select 1")
16871687

16881688
await conn.close()
1689+
1690+
1691+
@pytest.mark.skipolddriver
1692+
@pytest.mark.parametrize(
1693+
"value",
1694+
[
1695+
True,
1696+
False,
1697+
],
1698+
)
1699+
async def test_gcs_use_virtual_endpoints(value):
1700+
with mock.patch(
1701+
"snowflake.connector.aio._network.SnowflakeRestful.fetch",
1702+
return_value={"data": {"token": None, "masterToken": None}, "success": True},
1703+
):
1704+
cnx = snowflake.connector.aio.SnowflakeConnection(
1705+
user="test-user",
1706+
password="test-password",
1707+
host="test-host",
1708+
port="443",
1709+
account="test-account",
1710+
gcs_use_virtual_endpoints=value,
1711+
)
1712+
try:
1713+
await cnx.connect()
1714+
cnx.commit = cnx.rollback = (
1715+
lambda: None
1716+
) # Skip tear down, there's only a mocked rest api
1717+
assert cnx.gcs_use_virtual_endpoints == value
1718+
finally:
1719+
await cnx.close()

test/unit/aio/test_gcs_client_async.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ async def test_get_file_header_none_with_presigned_url(tmp_path):
330330
)
331331
storage_credentials = Mock()
332332
storage_credentials.creds = {}
333-
stage_info = Mock()
333+
stage_info: dict[str, any] = dict()
334334
connection = Mock()
335335
client = SnowflakeGCSRestClient(
336336
meta, storage_credentials, stage_info, connection, ""
@@ -339,3 +339,102 @@ async def test_get_file_header_none_with_presigned_url(tmp_path):
339339
await client._update_presigned_url()
340340
file_header = await client.get_file_header(meta.name)
341341
assert file_header is None
342+
343+
344+
@pytest.mark.parametrize(
345+
"region,return_url,use_regional_url,endpoint,gcs_use_virtual_endpoints",
346+
[
347+
(
348+
"US-CENTRAL1",
349+
"https://storage.us-central1.rep.googleapis.com",
350+
True,
351+
None,
352+
False,
353+
),
354+
(
355+
"ME-CENTRAL2",
356+
"https://storage.me-central2.rep.googleapis.com",
357+
True,
358+
None,
359+
False,
360+
),
361+
("US-CENTRAL1", "https://storage.googleapis.com", False, None, False),
362+
("US-CENTRAL1", "https://storage.googleapis.com", False, None, False),
363+
("US-CENTRAL1", "https://location.storage.googleapis.com", False, None, True),
364+
("US-CENTRAL1", "https://location.storage.googleapis.com", True, None, True),
365+
(
366+
"US-CENTRAL1",
367+
"https://overriddenurl.com",
368+
False,
369+
"https://overriddenurl.com",
370+
False,
371+
),
372+
(
373+
"US-CENTRAL1",
374+
"https://overriddenurl.com",
375+
True,
376+
"https://overriddenurl.com",
377+
False,
378+
),
379+
(
380+
"US-CENTRAL1",
381+
"https://overriddenurl.com",
382+
True,
383+
"https://overriddenurl.com",
384+
True,
385+
),
386+
(
387+
"US-CENTRAL1",
388+
"https://overriddenurl.com",
389+
False,
390+
"https://overriddenurl.com",
391+
False,
392+
),
393+
(
394+
"US-CENTRAL1",
395+
"https://overriddenurl.com",
396+
False,
397+
"https://overriddenurl.com",
398+
True,
399+
),
400+
],
401+
)
402+
def test_url(region, return_url, use_regional_url, endpoint, gcs_use_virtual_endpoints):
403+
gcs_location = SnowflakeGCSRestClient.get_location(
404+
stage_location="location",
405+
use_regional_url=use_regional_url,
406+
region=region,
407+
endpoint=endpoint,
408+
use_virtual_endpoints=gcs_use_virtual_endpoints,
409+
)
410+
assert gcs_location.endpoint == return_url
411+
412+
413+
@pytest.mark.parametrize(
414+
"region,use_regional_url,return_value",
415+
[
416+
("ME-CENTRAL2", False, True),
417+
("ME-CENTRAL2", True, True),
418+
("US-CENTRAL1", False, False),
419+
("US-CENTRAL1", True, True),
420+
],
421+
)
422+
def test_use_regional_url(region, use_regional_url, return_value):
423+
meta = SnowflakeFileMeta(
424+
name="path/some_file",
425+
src_file_name="path/some_file",
426+
stage_location_type="GCS",
427+
presigned_url="www.example.com",
428+
)
429+
storage_credentials = Mock()
430+
storage_credentials.creds = {}
431+
stage_info: dict[str, any] = dict()
432+
stage_info["region"] = region
433+
stage_info["useRegionalUrl"] = use_regional_url
434+
connection = Mock()
435+
436+
client = SnowflakeGCSRestClient(
437+
meta, storage_credentials, stage_info, connection, ""
438+
)
439+
440+
assert client.use_regional_url == return_value

test/unit/aio/test_s3_util_async.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,11 @@
2929
SnowflakeFileMeta,
3030
StorageCredential,
3131
)
32-
from snowflake.connector.s3_storage_client import ERRORNO_WSAECONNABORTED
3332
from snowflake.connector.vendored.requests import HTTPError
3433
except ImportError:
3534
# Compatibility for olddriver tests
3635
from requests import HTTPError
3736

38-
from snowflake.connector.s3_util import ERRORNO_WSAECONNABORTED # NOQA
39-
4037
SnowflakeFileMeta = dict
4138
SnowflakeS3RestClient = None
4239
RequestExceedMaxRetryError = None
@@ -500,3 +497,46 @@ async def test_accelerate_in_china_endpoint():
500497
8 * megabyte,
501498
)
502499
assert not await rest_client.transfer_accelerate_config()
500+
501+
502+
@pytest.mark.parametrize(
503+
"use_s3_regional_url,stage_info_flags,expected",
504+
[
505+
(False, {}, False),
506+
(True, {}, True),
507+
(False, {"useS3RegionalUrl": True}, True),
508+
(False, {"useRegionalUrl": True}, True),
509+
(True, {"useS3RegionalUrl": False}, True),
510+
(False, {"useS3RegionalUrl": True, "useRegionalUrl": False}, True),
511+
(False, {"useS3RegionalUrl": False, "useRegionalUrl": True}, True),
512+
(False, {"useS3RegionalUrl": False, "useRegionalUrl": False}, False),
513+
],
514+
)
515+
def test_s3_regional_url_logic_async(use_s3_regional_url, stage_info_flags, expected):
516+
"""Tests that the async S3 storage client correctly handles regional URL flags from stage_info."""
517+
if SnowflakeS3RestClient is None:
518+
pytest.skip("S3 storage client not available")
519+
520+
meta = SnowflakeFileMeta(
521+
name="path/some_file",
522+
src_file_name="path/some_file",
523+
stage_location_type="S3",
524+
)
525+
storage_credentials = StorageCredential({}, mock.Mock(), "test")
526+
527+
stage_info = {
528+
"region": "us-west-2",
529+
"location": "test-bucket",
530+
"endPoint": None,
531+
}
532+
stage_info.update(stage_info_flags)
533+
534+
client = SnowflakeS3RestClient(
535+
meta=meta,
536+
credentials=storage_credentials,
537+
stage_info=stage_info,
538+
chunk_size=1024,
539+
use_s3_regional_url=use_s3_regional_url,
540+
)
541+
542+
assert client.use_s3_regional_url == expected

0 commit comments

Comments
 (0)