Skip to content

Commit fe6faae

Browse files
sfc-gh-yuwangsfc-gh-aling
authored andcommitted
SNOW-1628850: fix s3 accelerate logic (#2070)
1 parent d8935da commit fe6faae

File tree

4 files changed

+57
-8
lines changed

4 files changed

+57
-8
lines changed

src/snowflake/connector/aio/_file_transfer_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def postprocess_done_cb(
261261
async def _transfer_accelerate_config(self) -> None:
262262
if self._stage_location_type == S3_FS and self._file_metadata:
263263
client = await self._create_file_transfer_client(self._file_metadata[0])
264-
self._use_accelerate_endpoint = client.transfer_accelerate_config()
264+
self._use_accelerate_endpoint = await client.transfer_accelerate_config()
265265

266266
async def _create_file_transfer_client(
267267
self, meta: SnowflakeFileMeta
@@ -289,6 +289,7 @@ async def _create_file_transfer_client(
289289
use_accelerate_endpoint=self._use_accelerate_endpoint,
290290
use_s3_regional_url=self._use_s3_regional_url,
291291
)
292+
await client.transfer_accelerate_config(self._use_accelerate_endpoint)
292293
return client
293294
elif self._stage_location_type == GCS_FS:
294295
client = SnowflakeGCSRestClient(

src/snowflake/connector/aio/_s3_storage_client.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,6 @@ def __init__(
8181
self.endpoint = (
8282
f"https://{self.s3location.bucket_name}." + stage_info["endPoint"]
8383
)
84-
# self.transfer_accelerate_config(use_accelerate_endpoint)
85-
self.transfer_accelerate_config(False)
86-
# TODO: fix accelerate logic SNOW-1628850
8784

8885
async def _send_request_with_authentication_and_retry(
8986
self,
@@ -376,6 +373,41 @@ async def _get_bucket_accelerate_config(self, bucket_name: str) -> bool:
376373
return use_accelerate_endpoint
377374
return False
378375

376+
async def transfer_accelerate_config(
377+
self, use_accelerate_endpoint: bool | None = None
378+
) -> bool:
379+
# accelerate cannot be used in China and us government
380+
if self.region_name and self.region_name.startswith("cn-"):
381+
self.endpoint = (
382+
f"https://{self.s3location.bucket_name}."
383+
f"s3.{self.region_name}.amazonaws.com.cn"
384+
)
385+
return False
386+
# if self.endpoint has been set, e.g. by metadata, no more config is needed.
387+
if self.endpoint is not None:
388+
return self.endpoint.find("s3-accelerate.amazonaws.com") >= 0
389+
if self.use_s3_regional_url:
390+
self.endpoint = (
391+
f"https://{self.s3location.bucket_name}."
392+
f"s3.{self.region_name}.amazonaws.com"
393+
)
394+
return False
395+
else:
396+
if use_accelerate_endpoint is None:
397+
use_accelerate_endpoint = await self._get_bucket_accelerate_config(
398+
self.s3location.bucket_name
399+
)
400+
401+
if use_accelerate_endpoint:
402+
self.endpoint = (
403+
f"https://{self.s3location.bucket_name}.s3-accelerate.amazonaws.com"
404+
)
405+
else:
406+
self.endpoint = (
407+
f"https://{self.s3location.bucket_name}.s3.amazonaws.com"
408+
)
409+
return use_accelerate_endpoint
410+
379411
async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:
380412
"""Extract error code and error message from the S3's error response.
381413
Expected format:

test/integ/aio/test_put_get_with_aws_token_async.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ async def test_put_with_invalid_token(tmpdir, aio_connection):
118118
)
119119

120120
client = SnowflakeS3RestClient(meta, creds, stage_info, 8388608)
121+
await client.transfer_accelerate_config(None)
121122
await client.get_file_header(meta.name) # positive case
122123

123124
# negative case, no aws token

test/unit/aio/test_s3_util_async.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ async def test_upload_file_with_s3_upload_failed_error(tmp_path):
105105
)
106106
exc = Exception("Stop executing")
107107

108-
def mock_transfer_accelerate_config(
108+
async def mock_transfer_accelerate_config(
109109
self: SnowflakeS3RestClient,
110110
use_accelerate_endpoint: bool | None = None,
111111
) -> bool:
@@ -117,7 +117,7 @@ def mock_transfer_accelerate_config(
117117
return_value=True,
118118
):
119119
with mock.patch(
120-
"snowflake.connector.s3_storage_client.SnowflakeS3RestClient.transfer_accelerate_config",
120+
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient.transfer_accelerate_config",
121121
mock_transfer_accelerate_config,
122122
):
123123
with mock.patch(
@@ -160,6 +160,7 @@ async def test_get_header_expiry_error():
160160
},
161161
8 * megabyte,
162162
)
163+
await rest_client.transfer_accelerate_config(None)
163164

164165
with mock.patch(
165166
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token",
@@ -241,6 +242,7 @@ async def test_upload_expiry_error():
241242
},
242243
8 * megabyte,
243244
)
245+
await rest_client.transfer_accelerate_config(None)
244246

245247
with mock.patch(
246248
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token",
@@ -332,6 +334,7 @@ async def test_download_expiry_error():
332334
},
333335
8 * megabyte,
334336
)
337+
await rest_client.transfer_accelerate_config(None)
335338

336339
with mock.patch(
337340
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token",
@@ -373,12 +376,23 @@ async def test_download_unknown_error(caplog):
373376
message="No, just chuck testing...",
374377
headers={},
375378
)
379+
380+
async def mock_transfer_accelerate_config(
381+
self: SnowflakeS3RestClient,
382+
use_accelerate_endpoint: bool | None = None,
383+
) -> bool:
384+
self.endpoint = f"https://{self.s3location.bucket_name}.s3.awsamazon.com"
385+
return False
386+
376387
with mock.patch(
377388
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._send_request_with_authentication_and_retry",
378389
side_effect=error,
379390
), mock.patch(
380391
"snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent._transfer_accelerate_config",
381392
side_effect=None,
393+
), mock.patch(
394+
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient.transfer_accelerate_config",
395+
mock_transfer_accelerate_config,
382396
):
383397
await agent.execute()
384398
assert agent._file_metadata[0].error_details.status == 400
@@ -422,6 +436,7 @@ async def test_download_retry_exceeded_error():
422436
},
423437
8 * megabyte,
424438
)
439+
await rest_client.transfer_accelerate_config()
425440
rest_client.SLEEP_UNIT = 0
426441

427442
with mock.patch(
@@ -466,7 +481,7 @@ async def test_accelerate_in_china_endpoint():
466481
},
467482
8 * megabyte,
468483
)
469-
assert not rest_client.transfer_accelerate_config()
484+
assert not await rest_client.transfer_accelerate_config()
470485

471486
rest_client = SnowflakeS3RestClient(
472487
meta,
@@ -484,4 +499,4 @@ async def test_accelerate_in_china_endpoint():
484499
},
485500
8 * megabyte,
486501
)
487-
assert not rest_client.transfer_accelerate_config()
502+
assert not await rest_client.transfer_accelerate_config()

0 commit comments

Comments
 (0)