Skip to content

Commit 9ab4bba

Browse files
Apply #2184+#2413 to async code
1 parent a16f77d commit 9ab4bba

File tree

6 files changed

+32
-16
lines changed

6 files changed

+32
-16
lines changed

src/snowflake/connector/aio/_azure_storage_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,15 @@ def __init__(
5050
chunk_size: int,
5151
stage_info: dict[str, Any],
5252
use_s3_regional_url: bool = False,
53+
unsafe_file_write: bool = False,
5354
) -> None:
5455
SnowflakeAzureRestClientSync.__init__(
5556
self,
5657
meta=meta,
5758
stage_info=stage_info,
5859
chunk_size=chunk_size,
5960
credentials=credentials,
61+
unsafe_file_write=unsafe_file_write,
6062
)
6163

6264
async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:

src/snowflake/connector/aio/_cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ async def execute(
662662
source_from_stream=file_stream,
663663
multipart_threshold=data.get("threshold"),
664664
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
665+
unsafe_file_write=self._connection.unsafe_file_write,
665666
)
666667
await sf_file_transfer_agent.execute()
667668
data = sf_file_transfer_agent.result()

src/snowflake/connector/aio/_file_transfer_agent.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,24 +62,26 @@ def __init__(
6262
multipart_threshold: int | None = None,
6363
source_from_stream: IO[bytes] | None = None,
6464
use_s3_regional_url: bool = False,
65+
unsafe_file_write: bool = False,
6566
) -> None:
6667
super().__init__(
67-
cursor,
68-
command,
69-
ret,
70-
put_callback,
71-
put_azure_callback,
72-
put_callback_output_stream,
73-
get_callback,
74-
get_azure_callback,
75-
get_callback_output_stream,
76-
show_progress_bar,
77-
raise_put_get_error,
78-
force_put_overwrite,
79-
skip_upload_on_content_match,
80-
multipart_threshold,
81-
source_from_stream,
82-
use_s3_regional_url,
68+
cursor=cursor,
69+
command=command,
70+
ret=ret,
71+
put_callback=put_callback,
72+
put_azure_callback=put_azure_callback,
73+
put_callback_output_stream=put_callback_output_stream,
74+
get_callback=get_callback,
75+
get_azure_callback=get_azure_callback,
76+
get_callback_output_stream=get_callback_output_stream,
77+
show_progress_bar=show_progress_bar,
78+
raise_put_get_error=raise_put_get_error,
79+
force_put_overwrite=force_put_overwrite,
80+
skip_upload_on_content_match=skip_upload_on_content_match,
81+
multipart_threshold=multipart_threshold,
82+
source_from_stream=source_from_stream,
83+
use_s3_regional_url=use_s3_regional_url,
84+
unsafe_file_write=unsafe_file_write,
8385
)
8486

8587
async def execute(self) -> None:
@@ -271,6 +273,7 @@ async def _create_file_transfer_client(
271273
meta,
272274
self._stage_info,
273275
4 * megabyte,
276+
unsafe_file_write=self._unsafe_file_write,
274277
)
275278
elif self._stage_location_type == AZURE_FS:
276279
return SnowflakeAzureRestClient(
@@ -279,6 +282,7 @@ async def _create_file_transfer_client(
279282
AZURE_CHUNK_SIZE,
280283
self._stage_info,
281284
use_s3_regional_url=self._use_s3_regional_url,
285+
unsafe_file_write=self._unsafe_file_write,
282286
)
283287
elif self._stage_location_type == S3_FS:
284288
client = SnowflakeS3RestClient(
@@ -288,6 +292,7 @@ async def _create_file_transfer_client(
288292
chunk_size=_chunk_size_calculator(meta.src_file_size),
289293
use_accelerate_endpoint=self._use_accelerate_endpoint,
290294
use_s3_regional_url=self._use_s3_regional_url,
295+
unsafe_file_write=self._unsafe_file_write,
291296
)
292297
await client.transfer_accelerate_config(self._use_accelerate_endpoint)
293298
return client
@@ -299,6 +304,7 @@ async def _create_file_transfer_client(
299304
self._cursor._connection,
300305
self._command,
301306
use_s3_regional_url=self._use_s3_regional_url,
307+
unsafe_file_write=self._unsafe_file_write,
302308
)
303309
if client.security_token:
304310
logger.debug(f"len(GCS_ACCESS_TOKEN): {len(client.security_token)}")

src/snowflake/connector/aio/_gcs_storage_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
cnx: SnowflakeConnection,
4040
command: str,
4141
use_s3_regional_url: bool = False,
42+
unsafe_file_write: bool = False,
4243
) -> None:
4344
"""Creates a client object with given stage credentials.
4445
@@ -55,6 +56,7 @@ def __init__(
5556
chunk_size=-1,
5657
credentials=credentials,
5758
chunked_transfer=False,
59+
unsafe_file_write=unsafe_file_write,
5860
)
5961
self.stage_info = stage_info
6062
self._command = command

src/snowflake/connector/aio/_s3_storage_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
chunk_size: int,
4949
use_accelerate_endpoint: bool | None = None,
5050
use_s3_regional_url: bool = False,
51+
unsafe_file_write: bool = False,
5152
) -> None:
5253
"""Rest client for S3 storage.
5354
@@ -60,6 +61,7 @@ def __init__(
6061
stage_info=stage_info,
6162
chunk_size=chunk_size,
6263
credentials=credentials,
64+
unsafe_file_write=unsafe_file_write,
6365
)
6466
# Signature version V4
6567
# Addressing style Virtual Host

src/snowflake/connector/aio/_storage_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
chunked_transfer: bool | None = True,
3838
credentials: StorageCredential | None = None,
3939
max_retry: int = 5,
40+
unsafe_file_write: bool = False,
4041
) -> None:
4142
SnowflakeStorageClientSync.__init__(
4243
self,
@@ -46,6 +47,7 @@ def __init__(
4647
chunked_transfer=chunked_transfer,
4748
credentials=credentials,
4849
max_retry=max_retry,
50+
unsafe_file_write=unsafe_file_write,
4951
)
5052

5153
@abstractmethod
@@ -162,6 +164,7 @@ async def finish_download(self) -> None:
162164
meta.encryption_material,
163165
str(self.intermediate_dst_path),
164166
tmp_dir=self.tmp_dir,
167+
unsafe_file_write=self.unsafe_file_write,
165168
)
166169
shutil.move(tmp_dst_file_name, self.full_dst_file_name)
167170
self.intermediate_dst_path.unlink()

0 commit comments

Comments
 (0)