Skip to content

Commit 8ab954c

Browse files
sfc-gh-mkeller and sfc-gh-pczajka
authored and committed
SNOW-1817982 iobound tpe limiting (#2115)
1 parent b2c73f8 commit 8ab954c

File tree

6 files changed

+133
-3
lines changed

6 files changed

+133
-3
lines changed

DESCRIPTION.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
4646
- Added a feature to verify if the connection is still good enough to send queries over.
4747
- Added support for base64-encoded DER private key strings in the `private_key` authentication type.
4848

49+
- v3.12.5(TBD)
50+
- Added a feature to limit the sizes of IO-bound ThreadPoolExecutors during PUT and GET commands.
51+
4952
- v3.12.4(December 3,2024)
5053
- Fixed a bug where multipart uploads to Azure would be missing their MD5 hashes.
5154
- Fixed a bug where OpenTelemetry header injection would sometimes cause Exceptions to be thrown.

src/snowflake/connector/connection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ def _get_private_bytes_from_file(
301301
False,
302302
bool,
303303
), # disable saml url check in okta authentication
304+
"iobound_tpe_limit": (
305+
None,
306+
(type(None), int),
307+
), # SNOW-1817982: limit iobound TPE sizes when executing PUT/GET
304308
}
305309

306310
APPLICATION_RE = re.compile(r"[\w\d_]+")
@@ -755,6 +759,10 @@ def auth_class(self, value: AuthByPlugin) -> None:
755759
def is_query_context_cache_disabled(self) -> bool:
756760
return self._disable_query_context_cache
757761

762+
@property
763+
def iobound_tpe_limit(self) -> int | None:
764+
return self._iobound_tpe_limit
765+
758766
def connect(self, **kwargs) -> None:
759767
"""Establishes connection to Snowflake."""
760768
logger.debug("connect")

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,7 @@ def execute(
10591059
source_from_stream=file_stream,
10601060
multipart_threshold=data.get("threshold"),
10611061
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
1062+
iobound_tpe_limit=self._connection.iobound_tpe_limit,
10621063
)
10631064
sf_file_transfer_agent.execute()
10641065
data = sf_file_transfer_agent.result()

src/snowflake/connector/file_transfer_agent.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ def __init__(
354354
multipart_threshold: int | None = None,
355355
source_from_stream: IO[bytes] | None = None,
356356
use_s3_regional_url: bool = False,
357+
iobound_tpe_limit: int | None = None,
357358
) -> None:
358359
self._cursor = cursor
359360
self._command = command
@@ -384,6 +385,7 @@ def __init__(
384385
self._multipart_threshold = multipart_threshold or 67108864 # Historical value
385386
self._use_s3_regional_url = use_s3_regional_url
386387
self._credentials: StorageCredential | None = None
388+
self._iobound_tpe_limit = iobound_tpe_limit
387389

388390
def execute(self) -> None:
389391
self._parse_command()
@@ -440,10 +442,15 @@ def execute(self) -> None:
440442
result.result_status = result.result_status.value
441443

442444
def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
445+
iobound_tpe_limit = min(len(metas), os.cpu_count())
446+
logger.debug("Decided IO-bound TPE size: %d", iobound_tpe_limit)
447+
if self._iobound_tpe_limit is not None:
448+
logger.debug("IO-bound TPE size is limited to: %d", self._iobound_tpe_limit)
449+
iobound_tpe_limit = min(iobound_tpe_limit, self._iobound_tpe_limit)
443450
max_concurrency = self._parallel
444451
network_tpe = ThreadPoolExecutor(max_concurrency)
445-
preprocess_tpe = ThreadPoolExecutor(min(len(metas), os.cpu_count()))
446-
postprocess_tpe = ThreadPoolExecutor(min(len(metas), os.cpu_count()))
452+
preprocess_tpe = ThreadPoolExecutor(iobound_tpe_limit)
453+
postprocess_tpe = ThreadPoolExecutor(iobound_tpe_limit)
447454
logger.debug(f"Chunk ThreadPoolExecutor size: {max_concurrency}")
448455
cv_main_thread = threading.Condition() # to signal the main thread
449456
cv_chunk_process = (
@@ -454,6 +461,9 @@ def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
454461
transfer_metadata = TransferMetadata() # this is protected by cv_chunk_process
455462
is_upload = self._command_type == CMD_TYPE_UPLOAD
456463
exception_caught_in_callback: Exception | None = None
464+
logger.debug(
465+
"Going to %sload %d files", "up" if is_upload else "down", len(metas)
466+
)
457467

458468
def notify_file_completed() -> None:
459469
# Increment the number of completed files, then notify the main thread.

test/integ/test_put_get.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,3 +827,21 @@ def test_put_md5(tmp_path, conn_cnx):
827827
cur.execute(f"LS @{stage_name}").fetchall(),
828828
)
829829
)
830+
831+
832+
@pytest.mark.skipolddriver
833+
def test_iobound_limit(tmp_path, conn_cnx, caplog):
834+
tmp_stage_name = random_string(5, "test_iobound_limit")
835+
file0 = tmp_path / "file0"
836+
file1 = tmp_path / "file1"
837+
file0.touch()
838+
file1.touch()
839+
with conn_cnx(iobound_tpe_limit=1) as conn:
840+
with conn.cursor() as cur:
841+
cur.execute(f"create temp stage {tmp_stage_name}")
842+
with caplog.at_level(
843+
logging.DEBUG, "snowflake.connector.file_transfer_agent"
844+
):
845+
cur.execute(f"put file://{tmp_path}/* @{tmp_stage_name}")
846+
assert "Decided IO-bound TPE size: 2" in caplog.text
847+
assert "IO-bound TPE size is limited to: 1" in caplog.text

