
Commit b4f5940

sfc-gh-mkeller authored and sfc-gh-pczajka committed
SNOW-1817982 iobound tpe limiting (#2115)
1 parent 5e22018 commit b4f5940

File tree

5 files changed (+130, -3 lines)

src/snowflake/connector/connection.py

Lines changed: 8 additions & 0 deletions
@@ -299,6 +299,10 @@ def _get_private_bytes_from_file(
         False,
         bool,
     ),  # disable saml url check in okta authentication
+    "iobound_tpe_limit": (
+        None,
+        (type(None), int),
+    ),  # SNOW-1817982: limit iobound TPE sizes when executing PUT/GET
 }
 
 APPLICATION_RE = re.compile(r"[\w\d_]+")
@@ -753,6 +757,10 @@ def auth_class(self, value: AuthByPlugin) -> None:
     def is_query_context_cache_disabled(self) -> bool:
         return self._disable_query_context_cache
 
+    @property
+    def iobound_tpe_limit(self) -> int | None:
+        return self._iobound_tpe_limit
+
     def connect(self, **kwargs) -> None:
         """Establishes connection to Snowflake."""
         logger.debug("connect")

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 0 deletions
@@ -1057,6 +1057,7 @@ def execute(
                 source_from_stream=file_stream,
                 multipart_threshold=data.get("threshold"),
                 use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
+                iobound_tpe_limit=self._connection.iobound_tpe_limit,
             )
             sf_file_transfer_agent.execute()
             data = sf_file_transfer_agent.result()

src/snowflake/connector/file_transfer_agent.py

Lines changed: 12 additions & 2 deletions
@@ -354,6 +354,7 @@ def __init__(
         multipart_threshold: int | None = None,
         source_from_stream: IO[bytes] | None = None,
         use_s3_regional_url: bool = False,
+        iobound_tpe_limit: int | None = None,
     ) -> None:
         self._cursor = cursor
         self._command = command
@@ -384,6 +385,7 @@ def __init__(
         self._multipart_threshold = multipart_threshold or 67108864  # Historical value
         self._use_s3_regional_url = use_s3_regional_url
         self._credentials: StorageCredential | None = None
+        self._iobound_tpe_limit = iobound_tpe_limit
 
     def execute(self) -> None:
         self._parse_command()
@@ -440,10 +442,15 @@ def execute(self) -> None:
             result.result_status = result.result_status.value
 
     def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
+        iobound_tpe_limit = min(len(metas), os.cpu_count())
+        logger.debug("Decided IO-bound TPE size: %d", iobound_tpe_limit)
+        if self._iobound_tpe_limit is not None:
+            logger.debug("IO-bound TPE size is limited to: %d", self._iobound_tpe_limit)
+            iobound_tpe_limit = min(iobound_tpe_limit, self._iobound_tpe_limit)
         max_concurrency = self._parallel
         network_tpe = ThreadPoolExecutor(max_concurrency)
-        preprocess_tpe = ThreadPoolExecutor(min(len(metas), os.cpu_count()))
-        postprocess_tpe = ThreadPoolExecutor(min(len(metas), os.cpu_count()))
+        preprocess_tpe = ThreadPoolExecutor(iobound_tpe_limit)
+        postprocess_tpe = ThreadPoolExecutor(iobound_tpe_limit)
         logger.debug(f"Chunk ThreadPoolExecutor size: {max_concurrency}")
         cv_main_thread = threading.Condition()  # to signal the main thread
         cv_chunk_process = (
@@ -454,6 +461,9 @@ def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
         transfer_metadata = TransferMetadata()  # this is protected by cv_chunk_process
         is_upload = self._command_type == CMD_TYPE_UPLOAD
         exception_caught_in_callback: Exception | None = None
+        logger.debug(
+            "Going to %sload %d files", "up" if is_upload else "down", len(metas)
+        )
 
         def notify_file_completed() -> None:
             # Increment the number of completed files, then notify the main thread.
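
The sizing rule added to transfer() is small enough to restate on its own. A standalone sketch of the same arithmetic (the function name is illustrative, not part of the connector; like the diff, it assumes os.cpu_count() returns an int):

import os

def decide_iobound_tpe_size(num_files: int, limit: int | None) -> int:
    # One worker per file, capped at the CPU count -- the pre-existing
    # behaviour of the preprocess/postprocess pools.
    size = min(num_files, os.cpu_count())
    # If iobound_tpe_limit was supplied, clamp the pool size further.
    if limit is not None:
        size = min(size, limit)
    return size

# Example: 3 files on an 8-core machine with iobound_tpe_limit=2 -> 2 workers.
print(decide_iobound_tpe_size(3, 2))

Note that the network pool is untouched: it stays at self._parallel, so the new option only throttles the IO-bound preprocess and postprocess executors.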

test/integ/test_put_get.py

Lines changed: 18 additions & 0 deletions
@@ -791,3 +791,21 @@ def test_get_multiple_files_with_same_name(tmp_path, conn_cnx, caplog):
                 # This is expected flakiness
                 pass
     assert "Downloading multiple files with the same name" in caplog.text
+
+
+@pytest.mark.skipolddriver
+def test_iobound_limit(tmp_path, conn_cnx, caplog):
+    tmp_stage_name = random_string(5, "test_iobound_limit")
+    file0 = tmp_path / "file0"
+    file1 = tmp_path / "file1"
+    file0.touch()
+    file1.touch()
+    with conn_cnx(iobound_tpe_limit=1) as conn:
+        with conn.cursor() as cur:
+            cur.execute(f"create temp stage {tmp_stage_name}")
+            with caplog.at_level(
+                logging.DEBUG, "snowflake.connector.file_transfer_agent"
+            ):
+                cur.execute(f"put file://{tmp_path}/* @{tmp_stage_name}")
+    assert "Decided IO-bound TPE size: 2" in caplog.text
+    assert "IO-bound TPE size is limited to: 1" in caplog.text

test/unit/test_put_get.py

Lines changed: 91 additions & 1 deletion
@@ -125,7 +125,6 @@ def test_percentage(tmp_path):
     func_callback(1)
 
 
-@pytest.mark.skipolddriver
 def test_upload_file_with_azure_upload_failed_error(tmp_path):
     """Tests Upload file with expired Azure storage token."""
     file1 = tmp_path / "file1"
@@ -166,3 +165,94 @@ def test_upload_file_with_azure_upload_failed_error(tmp_path):
         rest_client.execute()
     assert mock_update.called
     assert rest_client._results[0].error_details is exc
+
+
+def test_iobound_limit(tmp_path):
+    file1 = tmp_path / "file1"
+    file2 = tmp_path / "file2"
+    file3 = tmp_path / "file3"
+    file1.touch()
+    file2.touch()
+    file3.touch()
+    # Positive case
+    rest_client = SnowflakeFileTransferAgent(
+        mock.MagicMock(autospec=SnowflakeCursor),
+        "PUT some_file.txt",
+        {
+            "data": {
+                "command": "UPLOAD",
+                "src_locations": [file1, file2, file3],
+                "sourceCompression": "none",
+                "stageInfo": {
+                    "creds": {
+                        "AZURE_SAS_TOKEN": "sas_token",
+                    },
+                    "location": "some_bucket",
+                    "region": "no_region",
+                    "locationType": "AZURE",
+                    "path": "remote_loc",
+                    "endPoint": "",
+                    "storageAccount": "storage_account",
+                },
+            },
+            "success": True,
+        },
+    )
+    with mock.patch(
+        "snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
+    ) as tpe:
+        with mock.patch("snowflake.connector.file_transfer_agent.threading.Condition"):
+            with mock.patch(
+                "snowflake.connector.file_transfer_agent.TransferMetadata",
+                return_value=mock.Mock(
+                    num_files_started=0,
+                    num_files_completed=3,
+                ),
+            ):
+                try:
+                    rest_client.execute()
+                except AttributeError:
+                    pass
+    # 2 IObound TPEs should be created for 3 files unlimited
+    rest_client = SnowflakeFileTransferAgent(
+        mock.MagicMock(autospec=SnowflakeCursor),
+        "PUT some_file.txt",
+        {
+            "data": {
+                "command": "UPLOAD",
+                "src_locations": [file1, file2, file3],
+                "sourceCompression": "none",
+                "stageInfo": {
+                    "creds": {
+                        "AZURE_SAS_TOKEN": "sas_token",
+                    },
+                    "location": "some_bucket",
+                    "region": "no_region",
+                    "locationType": "AZURE",
+                    "path": "remote_loc",
+                    "endPoint": "",
+                    "storageAccount": "storage_account",
+                },
+            },
+            "success": True,
+        },
+        iobound_tpe_limit=2,
+    )
+    assert len(list(filter(lambda e: e.args == (3,), tpe.call_args_list))) == 2
+    with mock.patch(
+        "snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
+    ) as tpe:
+        with mock.patch("snowflake.connector.file_transfer_agent.threading.Condition"):
+            with mock.patch(
+                "snowflake.connector.file_transfer_agent.TransferMetadata",
+                return_value=mock.Mock(
+                    num_files_started=0,
+                    num_files_completed=3,
+                ),
+            ):
+                try:
+                    rest_client.execute()
+                except AttributeError:
+                    pass
+    # 2 IObound TPEs should be created for 3 files limited to 2
+    assert len(list(filter(lambda e: e.args == (2,), tpe.call_args_list))) == 2
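
A note on the "== 2" in both assertions: transfer() constructs three ThreadPoolExecutors per run, one sized by self._parallel for network work and two sized by the decided IO-bound limit (preprocess and postprocess), so filtering the mocked constructor's call_args_list by the IO-bound size should match exactly twice. A self-contained illustration of that counting trick against the stdlib (not the connector itself):

from unittest import mock
import concurrent.futures

with mock.patch("concurrent.futures.ThreadPoolExecutor") as tpe:
    concurrent.futures.ThreadPoolExecutor(4)  # stands in for the network pool
    concurrent.futures.ThreadPoolExecutor(2)  # stands in for the preprocess pool
    concurrent.futures.ThreadPoolExecutor(2)  # stands in for the postprocess pool

# Only the two IO-bound pools match the size filter, hence the expected count of 2.
assert len([c for c in tpe.call_args_list if c.args == (2,)]) == 2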
