Skip to content

Commit 266baa2

Browse files
Apply #2184 to async code
1 parent 6ac9d49 commit 266baa2

File tree

8 files changed

+31
-5
lines changed

8 files changed

+31
-5
lines changed

src/snowflake/connector/aio/_azure_storage_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,15 @@ def __init__(
4747
chunk_size: int,
4848
stage_info: dict[str, Any],
4949
use_s3_regional_url: bool = False,
50+
unsafe_file_write: bool = False,
5051
) -> None:
5152
SnowflakeAzureRestClientSync.__init__(
5253
self,
5354
meta=meta,
5455
stage_info=stage_info,
5556
chunk_size=chunk_size,
5657
credentials=credentials,
58+
unsafe_file_write=unsafe_file_write,
5759
)
5860

5961
async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:

src/snowflake/connector/aio/_connection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,10 @@ def _init_connection_parameters(
492492
for name, (value, _) in DEFAULT_CONFIGURATION.items():
493493
setattr(self, f"_{name}", value)
494494

495+
# Initialize unsafe_file_write explicitly since it's not in DEFAULT_CONFIGURATION
496+
# TODO SNOW-2207863
497+
self._unsafe_file_write = False
498+
495499
self._heartbeat_task = None
496500
is_kwargs_empty = not connection_init_kwargs
497501

@@ -704,6 +708,14 @@ def errorhandler(self, value) -> None:
704708
def rest(self) -> SnowflakeRestful | None:
705709
return self._rest
706710

711+
@property
712+
def unsafe_file_write(self) -> bool:
713+
return self._unsafe_file_write
714+
715+
@unsafe_file_write.setter
716+
def unsafe_file_write(self, value: bool) -> None:
717+
self._unsafe_file_write = value
718+
707719
async def authenticate_with_retry(self, auth_instance) -> None:
708720
# make some changes if needed before real __authenticate
709721
try:

src/snowflake/connector/aio/_cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,7 @@ async def execute(
661661
source_from_stream=file_stream,
662662
multipart_threshold=data.get("threshold"),
663663
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
664+
unsafe_file_write=self._connection.unsafe_file_write,
664665
)
665666
await sf_file_transfer_agent.execute()
666667
data = sf_file_transfer_agent.result()

src/snowflake/connector/aio/_file_transfer_agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(
6262
multipart_threshold: int | None = None,
6363
source_from_stream: IO[bytes] | None = None,
6464
use_s3_regional_url: bool = False,
65+
unsafe_file_write: bool = False,
6566
) -> None:
6667
super().__init__(
6768
cursor,
@@ -80,6 +81,7 @@ def __init__(
8081
multipart_threshold,
8182
source_from_stream,
8283
use_s3_regional_url,
84+
unsafe_file_write,
8385
)
8486

8587
async def execute(self) -> None:
@@ -271,6 +273,7 @@ async def _create_file_transfer_client(
271273
meta,
272274
self._stage_info,
273275
4 * megabyte,
276+
unsafe_file_write=self._unsafe_file_write,
274277
)
275278
elif self._stage_location_type == AZURE_FS:
276279
return SnowflakeAzureRestClient(
@@ -279,6 +282,7 @@ async def _create_file_transfer_client(
279282
AZURE_CHUNK_SIZE,
280283
self._stage_info,
281284
use_s3_regional_url=self._use_s3_regional_url,
285+
unsafe_file_write=self._unsafe_file_write,
282286
)
283287
elif self._stage_location_type == S3_FS:
284288
client = SnowflakeS3RestClient(
@@ -288,6 +292,7 @@ async def _create_file_transfer_client(
288292
chunk_size=_chunk_size_calculator(meta.src_file_size),
289293
use_accelerate_endpoint=self._use_accelerate_endpoint,
290294
use_s3_regional_url=self._use_s3_regional_url,
295+
unsafe_file_write=self._unsafe_file_write,
291296
)
292297
await client.transfer_accelerate_config(self._use_accelerate_endpoint)
293298
return client
@@ -299,6 +304,7 @@ async def _create_file_transfer_client(
299304
self._cursor._connection,
300305
self._command,
301306
use_s3_regional_url=self._use_s3_regional_url,
307+
unsafe_file_write=self._unsafe_file_write,
302308
)
303309
if client.security_token:
304310
logger.debug(f"len(GCS_ACCESS_TOKEN): {len(client.security_token)}")

src/snowflake/connector/aio/_gcs_storage_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
cnx: SnowflakeConnection,
4040
command: str,
4141
use_s3_regional_url: bool = False,
42+
unsafe_file_write: bool = False,
4243
) -> None:
4344
"""Creates a client object with given stage credentials.
4445
@@ -55,6 +56,7 @@ def __init__(
5556
chunk_size=-1,
5657
credentials=credentials,
5758
chunked_transfer=False,
59+
unsafe_file_write=unsafe_file_write,
5860
)
5961
self.stage_info = stage_info
6062
self._command = command

src/snowflake/connector/aio/_s3_storage_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
chunk_size: int,
4949
use_accelerate_endpoint: bool | None = None,
5050
use_s3_regional_url: bool = False,
51+
unsafe_file_write: bool = False,
5152
) -> None:
5253
"""Rest client for S3 storage.
5354
@@ -60,6 +61,7 @@ def __init__(
6061
stage_info=stage_info,
6162
chunk_size=chunk_size,
6263
credentials=credentials,
64+
unsafe_file_write=unsafe_file_write,
6365
)
6466
# Signature version V4
6567
# Addressing style Virtual Host

src/snowflake/connector/aio/_storage_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
chunked_transfer: bool | None = True,
3838
credentials: StorageCredential | None = None,
3939
max_retry: int = 5,
40+
unsafe_file_write: bool = False,
4041
) -> None:
4142
SnowflakeStorageClientSync.__init__(
4243
self,
@@ -46,6 +47,7 @@ def __init__(
4647
chunked_transfer=chunked_transfer,
4748
credentials=credentials,
4849
max_retry=max_retry,
50+
unsafe_file_write=unsafe_file_write,
4951
)
5052

5153
@abstractmethod

test/integ/aio/test_put_get_async.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,13 @@ async def test_get_empty_file(tmp_path, aio_connection):
157157
async def test_get_file_permission(tmp_path, aio_connection, caplog):
158158
test_file = tmp_path / "data.csv"
159159
test_file.write_text("1,2,3\n")
160-
stage_name = random_string(5, "test_get_empty_file_")
160+
stage_name = random_string(5, "test_get_file_permission_")
161161
await aio_connection.connect()
162162
cur = aio_connection.cursor()
163163
await cur.execute(f"create temporary stage {stage_name}")
164164
filename_in_put = str(test_file).replace("\\", "/")
165165
await cur.execute(
166-
f"PUT 'file://{filename_in_put}' @{stage_name}",
166+
f"PUT 'file://{filename_in_put}' @{stage_name} AUTO_COMPRESS=FALSE",
167167
)
168168

169169
with caplog.at_level(logging.ERROR):
@@ -173,9 +173,8 @@ async def test_get_file_permission(tmp_path, aio_connection, caplog):
173173
# get the default mask, usually it is 0o022
174174
default_mask = os.umask(0)
175175
os.umask(default_mask)
176-
# files by default are given the permission 644 (Octal)
177-
# umask is for denial, we need to negate
178-
assert oct(os.stat(test_file).st_mode)[-3:] == oct(0o666 & ~default_mask)[-3:]
176+
# files should be created with 0o600 permissions (owner read/write only)
177+
assert oct(os.stat(test_file).st_mode)[-3:] == oct(0o600 & ~default_mask)[-3:]
179178

180179

181180
async def test_get_multiple_files_with_same_name(tmp_path, aio_connection, caplog):

0 commit comments

Comments
 (0)