Skip to content

Commit 6ac9d49

Browse files
sfc-gh-mkubiksfc-gh-pczajka
authored andcommitted
SNOW-1944208 add unsafe write flag (#2184)
1 parent b4f5940 commit 6ac9d49

File tree

10 files changed

+94
-8
lines changed

10 files changed

+94
-8
lines changed

src/snowflake/connector/azure_storage_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,15 @@ def __init__(
6363
chunk_size: int,
6464
stage_info: dict[str, Any],
6565
use_s3_regional_url: bool = False,
66+
unsafe_file_write: bool = False,
6667
) -> None:
67-
super().__init__(meta, stage_info, chunk_size, credentials=credentials)
68+
super().__init__(
69+
meta,
70+
stage_info,
71+
chunk_size,
72+
credentials=credentials,
73+
unsafe_file_write=unsafe_file_write,
74+
)
6875
end_point: str = stage_info["endPoint"]
6976
if end_point.startswith("blob."):
7077
end_point = end_point[len("blob.") :]

src/snowflake/connector/connection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ class SnowflakeConnection:
381381
server_session_keep_alive: When true, the connector does not destroy the session on the Snowflake server side
382382
before the connector shuts down. Default value is false.
383383
token_file_path: The file path of the token file. If both token and token_file_path are provided, the token in token_file_path will be used.
384+
unsafe_file_write: When true, files downloaded by GET will be saved with 644 permissions. Otherwise, files will be saved with safe - owner-only permissions: 600.
384385
"""
385386

386387
OCSP_ENV_LOCK = Lock()
@@ -761,6 +762,14 @@ def is_query_context_cache_disabled(self) -> bool:
761762
def iobound_tpe_limit(self) -> int | None:
762763
return self._iobound_tpe_limit
763764

765+
@property
766+
def unsafe_file_write(self) -> bool:
767+
return self._unsafe_file_write
768+
769+
@unsafe_file_write.setter
770+
def unsafe_file_write(self, value: bool) -> None:
771+
self._unsafe_file_write = value
772+
764773
def connect(self, **kwargs) -> None:
765774
"""Establishes connection to Snowflake."""
766775
logger.debug("connect")
@@ -1232,6 +1241,11 @@ def __config(self, **kwargs):
12321241
if "protocol" not in kwargs:
12331242
self._protocol = "https"
12341243

1244+
if "unsafe_file_write" in kwargs:
1245+
self._unsafe_file_write = kwargs["unsafe_file_write"]
1246+
else:
1247+
self._unsafe_file_write = False
1248+
12351249
logger.info(
12361250
f"Connecting to {_DOMAIN_NAME_MAP.get(extract_top_level_domain_from_hostname(self._host), 'GLOBAL')} Snowflake domain"
12371251
)

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,7 @@ def execute(
10581058
multipart_threshold=data.get("threshold"),
10591059
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
10601060
iobound_tpe_limit=self._connection.iobound_tpe_limit,
1061+
unsafe_file_write=self._connection.unsafe_file_write,
10611062
)
10621063
sf_file_transfer_agent.execute()
10631064
data = sf_file_transfer_agent.result()

src/snowflake/connector/encryption_util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def decrypt_file(
194194
in_filename: str,
195195
chunk_size: int = 64 * kilobyte,
196196
tmp_dir: str | None = None,
197+
unsafe_file_write: bool = False,
197198
) -> str:
198199
"""Decrypts a file and stores the output in the temporary directory.
199200
@@ -212,8 +213,10 @@ def decrypt_file(
212213
temp_output_file = os.path.join(tmp_dir, temp_output_file)
213214

214215
logger.debug("encrypted file: %s, tmp file: %s", in_filename, temp_output_file)
216+
217+
file_opener = None if unsafe_file_write else owner_rw_opener
215218
with open(in_filename, "rb") as infile:
216-
with open(temp_output_file, "wb") as outfile:
219+
with open(temp_output_file, "wb", opener=file_opener) as outfile:
217220
SnowflakeEncryptionUtil.decrypt_stream(
218221
metadata, encryption_material, infile, outfile, chunk_size
219222
)

src/snowflake/connector/file_transfer_agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def __init__(
355355
source_from_stream: IO[bytes] | None = None,
356356
use_s3_regional_url: bool = False,
357357
iobound_tpe_limit: int | None = None,
358+
unsafe_file_write: bool = False,
358359
) -> None:
359360
self._cursor = cursor
360361
self._command = command
@@ -386,6 +387,7 @@ def __init__(
386387
self._use_s3_regional_url = use_s3_regional_url
387388
self._credentials: StorageCredential | None = None
388389
self._iobound_tpe_limit = iobound_tpe_limit
390+
self._unsafe_file_write = unsafe_file_write
389391

390392
def execute(self) -> None:
391393
self._parse_command()
@@ -673,6 +675,7 @@ def _create_file_transfer_client(
673675
meta,
674676
self._stage_info,
675677
4 * megabyte,
678+
unsafe_file_write=self._unsafe_file_write,
676679
)
677680
elif self._stage_location_type == AZURE_FS:
678681
return SnowflakeAzureRestClient(
@@ -681,6 +684,7 @@ def _create_file_transfer_client(
681684
AZURE_CHUNK_SIZE,
682685
self._stage_info,
683686
use_s3_regional_url=self._use_s3_regional_url,
687+
unsafe_file_write=self._unsafe_file_write,
684688
)
685689
elif self._stage_location_type == S3_FS:
686690
return SnowflakeS3RestClient(
@@ -690,6 +694,7 @@ def _create_file_transfer_client(
690694
_chunk_size_calculator(meta.src_file_size),
691695
use_accelerate_endpoint=self._use_accelerate_endpoint,
692696
use_s3_regional_url=self._use_s3_regional_url,
697+
unsafe_file_write=self._unsafe_file_write,
693698
)
694699
elif self._stage_location_type == GCS_FS:
695700
return SnowflakeGCSRestClient(
@@ -699,6 +704,7 @@ def _create_file_transfer_client(
699704
self._cursor._connection,
700705
self._command,
701706
use_s3_regional_url=self._use_s3_regional_url,
707+
unsafe_file_write=self._unsafe_file_write,
702708
)
703709
raise Exception(f"{self._stage_location_type} is an unknown stage type")
704710

src/snowflake/connector/gcs_storage_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
cnx: SnowflakeConnection,
5555
command: str,
5656
use_s3_regional_url: bool = False,
57+
unsafe_file_write: bool = False,
5758
) -> None:
5859
"""Creates a client object with given stage credentials.
5960
@@ -64,7 +65,12 @@ def __init__(
6465
The client to communicate with GCS.
6566
"""
6667
super().__init__(
67-
meta, stage_info, -1, credentials=credentials, chunked_transfer=False
68+
meta,
69+
stage_info,
70+
-1,
71+
credentials=credentials,
72+
chunked_transfer=False,
73+
unsafe_file_write=unsafe_file_write,
6874
)
6975
self.stage_info = stage_info
7076
self._command = command

src/snowflake/connector/local_storage_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ def __init__(
2626
meta: SnowflakeFileMeta,
2727
stage_info: dict[str, Any],
2828
chunk_size: int,
29+
unsafe_file_write: bool = False,
2930
) -> None:
30-
super().__init__(meta, stage_info, chunk_size)
31+
super().__init__(
32+
meta, stage_info, chunk_size, unsafe_file_write=unsafe_file_write
33+
)
3134
self.data_file = meta.src_file_name
3235
self.full_dst_file_name: str = os.path.join(
3336
stage_info["location"], os.path.basename(meta.dst_file_name)

src/snowflake/connector/s3_storage_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ def __init__(
6161
chunk_size: int,
6262
use_accelerate_endpoint: bool | None = None,
6363
use_s3_regional_url: bool = False,
64+
unsafe_file_write: bool = False,
6465
) -> None:
6566
"""Rest client for S3 storage.
6667
6768
Args:
6869
stage_info:
6970
"""
70-
super().__init__(meta, stage_info, chunk_size, credentials=credentials)
71+
super().__init__(
72+
meta,
73+
stage_info,
74+
chunk_size,
75+
credentials=credentials,
76+
unsafe_file_write=unsafe_file_write,
77+
)
7178
# Signature version V4
7279
# Addressing style Virtual Host
7380
self.region_name: str = stage_info["region"]

src/snowflake/connector/storage_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
chunked_transfer: bool | None = True,
7878
credentials: StorageCredential | None = None,
7979
max_retry: int = 5,
80+
unsafe_file_write: bool = False,
8081
) -> None:
8182
self.meta = meta
8283
self.stage_info = stage_info
@@ -115,6 +116,7 @@ def __init__(
115116
self.failed_transfers: int = 0
116117
# only used when PRESIGNED_URL expires
117118
self.last_err_is_presigned_url = False
119+
self.unsafe_file_write = unsafe_file_write
118120

119121
def compress(self) -> None:
120122
if self.meta.require_compress:
@@ -371,7 +373,7 @@ def finish_download(self) -> None:
371373
# For storage utils that do not have the privilege of
372374
# getting the metadata early, both object and metadata
373375
# are downloaded at once. In which case, the file meta will
374-
# be updated with all the metadata that we need and
376+
# be updated with all the metadata that we need, and
375377
# then we can call get_file_header to get just that and also
376378
# preserve the idea of getting metadata in the first place.
377379
# One example of this is the utils that use presigned url
@@ -385,6 +387,7 @@ def finish_download(self) -> None:
385387
meta.encryption_material,
386388
str(self.intermediate_dst_path),
387389
tmp_dir=self.tmp_dir,
390+
unsafe_file_write=self.unsafe_file_write,
388391
)
389392
shutil.move(tmp_dst_file_name, self.full_dst_file_name)
390393
self.intermediate_dst_path.unlink()

test/integ/test_put_get.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020

2121
from snowflake.connector import OperationalError
2222

23+
try:
24+
from src.snowflake.connector.compat import IS_WINDOWS
25+
except ImportError:
26+
import platform
27+
28+
IS_WINDOWS = platform.system() == "Windows"
29+
2330
try:
2431
from snowflake.connector.util_text import random_string
2532
except ImportError:
@@ -740,16 +747,44 @@ def test_get_empty_file(tmp_path, conn_cnx):
740747

741748

742749
@pytest.mark.skipolddriver
750+
@pytest.mark.skipif(IS_WINDOWS, reason="not supported on Windows")
743751
def test_get_file_permission(tmp_path, conn_cnx, caplog):
744752
test_file = tmp_path / "data.csv"
745753
test_file.write_text("1,2,3\n")
746-
stage_name = random_string(5, "test_get_empty_file_")
754+
stage_name = random_string(5, "test_get_file_permission_")
747755
with conn_cnx() as cnx:
748756
with cnx.cursor() as cur:
749757
cur.execute(f"create temporary stage {stage_name}")
750758
filename_in_put = str(test_file).replace("\\", "/")
751759
cur.execute(
752-
f"PUT 'file://{filename_in_put}' @{stage_name}",
760+
f"PUT 'file://{filename_in_put}' @{stage_name} AUTO_COMPRESS=FALSE",
761+
)
762+
763+
with caplog.at_level(logging.ERROR):
764+
cur.execute(f"GET @{stage_name}/data.csv file://{tmp_path}")
765+
assert "FileNotFoundError" not in caplog.text
766+
767+
default_mask = os.umask(0)
768+
os.umask(default_mask)
769+
770+
assert (
771+
oct(os.stat(test_file).st_mode)[-3:] == oct(0o600 & ~default_mask)[-3:]
772+
)
773+
774+
775+
@pytest.mark.skipolddriver
776+
@pytest.mark.skipif(IS_WINDOWS, reason="not supported on Windows")
777+
def test_get_unsafe_file_permission_when_flag_set(tmp_path, conn_cnx, caplog):
778+
test_file = tmp_path / "data.csv"
779+
test_file.write_text("1,2,3\n")
780+
stage_name = random_string(5, "test_get_file_permission_")
781+
with conn_cnx() as cnx:
782+
cnx.unsafe_file_write = True
783+
with cnx.cursor() as cur:
784+
cur.execute(f"create temporary stage {stage_name}")
785+
filename_in_put = str(test_file).replace("\\", "/")
786+
cur.execute(
787+
f"PUT 'file://{filename_in_put}' @{stage_name} AUTO_COMPRESS=FALSE",
753788
)
754789

755790
with caplog.at_level(logging.ERROR):
@@ -764,6 +799,7 @@ def test_get_file_permission(tmp_path, conn_cnx, caplog):
764799
assert (
765800
oct(os.stat(test_file).st_mode)[-3:] == oct(0o666 & ~default_mask)[-3:]
766801
)
802+
cnx.unsafe_file_write = False
767803

768804

769805
@pytest.mark.skipolddriver

0 commit comments

Comments
 (0)