diff --git a/setup.cfg b/setup.cfg
index 280890d066..6ce15afc56 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -47,7 +47,7 @@ install_requires =
     boto3>=1.24
     botocore>=1.24
     cffi>=1.9,<2.0.0
-    cryptography>=3.1.0
+    cryptography>=3.1.0,<=44.0.3
     pyOpenSSL>=22.0.0,<25.0.0
     pyjwt<3.0.0
     pytz
diff --git a/src/snowflake/connector/aio/auth/_webbrowser.py b/src/snowflake/connector/aio/auth/_webbrowser.py
index c00e9a3293..0e9bdce9aa 100644
--- a/src/snowflake/connector/aio/auth/_webbrowser.py
+++ b/src/snowflake/connector/aio/auth/_webbrowser.py
@@ -94,18 +94,19 @@ async def prepare(
             socket_connection.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

         try:
+            hostname = os.getenv("SF_AUTH_SOCKET_ADDR", "localhost")
             try:
                 socket_connection.bind(
                     (
-                        os.getenv("SF_AUTH_SOCKET_ADDR", "localhost"),
+                        hostname,
                         int(os.getenv("SF_AUTH_SOCKET_PORT", 0)),
                     )
                 )
             except socket.gaierror as ex:
                 if ex.args[0] == socket.EAI_NONAME:
                     raise OperationalError(
-                        msg="localhost is not found. Ensure /etc/hosts has "
-                        "localhost entry.",
+                        msg=f"{hostname} is not found. Ensure /etc/hosts has "
+                        f"{hostname} entry.",
                         errno=ER_NO_HOSTNAME_FOUND,
                     )
                 else:
diff --git a/src/snowflake/connector/auth/webbrowser.py b/src/snowflake/connector/auth/webbrowser.py
index 2f77badf8c..20b92efb52 100644
--- a/src/snowflake/connector/auth/webbrowser.py
+++ b/src/snowflake/connector/auth/webbrowser.py
@@ -123,18 +123,19 @@ def prepare(
             socket_connection.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

         try:
+            hostname = os.getenv("SF_AUTH_SOCKET_ADDR", "localhost")
             try:
                 socket_connection.bind(
                     (
-                        os.getenv("SF_AUTH_SOCKET_ADDR", "localhost"),
+                        hostname,
                         int(os.getenv("SF_AUTH_SOCKET_PORT", 0)),
                     )
                 )
             except socket.gaierror as ex:
                 if ex.args[0] == socket.EAI_NONAME:
                     raise OperationalError(
-                        msg="localhost is not found. Ensure /etc/hosts has "
-                        "localhost entry.",
+                        msg=f"{hostname} is not found. Ensure /etc/hosts has "
+                        f"{hostname} entry.",
                         errno=ER_NO_HOSTNAME_FOUND,
                     )
                 else:
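
Both prepare() hunks read the same two environment variables, so the external-browser SSO callback socket can be pointed at an explicit address and port, and a resolution failure now names the configured host in the OperationalError instead of hard-coding "localhost". A minimal sketch of the override, assuming the variables are read at prepare() time exactly as shown above (the address and port values are illustrative):

    import os

    # Bind the one-shot SSO callback socket to an explicit loopback address
    # and a fixed port instead of the defaults ("localhost", ephemeral port).
    os.environ["SF_AUTH_SOCKET_ADDR"] = "127.0.0.1"
    os.environ["SF_AUTH_SOCKET_PORT"] = "8001"

    # Connections opened with authenticator="externalbrowser" after this
    # point will bind their callback listener to 127.0.0.1:8001.
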
diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py
index 2a85965e6c..84e0052a62 100644
--- a/src/snowflake/connector/connection.py
+++ b/src/snowflake/connector/connection.py
@@ -459,10 +459,9 @@ def __init__(
         is_kwargs_empty = not kwargs

         if "application" not in kwargs:
-            if ENV_VAR_PARTNER in os.environ.keys():
-                kwargs["application"] = os.environ[ENV_VAR_PARTNER]
-            elif "streamlit" in sys.modules:
-                kwargs["application"] = "streamlit"
+            app = self._detect_application()
+            if app:
+                kwargs["application"] = app

         if "insecure_mode" in kwargs:
             warn_message = "The 'insecure_mode' connection property is deprecated. Please use 'disable_ocsp_checks' instead"
@@ -2146,3 +2145,17 @@ def is_valid(self) -> bool:
         except Exception as e:
             logger.debug("session could not be validated due to exception: %s", e)
             return False
+
+    @staticmethod
+    def _detect_application() -> None | str:
+        if ENV_VAR_PARTNER in os.environ.keys():
+            return os.environ[ENV_VAR_PARTNER]
+        if "streamlit" in sys.modules:
+            return "streamlit"
+        if all(
+            (jpmod in sys.modules)
+            for jpmod in ("ipykernel", "jupyter_core", "jupyter_client")
+        ):
+            return "jupyter_notebook"
+        if "snowbooks" in sys.modules:
+            return "snowflake_notebook"
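
The new staticmethod resolves the default application tag with a fixed precedence: the partner environment variable (ENV_VAR_PARTNER) wins, then streamlit, then a Jupyter kernel (all three of ipykernel, jupyter_core, and jupyter_client must be imported), then Snowflake Notebooks (snowbooks); with no match it falls through and implicitly returns None, leaving kwargs untouched. A small demonstration, borrowing the patch.dict trick the updated unit tests use and assuming the method is reachable on SnowflakeConnection as the hunk context suggests:

    import sys
    from unittest.mock import patch

    from snowflake.connector.connection import SnowflakeConnection

    # Injecting fake entries into sys.modules is enough; no real Jupyter
    # install is needed, but all three kernel modules must be present at
    # once. This assumes the partner env var is unset and streamlit is not
    # imported, since those branches win first.
    jupyter_mods = {"ipykernel": None, "jupyter_core": None, "jupyter_client": None}
    with patch.dict(sys.modules, jupyter_mods):
        assert SnowflakeConnection._detect_application() == "jupyter_notebook"
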
diff --git a/src/snowflake/connector/pandas_tools.py b/src/snowflake/connector/pandas_tools.py
index a9555dd553..6f7d30d0a2 100644
--- a/src/snowflake/connector/pandas_tools.py
+++ b/src/snowflake/connector/pandas_tools.py
@@ -20,7 +20,6 @@
 from snowflake.connector import ProgrammingError
 from snowflake.connector.options import pandas
 from snowflake.connector.telemetry import TelemetryData, TelemetryField
-from snowflake.connector.util_text import random_string

 from ._utils import (
     _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
@@ -108,11 +107,7 @@ def _create_temp_stage(
     overwrite: bool,
     use_scoped_temp_object: bool = False,
 ) -> str:
-    stage_name = (
-        random_name_for_temp_object(TempObjectType.STAGE)
-        if use_scoped_temp_object
-        else random_string()
-    )
+    stage_name = random_name_for_temp_object(TempObjectType.STAGE)
     stage_location = build_location_helper(
         database=database,
         schema=schema,
@@ -179,11 +174,7 @@ def _create_temp_file_format(
     sql_use_logical_type: str,
     use_scoped_temp_object: bool = False,
 ) -> str:
-    file_format_name = (
-        random_name_for_temp_object(TempObjectType.FILE_FORMAT)
-        if use_scoped_temp_object
-        else random_string()
-    )
+    file_format_name = random_name_for_temp_object(TempObjectType.FILE_FORMAT)
     file_format_location = build_location_helper(
         database=database,
         schema=schema,
@@ -269,6 +260,7 @@ def write_pandas(
     table_type: Literal["", "temp", "temporary", "transient"] = "",
     use_logical_type: bool | None = None,
     iceberg_config: dict[str, str] | None = None,
+    bulk_upload_chunks: bool = False,
     **kwargs: Any,
 ) -> tuple[
     bool,
@@ -340,6 +332,8 @@ def write_pandas(
             * base_location: the base directory that snowflake can write iceberg metadata and files to
             * catalog_sync: optionally sets the catalog integration configured for Polaris Catalog
             * storage_serialization_policy: specifies the storage serialization policy for the table
+        bulk_upload_chunks: If set to True, the upload uses the wildcard upload method, which can be faster.
+            Instead of uploading and cleaning up each chunk separately, it uploads all chunks at once and then removes the locally stored chunks.
@@ -388,6 +382,10 @@ def write_pandas(
             "Unsupported table type. Expected table types: temp/temporary, transient"
         )

+    if table_type.lower() in ["temp", "temporary"]:
+        # Add scoped keyword when applicable.
+        table_type = get_temp_type_for_object(_use_scoped_temp_object).lower()
+
     if chunk_size is None:
         chunk_size = len(df)

@@ -442,25 +440,26 @@
             chunk_path = os.path.join(tmp_folder, f"file{i}.txt")
             # Dump chunk into parquet file
             chunk.to_parquet(chunk_path, compression=compression, **kwargs)
-            # Upload parquet file
-            upload_sql = (
-                "PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-                "'file://{path}' ? PARALLEL={parallel}"
-            ).format(
-                path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"),
-                parallel=parallel,
-            )
-            params = ("@" + stage_location,)
-            logger.debug(f"uploading files with '{upload_sql}', params: %s", params)
-            cursor.execute(
-                upload_sql,
-                _is_internal=True,
-                _force_qmark_paramstyle=True,
-                params=params,
-                num_statements=1,
+            if not bulk_upload_chunks:
+                # Upload parquet file chunk right away
+                path = chunk_path.replace("\\", "\\\\").replace("'", "\\'")
+                cursor._upload(
+                    local_file_name=f"'file://{path}'",
+                    stage_location="@" + stage_location,
+                    options={"parallel": parallel, "source_compression": "auto_detect"},
+                )
+
+                # Remove chunk file
+                os.remove(chunk_path)
+
+        if bulk_upload_chunks:
+            # Upload tmp directory with parquet chunks
+            path = tmp_folder.replace("\\", "\\\\").replace("'", "\\'")
+            cursor._upload(
+                local_file_name=f"'file://{path}/*'",
+                stage_location="@" + stage_location,
+                options={"parallel": parallel, "source_compression": "auto_detect"},
             )
-            # Remove chunk file
-            os.remove(chunk_path)

     # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly
     # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html)
@@ -522,7 +521,11 @@ def drop_object(name: str, object_type: str) -> None:
     target_table_location = build_location_helper(
         database,
         schema,
-        random_string() if (overwrite and auto_create_table) else table_name,
+        (
+            random_name_for_temp_object(TempObjectType.TABLE)
+            if (overwrite and auto_create_table)
+            else table_name
+        ),
         quote_identifiers,
     )
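
For callers, the new flag is opt-in and changes only how the intermediate parquet chunks reach the stage; the COPY INTO that follows is unchanged. A hedged usage sketch (connection parameters and the table name are placeholders, not part of this change):

    import pandas as pd
    import snowflake.connector
    from snowflake.connector.pandas_tools import write_pandas

    # Placeholder credentials for illustration only.
    conn = snowflake.connector.connect(
        account="my_account",
        user="my_user",
        password="my_password",
        database="MY_DB",
        schema="PUBLIC",
    )

    df = pd.DataFrame({"name": ["Dash", "Luke"], "points": [50, 20]})

    # chunk_size=1 splits the frame into one parquet file per row; with
    # bulk_upload_chunks=True those files are staged by a single wildcard
    # upload and removed together, instead of one upload + unlink per chunk.
    success, nchunks, nrows, _ = write_pandas(
        conn,
        df,
        "USER_POINTS",
        auto_create_table=True,
        chunk_size=1,
        bulk_upload_chunks=True,
    )
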
diff --git a/test/integ/aio/test_connection_async.py b/test/integ/aio/test_connection_async.py
index c8d7ea6a4d..e0c771664a 100644
--- a/test/integ/aio/test_connection_async.py
+++ b/test/integ/aio/test_connection_async.py
@@ -928,8 +928,13 @@ async def test_invalid_connection_parameter(db_parameters, name, value, exc_warn):
             conn = snowflake.connector.aio.SnowflakeConnection(**conn_params)
             await conn.connect()
             assert getattr(conn, "_" + name) == value
-            assert len(w) == 1
-            assert str(w[0].message) == str(exc_warn)
+            # TODO: SNOW-2114216 remove filtering once the root cause for the deprecation warning is fixed
+            # Filter out the deprecation warning
+            filtered_w = [
+                warning for warning in w if warning.category != DeprecationWarning
+            ]
+            assert len(filtered_w) == 1
+            assert str(filtered_w[0].message) == str(exc_warn)
         finally:
             await conn.close()
@@ -955,7 +960,12 @@ async def test_invalid_connection_parameters_turned_off(db_parameters):
         await conn.connect()
         assert conn._autocommit == conn_params["autocommit"]
         assert conn._applucation == conn_params["applucation"]
-        assert len(w) == 0
+        # TODO: SNOW-2114216 remove filtering once the root cause for the deprecation warning is fixed
+        # Filter out the deprecation warning
+        filtered_w = [
+            warning for warning in w if warning.category != DeprecationWarning
+        ]
+        assert len(filtered_w) == 0
     finally:
         await conn.close()
diff --git a/test/integ/pandas/test_pandas_tools.py b/test/integ/pandas/test_pandas_tools.py
index 43f788ebbb..5f9bc34e93 100644
--- a/test/integ/pandas/test_pandas_tools.py
+++ b/test/integ/pandas/test_pandas_tools.py
@@ -6,6 +6,7 @@
 from datetime import datetime, timedelta, timezone
 from typing import TYPE_CHECKING, Any, Callable, Generator
 from unittest import mock
+from unittest.mock import MagicMock

 import numpy.random
 import pytest
@@ -543,7 +544,10 @@ def mocked_execute(*args, **kwargs):
     with mock.patch(
         "snowflake.connector.cursor.SnowflakeCursor.execute",
         side_effect=mocked_execute,
-    ) as m_execute:
+    ) as m_execute, mock.patch(
+        "snowflake.connector.cursor.SnowflakeCursor._upload",
+        side_effect=MagicMock(),
+    ) as _:
         success, nchunks, nrows, _ = write_pandas(
             cnx,
             sf_connector_version_df.get(),
@@ -593,7 +597,10 @@ def mocked_execute(*args, **kwargs):
     with mock.patch(
         "snowflake.connector.cursor.SnowflakeCursor.execute",
         side_effect=mocked_execute,
-    ) as m_execute:
+    ) as m_execute, mock.patch(
+        "snowflake.connector.cursor.SnowflakeCursor._upload",
+        side_effect=MagicMock(),
+    ) as _:
         success, nchunks, nrows, _ = write_pandas(
             cnx,
             sf_connector_version_df.get(),
@@ -645,7 +652,10 @@ def mocked_execute(*args, **kwargs):
     with mock.patch(
         "snowflake.connector.cursor.SnowflakeCursor.execute",
         side_effect=mocked_execute,
-    ) as m_execute:
+    ) as m_execute, mock.patch(
+        "snowflake.connector.cursor.SnowflakeCursor._upload",
+        side_effect=MagicMock(),
+    ) as _:
         cnx._update_parameters({"PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS": True})
         success, nchunks, nrows, _ = write_pandas(
             cnx,
@@ -703,7 +713,10 @@ def mocked_execute(*args, **kwargs):
     with mock.patch(
         "snowflake.connector.cursor.SnowflakeCursor.execute",
         side_effect=mocked_execute,
-    ) as m_execute:
+    ) as m_execute, mock.patch(
+        "snowflake.connector.cursor.SnowflakeCursor._upload",
+        side_effect=MagicMock(),
+    ) as _:
         success, nchunks, nrows, _ = write_pandas(
             cnx,
             sf_connector_version_df.get(),
@@ -1126,3 +1139,49 @@ def test_pandas_with_single_quote(
         )
     finally:
         cnx.execute_string(f"drop table if exists {table_name}")
+
+
+@pytest.mark.parametrize("bulk_upload_chunks", [True, False])
+def test_write_pandas_bulk_chunks_upload(conn_cnx, bulk_upload_chunks):
+    """Tests that writing a Pandas DataFrame in chunks works with both bulk and per-chunk upload."""
+    random_table_name = random_string(5, "userspoints_")
+    df_data = [("Dash", 50), ("Luke", 20), ("Mark", 10), ("John", 30)]
+    df = pandas.DataFrame(df_data, columns=["name", "points"])
+
+    table_name = random_table_name
+    col_id = "id"
+    col_name = "name"
+    col_points = "points"
+
+    create_sql = (
+        f"CREATE OR REPLACE TABLE {table_name}"
+        f"({col_name} STRING, {col_points} INT, {col_id} INT AUTOINCREMENT)"
+    )
+
+    select_count_sql = f"SELECT count(*) FROM {table_name}"
+    drop_sql = f"DROP TABLE IF EXISTS {table_name}"
+    with conn_cnx() as cnx:  # type: SnowflakeConnection
+        cnx.execute_string(create_sql)
+        try:
+            # Write the four-row dataframe in four chunks of one row each
+            success, nchunks, nrows, _ = write_pandas(
+                cnx,
+                df,
+                random_table_name,
+                quote_identifiers=False,
+                auto_create_table=False,
+                overwrite=True,
+                index=True,
+                on_error="continue",
+                chunk_size=1,
+                bulk_upload_chunks=bulk_upload_chunks,
+            )
+            # Check write_pandas output
+            assert success
+            assert nchunks == 4
+            assert nrows == 4
+            result = cnx.cursor(DictCursor).execute(select_count_sql).fetchone()
+            # Check number of rows
+            assert result["COUNT(*)"] == 4
+        finally:
+            cnx.execute_string(drop_sql)
diff --git a/test/integ/test_connection.py b/test/integ/test_connection.py
index 0df386afca..8436290e52 100644
--- a/test/integ/test_connection.py
+++ b/test/integ/test_connection.py
@@ -877,8 +877,13 @@ def test_invalid_connection_parameter(db_parameters, name, value, exc_warn):
     try:
         conn = snowflake.connector.connect(**conn_params)
         assert getattr(conn, "_" + name) == value
-        assert len(w) == 1
-        assert str(w[0].message) == str(exc_warn)
+        # TODO: SNOW-2114216 remove filtering once the root cause for the deprecation warning is fixed
+        # Filter out the deprecation warning
+        filtered_w = [
+            warning for warning in w if warning.category != DeprecationWarning
+        ]
+        assert len(filtered_w) == 1
+        assert str(filtered_w[0].message) == str(exc_warn)
     finally:
         conn.close()
@@ -903,7 +908,12 @@ def test_invalid_connection_parameters_turned_off(db_parameters):
         conn = snowflake.connector.connect(**conn_params)
         assert conn._autocommit == conn_params["autocommit"]
         assert conn._applucation == conn_params["applucation"]
-        assert len(w) == 0
+        # TODO: SNOW-2114216 remove filtering once the root cause for the deprecation warning is fixed
+        # Filter out the deprecation warning
+        filtered_w = [
+            warning for warning in w if warning.category != DeprecationWarning
+        ]
+        assert len(filtered_w) == 0
     finally:
         conn.close()
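
The same SNOW-2114216 workaround appears in all four connection tests: recorded warnings are filtered by category before the count assertion so an unrelated DeprecationWarning cannot break them. The idiom in isolation, as a standalone sketch rather than connector code:

    import warnings

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        warnings.warn("unrelated noise", DeprecationWarning)
        warnings.warn("the warning under test", UserWarning)

    # Drop DeprecationWarning entries before asserting, exactly as above.
    filtered_w = [warning for warning in w if warning.category != DeprecationWarning]
    assert len(filtered_w) == 1
    assert "the warning under test" in str(filtered_w[0].message)
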
diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py
index 43a6c63324..0778f58e6a 100644
--- a/test/unit/aio/test_connection_async_unit.py
+++ b/test/unit/aio/test_connection_async_unit.py
@@ -196,13 +196,25 @@ async def test_partner_env_var(mock_post_requests):
     )


-async def test_imported_module(mock_post_requests):
-    with patch.dict(sys.modules, {"streamlit": "foo"}):
+@pytest.mark.skipolddriver
+@pytest.mark.parametrize(
+    "sys_modules,application",
+    [
+        ({"streamlit": None}, "streamlit"),
+        (
+            {"ipykernel": None, "jupyter_core": None, "jupyter_client": None},
+            "jupyter_notebook",
+        ),
+        ({"snowbooks": None}, "snowflake_notebook"),
+    ],
+)
+async def test_imported_module(mock_post_requests, sys_modules, application):
+    with patch.dict(sys.modules, sys_modules):
         async with fake_db_conn() as conn:
-            assert conn.application == "streamlit"
+            assert conn.application == application
             assert (
                 mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"]
-                == "streamlit"
+                == application
             )
diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py
index 5fa43a4224..8e229b751f 100644
--- a/test/unit/test_connection.py
+++ b/test/unit/test_connection.py
@@ -194,12 +194,23 @@


 @pytest.mark.skipolddriver
-def test_imported_module(mock_post_requests):
-    with patch.dict(sys.modules, {"streamlit": "foo"}):
-        assert fake_connector().application == "streamlit"
+@pytest.mark.parametrize(
+    "sys_modules,application",
+    [
+        ({"streamlit": None}, "streamlit"),
+        (
+            {"ipykernel": None, "jupyter_core": None, "jupyter_client": None},
+            "jupyter_notebook",
+        ),
+        ({"snowbooks": None}, "snowflake_notebook"),
+    ],
+)
+def test_imported_module(mock_post_requests, sys_modules, application):
+    with patch.dict(sys.modules, sys_modules):
+        assert fake_connector().application == application
         assert (
             mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"]
-            == "streamlit"
+            == application
         )
diff --git a/tox.ini b/tox.ini
index 25bef2ffe7..ee1e1005f3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -84,7 +84,7 @@ deps =
     numpy==1.26.4
     pendulum!=2.1.1
     pytest<6.1.0
-    pytest-cov
+    pytest-cov<6.2.0
     pytest-rerunfailures
     pytest-timeout
     pytest-xdist