Skip to content

Commit 08dbe0e

Browse files
authored
SNOW-2216803 allow re-raising error in file transfer work function in main thread (#2443)
1 parent 2d8a795 commit 08dbe0e

File tree

6 files changed

+200
-0
lines changed

6 files changed

+200
-0
lines changed

src/snowflake/connector/connection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ def _get_private_bytes_from_file(
389389
_DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER, # default value
390390
int, # type
391391
), # snowflake internal
392+
"reraise_error_in_file_transfer_work_function": (
393+
False,
394+
bool,
395+
),
392396
}
393397

394398
APPLICATION_RE = re.compile(r"[\w\d_]+")

src/snowflake/connector/cursor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,7 @@ def execute(
10901090
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
10911091
self._connection
10921092
),
1093+
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
10931094
)
10941095
sf_file_transfer_agent.execute()
10951096
data = sf_file_transfer_agent.result()
@@ -1807,6 +1808,7 @@ def _download(
18071808
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
18081809
self._connection
18091810
),
1811+
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
18101812
)
18111813
file_transfer_agent.execute()
18121814
self._init_result_and_meta(file_transfer_agent.result())
@@ -1850,6 +1852,7 @@ def _upload(
18501852
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
18511853
self._connection
18521854
),
1855+
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
18531856
)
18541857
file_transfer_agent.execute()
18551858
self._init_result_and_meta(file_transfer_agent.result())
@@ -1921,6 +1924,7 @@ def _upload_stream(
19211924
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
19221925
self._connection
19231926
),
1927+
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
19241928
)
19251929
file_transfer_agent.execute()
19261930
self._init_result_and_meta(file_transfer_agent.result())

