
Commit 9833870

sfc-gh-zyao authored and sfc-gh-fpawlowski committed
SNOW-2057867 refactor BindUploadAgent to make it work for Python sprocs (#2303)
(cherry picked from commit 0d79989)
1 parent e845d3b commit 9833870

File tree: 5 files changed, +161 −8 lines

src/snowflake/connector/bind_upload_agent.py

Lines changed: 6 additions & 2 deletions
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 from __future__ import annotations
 
+import os
 import uuid
 from io import BytesIO
 from logging import getLogger
@@ -76,8 +77,11 @@ def upload(self) -> None:
                 if row_idx >= len(self.rows) or size >= self._stream_buffer_size:
                     break
             try:
-                self.cursor.execute(
-                    f"PUT file://{row_idx}.csv {self.stage_path}", file_stream=f
+                f.seek(0)
+                self.cursor._upload_stream(
+                    input_stream=f,
+                    stage_location=os.path.join(self.stage_path, f"{row_idx}.csv"),
+                    options={"source_compression": "auto_detect"},
                 )
             except Error as err:
                 logger.debug("Failed to upload the bindings file to stage.")
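A note on the added f.seek(0): the buffer has just been written, so its position sits at EOF, and _upload_stream reads from the current position. A minimal sketch of the new call shape, using a hypothetical stand-in cursor (only _upload_stream and its keyword arguments come from the diff above; the stub class and stage path are illustrative):

import os
from io import BytesIO

class _StubCursor:
    # Hypothetical stand-in for SnowflakeCursor, just to show the call shape.
    def _upload_stream(self, input_stream, stage_location, options):
        print(f"upload {len(input_stream.read())} bytes -> {stage_location} ({options})")

f = BytesIO()
f.write(b"1,2,3\n4,5,6\n")  # serialized bind rows, as in BindUploadAgent.upload
f.seek(0)                   # rewind so the upload reads from the start, not EOF
_StubCursor()._upload_stream(
    input_stream=f,
    stage_location=os.path.join("@SYSTEM$BIND/stage", "0.csv"),
    options={"source_compression": "auto_detect"},
)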

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 1 deletion
@@ -1463,7 +1463,7 @@ def executemany(
         bind_stage = None
         if (
             bind_size
-            > self.connection._session_parameters[
+            >= self.connection._session_parameters[
                 "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD"
             ]
             > 0
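The one-character change above sits inside a chained comparison, so bindings whose size exactly equals the threshold now take the stage path instead of inline binds. A quick sketch of the semantics (the threshold value is illustrative):

bind_size = 65280
threshold = 65280  # illustrative CLIENT_STAGE_ARRAY_BINDING_THRESHOLD value

# Python chains comparisons: a >= b > c means (a >= b) and (b > c).
use_stage = bind_size >= threshold > 0
assert use_stage  # with the old ">", the equal-size case fell through to inline binds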

src/snowflake/connector/direct_file_operation_utils.py

Lines changed: 31 additions & 3 deletions
@@ -1,7 +1,15 @@
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .connection import SnowflakeConnection
+
+import os
 from abc import ABC, abstractmethod
 
+from .constants import CMD_TYPE_UPLOAD
+
 
 class FileOperationParserBase(ABC):
     """The interface of internal utility functions for file operation parsing."""
@@ -37,8 +45,8 @@ def download_as_stream(self, ret, decompress=False):
 
 
 class FileOperationParser(FileOperationParserBase):
-    def __init__(self, connection):
-        pass
+    def __init__(self, connection: SnowflakeConnection):
+        self._connection = connection
 
     def parse_file_operation(
         self,
@@ -49,7 +57,27 @@ def parse_file_operation(
         options,
         has_source_from_stream=False,
     ):
-        raise NotImplementedError("parse_file_operation is not yet supported")
+        """Parses a file operation by constructing SQL and getting the SQL parsing result from server."""
+        options = options or {}
+        options_in_sql = " ".join(f"{k}={v}" for k, v in options.items())
+
+        if command_type == CMD_TYPE_UPLOAD:
+            if has_source_from_stream:
+                stage_location, unprefixed_local_file_name = os.path.split(
+                    stage_location
+                )
+                local_file_name = "file://" + unprefixed_local_file_name
+            sql = f"PUT {local_file_name} ? {options_in_sql}"
+            params = [stage_location]
+        else:
+            raise NotImplementedError(f"unsupported command type: {command_type}")
+
+        with self._connection.cursor() as cursor:
+            # Send constructed SQL to server and get back parsing result.
+            processed_params = cursor._connection._process_params_qmarks(params, cursor)
+            return cursor._execute_helper(
+                sql, binding_params=processed_params, is_internal=True
+            )
 
 
 class StreamDownloader(StreamDownloaderBase):
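To see what parse_file_operation builds for a stream upload, here is the same string construction replayed outside the class as a standalone sketch (the stage path and options are illustrative; only the logic mirrors the diff above):

import os

stage_location = "@mystage/batch/0.csv"  # full target path, as a caller would pass it
options = {"source_compression": "auto_detect"}

options_in_sql = " ".join(f"{k}={v}" for k, v in options.items())

# For a stream source, the file name is split off the stage path and the
# remaining stage location is bound as a qmark parameter.
stage_location, unprefixed_local_file_name = os.path.split(stage_location)
local_file_name = "file://" + unprefixed_local_file_name

sql = f"PUT {local_file_name} ? {options_in_sql}"

print(sql)              # PUT file://0.csv ? source_compression=auto_detect
print([stage_location]) # ['@mystage/batch']
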
test/integ/test_direct_file_operation_utils.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+from __future__ import annotations
+
+import os
+from tempfile import TemporaryDirectory
+from typing import TYPE_CHECKING, Callable, Generator
+
+import pytest
+
+try:
+    from snowflake.connector.options import pandas
+    from snowflake.connector.pandas_tools import (
+        _iceberg_config_statement_helper,
+        write_pandas,
+    )
+except ImportError:
+    pandas = None
+    write_pandas = None
+    _iceberg_config_statement_helper = None
+
+if TYPE_CHECKING:
+    from snowflake.connector import SnowflakeConnection, SnowflakeCursor
+
+
+def _normalize_windows_local_path(path):
+    return path.replace("\\", "\\\\").replace("'", "\\'")
+
+
+def _validate_upload_content(
+    expected_content, cursor, stage_name, local_dir, base_file_name, is_compressed
+):
+    gz_suffix = ".gz"
+    stage_path = f"@{stage_name}/{base_file_name}"
+    local_path = os.path.join(local_dir, base_file_name)
+
+    cursor.execute(
+        f"GET {stage_path} 'file://{_normalize_windows_local_path(local_dir)}'",
+    )
+    if is_compressed:
+        stage_path += gz_suffix
+        local_path += gz_suffix
+        import gzip
+
+        with gzip.open(local_path, "r") as f:
+            read_content = f.read().decode("utf-8")
+            assert read_content == expected_content, (read_content, expected_content)
+    else:
+        with open(local_path) as f:
+            read_content = f.read()
+            assert read_content == expected_content, (read_content, expected_content)
+
+
+def _test_runner(
+    conn_cnx: Callable[..., Generator[SnowflakeConnection]],
+    task: Callable[[SnowflakeCursor, str, str, str], None],
+    is_compressed: bool,
+    special_stage_name: str = None,
+    special_base_file_name: str = None,
+):
+    from snowflake.connector._utils import TempObjectType, random_name_for_temp_object
+
+    with conn_cnx() as conn:
+        cursor = conn.cursor()
+        stage_name = special_stage_name or random_name_for_temp_object(
+            TempObjectType.STAGE
+        )
+        cursor.execute(f"CREATE OR REPLACE SCOPED TEMP STAGE {stage_name}")
+        expected_content = "hello, world"
+        with TemporaryDirectory() as temp_dir:
+            base_file_name = special_base_file_name or "test.txt"
+            src_file_name = os.path.join(temp_dir, base_file_name)
+            with open(src_file_name, "w") as f:
+                f.write(expected_content)
+            # Run the file operation
+            task(cursor, stage_name, temp_dir, base_file_name)
+            # Clean up before validation.
+            os.remove(src_file_name)
+            # Validate result.
+            _validate_upload_content(
+                expected_content,
+                cursor,
+                stage_name,
+                temp_dir,
+                base_file_name,
+                is_compressed=is_compressed,
+            )
+
+
+@pytest.mark.skipolddriver
+@pytest.mark.parametrize("is_compressed", [False, True])
+def test_upload(
+    conn_cnx: Callable[..., Generator[SnowflakeConnection]],
+    is_compressed: bool,
+):
+    def upload_task(cursor, stage_name, temp_dir, base_file_name):
+        cursor._upload(
+            local_file_name=f"'file://{_normalize_windows_local_path(os.path.join(temp_dir, base_file_name))}'",
+            stage_location=f"@{stage_name}",
+            options={"auto_compress": is_compressed},
+        )
+
+    _test_runner(conn_cnx, upload_task, is_compressed=is_compressed)
+
+
+@pytest.mark.skipolddriver
+@pytest.mark.parametrize("is_compressed", [False, True])
+def test_upload_stream(
+    conn_cnx: Callable[..., Generator[SnowflakeConnection]],
+    is_compressed: bool,
+):
+    def upload_stream_task(cursor, stage_name, temp_dir, base_file_name):
+        with open(f"{os.path.join(temp_dir, base_file_name)}", "rb") as input_stream:
+            cursor._upload_stream(
+                input_stream=input_stream,
+                stage_location=f"@{os.path.join(stage_name, base_file_name)}",
+                options={"auto_compress": is_compressed},
+            )
+
+    _test_runner(conn_cnx, upload_stream_task, is_compressed=is_compressed)
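One detail worth calling out in this new test file: _normalize_windows_local_path doubles backslashes and escapes single quotes so a Windows path survives inside the quoted 'file://...' location of the GET and PUT commands. A standalone demonstration (the sample path is made up):

def _normalize_windows_local_path(path):
    return path.replace("\\", "\\\\").replace("'", "\\'")

print(_normalize_windows_local_path(r"C:\Users\o'brien\tmp"))
# prints: C:\\Users\\o\'brien\\tmp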

test/unit/test_bind_upload_agent.py

Lines changed: 4 additions & 2 deletions
@@ -12,7 +12,8 @@ def test_bind_upload_agent_uploading_multiple_files():
     rows = [bytes(10)] * 10
     agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
     agent.upload()
-    assert csr.execute.call_count == 11  # 1 for stage creation + 10 files
+    assert csr.execute.call_count == 1  # 1 for stage creation
+    assert csr._upload_stream.call_count == 10  # 10 for 10 files
 
 
 def test_bind_upload_agent_row_size_exceed_buffer_size():
@@ -22,7 +23,8 @@ def test_bind_upload_agent_row_size_exceed_buffer_size():
     rows = [bytes(15)] * 10
     agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
     agent.upload()
-    assert csr.execute.call_count == 11  # 1 for stage creation + 10 files
+    assert csr.execute.call_count == 1  # 1 for stage creation
+    assert csr._upload_stream.call_count == 10  # 10 for 10 files
 
 
 def test_bind_upload_agent_scoped_temp_object():
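The updated assertions lean on MagicMock creating the _upload_stream attribute on first access, so the same mock cursor records both call counts. A minimal sketch of the scenario these tests exercise (imports assumed available in the test environment):

from unittest.mock import MagicMock

from snowflake.connector.bind_upload_agent import BindUploadAgent

csr = MagicMock()  # mock cursor; execute and _upload_stream are auto-created
rows = [bytes(10)] * 10
agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
agent.upload()

assert csr.execute.call_count == 1          # stage creation only
assert csr._upload_stream.call_count == 10  # one upload per buffered file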
