Skip to content

Commit be84331

Browse files
SNOW-2283945 use AWS regional endpoints when required for storing pandas frames (#2513)
(cherry picked from commit a6450a5)
1 parent 3f858bd commit be84331

File tree

2 files changed

+44
-34
lines changed

2 files changed

+44
-34
lines changed

src/snowflake/connector/cursor.py

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@
7979
from pyarrow import Table
8080

8181
from .connection import SnowflakeConnection
82-
from .file_transfer_agent import SnowflakeProgressPercentage
82+
from .file_transfer_agent import (
83+
SnowflakeFileTransferAgent,
84+
SnowflakeProgressPercentage,
85+
)
8386
from .result_batch import ResultBatch
8487

8588
T = TypeVar("T", bound=collections.abc.Sequence)
@@ -1064,11 +1067,7 @@ def execute(
10641067
)
10651068
logger.debug("PUT OR GET: %s", self.is_file_transfer)
10661069
if self.is_file_transfer:
1067-
from .file_transfer_agent import SnowflakeFileTransferAgent
1068-
1069-
# Decide whether to use the old, or new code path
1070-
sf_file_transfer_agent = SnowflakeFileTransferAgent(
1071-
self,
1070+
sf_file_transfer_agent = self._create_file_transfer_agent(
10721071
query,
10731072
ret,
10741073
put_callback=_put_callback,
@@ -1084,13 +1083,6 @@ def execute(
10841083
skip_upload_on_content_match=_skip_upload_on_content_match,
10851084
source_from_stream=file_stream,
10861085
multipart_threshold=data.get("threshold"),
1087-
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
1088-
iobound_tpe_limit=self._connection.iobound_tpe_limit,
1089-
unsafe_file_write=self._connection.unsafe_file_write,
1090-
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1091-
self._connection
1092-
),
1093-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
10941086
)
10951087
sf_file_transfer_agent.execute()
10961088
data = sf_file_transfer_agent.result()
@@ -1786,8 +1778,6 @@ def _download(
17861778
_do_reset (bool, optional): Whether to reset the cursor before
17871779
downloading, by default we will reset the cursor.
17881780
"""
1789-
from .file_transfer_agent import SnowflakeFileTransferAgent
1790-
17911781
if _do_reset:
17921782
self.reset()
17931783

@@ -1801,14 +1791,9 @@ def _download(
18011791
)
18021792

18031793
# Execute the file operation based on the interpretation above.
1804-
file_transfer_agent = SnowflakeFileTransferAgent(
1805-
self,
1794+
file_transfer_agent = self._create_file_transfer_agent(
18061795
"", # empty command because it is triggered by directly calling this util not by a SQL query
18071796
ret,
1808-
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1809-
self._connection
1810-
),
1811-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
18121797
)
18131798
file_transfer_agent.execute()
18141799
self._init_result_and_meta(file_transfer_agent.result())
@@ -1829,7 +1814,6 @@ def _upload(
18291814
_do_reset (bool, optional): Whether to reset the cursor before
18301815
uploading, by default we will reset the cursor.
18311816
"""
1832-
from .file_transfer_agent import SnowflakeFileTransferAgent
18331817

18341818
if _do_reset:
18351819
self.reset()
@@ -1844,15 +1828,10 @@ def _upload(
18441828
)
18451829

18461830
# Execute the file operation based on the interpretation above.
1847-
file_transfer_agent = SnowflakeFileTransferAgent(
1848-
self,
1831+
file_transfer_agent = self._create_file_transfer_agent(
18491832
"", # empty command because it is triggered by directly calling this util not by a SQL query
18501833
ret,
18511834
force_put_overwrite=False, # _upload should respect user decision on overwriting
1852-
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1853-
self._connection
1854-
),
1855-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
18561835
)
18571836
file_transfer_agent.execute()
18581837
self._init_result_and_meta(file_transfer_agent.result())
@@ -1899,7 +1878,6 @@ def _upload_stream(
18991878
_do_reset (bool, optional): Whether to reset the cursor before
19001879
uploading, by default we will reset the cursor.
19011880
"""
1902-
from .file_transfer_agent import SnowflakeFileTransferAgent
19031881

19041882
if _do_reset:
19051883
self.reset()
@@ -1915,19 +1893,37 @@ def _upload_stream(
19151893
)
19161894

19171895
# Execute the file operation based on the interpretation above.
1918-
file_transfer_agent = SnowflakeFileTransferAgent(
1919-
self,
1896+
file_transfer_agent = self._create_file_transfer_agent(
19201897
"", # empty command because it is triggered by directly calling this util not by a SQL query
19211898
ret,
19221899
source_from_stream=input_stream,
19231900
force_put_overwrite=False, # _upload_stream should respect user decision on overwriting
1901+
)
1902+
file_transfer_agent.execute()
1903+
self._init_result_and_meta(file_transfer_agent.result())
1904+
1905+
def _create_file_transfer_agent(
1906+
self,
1907+
command: str,
1908+
ret: dict[str, Any],
1909+
/,
1910+
**kwargs,
1911+
) -> SnowflakeFileTransferAgent:
1912+
from .file_transfer_agent import SnowflakeFileTransferAgent
1913+
1914+
return SnowflakeFileTransferAgent(
1915+
self,
1916+
command,
1917+
ret,
1918+
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
1919+
iobound_tpe_limit=self._connection.iobound_tpe_limit,
1920+
unsafe_file_write=self._connection.unsafe_file_write,
19241921
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
19251922
self._connection
19261923
),
1927-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
1924+
reraise_error_in_file_transfer_work_function=self._connection._reraise_error_in_file_transfer_work_function,
1925+
**kwargs,
19281926
)
1929-
file_transfer_agent.execute()
1930-
self._init_result_and_meta(file_transfer_agent.result())
19311927

19321928

19331929
class DictCursor(SnowflakeCursor):

test/unit/test_cursor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def __init__(self):
2525
self._log_max_query_length = 0
2626
self._reuse_results = None
2727
self._reraise_error_in_file_transfer_work_function = False
28+
self._enable_stage_s3_privatelink_for_us_east_1 = False
29+
self._iobound_tpe_limit = None
30+
self._unsafe_file_write = False
2831

2932

3033
@pytest.mark.parametrize(
@@ -121,6 +124,8 @@ def test_download(self, MockFileTransferAgent):
121124
# - download_as_stream of connection._stream_downloader
122125
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
123126
fake_conn._stream_downloader.download_as_stream.assert_not_called()
127+
MockFileTransferAgent.assert_called_once()
128+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
124129
mock_file_transfer_agent_instance.execute.assert_called_once()
125130

126131
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
@@ -139,6 +144,8 @@ def test_upload(self, MockFileTransferAgent):
139144
# - download_as_stream of connection._stream_downloader
140145
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
141146
fake_conn._stream_downloader.download_as_stream.assert_not_called()
147+
MockFileTransferAgent.assert_called_once()
148+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
142149
mock_file_transfer_agent_instance.execute.assert_called_once()
143150

144151
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
@@ -157,6 +164,7 @@ def test_download_stream(self, MockFileTransferAgent):
157164
# - execute in SnowflakeFileTransferAgent
158165
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
159166
fake_conn._stream_downloader.download_as_stream.assert_called_once()
167+
MockFileTransferAgent.assert_not_called()
160168
mock_file_transfer_agent_instance.execute.assert_not_called()
161169

162170
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
@@ -176,6 +184,8 @@ def test_upload_stream(self, MockFileTransferAgent):
176184
# - download_as_stream of connection._stream_downloader
177185
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
178186
fake_conn._stream_downloader.download_as_stream.assert_not_called()
187+
MockFileTransferAgent.assert_called_once()
188+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
179189
mock_file_transfer_agent_instance.execute.assert_called_once()
180190

181191
def _setup_mocks(self, MockFileTransferAgent):
@@ -185,6 +195,10 @@ def _setup_mocks(self, MockFileTransferAgent):
185195
fake_conn = FakeConnection()
186196
fake_conn._file_operation_parser = MagicMock()
187197
fake_conn._stream_downloader = MagicMock()
198+
# this should be true on all new AWS deployments to use regional endpoints for staging operations
199+
fake_conn._enable_stage_s3_privatelink_for_us_east_1 = True
200+
fake_conn._iobound_tpe_limit = 1
201+
fake_conn._unsafe_file_write = False
188202

189203
cursor = SnowflakeCursor(fake_conn)
190204
cursor.reset = MagicMock()

0 commit comments

Comments
 (0)