Skip to content

Commit 0208d11

Browse files
[Async] Apply #2303 to async code
# Conflicts: # src/snowflake/connector/aio/_direct_file_operation_utils.py
1 parent 9833870 commit 0208d11

File tree

5 files changed

+162
-11
lines changed

5 files changed

+162
-11
lines changed

src/snowflake/connector/aio/_build_upload_agent.py renamed to src/snowflake/connector/aio/_bind_upload_agent.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
import os
67
from io import BytesIO
78
from logging import getLogger
89
from typing import TYPE_CHECKING, cast
@@ -56,8 +57,11 @@ async def upload(self) -> None:
5657
if row_idx >= len(self.rows) or size >= self._stream_buffer_size:
5758
break
5859
try:
59-
await self.cursor.execute(
60-
f"PUT file://{row_idx}.csv {self.stage_path}", file_stream=f
60+
f.seek(0)
61+
await self.cursor._upload_stream(
62+
input_stream=f,
63+
stage_location=os.path.join(self.stage_path, f"{row_idx}.csv"),
64+
options={"source_compression": "auto_detect"},
6165
)
6266
except Error as err:
6367
logger.debug("Failed to upload the bindings file to stage.")

src/snowflake/connector/aio/_cursor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
ProgrammingError,
2424
)
2525
from snowflake.connector._sql_util import get_file_transfer_type
26-
from snowflake.connector.aio._build_upload_agent import BindUploadAgent
26+
from snowflake.connector.aio._bind_upload_agent import BindUploadAgent
2727
from snowflake.connector.aio._result_batch import (
2828
ResultBatch,
2929
create_batches_from_response,
@@ -803,7 +803,7 @@ async def executemany(
803803
bind_stage = None
804804
if (
805805
bind_size
806-
> self.connection._session_parameters[
806+
>= self.connection._session_parameters[
807807
"CLIENT_STAGE_ARRAY_BINDING_THRESHOLD"
808808
]
809809
> 0

src/snowflake/connector/aio/_direct_file_operation_utils.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from ._connection import SnowflakeConnection
7+
8+
import os
39
from abc import ABC, abstractmethod
410

11+
from ..constants import CMD_TYPE_UPLOAD
12+
513

614
class FileOperationParserBase(ABC):
715
"""The interface of internal utility functions for file operation parsing."""
@@ -37,8 +45,8 @@ async def download_as_stream(self, ret, decompress=False):
3745

3846

3947
class FileOperationParser(FileOperationParserBase):
40-
def __init__(self, connection):
41-
pass
48+
def __init__(self, connection: SnowflakeConnection):
49+
self._connection = connection
4250

4351
async def parse_file_operation(
4452
self,
@@ -49,7 +57,27 @@ async def parse_file_operation(
4957
options,
5058
has_source_from_stream=False,
5159
):
52-
raise NotImplementedError("parse_file_operation is not yet supported")
60+
"""Parses a file operation by constructing SQL and getting the SQL parsing result from server."""
61+
options = options or {}
62+
options_in_sql = " ".join(f"{k}={v}" for k, v in options.items())
63+
64+
if command_type == CMD_TYPE_UPLOAD:
65+
if has_source_from_stream:
66+
stage_location, unprefixed_local_file_name = os.path.split(
67+
stage_location
68+
)
69+
local_file_name = "file://" + unprefixed_local_file_name
70+
sql = f"PUT {local_file_name} ? {options_in_sql}"
71+
params = [stage_location]
72+
else:
73+
raise NotImplementedError(f"unsupported command type: {command_type}")
74+
75+
async with self._connection.cursor() as cursor:
76+
# Send constructed SQL to server and get back parsing result.
77+
processed_params = cursor._connection._process_params_qmarks(params, cursor)
78+
return await cursor._execute_helper(
79+
sql, binding_params=processed_params, is_internal=True
80+
)
5381

5482

5583
class StreamDownloader(StreamDownloaderBase):
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#!/usr/bin/env python
2+
from __future__ import annotations
3+
4+
import os
5+
from tempfile import TemporaryDirectory
6+
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Coroutine
7+
8+
import pytest
9+
10+
try:
11+
from snowflake.connector.options import pandas
12+
from snowflake.connector.pandas_tools import (
13+
_iceberg_config_statement_helper,
14+
write_pandas,
15+
)
16+
except ImportError:
17+
pandas = None
18+
write_pandas = None
19+
_iceberg_config_statement_helper = None
20+
21+
if TYPE_CHECKING:
22+
from snowflake.connector.aio import SnowflakeConnection, SnowflakeCursor
23+
24+
from ..test_direct_file_operation_utils import _normalize_windows_local_path
25+
26+
27+
async def _validate_upload_content(
    expected_content, cursor, stage_name, local_dir, base_file_name, is_compressed
):
    """Download the staged file via GET and assert its content matches.

    Args:
        expected_content: Text expected inside the downloaded file.
        cursor: Async cursor used to run the GET command.
        stage_name: Stage name (without the leading ``@``) the file was uploaded to.
        local_dir: Local directory the file is downloaded into.
        base_file_name: File name on the stage and on disk.
        is_compressed: Whether the upload used auto-compression; if so the
            downloaded file carries a ``.gz`` suffix and must be gunzipped.
    """
    import gzip

    stage_path = f"@{stage_name}/{base_file_name}"
    local_path = os.path.join(local_dir, base_file_name)

    await cursor.execute(
        f"GET {stage_path} 'file://{_normalize_windows_local_path(local_dir)}'",
    )
    if is_compressed:
        # Auto-compressed uploads land on the stage (and thus locally) as .gz.
        local_path += ".gz"
        # Text mode lets gzip handle UTF-8 decoding instead of a manual decode().
        with gzip.open(local_path, "rt", encoding="utf-8") as f:
            read_content = f.read()
    else:
        with open(local_path) as f:
            read_content = f.read()
    assert read_content == expected_content, (read_content, expected_content)
49+
50+
51+
async def _test_runner(
    conn_cnx: Callable[..., AsyncGenerator[SnowflakeConnection, None]],
    task: Callable[[SnowflakeCursor, str, str, str], Coroutine[None, None, None]],
    is_compressed: bool,
    special_stage_name: str | None = None,
    special_base_file_name: str | None = None,
):
    """Create a temp stage, run *task* to upload a small file, then validate it.

    Args:
        conn_cnx: Fixture factory yielding an async connection context manager.
        task: Coroutine performing the upload; called with
            ``(cursor, stage_name, temp_dir, base_file_name)``.
        is_compressed: Whether the upload auto-compresses (affects validation).
        special_stage_name: Optional explicit stage name; a random temp-object
            name is generated when omitted.
        special_base_file_name: Optional explicit file name; ``test.txt`` when
            omitted.
    """
    from snowflake.connector._utils import TempObjectType, random_name_for_temp_object

    async with conn_cnx() as conn:
        cursor = conn.cursor()
        stage_name = special_stage_name or random_name_for_temp_object(
            TempObjectType.STAGE
        )
        await cursor.execute(f"CREATE OR REPLACE SCOPED TEMP STAGE {stage_name}")
        expected_content = "hello, world"
        with TemporaryDirectory() as temp_dir:
            base_file_name = special_base_file_name or "test.txt"
            src_file_name = os.path.join(temp_dir, base_file_name)
            with open(src_file_name, "w") as f:
                f.write(expected_content)
            # Run the file operation
            await task(cursor, stage_name, temp_dir, base_file_name)
            # Clean up before validation, so GET cannot accidentally "succeed"
            # by leaving the original source file in place.
            os.remove(src_file_name)
            # Validate result.
            await _validate_upload_content(
                expected_content,
                cursor,
                stage_name,
                temp_dir,
                base_file_name,
                is_compressed=is_compressed,
            )
85+
86+
87+
@pytest.mark.skipolddriver
@pytest.mark.parametrize("is_compressed", [False, True])
async def test_upload(
    conn_cnx: Callable[..., AsyncGenerator[SnowflakeConnection]],
    is_compressed: bool,
):
    """Upload a local file via cursor._upload and verify the staged content."""

    async def _do_upload(cursor, stage_name, temp_dir, base_file_name):
        # PUT-style upload of the on-disk source file to the stage root.
        await cursor._upload(
            local_file_name=f"'file://{_normalize_windows_local_path(os.path.join(temp_dir, base_file_name))}'",
            stage_location=f"@{stage_name}",
            options={"auto_compress": is_compressed},
        )

    await _test_runner(conn_cnx, _do_upload, is_compressed=is_compressed)
101+
102+
103+
@pytest.mark.skipolddriver
@pytest.mark.parametrize("is_compressed", [False, True])
async def test_upload_stream(
    conn_cnx: Callable[..., AsyncGenerator[SnowflakeConnection]],
    is_compressed: bool,
):
    """Upload from an open binary stream via cursor._upload_stream and verify."""

    async def upload_stream_task(cursor, stage_name, temp_dir, base_file_name):
        # Open the source file in binary mode and hand the raw stream to the
        # cursor; the stage location addresses the target file directly.
        with open(os.path.join(temp_dir, base_file_name), "rb") as input_stream:
            await cursor._upload_stream(
                input_stream=input_stream,
                stage_location=f"@{os.path.join(stage_name, base_file_name)}",
                options={"auto_compress": is_compressed},
            )

    await _test_runner(conn_cnx, upload_stream_task, is_compressed=is_compressed)

test/unit/aio/test_bind_upload_agent_async.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,22 @@
99

1010

1111
async def test_bind_upload_agent_uploading_multiple_files():
12-
from snowflake.connector.aio._build_upload_agent import BindUploadAgent
12+
from snowflake.connector.aio._bind_upload_agent import BindUploadAgent
1313

1414
csr = AsyncMock(auto_spec=True)
1515
rows = [bytes(10)] * 10
1616
agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
1717
await agent.upload()
18-
assert csr.execute.call_count == 11 # 1 for stage creation + 10 files
18+
assert csr.execute.call_count == 1 # 1 for stage creation
19+
assert csr._upload_stream.call_count == 10 # 10 for 10 files
1920

2021

2122
async def test_bind_upload_agent_row_size_exceed_buffer_size():
22-
from snowflake.connector.aio._build_upload_agent import BindUploadAgent
23+
from snowflake.connector.aio._bind_upload_agent import BindUploadAgent
2324

2425
csr = AsyncMock(auto_spec=True)
2526
rows = [bytes(15)] * 10
2627
agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
2728
await agent.upload()
28-
assert csr.execute.call_count == 11 # 1 for stage creation + 10 files
29+
assert csr.execute.call_count == 1 # 1 for stage creation
30+
assert csr._upload_stream.call_count == 10 # 10 for 10 files

0 commit comments

Comments
 (0)