Skip to content

Commit 259074f

Browse files
authored
SNOW-2173966 introduce server DoP cap (#2375)
1 parent 411c973 commit 259074f

File tree

6 files changed

+162
-2
lines changed

6 files changed

+162
-2
lines changed

src/snowflake/connector/_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ class TempObjectType(Enum):
3232

3333
REQUEST_ID_STATEMENT_PARAM_NAME = "requestId"
3434

35+
# Default server-side cap on the Degree of Parallelism (DoP) for file
# transfer. The value is 2^30 (~10^9): large enough that it never throttles
# a regular session.
_DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER = 2**30
# Name of the variable that carries the server DoP cap for file transfer.
_VARIABLE_NAME_SERVER_DOP_CAP_FOR_FILE_TRANSFER = (
    "snowflake_server_dop_cap_for_file_transfer"
)
43+
3544

3645
def generate_random_alphanumeric(length: int = 10) -> str:
3746
return "".join(choice(ALPHANUMERIC) for _ in range(length))
@@ -60,6 +69,15 @@ def is_uuid4(str_or_uuid: str | UUID) -> bool:
6069
return uuid_str == str_or_uuid
6170

6271

72+
def _snowflake_max_parallelism_for_file_transfer(connection):
    """Look up the server-side parallelism cap for file transfer on ``connection``.

    The cap is read from the connection attribute whose name is
    ``_VARIABLE_NAME_SERVER_DOP_CAP_FOR_FILE_TRANSFER`` prefixed with an
    underscore. When the connection has no such attribute,
    ``_DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER`` is returned instead.
    """
    cap_attribute = "_" + _VARIABLE_NAME_SERVER_DOP_CAP_FOR_FILE_TRANSFER
    try:
        return getattr(connection, cap_attribute)
    except AttributeError:
        # The connection does not carry a cap: fall back to the default.
        return _DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER
79+
80+
6381
class _TrackedQueryCancellationTimer(Timer):
6482
def __init__(self, interval, function, args=None, kwargs=None):
6583
super().__init__(interval, function, args, kwargs)

src/snowflake/connector/connection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929

3030
from . import errors, proxy
3131
from ._query_context_cache import QueryContextCache
32+
from ._utils import (
33+
_DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER,
34+
_VARIABLE_NAME_SERVER_DOP_CAP_FOR_FILE_TRANSFER,
35+
)
3236
from .auth import (
3337
FIRST_PARTY_AUTHENTICATORS,
3438
Auth,
@@ -369,6 +373,10 @@ def _get_private_bytes_from_file(
369373
str,
370374
# SNOW-2096721: External (Spark) session ID
371375
),
376+
_VARIABLE_NAME_SERVER_DOP_CAP_FOR_FILE_TRANSFER: (
377+
_DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER, # default value
378+
int, # type
379+
), # snowflake internal
372380
}
373381

374382
APPLICATION_RE = re.compile(r"[\w\d_]+")

src/snowflake/connector/cursor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from ._sql_util import get_file_transfer_type
3838
from ._utils import (
3939
REQUEST_ID_STATEMENT_PARAM_NAME,
40+
_snowflake_max_parallelism_for_file_transfer,
4041
_TrackedQueryCancellationTimer,
4142
is_uuid4,
4243
)
@@ -1086,6 +1087,9 @@ def execute(
10861087
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
10871088
iobound_tpe_limit=self._connection.iobound_tpe_limit,
10881089
unsafe_file_write=self._connection.unsafe_file_write,
1090+
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1091+
self._connection
1092+
),
10891093
)
10901094
sf_file_transfer_agent.execute()
10911095
data = sf_file_transfer_agent.result()
@@ -1800,6 +1804,9 @@ def _download(
18001804
self,
18011805
"", # empty command because it is triggered by directly calling this util not by a SQL query
18021806
ret,
1807+
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1808+
self._connection
1809+
),
18031810
)
18041811
file_transfer_agent.execute()
18051812
self._init_result_and_meta(file_transfer_agent.result())
@@ -1840,6 +1847,9 @@ def _upload(
18401847
"", # empty command because it is triggered by directly calling this util not by a SQL query
18411848
ret,
18421849
force_put_overwrite=False, # _upload should respect user decision on overwriting
1850+
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1851+
self._connection
1852+
),
18431853
)
18441854
file_transfer_agent.execute()
18451855
self._init_result_and_meta(file_transfer_agent.result())
@@ -1908,6 +1918,9 @@ def _upload_stream(
19081918
ret,
19091919
source_from_stream=input_stream,
19101920
force_put_overwrite=False, # _upload_stream should respect user decision on overwriting
1921+
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1922+
self._connection
1923+
),
19111924
)
19121925
file_transfer_agent.execute()
19131926
self._init_result_and_meta(file_transfer_agent.result())

