Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Source code is also available at:
<https://github.com/snowflakedb/snowflake-sqlalchemy>
# Unreleased Notes
- v1.7.8
- Add logging of SQLAlchemy version
- Add logging of SQLAlchemy version and pandas (if used)

# Release Notes
- v1.7.7(September 3, 2025)
Expand Down
28 changes: 23 additions & 5 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
logger = getLogger(__name__)


class TelemetryEvents(Enum):
NEW_CONNECTION = "sqlalchemy_new_connection"


class SnowflakeIsolationLevel(Enum):
READ_COMMITTED = "READ COMMITTED"
AUTOCOMMIT = "AUTOCOMMIT"
Expand Down Expand Up @@ -907,11 +911,11 @@ def connect(self, *cargs, **cparams):
cparams = _update_connection_application_name(**cparams)

connection = super().connect(*cargs, **cparams)
self._log_sql_alchemy_version(connection)
self._log_new_connection_event(connection)

return connection

def _log_sql_alchemy_version(self, connection):
def _log_new_connection_event(self, connection):
try:
snowflake_connection = cast(SnowflakeConnection, cast(object, connection))
snowflake_rest_client = SnowflakeRestful(
Expand All @@ -921,11 +925,22 @@ def _log_sql_alchemy_version(self, connection):
connection=snowflake_connection,
)
snowflake_telemetry_client = TelemetryClient(rest=snowflake_rest_client)

telemetry_value = {
"SQLAlchemy": SQLALCHEMY_VERSION,
}
try:
from pandas import __version__ as PANDAS_VERSION

telemetry_value["pandas"] = PANDAS_VERSION
except ImportError:
pass

snowflake_telemetry_client.add_log_to_batch(
TelemetryData.from_telemetry_data_dict(
from_dict={
TelemetryField.KEY_TYPE.value: "sqlalchemy_version",
TelemetryField.KEY_VALUE.value: SQLALCHEMY_VERSION,
TelemetryField.KEY_TYPE.value: TelemetryEvents.NEW_CONNECTION.value,
TelemetryField.KEY_VALUE.value: str(telemetry_value),
},
timestamp=int(time_in_seconds() * 1000),
connection=snowflake_connection,
Expand All @@ -934,7 +949,10 @@ def _log_sql_alchemy_version(self, connection):
snowflake_telemetry_client.send_batch()
except Exception as e:
logger.debug(
"Failed to send telemetry data: %s: %s", type(e).__name__, str(e)
"Failed to send telemetry data for %s event: %s: %s",
TelemetryEvents.NEW_CONNECTION.value,
type(e).__name__,
str(e),
)


Expand Down
67 changes: 62 additions & 5 deletions tests/test_dialect_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from sys import modules
from types import SimpleNamespace
from unittest import mock

import pytest
from sqlalchemy import __version__ as SQLALCHEMY_VERSION
from sqlalchemy.engine import default as sqla_default

from snowflake.sqlalchemy.snowdialect import SnowflakeDialect, TelemetryField
from snowflake.sqlalchemy.snowdialect import (
SnowflakeDialect,
TelemetryEvents,
TelemetryField,
)


@pytest.fixture
Expand All @@ -30,22 +35,74 @@ def test_connect_sends_telemetry(
"""Ensure telemetry is sent with the expected payload on connect."""
mock_connect.return_value = fake_connection

dialect = SnowflakeDialect()
result = dialect.connect()
# Mock out pandas to ensure deterministic behavior
with mock.patch.dict(modules, {"pandas": None}):
dialect = SnowflakeDialect()
result = dialect.connect()

assert result is fake_connection

# Verify add_log_to_batch was called with correct payload
telemetry_instance = mock_telemetry_client.return_value
payload = telemetry_instance.add_log_to_batch.call_args[0][0]
assert payload.message[TelemetryField.KEY_TYPE.value] == "sqlalchemy_version"
assert payload.message[TelemetryField.KEY_VALUE.value] == SQLALCHEMY_VERSION
assert (
payload.message[TelemetryField.KEY_TYPE.value]
== TelemetryEvents.NEW_CONNECTION.value
)
assert payload.message[TelemetryField.KEY_VALUE.value] == str(
{"SQLAlchemy": SQLALCHEMY_VERSION}
)
assert payload.timestamp != 0

# Verify send_batch was called
telemetry_instance.send_batch.assert_called_once()


@mock.patch.object(sqla_default.DefaultDialect, "connect")
@mock.patch("snowflake.sqlalchemy.snowdialect.TelemetryClient")
@mock.patch("snowflake.sqlalchemy.snowdialect.SnowflakeRestful")
def test_connect_telemetry_includes_pandas_when_available(
mock_restful, mock_telemetry_client, mock_connect, fake_connection
):
"""Ensure telemetry includes pandas version when pandas is installed."""
mock_connect.return_value = fake_connection

# Create a mock pandas module with a version
mock_pandas = mock.MagicMock()
mock_pandas.__version__ = "2.1.0"

with mock.patch.dict(modules, {"pandas": mock_pandas}):
dialect = SnowflakeDialect()
dialect.connect()

telemetry_instance = mock_telemetry_client.return_value
payload = telemetry_instance.add_log_to_batch.call_args[0][0]
telemetry_value = payload.message[TelemetryField.KEY_VALUE.value]

assert telemetry_value == str({"SQLAlchemy": SQLALCHEMY_VERSION, "pandas": "2.1.0"})


@mock.patch.object(sqla_default.DefaultDialect, "connect")
@mock.patch("snowflake.sqlalchemy.snowdialect.TelemetryClient")
@mock.patch("snowflake.sqlalchemy.snowdialect.SnowflakeRestful")
def test_connect_telemetry_excludes_pandas_when_not_available(
mock_restful, mock_telemetry_client, mock_connect, fake_connection
):
"""Ensure telemetry does not include pandas when it is not installed."""
mock_connect.return_value = fake_connection

# Simulate pandas not being installed
with mock.patch.dict(modules, {"pandas": None}):
dialect = SnowflakeDialect()
dialect.connect()

telemetry_instance = mock_telemetry_client.return_value
payload = telemetry_instance.add_log_to_batch.call_args[0][0]
telemetry_value = payload.message[TelemetryField.KEY_VALUE.value]

assert telemetry_value == str({"SQLAlchemy": SQLALCHEMY_VERSION})


@mock.patch.object(sqla_default.DefaultDialect, "connect")
@mock.patch("snowflake.sqlalchemy.snowdialect.TelemetryClient")
@mock.patch("snowflake.sqlalchemy.snowdialect.SnowflakeRestful")
Expand Down
Loading