Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/snowflake/connector/bind_upload_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
from __future__ import annotations

import os
import uuid
from io import BytesIO
from logging import getLogger
Expand Down Expand Up @@ -76,8 +77,11 @@ def upload(self) -> None:
if row_idx >= len(self.rows) or size >= self._stream_buffer_size:
break
try:
self.cursor.execute(
f"PUT file://{row_idx}.csv {self.stage_path}", file_stream=f
f.seek(0)
self.cursor._upload_stream(
input_stream=f,
stage_location=os.path.join(self.stage_path, f"{row_idx}.csv"),
options={},
)
except Error as err:
logger.debug("Failed to upload the bindings file to stage.")
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,7 +1459,7 @@ def executemany(
bind_stage = None
if (
bind_size
> self.connection._session_parameters[
>= self.connection._session_parameters[
"CLIENT_STAGE_ARRAY_BINDING_THRESHOLD"
]
> 0
Expand Down
34 changes: 31 additions & 3 deletions src/snowflake/connector/direct_file_operation_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .connection import SnowflakeConnection

import os
from abc import ABC, abstractmethod

from .constants import CMD_TYPE_UPLOAD


class FileOperationParserBase(ABC):
"""The interface of internal utility functions for file operation parsing."""
Expand Down Expand Up @@ -37,8 +45,8 @@ def download_as_stream(self, ret, decompress=False):


class FileOperationParser(FileOperationParserBase):
def __init__(self, connection):
pass
def __init__(self, connection: SnowflakeConnection):
self._connection = connection

def parse_file_operation(
self,
Expand All @@ -49,7 +57,27 @@ def parse_file_operation(
options,
has_source_from_stream=False,
):
raise NotImplementedError("parse_file_operation is not yet supported")
"""Parses a file operation by constructing SQL and getting the SQL parsing result from server."""
options = options or {}
options_in_sql = " ".join(f"{k}={v}" for k, v in options.items())

if command_type == CMD_TYPE_UPLOAD:
if has_source_from_stream:
stage_location, unprefixed_local_file_name = os.path.split(
stage_location
)
local_file_name = "file://" + unprefixed_local_file_name
sql = f"PUT {local_file_name} ? {options_in_sql}"
params = [stage_location]
else:
raise NotImplementedError(f"unsupported command type: {command_type}")

with self._connection.cursor() as cursor:
# Send constructed SQL to server and get back parsing result.
processed_params = cursor._connection._process_params_qmarks(params, cursor)
return cursor._execute_helper(
sql, binding_params=processed_params, is_internal=True
)


class StreamDownloader(StreamDownloaderBase):
Expand Down
113 changes: 113 additions & 0 deletions test/integ/test_direct_file_operation_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#!/usr/bin/env python
from __future__ import annotations

import os
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Callable, Generator

import pytest

from snowflake.connector._utils import TempObjectType, random_name_for_temp_object

try:
from snowflake.connector.options import pandas
from snowflake.connector.pandas_tools import (
_iceberg_config_statement_helper,
write_pandas,
)
except ImportError:
pandas = None
write_pandas = None
_iceberg_config_statement_helper = None

if TYPE_CHECKING:
from snowflake.connector import SnowflakeConnection, SnowflakeCursor


def _validate_upload_content(
    expected_content, cursor, stage_name, local_dir, base_file_name, is_compressed
):
    """Download the staged file via GET and assert its content matches.

    Args:
        expected_content: Text the staged file is expected to contain.
        cursor: Cursor used to run the GET command.
        stage_name: Stage name (without the leading ``@``) holding the file.
        local_dir: Local directory the file is downloaded into.
        base_file_name: File name on the stage, without any compression suffix.
        is_compressed: Whether the upload used auto-compression (adds ``.gz``).

    Raises:
        AssertionError: If the downloaded content differs from *expected_content*.
    """
    # PEP 8: keep imports at the top of the scope rather than inside a branch.
    import gzip

    stage_path = f"@{stage_name}/{base_file_name}"
    local_path = f"{local_dir}/{base_file_name}"

    # The stage path is bound as a qmark parameter; the local target is inlined.
    cursor.execute(
        f"GET ? 'file://{local_dir}'", params=[stage_path], _force_qmark_paramstyle=True
    )
    if is_compressed:
        # auto_compress uploads land as a gzipped file with a ".gz" suffix;
        # only the local download path needs the suffix here.
        local_path += ".gz"
        with gzip.open(local_path, "r") as f:
            read_content = f.read().decode("utf-8")
    else:
        with open(local_path) as f:
            read_content = f.read()
    assert read_content == expected_content, (read_content, expected_content)


def _test_runner(
    conn_cnx: Callable[..., Generator[SnowflakeConnection]],
    task: Callable[[SnowflakeCursor, str, str, str], None],
    is_compressed: bool,
    special_stage_name: str | None = None,
    special_base_file_name: str | None = None,
):
    """Create a temp stage, run *task* to upload a file, then validate the result.

    Args:
        conn_cnx: Factory returning a connection context manager.
        task: Upload callback; receives (cursor, stage_name, temp_dir,
            base_file_name) and must place the file onto the stage.
        is_compressed: Whether the upload auto-compresses the file.
        special_stage_name: Explicit stage name; defaults to a random temp name.
        special_base_file_name: Explicit file name; defaults to "test.txt".
    """
    # Close the cursor deterministically instead of leaking it.
    with conn_cnx() as conn, conn.cursor() as cursor:
        stage_name = special_stage_name or random_name_for_temp_object(
            TempObjectType.STAGE
        )
        cursor.execute(f"CREATE OR REPLACE SCOPED TEMP STAGE {stage_name}")
        expected_content = "hello, world"
        with TemporaryDirectory() as temp_dir:
            base_file_name = special_base_file_name or "test.txt"
            src_file_name = os.path.join(temp_dir, base_file_name)
            with open(src_file_name, "w") as f:
                f.write(expected_content)
            # Run the file operation under test.
            task(cursor, stage_name, temp_dir, base_file_name)
            # Remove the source before validation so GET cannot trivially
            # succeed by finding the original local file.
            os.remove(src_file_name)
            # Validate result.
            _validate_upload_content(
                expected_content,
                cursor,
                stage_name,
                temp_dir,
                base_file_name,
                is_compressed=is_compressed,
            )


@pytest.mark.parametrize("is_compressed", [False, True])
def test_upload(
    conn_cnx: Callable[..., Generator[SnowflakeConnection]],
    is_compressed: bool,
):
    """Upload a local file through cursor._upload and verify the staged content."""

    def _do_upload(cursor, stage_name, temp_dir, base_file_name):
        source_uri = f"file://{temp_dir}/{base_file_name}"
        cursor._upload(
            local_file_name=source_uri,
            stage_location=f"@{stage_name}",
            options={"auto_compress": is_compressed},
        )

    _test_runner(conn_cnx, _do_upload, is_compressed=is_compressed)


@pytest.mark.parametrize("is_compressed", [False, True])
def test_upload_stream(
    conn_cnx: Callable[..., Generator[SnowflakeConnection]],
    is_compressed: bool,
):
    """Upload from an open binary stream through cursor._upload_stream and verify."""

    def _do_upload_stream(cursor, stage_name, temp_dir, base_file_name):
        source_path = f"{temp_dir}/{base_file_name}"
        with open(source_path, "rb") as stream:
            cursor._upload_stream(
                input_stream=stream,
                stage_location=f"@{stage_name}/{base_file_name}",
                options={"auto_compress": is_compressed},
            )

    _test_runner(conn_cnx, _do_upload_stream, is_compressed=is_compressed)
6 changes: 4 additions & 2 deletions test/unit/test_bind_upload_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def test_bind_upload_agent_uploading_multiple_files():
rows = [bytes(10)] * 10
agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
agent.upload()
assert csr.execute.call_count == 11 # 1 for stage creation + 10 files
assert csr.execute.call_count == 1 # 1 for stage creation
assert csr._upload_stream.call_count == 10 # 10 for 10 files


def test_bind_upload_agent_row_size_exceed_buffer_size():
Expand All @@ -22,7 +23,8 @@ def test_bind_upload_agent_row_size_exceed_buffer_size():
rows = [bytes(15)] * 10
agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
agent.upload()
assert csr.execute.call_count == 11 # 1 for stage creation + 10 files
assert csr.execute.call_count == 1 # 1 for stage creation
assert csr._upload_stream.call_count == 10 # 10 for 10 files


def test_bind_upload_agent_scoped_temp_object():
Expand Down
Loading