Skip to content

Commit 114750b

Browse files
SNOW-1944208 add unsafe write flag (#2184)
1 parent 900a676 commit 114750b

File tree

11 files changed

+96
-9
lines changed

11 files changed

+96
-9
lines changed

DESCRIPTION.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1616
- Added support for iceberg tables to `write_pandas`.
1717
- Fixed base64 encoded private key tests.
1818
- Added Wiremock tests.
19-
- Fixed a bug where file permission check happened on Windows
19+
- Fixed a bug where file permission check happened on Windows.
2020
- Added support for File types.
21+
- Added `unsafe_file_write` connection parameter that restores the previous behaviour of saving files downloaded with GET with 644 permissions.
2122

2223
- v3.13.2(January 29, 2025)
2324
- Changed not to use scoped temporary objects.

src/snowflake/connector/azure_storage_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,15 @@ def __init__(
6565
chunk_size: int,
6666
stage_info: dict[str, Any],
6767
use_s3_regional_url: bool = False,
68+
unsafe_file_write: bool = False,
6869
) -> None:
69-
super().__init__(meta, stage_info, chunk_size, credentials=credentials)
70+
super().__init__(
71+
meta,
72+
stage_info,
73+
chunk_size,
74+
credentials=credentials,
75+
unsafe_file_write=unsafe_file_write,
76+
)
7077
end_point: str = stage_info["endPoint"]
7178
if end_point.startswith("blob."):
7279
end_point = end_point[len("blob.") :]

src/snowflake/connector/connection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ class SnowflakeConnection:
379379
server_session_keep_alive: When true, the connector does not destroy the session on the Snowflake server side
380380
before the connector shuts down. Default value is false.
381381
token_file_path: The file path of the token file. If both token and token_file_path are provided, the token in token_file_path will be used.
382+
unsafe_file_write: When true, files downloaded by GET will be saved with 644 permissions. Otherwise, files will be saved with safe - owner-only permissions: 600.
382383
"""
383384

384385
OCSP_ENV_LOCK = Lock()
@@ -736,6 +737,14 @@ def is_query_context_cache_disabled(self) -> bool:
736737
def iobound_tpe_limit(self) -> int | None:
737738
return self._iobound_tpe_limit
738739

740+
@property
741+
def unsafe_file_write(self) -> bool:
742+
return self._unsafe_file_write
743+
744+
@unsafe_file_write.setter
745+
def unsafe_file_write(self, value: bool) -> None:
746+
self._unsafe_file_write = value
747+
739748
def connect(self, **kwargs) -> None:
740749
"""Establishes connection to Snowflake."""
741750
logger.debug("connect")
@@ -1207,6 +1216,11 @@ def __config(self, **kwargs):
12071216
if "protocol" not in kwargs:
12081217
self._protocol = "https"
12091218

1219+
if "unsafe_file_write" in kwargs:
1220+
self._unsafe_file_write = kwargs["unsafe_file_write"]
1221+
else:
1222+
self._unsafe_file_write = False
1223+
12101224
logger.info(
12111225
f"Connecting to {_DOMAIN_NAME_MAP.get(extract_top_level_domain_from_hostname(self._host), 'GLOBAL')} Snowflake domain"
12121226
)

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,7 @@ def execute(
10601060
multipart_threshold=data.get("threshold"),
10611061
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
10621062
iobound_tpe_limit=self._connection.iobound_tpe_limit,
1063+
unsafe_file_write=self._connection.unsafe_file_write,
10631064
)
10641065
sf_file_transfer_agent.execute()
10651066
data = sf_file_transfer_agent.result()

src/snowflake/connector/encryption_util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def decrypt_file(
195195
in_filename: str,
196196
chunk_size: int = 64 * kilobyte,
197197
tmp_dir: str | None = None,
198+
unsafe_file_write: bool = False,
198199
) -> str:
199200
"""Decrypts a file and stores the output in the temporary directory.
200201
@@ -213,8 +214,10 @@ def decrypt_file(
213214
temp_output_file = os.path.join(tmp_dir, temp_output_file)
214215

215216
logger.debug("encrypted file: %s, tmp file: %s", in_filename, temp_output_file)
217+
218+
file_opener = None if unsafe_file_write else owner_rw_opener
216219
with open(in_filename, "rb") as infile:
217-
with open(temp_output_file, "wb", opener=owner_rw_opener) as outfile:
220+
with open(temp_output_file, "wb", opener=file_opener) as outfile:
218221
SnowflakeEncryptionUtil.decrypt_stream(
219222
metadata, encryption_material, infile, outfile, chunk_size
220223
)

src/snowflake/connector/file_transfer_agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def __init__(
355355
source_from_stream: IO[bytes] | None = None,
356356
use_s3_regional_url: bool = False,
357357
iobound_tpe_limit: int | None = None,
358+
unsafe_file_write: bool = False,
358359
) -> None:
359360
self._cursor = cursor
360361
self._command = command
@@ -386,6 +387,7 @@ def __init__(
386387
self._use_s3_regional_url = use_s3_regional_url
387388
self._credentials: StorageCredential | None = None
388389
self._iobound_tpe_limit = iobound_tpe_limit
390+
self._unsafe_file_write = unsafe_file_write
389391

390392
def execute(self) -> None:
391393
self._parse_command()
@@ -673,6 +675,7 @@ def _create_file_transfer_client(
673675
meta,
674676
self._stage_info,
675677
4 * megabyte,
678+
unsafe_file_write=self._unsafe_file_write,
676679
)
677680
elif self._stage_location_type == AZURE_FS:
678681
return SnowflakeAzureRestClient(
@@ -681,6 +684,7 @@ def _create_file_transfer_client(
681684
AZURE_CHUNK_SIZE,
682685
self._stage_info,
683686
use_s3_regional_url=self._use_s3_regional_url,
687+
unsafe_file_write=self._unsafe_file_write,
684688
)
685689
elif self._stage_location_type == S3_FS:
686690
return SnowflakeS3RestClient(
@@ -690,6 +694,7 @@ def _create_file_transfer_client(
690694
_chunk_size_calculator(meta.src_file_size),
691695
use_accelerate_endpoint=self._use_accelerate_endpoint,
692696
use_s3_regional_url=self._use_s3_regional_url,
697+
unsafe_file_write=self._unsafe_file_write,
693698
)
694699
elif self._stage_location_type == GCS_FS:
695700
return SnowflakeGCSRestClient(
@@ -699,6 +704,7 @@ def _create_file_transfer_client(
699704
self._cursor._connection,
700705
self._command,
701706
use_s3_regional_url=self._use_s3_regional_url,
707+
unsafe_file_write=self._unsafe_file_write,
702708
)
703709
raise Exception(f"{self._stage_location_type} is an unknown stage type")
704710

src/snowflake/connector/gcs_storage_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
cnx: SnowflakeConnection,
5555
command: str,
5656
use_s3_regional_url: bool = False,
57+
unsafe_file_write: bool = False,
5758
) -> None:
5859
"""Creates a client object with given stage credentials.
5960
@@ -64,7 +65,12 @@ def __init__(
6465
The client to communicate with GCS.
6566
"""
6667
super().__init__(
67-
meta, stage_info, -1, credentials=credentials, chunked_transfer=False
68+
meta,
69+
stage_info,
70+
-1,
71+
credentials=credentials,
72+
chunked_transfer=False,
73+
unsafe_file_write=unsafe_file_write,
6874
)
6975
self.stage_info = stage_info
7076
self._command = command

src/snowflake/connector/local_storage_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ def __init__(
2626
meta: SnowflakeFileMeta,
2727
stage_info: dict[str, Any],
2828
chunk_size: int,
29+
unsafe_file_write: bool = False,
2930
) -> None:
30-
super().__init__(meta, stage_info, chunk_size)
31+
super().__init__(
32+
meta, stage_info, chunk_size, unsafe_file_write=unsafe_file_write
33+
)
3134
self.data_file = meta.src_file_name
3235
self.full_dst_file_name: str = os.path.join(
3336
stage_info["location"], os.path.basename(meta.dst_file_name)

src/snowflake/connector/s3_storage_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ def __init__(
6161
chunk_size: int,
6262
use_accelerate_endpoint: bool | None = None,
6363
use_s3_regional_url: bool = False,
64+
unsafe_file_write: bool = False,
6465
) -> None:
6566
"""Rest client for S3 storage.
6667
6768
Args:
6869
stage_info:
6970
"""
70-
super().__init__(meta, stage_info, chunk_size, credentials=credentials)
71+
super().__init__(
72+
meta,
73+
stage_info,
74+
chunk_size,
75+
credentials=credentials,
76+
unsafe_file_write=unsafe_file_write,
77+
)
7178
# Signature version V4
7279
# Addressing style Virtual Host
7380
self.region_name: str = stage_info["region"]

src/snowflake/connector/storage_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
chunked_transfer: bool | None = True,
7878
credentials: StorageCredential | None = None,
7979
max_retry: int = 5,
80+
unsafe_file_write: bool = False,
8081
) -> None:
8182
self.meta = meta
8283
self.stage_info = stage_info
@@ -115,6 +116,7 @@ def __init__(
115116
self.failed_transfers: int = 0
116117
# only used when PRESIGNED_URL expires
117118
self.last_err_is_presigned_url = False
119+
self.unsafe_file_write = unsafe_file_write
118120

119121
def compress(self) -> None:
120122
if self.meta.require_compress:
@@ -376,7 +378,7 @@ def finish_download(self) -> None:
376378
# For storage utils that do not have the privilege of
377379
# getting the metadata early, both object and metadata
378380
# are downloaded at once. In which case, the file meta will
379-
# be updated with all the metadata that we need and
381+
# be updated with all the metadata that we need, and
380382
# then we can call get_file_header to get just that and also
381383
# preserve the idea of getting metadata in the first place.
382384
# One example of this is the utils that use presigned url
@@ -390,6 +392,7 @@ def finish_download(self) -> None:
390392
meta.encryption_material,
391393
str(self.intermediate_dst_path),
392394
tmp_dir=self.tmp_dir,
395+
unsafe_file_write=self.unsafe_file_write,
393396
)
394397
shutil.move(tmp_dst_file_name, self.full_dst_file_name)
395398
self.intermediate_dst_path.unlink()

0 commit comments

Comments
 (0)