Cherrypicks to aio connector part11 #2462

Open · wants to merge 10 commits into base: cherrypicks-to-aio-connector-part10
2 changes: 1 addition & 1 deletion setup.cfg
@@ -47,7 +47,7 @@ install_requires =
boto3>=1.24
botocore>=1.24
cffi>=1.9,<2.0.0
cryptography>=3.1.0
cryptography>=3.1.0,<=44.0.3
pyOpenSSL>=22.0.0,<25.0.0
pyjwt<3.0.0
pytz
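As a side note on the dependency pin above, here is a minimal sketch, assuming the packaging library is installed, of checking that an environment satisfies the new cryptography bound:

# Hedged sketch: confirm the installed cryptography version sits inside the
# range pinned in setup.cfg (>=3.1.0,<=44.0.3). Assumes 'packaging' is present.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("cryptography"))
assert Version("3.1.0") <= installed <= Version("44.0.3"), (
    f"cryptography {installed} is outside the pinned range"
)
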
7 changes: 4 additions & 3 deletions src/snowflake/connector/aio/auth/_webbrowser.py
@@ -94,18 +94,19 @@ async def prepare(
socket_connection.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

try:
hostname = os.getenv("SF_AUTH_SOCKET_ADDR", "localhost")
try:
socket_connection.bind(
(
os.getenv("SF_AUTH_SOCKET_ADDR", "localhost"),
hostname,
int(os.getenv("SF_AUTH_SOCKET_PORT", 0)),
)
)
except socket.gaierror as ex:
if ex.args[0] == socket.EAI_NONAME:
raise OperationalError(
msg="localhost is not found. Ensure /etc/hosts has "
"localhost entry.",
msg=f"{hostname} is not found. Ensure /etc/hosts has "
f"{hostname} entry.",
errno=ER_NO_HOSTNAME_FOUND,
)
else:
7 changes: 4 additions & 3 deletions src/snowflake/connector/auth/webbrowser.py
@@ -123,18 +123,19 @@ def prepare(
socket_connection.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

try:
hostname = os.getenv("SF_AUTH_SOCKET_ADDR", "localhost")
try:
socket_connection.bind(
(
os.getenv("SF_AUTH_SOCKET_ADDR", "localhost"),
hostname,
int(os.getenv("SF_AUTH_SOCKET_PORT", 0)),
)
)
except socket.gaierror as ex:
if ex.args[0] == socket.EAI_NONAME:
raise OperationalError(
msg="localhost is not found. Ensure /etc/hosts has "
"localhost entry.",
msg=f"{hostname} is not found. Ensure /etc/hosts has "
f"{hostname} entry.",
errno=ER_NO_HOSTNAME_FOUND,
)
else:
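Both webbrowser hunks above bind the external-browser auth callback socket from the same two environment variables. A standalone sketch of that behavior, not the connector's actual entry point, assuming the defaults when the variables are unset:

# Illustrative only: mirrors the bind logic and the improved error message
# shown in the diffs above.
import os
import socket

hostname = os.getenv("SF_AUTH_SOCKET_ADDR", "localhost")
port = int(os.getenv("SF_AUTH_SOCKET_PORT", 0))  # 0 lets the OS pick a free port

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    sock.bind((hostname, port))
    print("auth callback would listen on", sock.getsockname())
except socket.gaierror as ex:
    if ex.args[0] == socket.EAI_NONAME:
        print(f"{hostname} is not found. Ensure /etc/hosts has {hostname} entry.")
finally:
    sock.close()
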
21 changes: 17 additions & 4 deletions src/snowflake/connector/connection.py
@@ -459,10 +459,9 @@ def __init__(
is_kwargs_empty = not kwargs

if "application" not in kwargs:
if ENV_VAR_PARTNER in os.environ.keys():
kwargs["application"] = os.environ[ENV_VAR_PARTNER]
elif "streamlit" in sys.modules:
kwargs["application"] = "streamlit"
app = self._detect_application()
if app:
kwargs["application"] = app

if "insecure_mode" in kwargs:
warn_message = "The 'insecure_mode' connection property is deprecated. Please use 'disable_ocsp_checks' instead"
@@ -2146,3 +2145,17 @@ def is_valid(self) -> bool:
except Exception as e:
logger.debug("session could not be validated due to exception: %s", e)
return False

@staticmethod
def _detect_application() -> None | str:
if ENV_VAR_PARTNER in os.environ.keys():
return os.environ[ENV_VAR_PARTNER]
if "streamlit" in sys.modules:
return "streamlit"
if all(
(jpmod in sys.modules)
for jpmod in ("ipykernel", "jupyter_core", "jupyter_client")
):
return "jupyter_notebook"
if "snowbooks" in sys.modules:
return "snowflake_notebook"
63 changes: 33 additions & 30 deletions src/snowflake/connector/pandas_tools.py
@@ -20,7 +20,6 @@
from snowflake.connector import ProgrammingError
from snowflake.connector.options import pandas
from snowflake.connector.telemetry import TelemetryData, TelemetryField
from snowflake.connector.util_text import random_string

from ._utils import (
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
@@ -108,11 +107,7 @@ def _create_temp_stage(
overwrite: bool,
use_scoped_temp_object: bool = False,
) -> str:
stage_name = (
random_name_for_temp_object(TempObjectType.STAGE)
if use_scoped_temp_object
else random_string()
)
stage_name = random_name_for_temp_object(TempObjectType.STAGE)
stage_location = build_location_helper(
database=database,
schema=schema,
@@ -179,11 +174,7 @@ def _create_temp_file_format(
sql_use_logical_type: str,
use_scoped_temp_object: bool = False,
) -> str:
file_format_name = (
random_name_for_temp_object(TempObjectType.FILE_FORMAT)
if use_scoped_temp_object
else random_string()
)
file_format_name = random_name_for_temp_object(TempObjectType.FILE_FORMAT)
file_format_location = build_location_helper(
database=database,
schema=schema,
@@ -269,6 +260,7 @@ def write_pandas(
table_type: Literal["", "temp", "temporary", "transient"] = "",
use_logical_type: bool | None = None,
iceberg_config: dict[str, str] | None = None,
bulk_upload_chunks: bool = False,
**kwargs: Any,
) -> tuple[
bool,
Expand Down Expand Up @@ -340,6 +332,8 @@ def write_pandas(
* base_location: the base directory that snowflake can write iceberg metadata and files to
* catalog_sync: optionally sets the catalog integration configured for Polaris Catalog
* storage_serialization_policy: specifies the storage serialization policy for the table
bulk_upload_chunks: If set to True, the upload uses the wildcard upload method.
This is faster, but instead of uploading and cleaning up each chunk separately, it uploads all chunks at once and then cleans up the locally stored chunks.



@@ -388,6 +382,10 @@
"Unsupported table type. Expected table types: temp/temporary, transient"
)

if table_type.lower() in ["temp", "temporary"]:
# Add scoped keyword when applicable.
table_type = get_temp_type_for_object(_use_scoped_temp_object).lower()

if chunk_size is None:
chunk_size = len(df)

@@ -442,25 +440,26 @@
chunk_path = os.path.join(tmp_folder, f"file{i}.txt")
# Dump chunk into parquet file
chunk.to_parquet(chunk_path, compression=compression, **kwargs)
# Upload parquet file
upload_sql = (
"PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
"'file://{path}' ? PARALLEL={parallel}"
).format(
path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"),
parallel=parallel,
)
params = ("@" + stage_location,)
logger.debug(f"uploading files with '{upload_sql}', params: %s", params)
cursor.execute(
upload_sql,
_is_internal=True,
_force_qmark_paramstyle=True,
params=params,
num_statements=1,
if not bulk_upload_chunks:
# Upload parquet file chunk right away
path = chunk_path.replace("\\", "\\\\").replace("'", "\\'")
cursor._upload(
local_file_name=f"'file://{path}'",
stage_location="@" + stage_location,
options={"parallel": parallel, "source_compression": "auto_detect"},
)

# Remove chunk file
os.remove(chunk_path)

if bulk_upload_chunks:
# Upload tmp directory with parquet chunks
path = tmp_folder.replace("\\", "\\\\").replace("'", "\\'")
cursor._upload(
local_file_name=f"'file://{path}/*'",
stage_location="@" + stage_location,
options={"parallel": parallel, "source_compression": "auto_detect"},
)
# Remove chunk file
os.remove(chunk_path)

# in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly
# see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html)
Expand Down Expand Up @@ -522,7 +521,11 @@ def drop_object(name: str, object_type: str) -> None:
target_table_location = build_location_helper(
database,
schema,
random_string() if (overwrite and auto_create_table) else table_name,
(
random_name_for_temp_object(TempObjectType.TABLE)
if (overwrite and auto_create_table)
else table_name
),
quote_identifiers,
)

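A hedged usage sketch of the new bulk_upload_chunks flag in write_pandas; the connection parameters and table name below are placeholders:

# Placeholder credentials and table name; shows only how the new flag is passed.
import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

df = pd.DataFrame([("Dash", 50), ("Luke", 20)], columns=["name", "points"])

with snowflake.connector.connect(
    account="my_account", user="my_user", password="***"  # placeholders
) as cnx:
    success, nchunks, nrows, _ = write_pandas(
        cnx,
        df,
        "USERS_POINTS",
        chunk_size=1,             # force multiple parquet chunks
        bulk_upload_chunks=True,  # one wildcard upload instead of one per chunk
    )
    print(success, nchunks, nrows)
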
16 changes: 13 additions & 3 deletions test/integ/aio/test_connection_async.py
@@ -928,8 +928,13 @@ async def test_invalid_connection_parameter(db_parameters, name, value, exc_warn
conn = snowflake.connector.aio.SnowflakeConnection(**conn_params)
await conn.connect()
assert getattr(conn, "_" + name) == value
assert len(w) == 1
assert str(w[0].message) == str(exc_warn)
# TODO: SNOW-2114216 remove filtering once the root cause for deprecation warning is fixed
# Filter out the deprecation warning
filtered_w = [
warning for warning in w if warning.category != DeprecationWarning
]
assert len(filtered_w) == 1
assert str(filtered_w[0].message) == str(exc_warn)
finally:
await conn.close()

@@ -955,7 +960,12 @@ async def test_invalid_connection_parameters_turned_off(db_parameters):
await conn.connect()
assert conn._autocommit == conn_params["autocommit"]
assert conn._applucation == conn_params["applucation"]
assert len(w) == 0
# TODO: SNOW-2114216 remove filtering once the root cause for deprecation warning is fixed
# Filter out the deprecation warning
filtered_w = [
warning for warning in w if warning.category != DeprecationWarning
]
assert len(filtered_w) == 0
finally:
await conn.close()

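The connection tests here and in test/integ/test_connection.py below apply the same pattern: record all warnings, drop DeprecationWarning entries, and assert on the remainder. A minimal standalone sketch of that pattern, independent of the connector:

# Standalone illustration of the warning-filtering pattern used in these tests.
import warnings


def noisy() -> None:
    warnings.warn("invalid parameter", UserWarning)
    warnings.warn("internal deprecation", DeprecationWarning)


with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    noisy()

filtered_w = [warning for warning in w if warning.category != DeprecationWarning]
assert len(filtered_w) == 1
assert "invalid parameter" in str(filtered_w[0].message)
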
67 changes: 63 additions & 4 deletions test/integ/pandas/test_pandas_tools.py
@@ -6,6 +6,7 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Callable, Generator
from unittest import mock
from unittest.mock import MagicMock

import numpy.random
import pytest
@@ -543,7 +544,10 @@ def mocked_execute(*args, **kwargs):
with mock.patch(
"snowflake.connector.cursor.SnowflakeCursor.execute",
side_effect=mocked_execute,
) as m_execute:
) as m_execute, mock.patch(
"snowflake.connector.cursor.SnowflakeCursor._upload",
side_effect=MagicMock(),
) as _:
success, nchunks, nrows, _ = write_pandas(
cnx,
sf_connector_version_df.get(),
@@ -593,7 +597,10 @@ def mocked_execute(*args, **kwargs):
with mock.patch(
"snowflake.connector.cursor.SnowflakeCursor.execute",
side_effect=mocked_execute,
) as m_execute:
) as m_execute, mock.patch(
"snowflake.connector.cursor.SnowflakeCursor._upload",
side_effect=MagicMock(),
) as _:
success, nchunks, nrows, _ = write_pandas(
cnx,
sf_connector_version_df.get(),
@@ -645,7 +652,10 @@ def mocked_execute(*args, **kwargs):
with mock.patch(
"snowflake.connector.cursor.SnowflakeCursor.execute",
side_effect=mocked_execute,
) as m_execute:
) as m_execute, mock.patch(
"snowflake.connector.cursor.SnowflakeCursor._upload",
side_effect=MagicMock(),
) as _:
cnx._update_parameters({"PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS": True})
success, nchunks, nrows, _ = write_pandas(
cnx,
@@ -703,7 +713,10 @@ def mocked_execute(*args, **kwargs):
with mock.patch(
"snowflake.connector.cursor.SnowflakeCursor.execute",
side_effect=mocked_execute,
) as m_execute:
) as m_execute, mock.patch(
"snowflake.connector.cursor.SnowflakeCursor._upload",
side_effect=MagicMock(),
) as _:
success, nchunks, nrows, _ = write_pandas(
cnx,
sf_connector_version_df.get(),
@@ -1126,3 +1139,49 @@ def test_pandas_with_single_quote(
)
finally:
cnx.execute_string(f"drop table if exists {table_name}")


@pytest.mark.parametrize("bulk_upload_chunks", [True, False])
def test_write_pandas_bulk_chunks_upload(conn_cnx, bulk_upload_chunks):
"""Tests whether overwriting table using a Pandas DataFrame works as expected."""
random_table_name = random_string(5, "userspoints_")
df_data = [("Dash", 50), ("Luke", 20), ("Mark", 10), ("John", 30)]
df = pandas.DataFrame(df_data, columns=["name", "points"])

table_name = random_table_name
col_id = "id"
col_name = "name"
col_points = "points"

create_sql = (
f"CREATE OR REPLACE TABLE {table_name}"
f"({col_name} STRING, {col_points} INT, {col_id} INT AUTOINCREMENT)"
)

select_count_sql = f"SELECT count(*) FROM {table_name}"
drop_sql = f"DROP TABLE IF EXISTS {table_name}"
with conn_cnx() as cnx: # type: SnowflakeConnection
cnx.execute_string(create_sql)
try:
# Write dataframe in chunks of one row
success, nchunks, nrows, _ = write_pandas(
cnx,
df,
random_table_name,
quote_identifiers=False,
auto_create_table=False,
overwrite=True,
index=True,
on_error="continue",
chunk_size=1,
bulk_upload_chunks=bulk_upload_chunks,
)
# Check write_pandas output
assert success
assert nchunks == 4
assert nrows == 4
result = cnx.cursor(DictCursor).execute(select_count_sql).fetchone()
# Check number of rows
assert result["COUNT(*)"] == 4
finally:
cnx.execute_string(drop_sql)
16 changes: 13 additions & 3 deletions test/integ/test_connection.py
@@ -877,8 +877,13 @@ def test_invalid_connection_parameter(db_parameters, name, value, exc_warn):
try:
conn = snowflake.connector.connect(**conn_params)
assert getattr(conn, "_" + name) == value
assert len(w) == 1
assert str(w[0].message) == str(exc_warn)
# TODO: SNOW-2114216 remove filtering once the root cause for deprecation warning is fixed
# Filter out the deprecation warning
filtered_w = [
warning for warning in w if warning.category != DeprecationWarning
]
assert len(filtered_w) == 1
assert str(filtered_w[0].message) == str(exc_warn)
finally:
conn.close()

@@ -903,7 +908,12 @@ def test_invalid_connection_parameters_turned_off(db_parameters):
conn = snowflake.connector.connect(**conn_params)
assert conn._autocommit == conn_params["autocommit"]
assert conn._applucation == conn_params["applucation"]
assert len(w) == 0
# TODO: SNOW-2114216 remove filtering once the root cause for deprecation warning is fixed
# Filter out the deprecation warning
filtered_w = [
warning for warning in w if warning.category != DeprecationWarning
]
assert len(filtered_w) == 0
finally:
conn.close()
