Cherrypicks to aio connector part10 #2461

Open · wants to merge 12 commits into base: cherrypicks-to-aio-connector-part9
4 changes: 2 additions & 2 deletions setup.cfg
@@ -44,8 +44,8 @@ python_requires = >=3.9
packages = find_namespace:
install_requires =
asn1crypto>0.24.0,<2.0.0
boto3>=1.0
botocore>=1.0
boto3>=1.24
botocore>=1.24
cffi>=1.9,<2.0.0
cryptography>=3.1.0
pyOpenSSL>=22.0.0,<25.0.0
1 change: 1 addition & 0 deletions setup.py
@@ -103,6 +103,7 @@ def build_extension(self, ext):
"FixedSizeListConverter.cpp",
"FloatConverter.cpp",
"IntConverter.cpp",
"IntervalConverter.cpp",
"MapConverter.cpp",
"ObjectConverter.cpp",
"SnowflakeType.cpp",
@@ -3,6 +3,7 @@

from __future__ import annotations

import os
from io import BytesIO
from logging import getLogger
from typing import TYPE_CHECKING, cast
@@ -56,8 +57,11 @@ async def upload(self) -> None:
if row_idx >= len(self.rows) or size >= self._stream_buffer_size:
break
try:
await self.cursor.execute(
f"PUT file://{row_idx}.csv {self.stage_path}", file_stream=f
f.seek(0)
await self.cursor._upload_stream(
input_stream=f,
stage_location=os.path.join(self.stage_path, f"{row_idx}.csv"),
options={"source_compression": "auto_detect"},
)
except Error as err:
logger.debug("Failed to upload the bindings file to stage.")
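Reviewer note on the hunk above: the added `f.seek(0)` is what makes the stream upload send the buffered CSV bytes at all, assuming `_upload_stream` reads from the stream's current position. A minimal sketch of that behavior (not part of the PR):

```python
# Minimal sketch, not part of the PR: why rewinding the BytesIO buffer matters.
from io import BytesIO

f = BytesIO()
f.write(b"1,foo\n2,bar\n")   # simulate the rows written by the loop above
print(f.read())              # b'' -- position is at the end, nothing left to read
f.seek(0)
print(f.read())              # b'1,foo\n2,bar\n' -- the full payload after rewinding
```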
9 changes: 6 additions & 3 deletions src/snowflake/connector/aio/_cursor.py
@@ -23,7 +23,7 @@
ProgrammingError,
)
from snowflake.connector._sql_util import get_file_transfer_type
from snowflake.connector.aio._build_upload_agent import BindUploadAgent
from snowflake.connector.aio._bind_upload_agent import BindUploadAgent
from snowflake.connector.aio._result_batch import (
ResultBatch,
create_batches_from_response,
@@ -235,7 +235,10 @@ async def _execute_helper(
else:
# or detect it.
self._is_file_transfer = get_file_transfer_type(query) is not None
logger.debug("is_file_transfer: %s", self._is_file_transfer is not None)
logger.debug(
"is_file_transfer: %s",
self._is_file_transfer if self._is_file_transfer is not None else "None",
)

real_timeout = (
timeout if timeout and timeout > 0 else self._connection.network_timeout
@@ -800,7 +803,7 @@ async def executemany(
bind_stage = None
if (
bind_size
> self.connection._session_parameters[
>= self.connection._session_parameters[
"CLIENT_STAGE_ARRAY_BINDING_THRESHOLD"
]
> 0
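Reviewer note on the `>` to `>=` change above: the condition is a Python chained comparison, so stage binding now kicks in when the bind size reaches the threshold exactly, while a threshold of 0 still disables it. A small sketch with made-up numbers (not part of the PR):

```python
# Not part of the PR: a >= b > 0 is equivalent to (a >= b) and (b > 0).
bind_size = 100_000
threshold = 100_000  # stand-in for CLIENT_STAGE_ARRAY_BINDING_THRESHOLD

print(bind_size > threshold > 0)   # False -- old check skipped the stage at the exact threshold
print(bind_size >= threshold > 0)  # True  -- new check uses the stage once the threshold is reached
print(bind_size >= 0 > 0)          # False -- a zero threshold still disables stage binding
```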
34 changes: 31 additions & 3 deletions src/snowflake/connector/aio/_direct_file_operation_utils.py
@@ -1,7 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ._connection import SnowflakeConnection

import os
from abc import ABC, abstractmethod

from ..constants import CMD_TYPE_UPLOAD


class FileOperationParserBase(ABC):
"""The interface of internal utility functions for file operation parsing."""
Expand Down Expand Up @@ -37,8 +45,8 @@ async def download_as_stream(self, ret, decompress=False):


class FileOperationParser(FileOperationParserBase):
def __init__(self, connection):
pass
def __init__(self, connection: SnowflakeConnection):
self._connection = connection

async def parse_file_operation(
self,
@@ -49,7 +57,27 @@
options,
has_source_from_stream=False,
):
raise NotImplementedError("parse_file_operation is not yet supported")
"""Parses a file operation by constructing SQL and getting the SQL parsing result from server."""
options = options or {}
options_in_sql = " ".join(f"{k}={v}" for k, v in options.items())

if command_type == CMD_TYPE_UPLOAD:
if has_source_from_stream:
stage_location, unprefixed_local_file_name = os.path.split(
stage_location
)
local_file_name = "file://" + unprefixed_local_file_name
sql = f"PUT {local_file_name} ? {options_in_sql}"
params = [stage_location]
else:
raise NotImplementedError(f"unsupported command type: {command_type}")

async with self._connection.cursor() as cursor:
# Send constructed SQL to server and get back parsing result.
processed_params = cursor._connection._process_params_qmarks(params, cursor)
return await cursor._execute_helper(
sql, binding_params=processed_params, is_internal=True
)


class StreamDownloader(StreamDownloaderBase):
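Reviewer note: for a stream-sourced upload, `parse_file_operation` splits the file name off the combined stage path, rewrites it as a `file://` source, and binds the remaining stage location as a qmark parameter. A dry run of that string handling with assumed inputs (the stage path below is made up, not from the PR):

```python
# Not part of the PR: dry run of the PUT construction above with a made-up stage path.
import os

stage_location = "@SYSTEM$BIND/storage/0.csv"   # hypothetical combined stage path + file name
options = {"source_compression": "auto_detect"}

stage_location, unprefixed_local_file_name = os.path.split(stage_location)
local_file_name = "file://" + unprefixed_local_file_name
options_in_sql = " ".join(f"{k}={v}" for k, v in options.items())

sql = f"PUT {local_file_name} ? {options_in_sql}"
params = [stage_location]

print(sql)     # PUT file://0.csv ? source_compression=auto_detect
print(params)  # ['@SYSTEM$BIND/storage']
```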
40 changes: 39 additions & 1 deletion src/snowflake/connector/arrow_context.py
@@ -15,7 +15,7 @@
from .converter import _generate_tzinfo_from_tzoffset

if TYPE_CHECKING:
from numpy import datetime64, float64, int64
from numpy import datetime64, float64, int64, timedelta64


try:
@@ -163,3 +163,41 @@ def DECFLOAT_to_decimal(self, exponent: int, significand: bytes) -> decimal.Deci

def DECFLOAT_to_numpy_float64(self, exponent: int, significand: bytes) -> float64:
return numpy.float64(self.DECFLOAT_to_decimal(exponent, significand))

def INTERVAL_YEAR_MONTH_to_numpy_timedelta(self, months: int) -> timedelta64:
return numpy.timedelta64(months, "M")

def INTERVAL_DAY_TIME_int_to_numpy_timedelta(self, nanos: int) -> timedelta64:
return numpy.timedelta64(nanos, "ns")

def INTERVAL_DAY_TIME_int_to_timedelta(self, nanos: int) -> timedelta:
# Python timedelta only supports microsecond precision. We receive the value in
# nanoseconds.
return timedelta(microseconds=nanos // 1000)

def INTERVAL_DAY_TIME_decimal_to_numpy_timedelta(self, value: bytes) -> timedelta64:
# Snowflake supports up to 9 digits of leading field precision for the day-time
# interval, which, when represented in nanoseconds, cannot be stored in a 64-bit
# integer. So we send these as Decimal128 from server to client.
# Arrow uses little-endian by default.
# https://arrow.apache.org/docs/format/Columnar.html#byte-order-endianness
nanos = int.from_bytes(value, byteorder="little", signed=True)
# Numpy timedelta only supports up to 64-bit integers, so we need to change the
# unit to milliseconds to avoid overflow.
# Max value received from server
# = 10**9 * NANOS_PER_DAY - 1
# = 86399999999999999999999 nanoseconds
# = 86399999999999999 milliseconds
# math.log2(86399999999999999) = 56.3 < 64
return numpy.timedelta64(nanos // 1_000_000, "ms")

def INTERVAL_DAY_TIME_decimal_to_timedelta(self, value: bytes) -> timedelta:
# Snowflake supports up to 9 digits of leading field precision for the day-time
# interval, which, when represented in nanoseconds, cannot be stored in a 64-bit
# integer. So we send these as Decimal128 from server to client.
# Arrow uses little-endian by default.
# https://arrow.apache.org/docs/format/Columnar.html#byte-order-endianness
nanos = int.from_bytes(value, byteorder="little", signed=True)
# Python timedelta only supports microsecond precision. We receive the value in
# nanoseconds.
return timedelta(microseconds=nanos // 1000)
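Reviewer note: the millisecond unit in `INTERVAL_DAY_TIME_decimal_to_numpy_timedelta` rests on the arithmetic in the comments above. A quick sanity check of those numbers (not part of the PR):

```python
# Not part of the PR: verifying the overflow reasoning in the comments above.
import math

NANOS_PER_DAY = 24 * 60 * 60 * 10**9
max_nanos = 10**9 * NANOS_PER_DAY - 1   # 86399999999999999999999, per the comment
print(max_nanos.bit_length())           # 77 -- too wide for a signed 64-bit nanosecond count
max_millis = max_nanos // 1_000_000     # 86399999999999999
print(max_millis.bit_length())          # 57 -- fits comfortably once reduced to milliseconds
print(round(math.log2(max_millis), 1))  # 56.3, matching the comment
```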
8 changes: 6 additions & 2 deletions src/snowflake/connector/bind_upload_agent.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
from __future__ import annotations

import os
import uuid
from io import BytesIO
from logging import getLogger
@@ -76,8 +77,11 @@ def upload(self) -> None:
if row_idx >= len(self.rows) or size >= self._stream_buffer_size:
break
try:
self.cursor.execute(
f"PUT file://{row_idx}.csv {self.stage_path}", file_stream=f
f.seek(0)
self.cursor._upload_stream(
input_stream=f,
stage_location=os.path.join(self.stage_path, f"{row_idx}.csv"),
options={"source_compression": "auto_detect"},
)
except Error as err:
logger.debug("Failed to upload the bindings file to stage.")
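Reviewer note: same change as in the async agent earlier in this PR; each buffered chunk of rows now targets an explicit per-chunk destination under the bind stage instead of being named in a literal PUT statement. With a made-up stage path, the destinations look like this (on a POSIX system, since `os.path.join` is separator-dependent):

```python
# Not part of the PR: illustrative per-chunk stage destinations (stage path is made up).
import os

stage_path = "@SYSTEM$BIND/storage"   # hypothetical value of self.stage_path
for row_idx in (0, 10_000, 20_000):
    print(os.path.join(stage_path, f"{row_idx}.csv"))
# @SYSTEM$BIND/storage/0.csv
# @SYSTEM$BIND/storage/10000.csv
# @SYSTEM$BIND/storage/20000.csv
```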
2 changes: 1 addition & 1 deletion src/snowflake/connector/connection_diagnostic.py
@@ -579,7 +579,7 @@ def __check_for_proxies(self) -> None:
cert_reqs=cert_reqs,
)
resp = http.request(
"GET", "https://nonexistentdomain.invalidtld", timeout=10.0
"GET", "https://nonexistentdomain.invalid", timeout=10.0
)

# squid does not throw exception. Check HTML
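Reviewer note: `.invalid` is a TLD reserved by RFC 2606, so the sentinel host can never legitimately resolve, which is what the proxy check relies on; `.invalidtld` carries no such guarantee. A quick local check (not part of the PR):

```python
# Not part of the PR: the sentinel host should fail DNS resolution unless a proxy
# or wildcard DNS is intercepting requests.
import socket

for host in ("nonexistentdomain.invalid", "nonexistentdomain.invalidtld"):
    try:
        socket.gethostbyname(host)
        print(f"{host}: resolved, something is intercepting DNS")
    except socket.gaierror:
        print(f"{host}: does not resolve, as expected")
```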
7 changes: 5 additions & 2 deletions src/snowflake/connector/cursor.py
@@ -672,7 +672,10 @@ def _execute_helper(
else:
# or detect it.
self._is_file_transfer = get_file_transfer_type(query) is not None
logger.debug("is_file_transfer: %s", self._is_file_transfer is not None)
logger.debug(
"is_file_transfer: %s",
self._is_file_transfer if self._is_file_transfer is not None else "None",
)

real_timeout = (
timeout if timeout and timeout > 0 else self._connection.network_timeout
@@ -1460,7 +1463,7 @@ def executemany(
bind_stage = None
if (
bind_size
> self.connection._session_parameters[
>= self.connection._session_parameters[
"CLIENT_STAGE_ARRAY_BINDING_THRESHOLD"
]
> 0
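Reviewer note on the `is_file_transfer` debug line changed above (same fix as in the async cursor): the old message logged whether the flag was non-None, which is always True in the detection branch shown, rather than the flag's value. A tiny before/after (not part of the PR):

```python
# Not part of the PR: old vs. new debug output for a query that is not a file transfer.
is_file_transfer = False   # what the detection branch assigns for a plain query

print("old:", is_file_transfer is not None)                                  # old: True (misleading)
print("new:", is_file_transfer if is_file_transfer is not None else "None")  # new: False
```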
34 changes: 31 additions & 3 deletions src/snowflake/connector/direct_file_operation_utils.py
@@ -1,7 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .connection import SnowflakeConnection

import os
from abc import ABC, abstractmethod

from .constants import CMD_TYPE_UPLOAD


class FileOperationParserBase(ABC):
"""The interface of internal utility functions for file operation parsing."""
@@ -37,8 +45,8 @@ def download_as_stream(self, ret, decompress=False):


class FileOperationParser(FileOperationParserBase):
def __init__(self, connection):
pass
def __init__(self, connection: SnowflakeConnection):
self._connection = connection

def parse_file_operation(
self,
@@ -49,7 +57,27 @@
options,
has_source_from_stream=False,
):
raise NotImplementedError("parse_file_operation is not yet supported")
"""Parses a file operation by constructing SQL and getting the SQL parsing result from server."""
options = options or {}
options_in_sql = " ".join(f"{k}={v}" for k, v in options.items())

if command_type == CMD_TYPE_UPLOAD:
if has_source_from_stream:
stage_location, unprefixed_local_file_name = os.path.split(
stage_location
)
local_file_name = "file://" + unprefixed_local_file_name
sql = f"PUT {local_file_name} ? {options_in_sql}"
params = [stage_location]
else:
raise NotImplementedError(f"unsupported command type: {command_type}")

with self._connection.cursor() as cursor:
# Send constructed SQL to server and get back parsing result.
processed_params = cursor._connection._process_params_qmarks(params, cursor)
return cursor._execute_helper(
sql, binding_params=processed_params, is_internal=True
)


class StreamDownloader(StreamDownloaderBase):
@@ -13,6 +13,7 @@
#include "FixedSizeListConverter.hpp"
#include "FloatConverter.hpp"
#include "IntConverter.hpp"
#include "IntervalConverter.hpp"
#include "MapConverter.hpp"
#include "ObjectConverter.hpp"
#include "StringConverter.hpp"
@@ -479,6 +480,36 @@ std::shared_ptr<sf::IColumnConverter> getConverterFromSchema(
break;
}

case SnowflakeType::Type::INTERVAL_YEAR_MONTH: {
converter = std::make_shared<sf::IntervalYearMonthConverter>(
array, context, useNumpy);
break;
}

case SnowflakeType::Type::INTERVAL_DAY_TIME: {
switch (schemaView.type) {
case NANOARROW_TYPE_INT64:
converter = std::make_shared<sf::IntervalDayTimeConverterInt>(
array, context, useNumpy);
break;
case NANOARROW_TYPE_DECIMAL128:
converter = std::make_shared<sf::IntervalDayTimeConverterDecimal>(
array, context, useNumpy);
break;
default: {
std::string errorInfo = Logger::formatString(
"[Snowflake Exception] unknown arrow internal data type(%d) "
"for OBJECT data in %s",
NANOARROW_TYPE_ENUM_STRING[schemaView.type],
schemaView.schema->name);
logger->error(__FILE__, __func__, __LINE__, errorInfo.c_str());
PyErr_SetString(PyExc_Exception, errorInfo.c_str());
break;
}
}
break;
}

default: {
std::string errorInfo = Logger::formatString(
"[Snowflake Exception] unknown snowflake data type : %d", st);
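Reviewer note: the new dispatch above pairs with the `arrow_context` helpers added earlier in this PR; depending on the Arrow storage type and whether numpy results are requested, an interval column surfaces as `numpy.timedelta64` or `datetime.timedelta`. A rough Python-side illustration of the resulting values (not part of the PR):

```python
# Not part of the PR: the client-side values the interval paths produce.
from datetime import timedelta
import numpy

nanos = 90_000_000                            # ~90 ms, as sent on the INT64 day-time path
print(numpy.timedelta64(nanos, "ns"))         # 90000000 nanoseconds (numpy result)
print(timedelta(microseconds=nanos // 1000))  # 0:00:00.090000 (pure-Python result)
print(numpy.timedelta64(14, "M"))             # 14 months (year-month interval, numpy result)
```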