test/unit/test_put_get.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def test_percentage(tmp_path):
125125
func_callback(1)
126126

127127

128-
@pytest.mark.skipolddriver
129128
def test_upload_file_with_azure_upload_failed_error(tmp_path):
130129
"""Tests Upload file with expired Azure storage token."""
131130
file1 = tmp_path / "file1"
@@ -166,3 +165,94 @@ def test_upload_file_with_azure_upload_failed_error(tmp_path):
166165
rest_client.execute()
167166
assert mock_update.called
168167
assert rest_client._results[0].error_details is exc
168+
169+
170+
def test_iobound_limit(tmp_path):
171+
file1 = tmp_path / "file1"
172+
file2 = tmp_path / "file2"
173+
file3 = tmp_path / "file3"
174+
file1.touch()
175+
file2.touch()
176+
file3.touch()
177+
# Positive case
178+
rest_client = SnowflakeFileTransferAgent(
179+
mock.MagicMock(autospec=SnowflakeCursor),
180+
"PUT some_file.txt",
181+
{
182+
"data": {
183+
"command": "UPLOAD",
184+
"src_locations": [file1, file2, file3],
185+
"sourceCompression": "none",
186+
"stageInfo": {
187+
"creds": {
188+
"AZURE_SAS_TOKEN": "sas_token",
189+
},
190+
"location": "some_bucket",
191+
"region": "no_region",
192+
"locationType": "AZURE",
193+
"path": "remote_loc",
194+
"endPoint": "",
195+
"storageAccount": "storage_account",
196+
},
197+
},
198+
"success": True,
199+
},
200+
)
201+
with mock.patch(
202+
"snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
203+
) as tpe:
204+
with mock.patch("snowflake.connector.file_transfer_agent.threading.Condition"):
205+
with mock.patch(
206+
"snowflake.connector.file_transfer_agent.TransferMetadata",
207+
return_value=mock.Mock(
208+
num_files_started=0,
209+
num_files_completed=3,
210+
),
211+
):
212+
try:
213+
rest_client.execute()
214+
except AttributeError:
215+
pass
216+
# 2 IObound TPEs should be created for 3 files unlimited
217+
rest_client = SnowflakeFileTransferAgent(
218+
mock.MagicMock(autospec=SnowflakeCursor),
219+
"PUT some_file.txt",
220+
{
221+
"data": {
222+
"command": "UPLOAD",
223+
"src_locations": [file1, file2, file3],
224+
"sourceCompression": "none",
225+
"stageInfo": {
226+
"creds": {
227+
"AZURE_SAS_TOKEN": "sas_token",
228+
},
229+
"location": "some_bucket",
230+
"region": "no_region",
231+
"locationType": "AZURE",
232+
"path": "remote_loc",
233+
"endPoint": "",
234+
"storageAccount": "storage_account",
235+
},
236+
},
237+
"success": True,
238+
},
239+
iobound_tpe_limit=2,
240+
)
241+
assert len(list(filter(lambda e: e.args == (3,), tpe.call_args_list))) == 2
242+
with mock.patch(
243+
"snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
244+
) as tpe:
245+
with mock.patch("snowflake.connector.file_transfer_agent.threading.Condition"):
246+
with mock.patch(
247+
"snowflake.connector.file_transfer_agent.TransferMetadata",
248+
return_value=mock.Mock(
249+
num_files_started=0,
250+
num_files_completed=3,
251+
),
252+
):
253+
try:
254+
rest_client.execute()
255+
except AttributeError:
256+
pass
257+
# 2 IObound TPEs should be created for 3 files limited to 2
258+
assert len(list(filter(lambda e: e.args == (2,), tpe.call_args_list))) == 2

0 commit comments

Comments (0)