Skip to content

Commit 8017d15

Browse files
[async] Applied #2513 to async code
1 parent be84331 commit 8017d15

File tree

2 files changed

+39
-26
lines changed

2 files changed

+39
-26
lines changed

src/snowflake/connector/aio/_cursor.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from snowflake.connector._sql_util import get_file_transfer_type
2626
from snowflake.connector.aio._bind_upload_agent import BindUploadAgent
27+
from snowflake.connector.aio._file_transfer_agent import SnowflakeFileTransferAgent
2728
from snowflake.connector.aio._result_batch import (
2829
ResultBatch,
2930
create_batches_from_response,
@@ -664,11 +665,8 @@ async def execute(
664665
)
665666
logger.debug("PUT OR GET: %s", self.is_file_transfer)
666667
if self.is_file_transfer:
667-
from ._file_transfer_agent import SnowflakeFileTransferAgent
668-
669668
# Decide whether to use the old, or new code path
670-
sf_file_transfer_agent = SnowflakeFileTransferAgent(
671-
self,
669+
sf_file_transfer_agent = self._create_file_transfer_agent(
672670
query,
673671
ret,
674672
put_callback=_put_callback,
@@ -684,9 +682,6 @@ async def execute(
684682
skip_upload_on_content_match=_skip_upload_on_content_match,
685683
source_from_stream=file_stream,
686684
multipart_threshold=data.get("threshold"),
687-
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
688-
unsafe_file_write=self._connection.unsafe_file_write,
689-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
690685
)
691686
await sf_file_transfer_agent.execute()
692687
data = sf_file_transfer_agent.result()
@@ -1082,8 +1077,6 @@ async def _download(
10821077
_do_reset (bool, optional): Whether to reset the cursor before
10831078
downloading, by default we will reset the cursor.
10841079
"""
1085-
from ._file_transfer_agent import SnowflakeFileTransferAgent
1086-
10871080
if _do_reset:
10881081
self.reset()
10891082

@@ -1097,11 +1090,9 @@ async def _download(
10971090
)
10981091

10991092
# Execute the file operation based on the interpretation above.
1100-
file_transfer_agent = SnowflakeFileTransferAgent(
1101-
self,
1093+
file_transfer_agent = self._create_file_transfer_agent(
11021094
"", # empty command because it is triggered by directly calling this util not by a SQL query
11031095
ret,
1104-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
11051096
)
11061097
await file_transfer_agent.execute()
11071098
await self._init_result_and_meta(file_transfer_agent.result())
@@ -1122,8 +1113,6 @@ async def _upload(
11221113
_do_reset (bool, optional): Whether to reset the cursor before
11231114
uploading, by default we will reset the cursor.
11241115
"""
1125-
from ._file_transfer_agent import SnowflakeFileTransferAgent
1126-
11271116
if _do_reset:
11281117
self.reset()
11291118

@@ -1137,12 +1126,10 @@ async def _upload(
11371126
)
11381127

11391128
# Execute the file operation based on the interpretation above.
1140-
file_transfer_agent = SnowflakeFileTransferAgent(
1141-
self,
1129+
file_transfer_agent = self._create_file_transfer_agent(
11421130
"", # empty command because it is triggered by directly calling this util not by a SQL query
11431131
ret,
11441132
force_put_overwrite=False, # _upload should respect user decision on overwriting
1145-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
11461133
)
11471134
await file_transfer_agent.execute()
11481135
await self._init_result_and_meta(file_transfer_agent.result())
@@ -1191,8 +1178,6 @@ async def _upload_stream(
11911178
_do_reset (bool, optional): Whether to reset the cursor before
11921179
uploading, by default we will reset the cursor.
11931180
"""
1194-
from ._file_transfer_agent import SnowflakeFileTransferAgent
1195-
11961181
if _do_reset:
11971182
self.reset()
11981183

@@ -1207,13 +1192,11 @@ async def _upload_stream(
12071192
)
12081193

12091194
# Execute the file operation based on the interpretation above.
1210-
file_transfer_agent = SnowflakeFileTransferAgent(
1211-
self,
1195+
file_transfer_agent = self._create_file_transfer_agent(
12121196
"", # empty command because it is triggered by directly calling this util not by a SQL query
12131197
ret,
12141198
source_from_stream=input_stream,
12151199
force_put_overwrite=False, # _upload should respect user decision on overwriting
1216-
reraise_error_in_file_transfer_work_function=self.connection._reraise_error_in_file_transfer_work_function,
12171200
)
12181201
await file_transfer_agent.execute()
12191202
await self._init_result_and_meta(file_transfer_agent.result())
@@ -1321,6 +1304,24 @@ async def query_result(self, qid: str) -> SnowflakeCursor:
13211304
)
13221305
return self
13231306

1307+
def _create_file_transfer_agent(
1308+
self,
1309+
command: str,
1310+
ret: dict[str, Any],
1311+
/,
1312+
**kwargs,
1313+
) -> SnowflakeFileTransferAgent:
1314+
1315+
return SnowflakeFileTransferAgent(
1316+
self,
1317+
command,
1318+
ret,
1319+
use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
1320+
unsafe_file_write=self._connection.unsafe_file_write,
1321+
reraise_error_in_file_transfer_work_function=self._connection._reraise_error_in_file_transfer_work_function,
1322+
**kwargs,
1323+
)
1324+
13241325

13251326
class DictCursor(DictCursorSync, SnowflakeCursor):
13261327
pass

test/unit/aio/test_cursor_async_unit.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def __init__(self):
2929
self._log_max_query_length = 0
3030
self._reuse_results = None
3131
self._reraise_error_in_file_transfer_work_function = False
32+
self._enable_stage_s3_privatelink_for_us_east_1 = False
33+
self._unsafe_file_write = False
3234

3335

3436
@pytest.mark.parametrize(
@@ -109,7 +111,7 @@ async def mock_cmd_query(*args, **kwargs):
109111
class TestUploadDownloadMethods(IsolatedAsyncioTestCase):
110112
"""Test the _upload/_download/_upload_stream/_download_stream methods."""
111113

112-
@patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent")
114+
@patch("snowflake.connector.aio._cursor.SnowflakeFileTransferAgent")
113115
async def test_download(self, MockFileTransferAgent):
114116
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
115117
MockFileTransferAgent
@@ -125,9 +127,11 @@ async def test_download(self, MockFileTransferAgent):
125127
# - download_as_stream of connection._stream_downloader
126128
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
127129
fake_conn._stream_downloader.download_as_stream.assert_not_called()
130+
MockFileTransferAgent.assert_called_once()
131+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
128132
mock_file_transfer_agent_instance.execute.assert_called_once()
129133

130-
@patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent")
134+
@patch("snowflake.connector.aio._cursor.SnowflakeFileTransferAgent")
131135
async def test_upload(self, MockFileTransferAgent):
132136
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
133137
MockFileTransferAgent
@@ -143,9 +147,11 @@ async def test_upload(self, MockFileTransferAgent):
143147
# - download_as_stream of connection._stream_downloader
144148
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
145149
fake_conn._stream_downloader.download_as_stream.assert_not_called()
150+
MockFileTransferAgent.assert_called_once()
151+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
146152
mock_file_transfer_agent_instance.execute.assert_called_once()
147153

148-
@patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent")
154+
@patch("snowflake.connector.aio._cursor.SnowflakeFileTransferAgent")
149155
async def test_download_stream(self, MockFileTransferAgent):
150156
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
151157
MockFileTransferAgent
@@ -161,9 +167,10 @@ async def test_download_stream(self, MockFileTransferAgent):
161167
# - execute in SnowflakeFileTransferAgent
162168
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
163169
fake_conn._stream_downloader.download_as_stream.assert_called_once()
170+
MockFileTransferAgent.assert_not_called()
164171
mock_file_transfer_agent_instance.execute.assert_not_called()
165172

166-
@patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent")
173+
@patch("snowflake.connector.aio._cursor.SnowflakeFileTransferAgent")
167174
async def test_upload_stream(self, MockFileTransferAgent):
168175
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
169176
MockFileTransferAgent
@@ -180,6 +187,8 @@ async def test_upload_stream(self, MockFileTransferAgent):
180187
# - download_as_stream of connection._stream_downloader
181188
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
182189
fake_conn._stream_downloader.download_as_stream.assert_not_called()
190+
MockFileTransferAgent.assert_called_once()
191+
assert MockFileTransferAgent.call_args.kwargs.get("use_s3_regional_url", False)
183192
mock_file_transfer_agent_instance.execute.assert_called_once()
184193

185194
def _setup_mocks(self, MockFileTransferAgent):
@@ -191,6 +200,9 @@ def _setup_mocks(self, MockFileTransferAgent):
191200
fake_conn._file_operation_parser.parse_file_operation = AsyncMock()
192201
fake_conn._stream_downloader = MagicMock()
193202
fake_conn._stream_downloader.download_as_stream = AsyncMock()
203+
# this should be true on all new AWS deployments to use regional endpoints for staging operations
204+
fake_conn._enable_stage_s3_privatelink_for_us_east_1 = True
205+
fake_conn._unsafe_file_write = False
194206

195207
cursor = SnowflakeCursor(fake_conn)
196208
cursor.reset = MagicMock()

0 commit comments

Comments
 (0)