Skip to content

Commit e687be4

Browse files
sfc-gh-yuwang and sfc-gh-aling
authored and committed
Asyncio support for aws file transfer (#2031)
1 parent a753daa commit e687be4

12 files changed

+2852
-4
lines changed

src/snowflake/connector/aio/_cursor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ async def execute(
615615
)
616616
logger.debug("PUT OR GET: %s", self.is_file_transfer)
617617
if self.is_file_transfer:
618-
from ..file_transfer_agent import SnowflakeFileTransferAgent
618+
from ._file_transfer_agent import SnowflakeFileTransferAgent
619619

620620
# Decide whether to use the old, or new code path
621621
sf_file_transfer_agent = SnowflakeFileTransferAgent(
@@ -637,7 +637,7 @@ async def execute(
637637
multipart_threshold=data.get("threshold"),
638638
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
639639
)
640-
sf_file_transfer_agent.execute()
640+
await sf_file_transfer_agent.execute()
641641
data = sf_file_transfer_agent.result()
642642
self._total_rowcount = len(data["rowset"]) if "rowset" in data else -1
643643

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
import asyncio
8+
import os
9+
import sys
10+
from logging import getLogger
11+
from typing import IO, TYPE_CHECKING, Any
12+
13+
from ..azure_storage_client import SnowflakeAzureRestClient
14+
from ..constants import (
15+
AZURE_CHUNK_SIZE,
16+
AZURE_FS,
17+
CMD_TYPE_DOWNLOAD,
18+
CMD_TYPE_UPLOAD,
19+
GCS_FS,
20+
LOCAL_FS,
21+
S3_FS,
22+
ResultStatus,
23+
megabyte,
24+
)
25+
from ..errorcode import ER_FILE_NOT_EXISTS
26+
from ..errors import Error, OperationalError
27+
from ..file_transfer_agent import SnowflakeFileMeta
28+
from ..file_transfer_agent import (
29+
SnowflakeFileTransferAgent as SnowflakeFileTransferAgentSync,
30+
)
31+
from ..file_transfer_agent import SnowflakeProgressPercentage, _chunk_size_calculator
32+
from ..gcs_storage_client import SnowflakeGCSRestClient
33+
from ..local_storage_client import SnowflakeLocalStorageClient
34+
from ._s3_storage_client import SnowflakeS3RestClient
35+
from ._storage_client import SnowflakeStorageClient
36+
37+
if TYPE_CHECKING: # pragma: no cover
38+
from ._cursor import SnowflakeCursor
39+
40+
41+
logger = getLogger(__name__)
42+
43+
44+
class SnowflakeFileTransferAgent(SnowflakeFileTransferAgentSync):
    """Snowflake File Transfer Agent provides cloud provider independent implementation for putting/getting files.

    This is the asyncio variant of the agent: ``execute`` and ``transfer`` are
    coroutines, and the per-chunk upload/download work is scheduled as asyncio
    tasks. Everything else (command parsing, metadata initialisation, result
    building) is inherited from the synchronous agent.
    """

    def __init__(
        self,
        cursor: SnowflakeCursor,
        command: str,
        ret: dict[str, Any],
        put_callback: type[SnowflakeProgressPercentage] | None = None,
        put_azure_callback: type[SnowflakeProgressPercentage] | None = None,
        put_callback_output_stream: IO[str] = sys.stdout,
        get_callback: type[SnowflakeProgressPercentage] | None = None,
        get_azure_callback: type[SnowflakeProgressPercentage] | None = None,
        get_callback_output_stream: IO[str] = sys.stdout,
        show_progress_bar: bool = True,
        raise_put_get_error: bool = True,
        force_put_overwrite: bool = True,
        skip_upload_on_content_match: bool = False,
        multipart_threshold: int | None = None,
        source_from_stream: IO[bytes] | None = None,
        use_s3_regional_url: bool = False,
    ) -> None:
        """Forward every argument, unchanged and in order, to the synchronous agent."""
        super().__init__(
            cursor,
            command,
            ret,
            put_callback,
            put_azure_callback,
            put_callback_output_stream,
            get_callback,
            get_azure_callback,
            get_callback_output_stream,
            show_progress_bar,
            raise_put_get_error,
            force_put_overwrite,
            skip_upload_on_content_match,
            multipart_threshold,
            source_from_stream,
            use_s3_regional_url,
        )

    async def execute(self) -> None:
        """Run the PUT/GET command end to end.

        Parses the command, initialises the file metadata, creates the local
        target directories that are needed, copies the agent-level options onto
        every file's metadata, awaits :meth:`transfer`, and finally converts each
        result's ``result_status`` enum into its plain value.

        When ``self._raise_put_get_error`` is set and no file metadata was
        produced, an ``OperationalError`` with ``ER_FILE_NOT_EXISTS`` is reported
        through ``Error.errorhandler_wrapper``.
        """
        self._parse_command()
        self._init_file_metadata()

        if self._command_type == CMD_TYPE_UPLOAD:
            self._process_file_compression_type()

        # The agent reference is attached before _transfer_accelerate_config();
        # presumably that helper builds a client from the metadata and needs it
        # -- TODO confirm against the synchronous agent.
        for m in self._file_metadata:
            m.sfagent = self

        self._transfer_accelerate_config()

        # exist_ok=True replaces the former isdir()-then-makedirs() sequence,
        # which could fail if the directory appeared between the two calls.
        # A path that exists but is not a directory still raises.
        if self._command_type == CMD_TYPE_DOWNLOAD:
            os.makedirs(self._local_location, exist_ok=True)

        if self._stage_location_type == LOCAL_FS:
            os.makedirs(self._stage_info["location"], exist_ok=True)

        for m in self._file_metadata:
            m.overwrite = self._overwrite
            m.skip_upload_on_content_match = self._skip_upload_on_content_match
            m.sfagent = self
            if self._stage_location_type != LOCAL_FS:
                m.put_callback = self._put_callback
                m.put_azure_callback = self._put_azure_callback
                m.put_callback_output_stream = self._put_callback_output_stream
                m.get_callback = self._get_callback
                m.get_azure_callback = self._get_azure_callback
                m.get_callback_output_stream = self._get_callback_output_stream
                m.show_progress_bar = self._show_progress_bar

                # multichunk threshold
                m.multipart_threshold = self._multipart_threshold

        # TODO: https://snowflakecomputing.atlassian.net/browse/SNOW-1625364
        logger.debug(f"parallel=[{self._parallel}]")
        if self._raise_put_get_error and not self._file_metadata:
            Error.errorhandler_wrapper(
                self._cursor.connection,
                self._cursor,
                OperationalError,
                {
                    "msg": "While getting file(s) there was an error: "
                    "the file does not exist.",
                    "errno": ER_FILE_NOT_EXISTS,
                },
            )
        await self.transfer(self._file_metadata)

        # turn enum to string, in order to have backward compatible interface

        for result in self._results:
            result.result_status = result.result_status.value

    async def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
        """Upload or download every file described by ``metas``.

        For each file a storage client is created and its ``prepare_upload`` /
        ``prepare_download`` coroutine is awaited (one file after another). The
        follow-up work for that file (chunk transfers and the final
        ``finish_upload`` / ``finish_download``) is then scheduled as an asyncio
        task, and all of those tasks are awaited together.

        Errors are not raised from here; they are stored on the corresponding
        ``meta.error_details``. On return ``self._results`` is set to ``metas``.

        Args:
            metas: Metadata objects of the files to transfer.
        """
        files = [self._create_file_transfer_client(m) for m in metas]
        is_upload = self._command_type == CMD_TYPE_UPLOAD
        # Every finish_upload()/finish_download() task created during this call.
        # Only used to drain stragglers before the results are published; each
        # file awaits its *own* finish tasks (see preprocess_done_cb).
        all_finish_tasks: list[asyncio.Task] = []

        async def preprocess_done_cb(
            success: bool,
            result: Any,
            done_client: SnowflakeStorageClient,
        ) -> None:
            """Continue the transfer of one file after its prepare step ran."""
            if not success:
                logger.debug(f"Failed to prepare {done_client.meta.name}.")
                try:
                    if is_upload:
                        await done_client.finish_upload()
                        done_client.delete_client_data()
                    else:
                        await done_client.finish_download()
                except Exception as error:
                    done_client.meta.error_details = error
            elif done_client.meta.result_status == ResultStatus.SKIPPED:
                # this case applies to upload only
                return
            else:
                # Finish tasks belonging to this file only. Previously a single
                # list was shared by all files, so a failing finish task of one
                # file was re-raised here and recorded as the error of another.
                finish_tasks: list[asyncio.Task] = []
                try:
                    logger.debug(f"Finished preparing file {done_client.meta.name}")
                    tasks = []
                    for _chunk_id in range(done_client.num_of_chunks):
                        task = (
                            asyncio.create_task(done_client.upload_chunk(_chunk_id))
                            if is_upload
                            else asyncio.create_task(
                                done_client.download_chunk(_chunk_id)
                            )
                        )
                        # Bind loop variables as defaults so each callback sees
                        # its own chunk id (closures are late-binding).
                        task.add_done_callback(
                            lambda t, dc=done_client, _chunk_id=_chunk_id, ft=finish_tasks: transfer_done_cb(
                                t, dc, _chunk_id, ft
                            )
                        )
                        tasks.append(task)
                    # NOTE(review): gather() without return_exceptions re-raises
                    # the first chunk failure immediately while the remaining
                    # chunk tasks keep running; kept as in the original.
                    await asyncio.gather(*tasks)
                    await asyncio.gather(*finish_tasks)
                except Exception as error:
                    done_client.meta.error_details = error

        def transfer_done_cb(
            task: asyncio.Task,
            done_client: SnowflakeStorageClient,
            chunk_id: int,
            finish_tasks: list[asyncio.Task],
        ) -> None:
            """Book-keep one finished chunk and schedule the finish step when all are done."""
            # Note: chunk_id is 0 based while num_of_chunks is count
            logger.debug(
                f"Chunk {chunk_id}/{done_client.num_of_chunks} of file {done_client.meta.name} reached callback"
            )
            if task.cancelled():
                # Task.exception() raises CancelledError for a cancelled task;
                # raising from a done-callback only produces an event-loop error
                # report, so stop here and leave the counters untouched.
                logger.debug(
                    f"Chunk {chunk_id} of file {done_client.meta.name} was cancelled"
                )
                return
            error = task.exception()
            if error is not None:
                done_client.failed_transfers += 1
                logger.debug(
                    f"Chunk {chunk_id} of file {done_client.meta.name} failed to transfer for unexpected exception {error}"
                )
            else:
                done_client.successful_transfers += 1
            logger.debug(
                f"Chunk progress: {done_client.meta.name}: completed: {done_client.successful_transfers} failed: {done_client.failed_transfers} total: {done_client.num_of_chunks}"
            )
            if (
                done_client.successful_transfers + done_client.failed_transfers
                == done_client.num_of_chunks
            ):
                if is_upload:
                    finish_upload_task = asyncio.create_task(
                        done_client.finish_upload()
                    )
                    finish_tasks.append(finish_upload_task)
                    all_finish_tasks.append(finish_upload_task)
                    done_client.delete_client_data()
                else:
                    finish_download_task = asyncio.create_task(
                        done_client.finish_download()
                    )
                    finish_download_task.add_done_callback(
                        lambda t, dc=done_client: postprocess_done_cb(t, dc)
                    )
                    finish_tasks.append(finish_download_task)
                    all_finish_tasks.append(finish_download_task)

        def postprocess_done_cb(
            task: asyncio.Task,
            done_client: SnowflakeStorageClient,
        ) -> None:
            """Record a failure of the ``finish_download`` task of one file."""
            logger.debug(f"File {done_client.meta.name} reached postprocess callback")

            if task.cancelled():
                # Same reasoning as in transfer_done_cb: exception() would raise.
                logger.debug(f"File {done_client.meta.name} postprocess was cancelled")
                return
            error = task.exception()
            if error is not None:
                done_client.failed_transfers += 1
                logger.debug(
                    f"File {done_client.meta.name} failed to transfer for unexpected exception {error}"
                )
            # Whether there was an exception or not, we're done with the file.

        task_of_files = []
        for file_client in files:
            try:
                # TODO: https://snowflakecomputing.atlassian.net/browse/SNOW-1708819
                res = (
                    await file_client.prepare_upload()
                    if is_upload
                    else await file_client.prepare_download()
                )
                is_successful = True
            except Exception as e:
                res = e
                file_client.meta.error_details = e
                is_successful = False

            task = asyncio.create_task(
                preprocess_done_cb(is_successful, res, done_client=file_client)
            )
            task_of_files.append(task)
        await asyncio.gather(*task_of_files)

        # Drain finish tasks that no per-file coroutine awaited (e.g. a chunk
        # failed first). Their outcome is already reflected in the metadata, so
        # exceptions are collected rather than raised.
        if all_finish_tasks:
            await asyncio.gather(*all_finish_tasks, return_exceptions=True)

        self._results = metas

    def _create_file_transfer_client(
        self, meta: SnowflakeFileMeta
    ) -> SnowflakeStorageClient:
        """Build the storage client matching ``self._stage_location_type``.

        Args:
            meta: Metadata of the file the client will transfer.

        Returns:
            A local, Azure, S3 or GCS storage client bound to ``meta``.

        Raises:
            Exception: If the stage location type is none of the known ones.
        """
        # NOTE(review): only the S3 client is imported from the aio package; the
        # local/Azure/GCS clients come from the synchronous package. transfer()
        # awaits prepare_*/finish_* on the client -- confirm those stage types
        # are usable from this agent.
        if self._stage_location_type == LOCAL_FS:
            return SnowflakeLocalStorageClient(
                meta,
                self._stage_info,
                4 * megabyte,
            )
        elif self._stage_location_type == AZURE_FS:
            return SnowflakeAzureRestClient(
                meta,
                self._credentials,
                AZURE_CHUNK_SIZE,
                self._stage_info,
                use_s3_regional_url=self._use_s3_regional_url,
            )
        elif self._stage_location_type == S3_FS:
            return SnowflakeS3RestClient(
                meta=meta,
                credentials=self._credentials,
                stage_info=self._stage_info,
                chunk_size=_chunk_size_calculator(meta.src_file_size),
                use_accelerate_endpoint=self._use_accelerate_endpoint,
                use_s3_regional_url=self._use_s3_regional_url,
            )
        elif self._stage_location_type == GCS_FS:
            return SnowflakeGCSRestClient(
                meta,
                self._credentials,
                self._stage_info,
                self._cursor._connection,
                self._command,
                use_s3_regional_url=self._use_s3_regional_url,
            )
        raise Exception(f"{self._stage_location_type} is an unknown stage type")

0 commit comments

Comments
 (0)