Skip to content

Commit f50427b

Browse files
Implement thread safety during instantiation of TelemetryService (#1729)
* Implement thread safety during instantiation on TelemetryService * Code and test cleanup * Edited DESCRIPTION.md * Update DESCRIPTION.md Take Adam's suggestion Co-authored-by: Adam Ling <[email protected]> --------- Co-authored-by: Adam Ling <[email protected]>
1 parent d0e00ad commit f50427b

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

DESCRIPTION.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
88

99
# Release Notes
1010

11+
- v3.2.1(TBD)
12+
13+
- Added thread safety in telemetry when instantiating multiple connections concurrently.
14+
1115
- v3.2.0(September 06,2023)
1216

1317
- Made the ``parser`` -> ``manager`` renaming more consistent in ``snowflake.connector.config_manager`` module.

src/snowflake/connector/telemetry_oob.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import uuid
1212
from collections import namedtuple
1313
from queue import Queue
14+
from threading import Lock
1415
from typing import Any
1516

1617
from .compat import OK
@@ -152,13 +153,16 @@ def get_type(self) -> str:
152153

153154
class TelemetryService:
154155
__instance = None
156+
# prevents race condition from multiple threads creating Snowflake connections
157+
__lock_init = Lock()
155158

156-
@staticmethod
157-
def get_instance() -> TelemetryService:
159+
@classmethod
160+
def get_instance(cls) -> TelemetryService:
158161
"""Static access method."""
159-
if TelemetryService.__instance is None:
160-
TelemetryService()
161-
return TelemetryService.__instance
162+
with cls.__lock_init:
163+
if cls.__instance is None:
164+
cls()
165+
return cls.__instance
162166

163167
def __init__(self) -> None:
164168
"""Virtually private constructor."""

test/unit/test_telemetry_oob.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from __future__ import annotations
77

88
import logging
9+
import time
10+
from concurrent.futures import ThreadPoolExecutor
911

1012
import pytest
1113

@@ -26,6 +28,8 @@
2628
"password": "ShouldNotShowUp",
2729
"protocol": "http",
2830
}
31+
TEST_RACE_CONDITION_THREAD_COUNT = 2
32+
TEST_RACE_CONDITION_DELAY_SECONDS = 1
2933
telemetry_data = {}
3034
exception = RevocationCheckError("Test OCSP Revocation error")
3135
event_type = "Test OCSP Exception"
@@ -198,3 +202,25 @@ def test_generate_telemetry_with_driver_info():
198202
snowflake.connector.telemetry.TelemetryField.KEY_OOB_VERSION.value: "1.2.3",
199203
"key": "value",
200204
}
205+
206+
207+
class MockTelemetryService(TelemetryService):
208+
"""Mocks a delay in the __init__ of TelemetryService to simulate a race condition"""
209+
210+
def __init__(self, *args, **kwargs):
211+
# this delay all but guarantees enough time to catch multiple threads entering __init__
212+
time.sleep(TEST_RACE_CONDITION_DELAY_SECONDS)
213+
super().__init__(*args, **kwargs)
214+
215+
216+
def test_get_instance_multithreaded():
217+
"""Tests thread safety of multithreaded calls to TelemetryService.get_instance()"""
218+
TelemetryService._TelemetryService__instance = None
219+
with ThreadPoolExecutor() as executor:
220+
futures = [
221+
executor.submit(MockTelemetryService.get_instance)
222+
for _ in range(TEST_RACE_CONDITION_THREAD_COUNT)
223+
]
224+
for future in futures:
225+
# will error if singleton constraint violated
226+
future.result()

0 commit comments

Comments
 (0)