src/snowflake/connector/file_transfer_agent.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ def __init__(
357357
iobound_tpe_limit: int | None = None,
358358
unsafe_file_write: bool = False,
359359
snowflake_server_dop_cap_for_file_transfer=_DEFAULT_VALUE_SERVER_DOP_CAP_FOR_FILE_TRANSFER,
360+
reraise_error_in_file_transfer_work_function: bool = False,
360361
) -> None:
361362
self._cursor = cursor
362363
self._command = command
@@ -392,6 +393,9 @@ def __init__(
392393
self._snowflake_server_dop_cap_for_file_transfer = (
393394
snowflake_server_dop_cap_for_file_transfer
394395
)
396+
self._reraise_error_in_file_transfer_work_function = (
397+
reraise_error_in_file_transfer_work_function
398+
)
395399

396400
def execute(self) -> None:
397401
self._parse_command()
@@ -471,6 +475,7 @@ def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
471475
transfer_metadata = TransferMetadata() # this is protected by cv_chunk_process
472476
is_upload = self._command_type == CMD_TYPE_UPLOAD
473477
exception_caught_in_callback: Exception | None = None
478+
exception_caught_in_work: Exception | None = None
474479
logger.debug(
475480
"Going to %sload %d files", "up" if is_upload else "down", len(metas)
476481
)
@@ -626,6 +631,17 @@ def function_and_callback_wrapper(
626631
logger.error(f"An exception was raised in {repr(work)}", exc_info=True)
627632
file_meta.error_details = e
628633
result = (False, e)
634+
# If the reraise is enabled, notify the main thread of work
635+
# function error, with the concrete exception stored aside in
636+
# exception_caught_in_work, such that towards the end of
637+
# the transfer call, we reraise the error as is immediately
638+
# instead of continuing the execution after transfer.
639+
if self._reraise_error_in_file_transfer_work_function:
640+
with cv_main_thread:
641+
nonlocal exception_caught_in_work
642+
exception_caught_in_work = e
643+
cv_main_thread.notify()
644+
629645
try:
630646
_callback(*result, file_meta)
631647
except Exception as e:
@@ -670,6 +686,10 @@ def function_and_callback_wrapper(
670686
with cv_main_thread:
671687
while transfer_metadata.num_files_completed < num_total_files:
672688
cv_main_thread.wait()
689+
# If both exception_caught_in_work and exception_caught_in_callback
690+
# are present, the former will take precedence.
691+
if exception_caught_in_work is not None:
692+
raise exception_caught_in_work
673693
if exception_caught_in_callback is not None:
674694
raise exception_caught_in_callback
675695

test/unit/test_connection.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,3 +773,41 @@ def test_single_use_refresh_tokens_option_is_plumbed_into_authbyauthcode(
773773
oauth_enable_single_use_refresh_tokens=rtr_enabled,
774774
)
775775
assert conn.auth_class._enable_single_use_refresh_tokens == rtr_enabled
776+
777+
778+
# Skip for old drivers because the connection config of
779+
# reraise_error_in_file_transfer_work_function is newly introduced.
780+
@pytest.mark.skipolddriver
781+
@pytest.mark.parametrize("reraise_enabled", [True, False, None])
782+
def test_reraise_error_in_file_transfer_work_function_config(
783+
reraise_enabled: bool | None,
784+
):
785+
"""Test that reraise_error_in_file_transfer_work_function config is
786+
properly set on connection."""
787+
788+
with mock.patch(
789+
"snowflake.connector.network.SnowflakeRestful._post_request",
790+
return_value={
791+
"data": {
792+
"serverVersion": "a.b.c",
793+
},
794+
"code": None,
795+
"message": None,
796+
"success": True,
797+
},
798+
):
799+
if reraise_enabled is not None:
800+
# Create a connection with the config set to the value of reraise_enabled.
801+
conn = fake_connector(
802+
**{"reraise_error_in_file_transfer_work_function": reraise_enabled}
803+
)
804+
else:
805+
# Special test setup: when reraise_enabled is None, create a
806+
# connection without setting the config.
807+
conn = fake_connector()
808+
809+
# When reraise_enabled is None, we expect a default value of False,
810+
# so taking bool() on it also makes sense.
811+
expected_value = bool(reraise_enabled)
812+
actual_value = conn._reraise_error_in_file_transfer_work_function
813+
assert actual_value == expected_value

test/unit/test_cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class FakeConnection(SnowflakeConnection):
2424
def __init__(self):
2525
self._log_max_query_length = 0
2626
self._reuse_results = None
27+
self._reraise_error_in_file_transfer_work_function = False
2728

2829

2930
@pytest.mark.parametrize(

test/unit/test_put_get.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,136 @@ def test_server_dop_cap(tmp_path):
348348
# and due to the server DoP cap, each of them will have a thread count
349349
# of 1.
350350
assert len(list(filter(lambda e: e.args == (1,), tpe.call_args_list))) == 3
351+
352+
353+
def _setup_test_for_reraise_file_transfer_work_fn_error(tmp_path, reraise_param_value):
354+
"""Helper function to set up common test infrastructure for tests related to re-raising file transfer work function error.
355+
356+
Returns:
357+
tuple: (agent, test_exception, mock_client, mock_create_client)
358+
"""
359+
360+
file1 = tmp_path / "file1"
361+
file1.write_text("test content")
362+
363+
# Mock cursor with connection attribute
364+
mock_cursor = mock.MagicMock(autospec=SnowflakeCursor)
365+
mock_cursor.connection._reraise_error_in_file_transfer_work_function = (
366+
reraise_param_value
367+
)
368+
369+
# Create file transfer agent
370+
agent = SnowflakeFileTransferAgent(
371+
mock_cursor,
372+
"PUT some_file.txt",
373+
{
374+
"data": {
375+
"command": "UPLOAD",
376+
"src_locations": [str(file1)],
377+
"sourceCompression": "none",
378+
"parallel": 1,
379+
"stageInfo": {
380+
"creds": {
381+
"AZURE_SAS_TOKEN": "sas_token",
382+
},
383+
"location": "some_bucket",
384+
"region": "no_region",
385+
"locationType": "AZURE",
386+
"path": "remote_loc",
387+
"endPoint": "",
388+
"storageAccount": "storage_account",
389+
},
390+
},
391+
"success": True,
392+
},
393+
reraise_error_in_file_transfer_work_function=reraise_param_value,
394+
)
395+
396+
# Quick check to make sure the field _reraise_error_in_file_transfer_work_function is correctly populated
397+
assert (
398+
agent._reraise_error_in_file_transfer_work_function == reraise_param_value
399+
), f"expected {reraise_param_value}, got {agent._reraise_error_in_file_transfer_work_function}"
400+
401+
# Parse command and initialize file metadata
402+
agent._parse_command()
403+
agent._init_file_metadata()
404+
agent._process_file_compression_type()
405+
406+
# Create a custom exception to be raised by the work function
407+
test_exception = Exception("Test work function failure")
408+
409+
def mock_upload_chunk_with_delay(*args, **kwargs):
410+
import time
411+
412+
time.sleep(0.2)
413+
raise test_exception
414+
415+
# Set up mock client patch, which we will activate in each unit test case.
416+
mock_create_client = mock.patch.object(agent, "_create_file_transfer_client")
417+
mock_client = mock.MagicMock()
418+
mock_client.upload_chunk.side_effect = mock_upload_chunk_with_delay
419+
420+
# Set up mock client attributes needed for the transfer flow
421+
mock_client.meta = agent._file_metadata[0]
422+
mock_client.num_of_chunks = 1
423+
mock_client.successful_transfers = 0
424+
mock_client.failed_transfers = 0
425+
mock_client.lock = mock.MagicMock()
426+
# Mock methods that would be called during cleanup
427+
mock_client.finish_upload = mock.MagicMock()
428+
mock_client.delete_client_data = mock.MagicMock()
429+
430+
return agent, test_exception, mock_client, mock_create_client
431+
432+
433+
# Skip for old drivers because the connection config of
434+
# reraise_error_in_file_transfer_work_function is newly introduced.
435+
@pytest.mark.skipolddriver
436+
def test_python_reraise_file_transfer_work_fn_error_as_is(tmp_path):
437+
"""Tests that when reraise_error_in_file_transfer_work_function config is True,
438+
exceptions are reraised immediately without continuing execution after transfer().
439+
"""
440+
agent, test_exception, mock_client, mock_create_client_patch = (
441+
_setup_test_for_reraise_file_transfer_work_fn_error(tmp_path, True)
442+
)
443+
444+
with mock_create_client_patch as mock_create_client:
445+
mock_create_client.return_value = mock_client
446+
447+
# Test that with the connection config
448+
# reraise_error_in_file_transfer_work_function is True, the
449+
# exception is reraised immediately in main thread of transfer.
450+
with pytest.raises(Exception) as exc_info:
451+
agent.transfer(agent._file_metadata)
452+
453+
# Verify it's the same exception we injected
454+
assert exc_info.value is test_exception
455+
456+
# Verify that prepare_upload was called (showing the work function was executed)
457+
mock_client.prepare_upload.assert_called_once()
458+
459+
460+
# Skip for old drivers because the connection config of
461+
# reraise_error_in_file_transfer_work_function is newly introduced.
462+
@pytest.mark.skipolddriver
463+
def test_python_not_reraise_file_transfer_work_fn_error_as_is(tmp_path):
464+
"""Tests that when reraise_error_in_file_transfer_work_function config is False (default),
465+
where exceptions are stored in file metadata but execution continues.
466+
"""
467+
agent, test_exception, mock_client, mock_create_client_patch = (
468+
_setup_test_for_reraise_file_transfer_work_fn_error(tmp_path, False)
469+
)
470+
471+
with mock_create_client_patch as mock_create_client:
472+
mock_create_client.return_value = mock_client
473+
474+
# Verify that with the connection config
475+
# reraise_error_in_file_transfer_work_function is False, the
476+
# exception is not reraised (but instead stored in file metadata).
477+
agent.transfer(agent._file_metadata)
478+
479+
# Verify that the error was stored in the file metadata
480+
assert agent._file_metadata[0].error_details is test_exception
481+
482+
# Verify that prepare_upload was called
483+
mock_client.prepare_upload.assert_called_once()

0 commit comments

Comments
 (0)