Skip to content

Commit b29daf1

Browse files
sfc-gh-mkubiksfc-gh-pczajka
authored andcommitted
SNOW-1944208 add unsafe write flag (#2184)
1 parent 8ab954c commit b29daf1

File tree

10 files changed

+94
-8
lines changed

10 files changed

+94
-8
lines changed

src/snowflake/connector/azure_storage_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,15 @@ def __init__(
6565
chunk_size: int,
6666
stage_info: dict[str, Any],
6767
use_s3_regional_url: bool = False,
68+
unsafe_file_write: bool = False,
6869
) -> None:
69-
super().__init__(meta, stage_info, chunk_size, credentials=credentials)
70+
super().__init__(
71+
meta,
72+
stage_info,
73+
chunk_size,
74+
credentials=credentials,
75+
unsafe_file_write=unsafe_file_write,
76+
)
7077
end_point: str = stage_info["endPoint"]
7178
if end_point.startswith("blob."):
7279
end_point = end_point[len("blob.") :]

src/snowflake/connector/connection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ class SnowflakeConnection:
383383
server_session_keep_alive: When true, the connector does not destroy the session on the Snowflake server side
384384
before the connector shuts down. Default value is false.
385385
token_file_path: The file path of the token file. If both token and token_file_path are provided, the token in token_file_path will be used.
386+
unsafe_file_write: When true, files downloaded by GET will be saved with 644 permissions. Otherwise, files will be saved with safe - owner-only permissions: 600.
386387
"""
387388

388389
OCSP_ENV_LOCK = Lock()
@@ -763,6 +764,14 @@ def is_query_context_cache_disabled(self) -> bool:
763764
def iobound_tpe_limit(self) -> int | None:
764765
return self._iobound_tpe_limit
765766

767+
@property
768+
def unsafe_file_write(self) -> bool:
769+
return self._unsafe_file_write
770+
771+
@unsafe_file_write.setter
772+
def unsafe_file_write(self, value: bool) -> None:
773+
self._unsafe_file_write = value
774+
766775
def connect(self, **kwargs) -> None:
767776
"""Establishes connection to Snowflake."""
768777
logger.debug("connect")
@@ -1234,6 +1243,11 @@ def __config(self, **kwargs):
12341243
if "protocol" not in kwargs:
12351244
self._protocol = "https"
12361245

1246+
if "unsafe_file_write" in kwargs:
1247+
self._unsafe_file_write = kwargs["unsafe_file_write"]
1248+
else:
1249+
self._unsafe_file_write = False
1250+
12371251
logger.info(
12381252
f"Connecting to {_DOMAIN_NAME_MAP.get(extract_top_level_domain_from_hostname(self._host), 'GLOBAL')} Snowflake domain"
12391253
)

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,7 @@ def execute(
10601060
multipart_threshold=data.get("threshold"),
10611061
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
10621062
iobound_tpe_limit=self._connection.iobound_tpe_limit,
1063+
unsafe_file_write=self._connection.unsafe_file_write,
10631064
)
10641065
sf_file_transfer_agent.execute()
10651066
data = sf_file_transfer_agent.result()

src/snowflake/connector/encryption_util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def decrypt_file(
195195
in_filename: str,
196196
chunk_size: int = 64 * kilobyte,
197197
tmp_dir: str | None = None,
198+
unsafe_file_write: bool = False,
198199
) -> str:
199200
"""Decrypts a file and stores the output in the temporary directory.
200201
@@ -213,8 +214,10 @@ def decrypt_file(
213214
temp_output_file = os.path.join(tmp_dir, temp_output_file)
214215

215216
logger.debug("encrypted file: %s, tmp file: %s", in_filename, temp_output_file)
217+
218+
file_opener = None if unsafe_file_write else owner_rw_opener
216219
with open(in_filename, "rb") as infile:
217-
with open(temp_output_file, "wb", opener=owner_rw_opener) as outfile:
220+
with open(temp_output_file, "wb", opener=file_opener) as outfile:
218221
SnowflakeEncryptionUtil.decrypt_stream(
219222
metadata, encryption_material, infile, outfile, chunk_size
220223
)

src/snowflake/connector/file_transfer_agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def __init__(
355355
source_from_stream: IO[bytes] | None = None,
356356
use_s3_regional_url: bool = False,
357357
iobound_tpe_limit: int | None = None,
358+
unsafe_file_write: bool = False,
358359
) -> None:
359360
self._cursor = cursor
360361
self._command = command
@@ -386,6 +387,7 @@ def __init__(
386387
self._use_s3_regional_url = use_s3_regional_url
387388
self._credentials: StorageCredential | None = None
388389
self._iobound_tpe_limit = iobound_tpe_limit
390+
self._unsafe_file_write = unsafe_file_write
389391

390392
def execute(self) -> None:
391393
self._parse_command()
@@ -673,6 +675,7 @@ def _create_file_transfer_client(
673675
meta,
674676
self._stage_info,
675677
4 * megabyte,
678+
unsafe_file_write=self._unsafe_file_write,
676679
)
677680
elif self._stage_location_type == AZURE_FS:
678681
return SnowflakeAzureRestClient(
@@ -681,6 +684,7 @@ def _create_file_transfer_client(
681684
AZURE_CHUNK_SIZE,
682685
self._stage_info,
683686
use_s3_regional_url=self._use_s3_regional_url,
687+
unsafe_file_write=self._unsafe_file_write,
684688
)
685689
elif self._stage_location_type == S3_FS:
686690
return SnowflakeS3RestClient(
@@ -690,6 +694,7 @@ def _create_file_transfer_client(
690694
_chunk_size_calculator(meta.src_file_size),
691695
use_accelerate_endpoint=self._use_accelerate_endpoint,
692696
use_s3_regional_url=self._use_s3_regional_url,
697+
unsafe_file_write=self._unsafe_file_write,
693698
)
694699
elif self._stage_location_type == GCS_FS:
695700
return SnowflakeGCSRestClient(
@@ -699,6 +704,7 @@ def _create_file_transfer_client(
699704
self._cursor._connection,
700705
self._command,
701706
use_s3_regional_url=self._use_s3_regional_url,
707+
unsafe_file_write=self._unsafe_file_write,
702708
)
703709
raise Exception(f"{self._stage_location_type} is an unknown stage type")
704710

src/snowflake/connector/gcs_storage_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
cnx: SnowflakeConnection,
5555
command: str,
5656
use_s3_regional_url: bool = False,
57+
unsafe_file_write: bool = False,
5758
) -> None:
5859
"""Creates a client object with given stage credentials.
5960
@@ -64,7 +65,12 @@ def __init__(
6465
The client to communicate with GCS.
6566
"""
6667
super().__init__(
67-
meta, stage_info, -1, credentials=credentials, chunked_transfer=False
68+
meta,
69+
stage_info,
70+
-1,
71+
credentials=credentials,
72+
chunked_transfer=False,
73+
unsafe_file_write=unsafe_file_write,
6874
)
6975
self.stage_info = stage_info
7076
self._command = command

src/snowflake/connector/local_storage_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ def __init__(
2626
meta: SnowflakeFileMeta,
2727
stage_info: dict[str, Any],
2828
chunk_size: int,
29+
unsafe_file_write: bool = False,
2930
) -> None:
30-
super().__init__(meta, stage_info, chunk_size)
31+
super().__init__(
32+
meta, stage_info, chunk_size, unsafe_file_write=unsafe_file_write
33+
)
3134
self.data_file = meta.src_file_name
3235
self.full_dst_file_name: str = os.path.join(
3336
stage_info["location"], os.path.basename(meta.dst_file_name)

src/snowflake/connector/s3_storage_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ def __init__(
6161
chunk_size: int,
6262
use_accelerate_endpoint: bool | None = None,
6363
use_s3_regional_url: bool = False,
64+
unsafe_file_write: bool = False,
6465
) -> None:
6566
"""Rest client for S3 storage.
6667
6768
Args:
6869
stage_info:
6970
"""
70-
super().__init__(meta, stage_info, chunk_size, credentials=credentials)
71+
super().__init__(
72+
meta,
73+
stage_info,
74+
chunk_size,
75+
credentials=credentials,
76+
unsafe_file_write=unsafe_file_write,
77+
)
7178
# Signature version V4
7279
# Addressing style Virtual Host
7380
self.region_name: str = stage_info["region"]

src/snowflake/connector/storage_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
chunked_transfer: bool | None = True,
7878
credentials: StorageCredential | None = None,
7979
max_retry: int = 5,
80+
unsafe_file_write: bool = False,
8081
) -> None:
8182
self.meta = meta
8283
self.stage_info = stage_info
@@ -115,6 +116,7 @@ def __init__(
115116
self.failed_transfers: int = 0
116117
# only used when PRESIGNED_URL expires
117118
self.last_err_is_presigned_url = False
119+
self.unsafe_file_write = unsafe_file_write
118120

119121
def compress(self) -> None:
120122
if self.meta.require_compress:
@@ -376,7 +378,7 @@ def finish_download(self) -> None:
376378
# For storage utils that do not have the privilege of
377379
# getting the metadata early, both object and metadata
378380
# are downloaded at once. In which case, the file meta will
379-
# be updated with all the metadata that we need and
381+
# be updated with all the metadata that we need, and
380382
# then we can call get_file_header to get just that and also
381383
# preserve the idea of getting metadata in the first place.
382384
# One example of this is the utils that use presigned url
@@ -390,6 +392,7 @@ def finish_download(self) -> None:
390392
meta.encryption_material,
391393
str(self.intermediate_dst_path),
392394
tmp_dir=self.tmp_dir,
395+
unsafe_file_write=self.unsafe_file_write,
393396
)
394397
shutil.move(tmp_dst_file_name, self.full_dst_file_name)
395398
self.intermediate_dst_path.unlink()

test/integ/test_put_get.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020

2121
from snowflake.connector import OperationalError
2222

23+
try:
24+
from src.snowflake.connector.compat import IS_WINDOWS
25+
except ImportError:
26+
import platform
27+
28+
IS_WINDOWS = platform.system() == "Windows"
29+
2330
try:
2431
from snowflake.connector.util_text import random_string
2532
except ImportError:
@@ -740,16 +747,44 @@ def test_get_empty_file(tmp_path, conn_cnx):
740747

741748

742749
@pytest.mark.skipolddriver
750+
@pytest.mark.skipif(IS_WINDOWS, reason="not supported on Windows")
743751
def test_get_file_permission(tmp_path, conn_cnx, caplog):
744752
test_file = tmp_path / "data.csv"
745753
test_file.write_text("1,2,3\n")
746-
stage_name = random_string(5, "test_get_empty_file_")
754+
stage_name = random_string(5, "test_get_file_permission_")
747755
with conn_cnx() as cnx:
748756
with cnx.cursor() as cur:
749757
cur.execute(f"create temporary stage {stage_name}")
750758
filename_in_put = str(test_file).replace("\\", "/")
751759
cur.execute(
752-
f"PUT 'file://{filename_in_put}' @{stage_name}",
760+
f"PUT 'file://{filename_in_put}' @{stage_name} AUTO_COMPRESS=FALSE",
761+
)
762+
763+
with caplog.at_level(logging.ERROR):
764+
cur.execute(f"GET @{stage_name}/data.csv file://{tmp_path}")
765+
assert "FileNotFoundError" not in caplog.text
766+
767+
default_mask = os.umask(0)
768+
os.umask(default_mask)
769+
770+
assert (
771+
oct(os.stat(test_file).st_mode)[-3:] == oct(0o600 & ~default_mask)[-3:]
772+
)
773+
774+
775+
@pytest.mark.skipolddriver
776+
@pytest.mark.skipif(IS_WINDOWS, reason="not supported on Windows")
777+
def test_get_unsafe_file_permission_when_flag_set(tmp_path, conn_cnx, caplog):
778+
test_file = tmp_path / "data.csv"
779+
test_file.write_text("1,2,3\n")
780+
stage_name = random_string(5, "test_get_file_permission_")
781+
with conn_cnx() as cnx:
782+
cnx.unsafe_file_write = True
783+
with cnx.cursor() as cur:
784+
cur.execute(f"create temporary stage {stage_name}")
785+
filename_in_put = str(test_file).replace("\\", "/")
786+
cur.execute(
787+
f"PUT 'file://{filename_in_put}' @{stage_name} AUTO_COMPRESS=FALSE",
753788
)
754789

755790
with caplog.at_level(logging.ERROR):
@@ -764,6 +799,7 @@ def test_get_file_permission(tmp_path, conn_cnx, caplog):
764799
assert (
765800
oct(os.stat(test_file).st_mode)[-3:] == oct(0o666 & ~default_mask)[-3:]
766801
)
802+
cnx.unsafe_file_write = False
767803

768804

769805
@pytest.mark.skipolddriver

0 commit comments

Comments
 (0)