Skip to content

Commit ab5902f

Browse files
SNOW-1763096: Add async telemetry support (#2585)
1 parent 8404bbf commit ab5902f

File tree

7 files changed

+144
-37
lines changed

7 files changed

+144
-37
lines changed

src/snowflake/connector/errors.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
from __future__ import annotations
33

4+
import inspect
45
import logging
56
import os
67
import re
@@ -14,6 +15,8 @@
1415
from .time_util import get_time_millis
1516

1617
if TYPE_CHECKING: # pragma: no cover
18+
from .aio._connection import SnowflakeConnection as AsyncSnowflakeConnection
19+
from .aio._cursor import SnowflakeCursor as AsyncSnowflakeCursor
1720
from .connection import SnowflakeConnection
1821
from .cursor import SnowflakeCursor
1922

@@ -35,8 +38,8 @@ def __init__(
3538
sfqid: str | None = None,
3639
query: str | None = None,
3740
done_format_msg: bool | None = None,
38-
connection: SnowflakeConnection | None = None,
39-
cursor: SnowflakeCursor | None = None,
41+
connection: SnowflakeConnection | AsyncSnowflakeConnection | None = None,
42+
cursor: SnowflakeCursor | AsyncSnowflakeCursor | None = None,
4043
errtype: TelemetryField = TelemetryField.SQL_EXCEPTION,
4144
send_telemetry: bool = True,
4245
) -> None:
@@ -145,11 +148,10 @@ def generate_telemetry_exception_data(
145148

146149
def send_exception_telemetry(
147150
self,
148-
connection: SnowflakeConnection | None,
151+
connection: SnowflakeConnection | AsyncSnowflakeConnection | None,
149152
telemetry_data: dict[str, Any],
150153
) -> None:
151154
"""Send telemetry data by in-band telemetry if it is enabled, otherwise send through out-of-band telemetry."""
152-
153155
if (
154156
connection is not None
155157
and connection.telemetry_enabled
@@ -159,21 +161,34 @@ def send_exception_telemetry(
159161
telemetry_data[TelemetryField.KEY_TYPE.value] = self.errtype.value
160162
telemetry_data[TelemetryField.KEY_SOURCE.value] = connection.application
161163
telemetry_data[TelemetryField.KEY_EXCEPTION.value] = self.__class__.__name__
164+
telemetry_data[TelemetryField.KEY_USES_AIO.value] = str(
165+
self._is_aio_connection(connection)
166+
).lower()
162167
ts = get_time_millis()
163168
try:
164-
connection._log_telemetry(
169+
result = connection._log_telemetry(
165170
TelemetryData.from_telemetry_data_dict(
166171
from_dict=telemetry_data, timestamp=ts, connection=connection
167172
)
168173
)
174+
if inspect.isawaitable(result):
175+
try:
176+
import asyncio
177+
178+
asyncio.get_running_loop().create_task(result)
179+
except Exception:
180+
logger.debug(
181+
"Failed to schedule async telemetry logging.",
182+
exc_info=True,
183+
)
169184
except AttributeError:
170185
logger.debug("Cursor failed to log to telemetry.", exc_info=True)
171186

172187
def exception_telemetry(
173188
self,
174189
msg: str,
175-
cursor: SnowflakeCursor | None,
176-
connection: SnowflakeConnection | None,
190+
cursor: SnowflakeCursor | AsyncSnowflakeCursor | None,
191+
connection: SnowflakeConnection | AsyncSnowflakeConnection | None,
177192
) -> None:
178193
"""Main method to generate and send telemetry data for exceptions."""
179194
try:
@@ -370,6 +385,18 @@ def errorhandler_make_exception(
370385
)
371386
return error_class(error_value)
372387

388+
@staticmethod
389+
def _is_aio_connection(
390+
connection: SnowflakeConnection | AsyncSnowflakeConnection,
391+
) -> bool:
392+
try:
393+
# Try import async connection. The import may fail if aio is not installed.
394+
from .aio._connection import SnowflakeConnection as AsyncSnowflakeConnection
395+
396+
return isinstance(connection, AsyncSnowflakeConnection)
397+
except ImportError:
398+
return False
399+
373400

374401
class _Warning(Exception):
375402
"""Exception for important warnings."""

src/snowflake/connector/telemetry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class TelemetryField(Enum):
5151
KEY_REASON = "reason"
5252
KEY_VALUE = "value"
5353
KEY_EXCEPTION = "exception"
54+
KEY_USES_AIO = "uses_aio"
5455
# Reserved UpperCamelName keys
5556
KEY_ERROR_NUMBER = "ErrorNumber"
5657
KEY_ERROR_MESSAGE = "ErrorMessage"

test/csp_helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,7 @@ def __enter__(self):
446446
def __exit__(self, *args, **kwargs):
447447
self.os_environment_patch.__exit__(*args)
448448
super().__exit__(*args, **kwargs)
449+
450+
451+
def is_running_against_gcp():
452+
return os.getenv("cloud_provider").lower() == "gcp"

test/integ/aio_it/test_cursor_binding_async.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
from test.csp_helpers import is_running_against_gcp
9+
810
import pytest
911

1012
from snowflake.connector.errors import ProgrammingError
@@ -46,21 +48,23 @@ async def test_binding_security(conn_cnx, db_parameters):
4648

4749
# SQL injection safe test
4850
# Good Example
49-
with pytest.raises(ProgrammingError):
50-
await cnx.cursor().execute(
51-
"SELECT * FROM {name} WHERE aa=%s".format(
52-
name=db_parameters["name"]
53-
),
54-
("1 or aa>0",),
55-
)
56-
57-
with pytest.raises(ProgrammingError):
58-
await cnx.cursor().execute(
59-
"SELECT * FROM {name} WHERE aa=%(aa)s".format(
60-
name=db_parameters["name"]
61-
),
62-
{"aa": "1 or aa>0"},
63-
)
51+
if not is_running_against_gcp():
52+
with pytest.raises(ProgrammingError):
53+
r = await cnx.cursor().execute(
54+
"SELECT * FROM {name} WHERE aa=%s".format(
55+
name=db_parameters["name"]
56+
),
57+
("1 or aa>0",),
58+
)
59+
await r.fetchall()
60+
61+
with pytest.raises(ProgrammingError):
62+
await cnx.cursor().execute(
63+
"SELECT * FROM {name} WHERE aa=%(aa)s".format(
64+
name=db_parameters["name"]
65+
),
66+
{"aa": "1 or aa>0"},
67+
)
6468

6569
# Bad Example in application. DON'T DO THIS
6670
c = cnx.cursor()

test/integ/test_cursor_binding.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#!/usr/bin/env python
22
from __future__ import annotations
33

4+
from test.csp_helpers import is_running_against_gcp
5+
46
import pytest
57

68
from snowflake.connector.errors import ProgrammingError
@@ -42,21 +44,22 @@ def test_binding_security(conn_cnx, db_parameters):
4244

4345
# SQL injection safe test
4446
# Good Example
45-
with pytest.raises(ProgrammingError):
46-
cnx.cursor().execute(
47-
"SELECT * FROM {name} WHERE aa=%s".format(
48-
name=db_parameters["name"]
49-
),
50-
("1 or aa>0",),
51-
)
52-
53-
with pytest.raises(ProgrammingError):
54-
cnx.cursor().execute(
55-
"SELECT * FROM {name} WHERE aa=%(aa)s".format(
56-
name=db_parameters["name"]
57-
),
58-
{"aa": "1 or aa>0"},
59-
)
47+
if not is_running_against_gcp():
48+
with pytest.raises(ProgrammingError):
49+
cnx.cursor().execute(
50+
"SELECT * FROM {name} WHERE aa=%s".format(
51+
name=db_parameters["name"]
52+
),
53+
("1 or aa>0",),
54+
)
55+
56+
with pytest.raises(ProgrammingError):
57+
cnx.cursor().execute(
58+
"SELECT * FROM {name} WHERE aa=%(aa)s".format(
59+
name=db_parameters["name"]
60+
),
61+
{"aa": "1 or aa>0"},
62+
)
6063

6164
# Bad Example in application. DON'T DO THIS
6265
c = cnx.cursor()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from unittest.mock import AsyncMock, Mock, patch
4+
5+
from snowflake.connector.aio import SnowflakeConnection
6+
from snowflake.connector.errors import Error
7+
from snowflake.connector.telemetry import TelemetryData, TelemetryField
8+
9+
10+
def _extract_message_from_log_call(mock_conn: Mock) -> dict:
11+
mock_conn._log_telemetry.assert_called_once()
12+
td = mock_conn._log_telemetry.call_args[0][0]
13+
assert isinstance(td, TelemetryData)
14+
return td.message
15+
16+
17+
async def test_error_telemetry_async_connection():
18+
conn = Mock(SnowflakeConnection)
19+
conn.telemetry_enabled = True
20+
conn._telemetry = Mock()
21+
conn._telemetry.is_closed = False
22+
conn.application = "pytest_app_async"
23+
conn._log_telemetry = AsyncMock()
24+
25+
with patch("asyncio.get_running_loop") as loop_mock:
26+
Error(msg="kaboom", errno=654321, sqlstate="00000", connection=conn)
27+
loop_mock.return_value.create_task.assert_called_once()
28+
29+
msg = _extract_message_from_log_call(conn)
30+
assert msg[TelemetryField.KEY_TYPE.value] == TelemetryField.SQL_EXCEPTION.value
31+
assert msg[TelemetryField.KEY_SOURCE.value] == conn.application
32+
assert msg[TelemetryField.KEY_EXCEPTION.value] == "Error"
33+
assert msg[TelemetryField.KEY_USES_AIO.value] == "true"
34+
assert TelemetryField.KEY_DRIVER_TYPE.value in msg
35+
assert TelemetryField.KEY_DRIVER_VERSION.value in msg

test/unit/test_errors_telemetry.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
from unittest.mock import Mock
4+
5+
from snowflake.connector.errors import Error
6+
from snowflake.connector.telemetry import TelemetryData, TelemetryField
7+
8+
9+
def _extract_message_from_log_call(mock_conn: Mock) -> dict:
10+
mock_conn._log_telemetry.assert_called_once()
11+
td = mock_conn._log_telemetry.call_args[0][0]
12+
assert isinstance(td, TelemetryData)
13+
return td.message
14+
15+
16+
def test_error_telemetry_sync_connection():
17+
conn = Mock()
18+
conn.telemetry_enabled = True
19+
conn._telemetry = Mock()
20+
conn._telemetry.is_closed = False
21+
conn.application = "pytest_app"
22+
conn._log_telemetry = Mock()
23+
24+
err = Error(msg="boom", errno=123456, sqlstate="00000", connection=conn)
25+
assert str(err)
26+
27+
msg = _extract_message_from_log_call(conn)
28+
assert msg[TelemetryField.KEY_TYPE.value] == TelemetryField.SQL_EXCEPTION.value
29+
assert msg[TelemetryField.KEY_SOURCE.value] == conn.application
30+
assert msg[TelemetryField.KEY_EXCEPTION.value] == "Error"
31+
assert msg[TelemetryField.KEY_USES_AIO.value] == "false"
32+
assert TelemetryField.KEY_DRIVER_TYPE.value in msg
33+
assert TelemetryField.KEY_DRIVER_VERSION.value in msg

0 commit comments

Comments
 (0)