Skip to content

Commit 0b31d9b

Browse files
SNOW-2283945 use AWS regional endpoints when required for storing pandas frames (#2513)
(cherry picked from commit a6450a5)
1 parent fba876e commit 0b31d9b

File tree

2 files changed

+44
-34
lines changed

2 files changed

+44
-34
lines changed

src/snowflake/connector/cursor.py

Lines changed: 30 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,10 @@
7979
from pyarrow import Table
8080

8181
from .connection import SnowflakeConnection
82-
from .file_transfer_agent import SnowflakeProgressPercentage
82+
from .file_transfer_agent import (
83+
SnowflakeFileTransferAgent,
84+
SnowflakeProgressPercentage,
85+
)
8386
from .result_batch import ResultBatch
8487

8588
T = TypeVar("T", bound=collections.abc.Sequence)
@@ -1064,11 +1067,7 @@ def execute(
10641067
)
10651068
logger.debug("PUT OR GET: %s", self.is_file_transfer)
10661069
if self.is_file_transfer:
1067-
from .file_transfer_agent import SnowflakeFileTransferAgent
1068-
1069-
# Decide whether to use the old, or new code path
1070-
sf_file_transfer_agent = SnowflakeFileTransferAgent(
1071-
self,
1070+
sf_file_transfer_agent = self._create_file_transfer_agent(
10721071
query,
10731072
ret,
10741073
put_callback=_put_callback,
@@ -1084,13 +1083,6 @@ def execute(
10841083
skip_upload_on_content_match=_skip_upload_on_content_match,
10851084
source_from_stream=file_stream,
10861085
multipart_threshold=data.get("threshold"),
1087-
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
1088-
iobound_tpe_limit=self._connection.iobound_tpe_limit,
1089-
unsafe_file_write=self._connection.unsafe_file_write,
1090-
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1091-
self._connection
1092-
),
1093-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
10941086
)
10951087
sf_file_transfer_agent.execute()
10961088
data = sf_file_transfer_agent.result()
@@ -1785,8 +1777,6 @@ def _download(
17851777
_do_reset (bool, optional): Whether to reset the cursor before
17861778
downloading, by default we will reset the cursor.
17871779
"""
1788-
from .file_transfer_agent import SnowflakeFileTransferAgent
1789-
17901780
if _do_reset:
17911781
self.reset()
17921782

@@ -1800,14 +1790,9 @@ def _download(
18001790
)
18011791

18021792
# Execute the file operation based on the interpretation above.
1803-
file_transfer_agent = SnowflakeFileTransferAgent(
1804-
self,
1793+
file_transfer_agent = self._create_file_transfer_agent(
18051794
"", # empty command because it is triggered by directly calling this util not by a SQL query
18061795
ret,
1807-
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1808-
self._connection
1809-
),
1810-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
18111796
)
18121797
file_transfer_agent.execute()
18131798
self._init_result_and_meta(file_transfer_agent.result())
@@ -1828,7 +1813,6 @@ def _upload(
18281813
_do_reset (bool, optional): Whether to reset the cursor before
18291814
uploading, by default we will reset the cursor.
18301815
"""
1831-
from .file_transfer_agent import SnowflakeFileTransferAgent
18321816

18331817
if _do_reset:
18341818
self.reset()
@@ -1843,15 +1827,10 @@ def _upload(
18431827
)
18441828

18451829
# Execute the file operation based on the interpretation above.
1846-
file_transfer_agent = SnowflakeFileTransferAgent(
1847-
self,
1830+
file_transfer_agent = self._create_file_transfer_agent(
18481831
"", # empty command because it is triggered by directly calling this util not by a SQL query
18491832
ret,
18501833
force_put_overwrite=False, # _upload should respect user decision on overwriting
1851-
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
1852-
self._connection
1853-
),
1854-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
18551834
)
18561835
file_transfer_agent.execute()
18571836
self._init_result_and_meta(file_transfer_agent.result())
@@ -1898,7 +1877,6 @@ def _upload_stream(
18981877
_do_reset (bool, optional): Whether to reset the cursor before
18991878
uploading, by default we will reset the cursor.
19001879
"""
1901-
from .file_transfer_agent import SnowflakeFileTransferAgent
19021880

19031881
if _do_reset:
19041882
self.reset()
@@ -1914,19 +1892,37 @@ def _upload_stream(
19141892
)
19151893

19161894
# Execute the file operation based on the interpretation above.
1917-
file_transfer_agent = SnowflakeFileTransferAgent(
1918-
self,
1895+
file_transfer_agent = self._create_file_transfer_agent(
19191896
"", # empty command because it is triggered by directly calling this util not by a SQL query
19201897
ret,
19211898
source_from_stream=input_stream,
19221899
force_put_overwrite=False, # _upload_stream should respect user decision on overwriting
1900+
)
1901+
file_transfer_agent.execute()
1902+
self._init_result_and_meta(file_transfer_agent.result())
1903+
1904+
def _create_file_transfer_agent(
1905+
self,
1906+
command: str,
1907+
ret: dict[str, Any],
1908+
/,
1909+
**kwargs,
1910+
) -> SnowflakeFileTransferAgent:
1911+
from .file_transfer_agent import SnowflakeFileTransferAgent
1912+
1913+
return SnowflakeFileTransferAgent(
1914+
self,
1915+
command,
1916+
ret,
1917+
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
1918+
iobound_tpe_limit=self._connection.iobound_tpe_limit,
1919+
unsafe_file_write=self._connection.unsafe_file_write,
19231920
snowflake_server_dop_cap_for_file_transfer=_snowflake_max_parallelism_for_file_transfer(
19241921
self._connection
19251922
),
1926-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
1923+
reraise_error_in_file_transfer_work_function=self._connection._reraise_error_in_file_transfer_work_function,
1924+
**kwargs,
19271925
)
1928-
file_transfer_agent.execute()
1929-
self._init_result_and_meta(file_transfer_agent.result())
19301926

19311927

19321928
class DictCursor(SnowflakeCursor):

test/unit/test_cursor.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,9 @@ def __init__(self):
2525
self._log_max_query_length = 0
2626
self._reuse_results = None
2727
self._reraise_error_in_file_transfer_work_function = False
28+
self._enable_stage_s3_privatelink_for_us_east_1 = False
29+
self._iobound_tpe_limit = None
30+
self._unsafe_file_write = False
2831

2932

3033
@pytest.mark.parametrize(
@@ -121,6 +124,8 @@ def test_download(self, MockFileTransferAgent):
121124
# - download_as_stream of connection._stream_downloader
122125
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
123126
fake_conn._stream_downloader.download_as_stream.assert_not_called()
127+
MockFileTransferAgent.assert_called_once()
128+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
124129
mock_file_transfer_agent_instance.execute.assert_called_once()
125130

126131
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
@@ -139,6 +144,8 @@ def test_upload(self, MockFileTransferAgent):
139144
# - download_as_stream of connection._stream_downloader
140145
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
141146
fake_conn._stream_downloader.download_as_stream.assert_not_called()
147+
MockFileTransferAgent.assert_called_once()
148+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
142149
mock_file_transfer_agent_instance.execute.assert_called_once()
143150

144151
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
@@ -157,6 +164,7 @@ def test_download_stream(self, MockFileTransferAgent):
157164
# - execute in SnowflakeFileTransferAgent
158165
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
159166
fake_conn._stream_downloader.download_as_stream.assert_called_once()
167+
MockFileTransferAgent.assert_not_called()
160168
mock_file_transfer_agent_instance.execute.assert_not_called()
161169

162170
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
@@ -176,6 +184,8 @@ def test_upload_stream(self, MockFileTransferAgent):
176184
# - download_as_stream of connection._stream_downloader
177185
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
178186
fake_conn._stream_downloader.download_as_stream.assert_not_called()
187+
MockFileTransferAgent.assert_called_once()
188+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
179189
mock_file_transfer_agent_instance.execute.assert_called_once()
180190

181191
def _setup_mocks(self, MockFileTransferAgent):
@@ -185,6 +195,10 @@ def _setup_mocks(self, MockFileTransferAgent):
185195
fake_conn = FakeConnection()
186196
fake_conn._file_operation_parser = MagicMock()
187197
fake_conn._stream_downloader = MagicMock()
198+
# this should be true on all new AWS deployments to use regional endpoints for staging operations
199+
fake_conn._enable_stage_s3_privatelink_for_us_east_1 = True
200+
fake_conn._iobound_tpe_limit = 1
201+
fake_conn._unsafe_file_write = False
188202

189203
cursor = SnowflakeCursor(fake_conn)
190204
cursor.reset = MagicMock()

0 commit comments

Comments
 (0)