Skip to content

Commit 9d7e491

Browse files
authored
SNOW-892716 Improve localfs put/get test (#1786)
* SNOW-892716 Improve localfs put/get test Description This is a long overdue improvement from early slack message from Brandon Testing unit test * comment
1 parent d80f72e commit 9d7e491

File tree

1 file changed

+13
-30
lines changed

1 file changed

+13
-30
lines changed

test/unit/test_local_storage_client.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
33
#
44

5-
import os
65
import random
7-
import shutil
86
import string
97
import tempfile
8+
from pathlib import Path
109

1110
import pytest
1211

@@ -26,18 +25,14 @@ def test_multi_chunk_upload(multipart_threshold):
2625
[random.choice(string.ascii_letters) for _ in range(300)]
2726
).encode()
2827
file_name = "test_file"
29-
stage_dir = tempfile.mkdtemp()
30-
stage_file = os.path.join(stage_dir, file_name)
31-
local_dir = tempfile.mkdtemp()
32-
local_file = os.path.join(local_dir, file_name)
33-
34-
try:
35-
with open(local_file, "wb+") as fd:
36-
fd.write(file_content)
28+
with tempfile.TemporaryDirectory() as stage_dir, tempfile.TemporaryDirectory() as local_dir:
29+
stage_file = Path(stage_dir) / file_name
30+
local_file = Path(local_dir) / file_name
31+
Path(local_file).write_bytes(file_content)
3732

3833
meta = SnowflakeFileMeta(
3934
name=file_name,
40-
src_file_name=local_file,
35+
src_file_name=str(local_file),
4136
stage_location_type=LOCAL_FS,
4237
dst_file_name=file_name,
4338
multipart_threshold=multipart_threshold,
@@ -47,11 +42,7 @@ def test_multi_chunk_upload(multipart_threshold):
4742
for chunk_id in range(client.num_of_chunks):
4843
client.upload_chunk(chunk_id)
4944

50-
with open(stage_file, "rb") as fd:
51-
assert fd.read() == file_content
52-
finally:
53-
shutil.rmtree(stage_dir, ignore_errors=True)
54-
shutil.rmtree(local_dir, ignore_errors=True)
45+
assert Path(stage_file).read_bytes() == file_content
5546

5647

5748
@pytest.mark.parametrize("multipart_threshold", [0, 67108864])
@@ -60,18 +51,14 @@ def test_multi_chunk_download(multipart_threshold):
6051
[random.choice(string.ascii_letters) for _ in range(300)]
6152
).encode()
6253
file_name = "test_file"
63-
stage_dir = tempfile.mkdtemp()
64-
stage_file = os.path.join(stage_dir, file_name)
65-
local_dir = tempfile.mkdtemp()
66-
local_file = os.path.join(local_dir, file_name)
67-
68-
try:
69-
with open(stage_file, "wb+") as fd:
70-
fd.write(file_content)
54+
with tempfile.TemporaryDirectory() as stage_dir, tempfile.TemporaryDirectory() as local_dir:
55+
stage_file = Path(stage_dir) / file_name
56+
local_file = Path(local_dir) / file_name
57+
Path(stage_file).write_bytes(file_content)
7158

7259
meta = SnowflakeFileMeta(
7360
name=file_name,
74-
src_file_name=stage_file,
61+
src_file_name=str(stage_file),
7562
stage_location_type=LOCAL_FS,
7663
dst_file_name=file_name,
7764
local_location=local_dir,
@@ -83,8 +70,4 @@ def test_multi_chunk_download(multipart_threshold):
8370
client.download_chunk(chunk_id)
8471
client.finish_download()
8572

86-
with open(local_file, "rb") as fd:
87-
assert fd.read() == file_content
88-
finally:
89-
shutil.rmtree(stage_dir, ignore_errors=True)
90-
shutil.rmtree(local_dir, ignore_errors=True)
73+
assert Path(local_file).read_bytes() == file_content

0 commit comments

Comments
 (0)