Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from collections import defaultdict
from enum import Enum
from functools import reduce
from typing import Any, Collection, Optional
from logging import getLogger
from time import time as time_in_seconds
from typing import Any, Collection, Optional, cast
from urllib.parse import unquote_plus

import sqlalchemy.sql.sqltypes as sqltypes
from sqlalchemy import __version__ as SQLALCHEMY_VERSION
from sqlalchemy import event as sa_vnt
from sqlalchemy import exc as sa_exc
from sqlalchemy import util as sa_util
Expand All @@ -19,8 +22,10 @@
from sqlalchemy.types import FLOAT, Date, DateTime, Float, Time

from snowflake.connector import errors as sf_errors
from snowflake.connector.connection import DEFAULT_CONFIGURATION
from snowflake.connector.connection import DEFAULT_CONFIGURATION, SnowflakeConnection
from snowflake.connector.constants import UTF8
from snowflake.connector.network import SnowflakeRestful
from snowflake.connector.telemetry import TelemetryClient, TelemetryData, TelemetryField
from snowflake.sqlalchemy.compat import returns_unicode
from snowflake.sqlalchemy.name_utils import _NameUtils
from snowflake.sqlalchemy.structured_type_info_manager import _StructuredTypeInfoManager
Expand Down Expand Up @@ -59,6 +64,8 @@

_ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True

logger = getLogger(__name__)


class SnowflakeIsolationLevel(Enum):
READ_COMMITTED = "READ COMMITTED"
Expand Down Expand Up @@ -896,18 +903,39 @@ def get_indexes(self, connection, tablename, schema, **kw):
return self._value_or_default(data, table_name, schema)

def connect(self, *cargs, **cparams):
return (
super().connect(
*cargs,
**(
_update_connection_application_name(**cparams)
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
else cparams
),
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME:
cparams = _update_connection_application_name(**cparams)

connection = super().connect(*cargs, **cparams)
self._log_sql_alchemy_version(connection)

return connection

def _log_sql_alchemy_version(self, connection):
try:
snowflake_connection = cast(SnowflakeConnection, cast(object, connection))
snowflake_rest_client = SnowflakeRestful(
host=snowflake_connection.host,
port=snowflake_connection.port,
protocol="https",
connection=snowflake_connection,
)
snowflake_telemetry_client = TelemetryClient(rest=snowflake_rest_client)
snowflake_telemetry_client.add_log_to_batch(
TelemetryData.from_telemetry_data_dict(
from_dict={
TelemetryField.KEY_TYPE.value: "sqlalchemy_version",
TelemetryField.KEY_VALUE.value: SQLALCHEMY_VERSION,
},
timestamp=int(time_in_seconds() * 1000),
connection=snowflake_connection,
)
)
snowflake_telemetry_client.send_batch()
except Exception as e:
logger.debug(
"Failed to send telemetry data: %s: %s", type(e).__name__, str(e)
)
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
else super().connect(*cargs, **cparams)
)


@sa_vnt.listens_for(Table, "before_create")
Expand Down
67 changes: 67 additions & 0 deletions tests/test_dialect_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from types import SimpleNamespace
from unittest import mock

import pytest
from sqlalchemy import __version__ as SQLALCHEMY_VERSION
from sqlalchemy.engine import default as sqla_default

from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, TelemetryField


@pytest.fixture
def fake_connection():
return SimpleNamespace(
host="example.snowflakecomputing.com",
port=443,
application="test_app",
)


@mock.patch.object(sqla_default.DefaultDialect, "connect")
@mock.patch("snowflake.sqlalchemy.snowdialect.TelemetryClient")
@mock.patch("snowflake.sqlalchemy.snowdialect.SnowflakeRestful")
def test_connect_sends_telemetry(
mock_restful, mock_telemetry_client, mock_connect, fake_connection
):
"""Ensure telemetry is sent with the expected payload on connect."""
mock_connect.return_value = fake_connection

dialect = SnowflakeDialect()
result = dialect.connect()

assert result is fake_connection

# Verify add_log_to_batch was called with correct payload
telemetry_instance = mock_telemetry_client.return_value
payload = telemetry_instance.add_log_to_batch.call_args[0][0]
assert payload.message[TelemetryField.KEY_TYPE.value] == "sqlalchemy_version"
assert payload.message[TelemetryField.KEY_VALUE.value] == SQLALCHEMY_VERSION
assert payload.timestamp != 0

# Verify send_batch was called
telemetry_instance.send_batch.assert_called_once()


@mock.patch.object(sqla_default.DefaultDialect, "connect")
@mock.patch("snowflake.sqlalchemy.snowdialect.TelemetryClient")
@mock.patch("snowflake.sqlalchemy.snowdialect.SnowflakeRestful")
def test_connect_logs_when_telemetry_fails(
mock_restful, mock_telemetry_client, mock_connect, caplog, fake_connection
):
"""Ensure failures in telemetry do not break connect and are logged."""
mock_connect.return_value = fake_connection
mock_telemetry_client.side_effect = RuntimeError("boom")

caplog.set_level("DEBUG", logger="snowflake.sqlalchemy.snowdialect")

dialect = SnowflakeDialect()
result = dialect.connect()

assert result is fake_connection
assert any(
"Failed to send telemetry data" in message for message in caplog.messages
)
Loading