Skip to content

Commit a7e0f8e

Browse files
sfc-gh-zyaosfc-gh-fpawlowski
authored andcommitted
SNOW-1963078 Port _upload / _download / _upload_stream / _download_st… (#2198)
(cherry picked from commit a3229c3)
1 parent 609427f commit a7e0f8e

File tree

5 files changed

+319
-0
lines changed

5 files changed

+319
-0
lines changed

src/snowflake/connector/connection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
PYTHON_VERSION,
8383
SNOWFLAKE_CONNECTOR_VERSION,
8484
)
85+
from .direct_file_operation_utils import FileOperationParser, StreamDownloader
8586
from .errorcode import (
8687
ER_CONNECTION_IS_CLOSED,
8788
ER_FAILED_PROCESSING_PYFORMAT,
@@ -512,6 +513,10 @@ def __init__(
512513
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
513514
atexit.register(self._close_at_exit)
514515

516+
# Set up the file operation parser and stream downloader.
517+
self._file_operation_parser = FileOperationParser(self)
518+
self._stream_downloader = StreamDownloader(self)
519+
515520
# Deprecated
516521
@property
517522
def insecure_mode(self) -> bool:

src/snowflake/connector/cursor.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
from ._utils import _TrackedQueryCancellationTimer
4343
from .bind_upload_agent import BindUploadAgent, BindUploadError
4444
from .constants import (
45+
CMD_TYPE_DOWNLOAD,
46+
CMD_TYPE_UPLOAD,
4547
FIELD_NAME_TO_ID,
4648
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
4749
FileTransferType,
@@ -1740,6 +1742,151 @@ def get_result_batches(self) -> list[ResultBatch] | None:
17401742
)
17411743
return self._result_set.batches
17421744

1745+
def _download(
1746+
self,
1747+
stage_location: str,
1748+
target_directory: str,
1749+
options: dict[str, Any],
1750+
_do_reset: bool = True,
1751+
) -> None:
1752+
"""Downloads from the stage location to the target directory.
1753+
1754+
Args:
1755+
stage_location (str): The location of the stage to download from.
1756+
target_directory (str): The destination directory to download into.
1757+
options (dict[str, Any]): The download options.
1758+
_do_reset (bool, optional): Whether to reset the cursor before
1759+
downloading, by default we will reset the cursor.
1760+
"""
1761+
from .file_transfer_agent import SnowflakeFileTransferAgent
1762+
1763+
if _do_reset:
1764+
self.reset()
1765+
1766+
# Interpret the file operation.
1767+
ret = self.connection._file_operation_parser.parse_file_operation(
1768+
stage_location=stage_location,
1769+
local_file_name=None,
1770+
target_directory=target_directory,
1771+
command_type=CMD_TYPE_DOWNLOAD,
1772+
options=options,
1773+
)
1774+
1775+
# Execute the file operation based on the interpretation above.
1776+
file_transfer_agent = SnowflakeFileTransferAgent(
1777+
self,
1778+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1779+
ret,
1780+
)
1781+
file_transfer_agent.execute()
1782+
self._init_result_and_meta(file_transfer_agent.result())
1783+
1784+
def _upload(
1785+
self,
1786+
local_file_name: str,
1787+
stage_location: str,
1788+
options: dict[str, Any],
1789+
_do_reset: bool = True,
1790+
) -> None:
1791+
"""Uploads the local file to the stage location.
1792+
1793+
Args:
1794+
local_file_name (str): The local file to be uploaded.
1795+
stage_location (str): The stage location to upload the local file to.
1796+
options (dict[str, Any]): The upload options.
1797+
_do_reset (bool, optional): Whether to reset the cursor before
1798+
uploading, by default we will reset the cursor.
1799+
"""
1800+
from .file_transfer_agent import SnowflakeFileTransferAgent
1801+
1802+
if _do_reset:
1803+
self.reset()
1804+
1805+
# Interpret the file operation.
1806+
ret = self.connection._file_operation_parser.parse_file_operation(
1807+
stage_location=stage_location,
1808+
local_file_name=local_file_name,
1809+
target_directory=None,
1810+
command_type=CMD_TYPE_UPLOAD,
1811+
options=options,
1812+
)
1813+
1814+
# Execute the file operation based on the interpretation above.
1815+
file_transfer_agent = SnowflakeFileTransferAgent(
1816+
self,
1817+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1818+
ret,
1819+
)
1820+
file_transfer_agent.execute()
1821+
self._init_result_and_meta(file_transfer_agent.result())
1822+
1823+
def _download_stream(
1824+
self, stage_location: str, decompress: bool = False
1825+
) -> IO[bytes]:
1826+
"""Downloads from the stage location as a stream.
1827+
1828+
Args:
1829+
stage_location (str): The location of the stage to download from.
1830+
decompress (bool, optional): Whether to decompress the file, by
1831+
default we do not decompress.
1832+
1833+
Returns:
1834+
IO[bytes]: A stream to read from.
1835+
"""
1836+
# Interpret the file operation.
1837+
ret = self.connection._file_operation_parser.parse_file_operation(
1838+
stage_location=stage_location,
1839+
local_file_name=None,
1840+
target_directory=None,
1841+
command_type=CMD_TYPE_DOWNLOAD,
1842+
options=None,
1843+
has_source_from_stream=True,
1844+
)
1845+
1846+
# Set up stream downloading based on the interpretation and return the stream for reading.
1847+
return self.connection._stream_downloader.download_as_stream(ret, decompress)
1848+
1849+
def _upload_stream(
1850+
self,
1851+
input_stream: IO[bytes],
1852+
stage_location: str,
1853+
options: dict[str, Any],
1854+
_do_reset: bool = True,
1855+
) -> None:
1856+
"""Uploads content in the input stream to the stage location.
1857+
1858+
Args:
1859+
input_stream (IO[bytes]): A stream to read from.
1860+
stage_location (str): The location of the stage to upload to.
1861+
options (dict[str, Any]): The upload options.
1862+
_do_reset (bool, optional): Whether to reset the cursor before
1863+
uploading, by default we will reset the cursor.
1864+
"""
1865+
from .file_transfer_agent import SnowflakeFileTransferAgent
1866+
1867+
if _do_reset:
1868+
self.reset()
1869+
1870+
# Interpret the file operation.
1871+
ret = self.connection._file_operation_parser.parse_file_operation(
1872+
stage_location=stage_location,
1873+
local_file_name=None,
1874+
target_directory=None,
1875+
command_type=CMD_TYPE_UPLOAD,
1876+
options=options,
1877+
has_source_from_stream=input_stream,
1878+
)
1879+
1880+
# Execute the file operation based on the interpretation above.
1881+
file_transfer_agent = SnowflakeFileTransferAgent(
1882+
self,
1883+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1884+
ret,
1885+
source_from_stream=input_stream,
1886+
)
1887+
file_transfer_agent.execute()
1888+
self._init_result_and_meta(file_transfer_agent.result())
1889+
17431890

17441891
class DictCursor(SnowflakeCursor):
17451892
"""Cursor returning results in a dictionary."""
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
from abc import ABC, abstractmethod
8+
9+
10+
class FileOperationParserBase(ABC):
11+
"""The interface of internal utility functions for file operation parsing."""
12+
13+
@abstractmethod
14+
def __init__(self, connection):
15+
pass
16+
17+
@abstractmethod
18+
def parse_file_operation(
19+
self,
20+
stage_location,
21+
local_file_name,
22+
target_directory,
23+
command_type,
24+
options,
25+
has_source_from_stream=False,
26+
):
27+
"""Converts the file operation details into a SQL and returns the SQL parsing result."""
28+
pass
29+
30+
31+
class StreamDownloaderBase(ABC):
32+
"""The interface of internal utility functions for stream downloading of file."""
33+
34+
@abstractmethod
35+
def __init__(self, connection):
36+
pass
37+
38+
@abstractmethod
39+
def download_as_stream(self, ret, decompress=False):
40+
pass
41+
42+
43+
class FileOperationParser(FileOperationParserBase):
44+
def __init__(self, connection):
45+
pass
46+
47+
def parse_file_operation(
48+
self,
49+
stage_location,
50+
local_file_name,
51+
target_directory,
52+
command_type,
53+
options,
54+
has_source_from_stream=False,
55+
):
56+
raise NotImplementedError("parse_file_operation is not yet supported")
57+
58+
59+
class StreamDownloader(StreamDownloaderBase):
60+
def __init__(self, connection):
61+
pass
62+
63+
def download_as_stream(self, ret, decompress=False):
64+
raise NotImplementedError("download_as_stream is not yet supported")

test/integ/test_connection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,3 +1628,12 @@ def test_no_auth_connection_negative_case():
16281628
# connection is not able to run any query
16291629
with pytest.raises(DatabaseError, match="Connection is closed"):
16301630
conn.execute_string("select 1")
1631+
1632+
1633+
# _file_operation_parser and _stream_downloader are newly introduced and
1634+
# therefore should not be tested on old drivers.
1635+
@pytest.mark.skipolddriver
1636+
def test_file_utils_sanity_check():
1637+
conn = create_connection("default")
1638+
assert hasattr(conn._file_operation_parser, "parse_file_operation")
1639+
assert hasattr(conn._stream_downloader, "download_as_stream")

test/unit/test_cursor.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import time
8+
from unittest import TestCase
89
from unittest.mock import MagicMock, patch
910

1011
import pytest
@@ -99,3 +100,96 @@ def mock_cmd_query(*args, **kwargs):
99100

100101
# query cancel request should be sent upon timeout
101102
assert mockCancelQuery.called
103+
104+
105+
# The _upload/_download/_upload_stream/_download_stream are newly introduced
106+
# and therefore should not be tested in old drivers.
107+
@pytest.mark.skipolddriver
108+
class TestUploadDownloadMethods(TestCase):
109+
"""Test the _upload/_download/_upload_stream/_download_stream methods."""
110+
111+
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
112+
def test_download(self, MockFileTransferAgent):
113+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
114+
MockFileTransferAgent
115+
)
116+
117+
# Call _download method
118+
cursor._download("@st", "/tmp/test.txt", {})
119+
120+
# In the process of _download execution, we expect these methods to be called
121+
# - parse_file_operation in connection._file_operation_parser
122+
# - execute in SnowflakeFileTransferAgent
123+
# And we do not expect this method to be involved
124+
# - download_as_stream of connection._stream_downloader
125+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
126+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
127+
mock_file_transfer_agent_instance.execute.assert_called_once()
128+
129+
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
130+
def test_upload(self, MockFileTransferAgent):
131+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
132+
MockFileTransferAgent
133+
)
134+
135+
# Call _upload method
136+
cursor._upload("/tmp/test.txt", "@st", {})
137+
138+
# In the process of _upload execution, we expect these methods to be called
139+
# - parse_file_operation in connection._file_operation_parser
140+
# - execute in SnowflakeFileTransferAgent
141+
# And we do not expect this method to be involved
142+
# - download_as_stream of connection._stream_downloader
143+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
144+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
145+
mock_file_transfer_agent_instance.execute.assert_called_once()
146+
147+
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
148+
def test_download_stream(self, MockFileTransferAgent):
149+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
150+
MockFileTransferAgent
151+
)
152+
153+
# Call _download_stream method
154+
cursor._download_stream("@st/test.txt", decompress=True)
155+
156+
# In the process of _download_stream execution, we expect these methods to be called
157+
# - parse_file_operation in connection._file_operation_parser
158+
# - download_as_stream of connection._stream_downloader
159+
# And we do not expect this method to be involved
160+
# - execute in SnowflakeFileTransferAgent
161+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
162+
fake_conn._stream_downloader.download_as_stream.assert_called_once()
163+
mock_file_transfer_agent_instance.execute.assert_not_called()
164+
165+
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
166+
def test_upload_stream(self, MockFileTransferAgent):
167+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
168+
MockFileTransferAgent
169+
)
170+
171+
# Call _upload_stream method
172+
fd = MagicMock()
173+
cursor._upload_stream(fd, "@st/test.txt", {})
174+
175+
# In the process of _upload_stream execution, we expect these methods to be called
176+
# - parse_file_operation in connection._file_operation_parser
177+
# - execute in SnowflakeFileTransferAgent
178+
# And we do not expect this method to be involved
179+
# - download_as_stream of connection._stream_downloader
180+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
181+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
182+
mock_file_transfer_agent_instance.execute.assert_called_once()
183+
184+
def _setup_mocks(self, MockFileTransferAgent):
185+
mock_file_transfer_agent_instance = MockFileTransferAgent.return_value
186+
mock_file_transfer_agent_instance.execute.return_value = None
187+
188+
fake_conn = FakeConnection()
189+
fake_conn._file_operation_parser = MagicMock()
190+
fake_conn._stream_downloader = MagicMock()
191+
192+
cursor = SnowflakeCursor(fake_conn)
193+
cursor.reset = MagicMock()
194+
cursor._init_result_and_meta = MagicMock()
195+
return cursor, fake_conn, mock_file_transfer_agent_instance

0 commit comments

Comments
 (0)