src/snowflake/connector/file_transfer_agent.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from time import time
1616
from typing import IO, TYPE_CHECKING, Any, Callable, TypeVar
1717

18+
from ._utils import _DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER
1819
from .azure_storage_client import SnowflakeAzureRestClient
1920
from .compat import IS_WINDOWS
2021
from .constants import (
@@ -355,6 +356,7 @@ def __init__(
355356
use_s3_regional_url: bool = False,
356357
iobound_tpe_limit: int | None = None,
357358
unsafe_file_write: bool = False,
359+
snowflake_server_dop_cap_for_file_transfer=_DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER,
358360
) -> None:
359361
self._cursor = cursor
360362
self._command = command
@@ -387,6 +389,9 @@ def __init__(
387389
self._credentials: StorageCredential | None = None
388390
self._iobound_tpe_limit = iobound_tpe_limit
389391
self._unsafe_file_write = unsafe_file_write
392+
self._snowflake_server_dop_cap_for_file_transfer = (
393+
snowflake_server_dop_cap_for_file_transfer
394+
)
390395

391396
def execute(self) -> None:
392397
self._parse_command()
@@ -443,12 +448,16 @@ def execute(self) -> None:
443448
result.result_status = result.result_status.value
444449

445450
def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
446-
iobound_tpe_limit = min(len(metas), os.cpu_count())
451+
iobound_tpe_limit = min(
452+
len(metas), os.cpu_count(), self._snowflake_server_dop_cap_for_file_transfer
453+
)
447454
logger.debug("Decided IO-bound TPE size: %d", iobound_tpe_limit)
448455
if self._iobound_tpe_limit is not None:
449456
logger.debug("IO-bound TPE size is limited to: %d", self._iobound_tpe_limit)
450457
iobound_tpe_limit = min(iobound_tpe_limit, self._iobound_tpe_limit)
451-
max_concurrency = self._parallel
458+
max_concurrency = min(
459+
self._parallel, self._snowflake_server_dop_cap_for_file_transfer
460+
)
452461
network_tpe = ThreadPoolExecutor(max_concurrency)
453462
preprocess_tpe = ThreadPoolExecutor(iobound_tpe_limit)
454463
postprocess_tpe = ThreadPoolExecutor(iobound_tpe_limit)

test/unit/test_cursor.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,60 @@ def _setup_mocks(self, MockFileTransferAgent):
189189
cursor.reset = MagicMock()
190190
cursor._init_result_and_meta = MagicMock()
191191
return cursor, fake_conn, mock_file_transfer_agent_instance
192+
193+
def _run_dop_cap_test(self, task, dop_cap):
    """Run ``task`` on a cursor whose connection carries ``dop_cap``.

    Verifies that the file operation performed by ``task`` constructs a
    ``SnowflakeFileTransferAgent`` whose
    ``snowflake_server_dop_cap_for_file_transfer`` keyword argument equals
    the cap set on the connection.

    Args:
        task: Callable that receives the cursor and performs one file
            operation on it.
        dop_cap: Server DoP cap to set on the fake connection and to expect
            in the keyword arguments of the agent constructor.
    """
    from snowflake.connector._utils import (
        _VARIABLE_NAME_SERVER_DOP_CAP_FOR_FILE_TRANSFER,
    )

    mock_conn = FakeConnection()
    setattr(
        mock_conn, f"_{_VARIABLE_NAME_SERVER_DOP_CAP_FOR_FILE_TRANSFER}", dop_cap
    )

    class FakeFileOperationParser:
        # Stand-in parser: accepts the arguments of a file operation and
        # returns an empty parse result.
        def parse_file_operation(
            self,
            stage_location,
            local_file_name,
            target_directory,
            command_type,
            options,
            has_source_from_stream=False,
        ):
            return {}

    mock_cursor = SnowflakeCursor(mock_conn)
    mock_conn._file_operation_parser = FakeFileOperationParser()
    with patch.object(
        mock_cursor, "_init_result_and_meta", return_value=None
    ), patch(
        "snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent"
    ) as MockFileTransferAgent:
        task(mock_cursor)
        # Fail with a clear message if the task never built an agent;
        # otherwise ``call_args`` is None and unpacking it raises an
        # unrelated-looking TypeError.
        MockFileTransferAgent.assert_called()
        # Verify that the file operation built the FileTransferAgent with
        # the server DoP cap that was set on the connection.
        _, kwargs = MockFileTransferAgent.call_args
        assert "snowflake_server_dop_cap_for_file_transfer" in kwargs
        assert dop_cap == kwargs["snowflake_server_dop_cap_for_file_transfer"]
230+
231+
def test_dop_cap_for_upload(self):
    """Check the DoP cap handling of ``_upload`` with a cap of 1."""
    self._run_dop_cap_test(
        lambda cursor: cursor._upload("/tmp/test.txt", "@st", {}),
        dop_cap=1,
    )
236+
237+
def test_dop_cap_for_upload_stream(self):
    """Check the DoP cap handling of ``_upload_stream`` with a cap of 1."""
    self._run_dop_cap_test(
        # A fresh mock stands in for the input stream on each invocation.
        lambda cursor: cursor._upload_stream(MagicMock(), "@st", {}),
        dop_cap=1,
    )
243+
244+
def test_dop_cap_for_download(self):
    """Check the DoP cap handling of ``_download`` with a cap of 1."""
    self._run_dop_cap_test(
        lambda cursor: cursor._download("@st", "/tmp", {}),
        dop_cap=1,
    )

test/unit/test_put_get.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,58 @@ def test_strip_stage_prefix_from_dst_file_name_for_download():
293293
agent._strip_stage_prefix_from_dst_file_name_for_download.assert_called_with(
294294
file
295295
)
296+
297+
298+
# The server DoP cap is newly introduced and therefore should not be tested in
# old drivers.
@pytest.mark.skipolddriver
def test_server_dop_cap(tmp_path):
    """With a server DoP cap of 1, three thread pools of size 1 are created."""
    source_files = [tmp_path / "file1", tmp_path / "file2"]
    for source_file in source_files:
        source_file.touch()

    stage_info = {
        "creds": {},
        "location": "some_bucket",
        "region": "no_region",
        "locationType": "AZURE",
        "path": "remote_loc",
        "endPoint": "",
        "storageAccount": "storage_account",
    }
    put_response = {
        "data": {
            "command": "UPLOAD",
            "src_locations": source_files,
            "sourceCompression": "none",
            "parallel": 8,
            "stageInfo": stage_info,
        },
        "success": True,
    }

    # Positive case
    rest_client = SnowflakeFileTransferAgent(
        mock.MagicMock(autospec=SnowflakeCursor),
        "PUT some_file.txt",
        put_response,
        snowflake_server_dop_cap_for_file_transfer=1,
    )
    transfer_metadata = mock.Mock(
        num_files_started=0,
        num_files_completed=3,
    )
    with mock.patch(
        "snowflake.connector.file_transfer_agent.ThreadPoolExecutor"
    ) as tpe, mock.patch(
        "snowflake.connector.file_transfer_agent.threading.Condition"
    ), mock.patch(
        "snowflake.connector.file_transfer_agent.TransferMetadata",
        return_value=transfer_metadata,
    ):
        # NOTE(review): AttributeError is swallowed on purpose here,
        # presumably because the mocked collaborators break a later stage
        # of execute() -- confirm.
        try:
            rest_client.execute()
        except AttributeError:
            pass

    # We expect 3 thread pool executors to be created with thread count as 1,
    # because we will create executors for network, preprocess and postprocess,
    # and due to the server DoP cap, each of them will have a thread count
    # of 1.
    single_thread_pools = [
        pool_call for pool_call in tpe.call_args_list if pool_call.args == (1,)
    ]
    assert len(single_thread_pools) == 3

0 commit comments

Comments
 (0)