Skip to content

Commit 35c1e66

Browse files
[Async] Apply #2198 to async code
1 parent a7e0f8e commit 35c1e66

File tree

4 files changed

+280
-1
lines changed

4 files changed

+280
-1
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@
5252
QueryStatus,
5353
)
5454
from ..description import PLATFORM, PYTHON_VERSION, SNOWFLAKE_CONNECTOR_VERSION
55+
from ..direct_file_operation_utils import (
56+
FileOperationParser as FileOperationParserSynch,
57+
)
5558
from ..errorcode import (
5659
ER_CONNECTION_IS_CLOSED,
5760
ER_FAILED_TO_CONNECT_TO_DB,
@@ -76,6 +79,7 @@
7679
from ..wif_util import AttestationProvider
7780
from ._cursor import SnowflakeCursor
7881
from ._description import CLIENT_NAME
82+
from ._direct_file_operation_utils import StreamDownloader
7983
from ._network import SnowflakeRestful
8084
from ._telemetry import TelemetryClient
8185
from ._time_util import HeartBeatTimer
@@ -121,6 +125,10 @@ def __init__(
121125
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
122126
atexit.register(self._close_at_exit)
123127

128+
# Set up the file operation parser and stream downloader.
129+
self._file_operation_parser = FileOperationParserSynch(self)
130+
self._stream_downloader = StreamDownloader(self)
131+
124132
def __enter__(self):
125133
# async connection does not support sync context manager
126134
raise TypeError(

src/snowflake/connector/aio/_cursor.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
)
3535
from snowflake.connector.aio._result_set import ResultSet, ResultSetIterator
3636
from snowflake.connector.constants import (
37+
CMD_TYPE_DOWNLOAD,
38+
CMD_TYPE_UPLOAD,
3739
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
3840
QueryStatus,
3941
)
@@ -1043,6 +1045,153 @@ async def get_result_batches(self) -> list[ResultBatch] | None:
10431045
)
10441046
return self._result_set.batches
10451047

1048+
async def _download(
1049+
self,
1050+
stage_location: str,
1051+
target_directory: str,
1052+
options: dict[str, Any],
1053+
_do_reset: bool = True,
1054+
) -> None:
1055+
"""Downloads from the stage location to the target directory.
1056+
1057+
Args:
1058+
stage_location (str): The location of the stage to download from.
1059+
target_directory (str): The destination directory to download into.
1060+
options (dict[str, Any]): The download options.
1061+
_do_reset (bool, optional): Whether to reset the cursor before
1062+
downloading, by default we will reset the cursor.
1063+
"""
1064+
from ._file_transfer_agent import SnowflakeFileTransferAgent
1065+
1066+
if _do_reset:
1067+
self.reset()
1068+
1069+
# Interpret the file operation.
1070+
ret = self.connection._file_operation_parser.parse_file_operation(
1071+
stage_location=stage_location,
1072+
local_file_name=None,
1073+
target_directory=target_directory,
1074+
command_type=CMD_TYPE_DOWNLOAD,
1075+
options=options,
1076+
)
1077+
1078+
# Execute the file operation based on the interpretation above.
1079+
file_transfer_agent = SnowflakeFileTransferAgent(
1080+
self,
1081+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1082+
ret,
1083+
)
1084+
await file_transfer_agent.execute()
1085+
await self._init_result_and_meta(file_transfer_agent.result())
1086+
1087+
async def _upload(
1088+
self,
1089+
local_file_name: str,
1090+
stage_location: str,
1091+
options: dict[str, Any],
1092+
_do_reset: bool = True,
1093+
) -> None:
1094+
"""Uploads the local file to the stage location.
1095+
1096+
Args:
1097+
local_file_name (str): The local file to be uploaded.
1098+
stage_location (str): The stage location to upload the local file to.
1099+
options (dict[str, Any]): The upload options.
1100+
_do_reset (bool, optional): Whether to reset the cursor before
1101+
uploading, by default we will reset the cursor.
1102+
"""
1103+
from ._file_transfer_agent import SnowflakeFileTransferAgent
1104+
1105+
if _do_reset:
1106+
self.reset()
1107+
1108+
# Interpret the file operation.
1109+
ret = self.connection._file_operation_parser.parse_file_operation(
1110+
stage_location=stage_location,
1111+
local_file_name=local_file_name,
1112+
target_directory=None,
1113+
command_type=CMD_TYPE_UPLOAD,
1114+
options=options,
1115+
)
1116+
1117+
# Execute the file operation based on the interpretation above.
1118+
file_transfer_agent = SnowflakeFileTransferAgent(
1119+
self,
1120+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1121+
ret,
1122+
)
1123+
await file_transfer_agent.execute()
1124+
await self._init_result_and_meta(file_transfer_agent.result())
1125+
1126+
async def _download_stream(
1127+
self, stage_location: str, decompress: bool = False
1128+
) -> IO[bytes]:
1129+
"""Downloads from the stage location as a stream.
1130+
1131+
Args:
1132+
stage_location (str): The location of the stage to download from.
1133+
decompress (bool, optional): Whether to decompress the file, by
1134+
default we do not decompress.
1135+
1136+
Returns:
1137+
IO[bytes]: A stream to read from.
1138+
"""
1139+
# Interpret the file operation.
1140+
ret = self.connection._file_operation_parser.parse_file_operation(
1141+
stage_location=stage_location,
1142+
local_file_name=None,
1143+
target_directory=None,
1144+
command_type=CMD_TYPE_DOWNLOAD,
1145+
options=None,
1146+
has_source_from_stream=True,
1147+
)
1148+
1149+
# Set up stream downloading based on the interpretation and return the stream for reading.
1150+
return await self.connection._stream_downloader.download_as_stream(
1151+
ret, decompress
1152+
)
1153+
1154+
async def _upload_stream(
1155+
self,
1156+
input_stream: IO[bytes],
1157+
stage_location: str,
1158+
options: dict[str, Any],
1159+
_do_reset: bool = True,
1160+
) -> None:
1161+
"""Uploads content in the input stream to the stage location.
1162+
1163+
Args:
1164+
input_stream (IO[bytes]): A stream to read from.
1165+
stage_location (str): The location of the stage to upload to.
1166+
options (dict[str, Any]): The upload options.
1167+
_do_reset (bool, optional): Whether to reset the cursor before
1168+
uploading, by default we will reset the cursor.
1169+
"""
1170+
from ._file_transfer_agent import SnowflakeFileTransferAgent
1171+
1172+
if _do_reset:
1173+
self.reset()
1174+
1175+
# Interpret the file operation.
1176+
ret = self.connection._file_operation_parser.parse_file_operation(
1177+
stage_location=stage_location,
1178+
local_file_name=None,
1179+
target_directory=None,
1180+
command_type=CMD_TYPE_UPLOAD,
1181+
options=options,
1182+
has_source_from_stream=input_stream,
1183+
)
1184+
1185+
# Execute the file operation based on the interpretation above.
1186+
file_transfer_agent = SnowflakeFileTransferAgent(
1187+
self,
1188+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1189+
ret,
1190+
source_from_stream=input_stream,
1191+
)
1192+
await file_transfer_agent.execute()
1193+
await self._init_result_and_meta(file_transfer_agent.result())
1194+
10461195
async def get_results_from_sfqid(self, sfqid: str) -> None:
10471196
"""Gets the results from previously ran query. This methods differs from ``SnowflakeCursor.query_result``
10481197
in that it monitors the ``sfqid`` until it is no longer running, and then retrieves the results.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
from abc import ABC, abstractmethod
8+
9+
10+
class StreamDownloaderBase(ABC):
11+
"""The interface of internal utility functions for stream downloading of file."""
12+
13+
@abstractmethod
14+
def __init__(self, connection):
15+
pass
16+
17+
@abstractmethod
18+
async def download_as_stream(self, ret, decompress=False):
19+
pass
20+
21+
22+
class StreamDownloader(StreamDownloaderBase):
23+
def __init__(self, connection):
24+
pass
25+
26+
async def download_as_stream(self, ret, decompress=False):
27+
raise NotImplementedError("download_as_stream is not yet supported")

test/unit/aio/test_cursor_async_unit.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
import asyncio
88
import unittest.mock
9-
from unittest.mock import MagicMock, patch
9+
from unittest import IsolatedAsyncioTestCase
10+
from unittest.mock import AsyncMock, MagicMock, patch
1011

1112
import pytest
1213

@@ -99,3 +100,97 @@ async def mock_cmd_query(*args, **kwargs):
99100

100101
# query cancel request should be sent upon timeout
101102
assert mockCancelQuery.called
103+
104+
105+
# The _upload/_download/_upload_stream/_download_stream are newly introduced
106+
# and therefore should not be tested in old drivers.
107+
@pytest.mark.skipolddriver
108+
class TestUploadDownloadMethods(IsolatedAsyncioTestCase):
109+
"""Test the _upload/_download/_upload_stream/_download_stream methods."""
110+
111+
@patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent")
112+
async def test_download(self, MockFileTransferAgent):
113+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
114+
MockFileTransferAgent
115+
)
116+
117+
# Call _download method
118+
await cursor._download("@st", "/tmp/test.txt", {})
119+
120+
# In the process of _download execution, we expect these methods to be called
121+
# - parse_file_operation in connection._file_operation_parser
122+
# - execute in SnowflakeFileTransferAgent
123+
# And we do not expect this method to be involved
124+
# - download_as_stream of connection._stream_downloader
125+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
126+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
127+
mock_file_transfer_agent_instance.execute.assert_called_once()
128+
129+
@patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent")
130+
async def test_upload(self, MockFileTransferAgent):
131+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
132+
MockFileTransferAgent
133+
)
134+
135+
# Call _upload method
136+
await cursor._upload("/tmp/test.txt", "@st", {})
137+
138+
# In the process of _upload execution, we expect these methods to be called
139+
# - parse_file_operation in connection._file_operation_parser
140+
# - execute in SnowflakeFileTransferAgent
141+
# And we do not expect this method to be involved
142+
# - download_as_stream of connection._stream_downloader
143+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
144+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
145+
mock_file_transfer_agent_instance.execute.assert_called_once()
146+
147+
@patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent")
148+
async def test_download_stream(self, MockFileTransferAgent):
149+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
150+
MockFileTransferAgent
151+
)
152+
153+
# Call _download_stream method
154+
await cursor._download_stream("@st/test.txt", decompress=True)
155+
156+
# In the process of _download_stream execution, we expect these methods to be called
157+
# - parse_file_operation in connection._file_operation_parser
158+
# - download_as_stream of connection._stream_downloader
159+
# And we do not expect this method to be involved
160+
# - execute in SnowflakeFileTransferAgent
161+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
162+
fake_conn._stream_downloader.download_as_stream.assert_called_once()
163+
mock_file_transfer_agent_instance.execute.assert_not_called()
164+
165+
@patch("snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent")
166+
async def test_upload_stream(self, MockFileTransferAgent):
167+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
168+
MockFileTransferAgent
169+
)
170+
171+
# Call _upload_stream method
172+
fd = MagicMock()
173+
await cursor._upload_stream(fd, "@st/test.txt", {})
174+
175+
# In the process of _upload_stream execution, we expect these methods to be called
176+
# - parse_file_operation in connection._file_operation_parser
177+
# - execute in SnowflakeFileTransferAgent
178+
# And we do not expect this method to be involved
179+
# - download_as_stream of connection._stream_downloader
180+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
181+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
182+
mock_file_transfer_agent_instance.execute.assert_called_once()
183+
184+
def _setup_mocks(self, MockFileTransferAgent):
185+
mock_file_transfer_agent_instance = MockFileTransferAgent.return_value
186+
mock_file_transfer_agent_instance.execute = AsyncMock(return_value=None)
187+
188+
fake_conn = FakeConnection()
189+
fake_conn._file_operation_parser = MagicMock()
190+
fake_conn._stream_downloader = MagicMock()
191+
fake_conn._stream_downloader.download_as_stream = AsyncMock()
192+
193+
cursor = SnowflakeCursor(fake_conn)
194+
cursor.reset = MagicMock()
195+
cursor._init_result_and_meta = AsyncMock()
196+
return cursor, fake_conn, mock_file_transfer_agent_instance

0 commit comments

Comments
 (0)