Skip to content

Commit 86c0a11

Browse files
SNOW-891470 adding implementation for default_connection_name (#1721)
1 parent da1ae4e commit 86c0a11

File tree

5 files changed

+170
-8
lines changed

5 files changed

+170
-8
lines changed

src/snowflake/connector/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#
88
from __future__ import annotations
99

10+
from functools import wraps
11+
1012
apilevel = "2.0"
1113
threadsafety = 2
1214
paramstyle = "pyformat"
@@ -47,6 +49,7 @@
4749
logging.getLogger(__name__).addHandler(NullHandler())
4850

4951

52+
@wraps(SnowflakeConnection.__init__)
5053
def Connect(**kwargs) -> SnowflakeConnection:
5154
return SnowflakeConnection(**kwargs)
5255

src/snowflake/connector/config_manager.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def value(self) -> Any:
126126
value = self._get_config()
127127
source = "configuration file"
128128
except MissingConfigOptionError:
129-
if self.default:
129+
if self.default is not None:
130130
source = "default_value"
131131
value = self.default
132132
else:
@@ -166,8 +166,8 @@ def _get_env(self) -> tuple[bool, str | _T | None]:
166166
env_var = os.environ.get(env_name)
167167
if env_var is None:
168168
return False, None
169-
loaded_var: str | _T | None = env_var
170-
if env_var and self.parse_str is not None:
169+
loaded_var: str | _T = env_var
170+
if self.parse_str is not None:
171171
loaded_var = self.parse_str(env_var)
172172
if isinstance(loaded_var, (Table, tomlkit.TOMLDocument)):
173173
# If we got a TOML table we probably want it in dictionary form
@@ -449,6 +449,7 @@ def __getitem__(self, name: str) -> ConfigOption | ConfigManager:
449449
CONFIG_MANAGER.add_option(
450450
name="connections",
451451
parse_str=tomlkit.parse,
452+
default=dict(),
452453
)
453454
CONFIG_MANAGER.add_option(
454455
name="default_connection_name",

src/snowflake/connector/connection.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,21 @@ def __init__(
302302
connections_file_path: pathlib.Path | None = None,
303303
**kwargs,
304304
) -> None:
305+
"""Create a new SnowflakeConnection.
306+
307+
Connections can be loaded from the TOML file located at
308+
snowflake.connector.constants.CONNECTIONS_FILE.
309+
310+
When connection_name is supplied we will first load that connection
311+
and then override any other values supplied.
312+
313+
When no arguments are given (other than connection_file_path) the
314+
default connection will be loaded first. Note that no overwriting is
315+
supported in this case.
316+
317+
If overwriting values from the default connection is desirable, supply
318+
the name explicitly.
319+
"""
305320
self._lock_sequence_counter = Lock()
306321
self.sequence_counter = 0
307322
self._errorhandler = Error.default_errorhandler
@@ -324,6 +339,7 @@ def __init__(
324339
setattr(self, f"_{name}", value)
325340

326341
self.heartbeat_thread = None
342+
is_kwargs_empty = not kwargs
327343

328344
if "application" not in kwargs:
329345
if ENV_VAR_PARTNER in os.environ.keys():
@@ -349,6 +365,17 @@ def __init__(
349365
f" known ones are {list(connections.keys())}"
350366
)
351367
kwargs = {**connections[connection_name], **kwargs}
368+
elif is_kwargs_empty:
369+
# connection_name is None and kwargs was empty when called
370+
def_connection_name = CONFIG_MANAGER["default_connection_name"]
371+
connections = CONFIG_MANAGER["connections"]
372+
if def_connection_name not in connections:
373+
raise Error(
374+
f"Default connection with name '{def_connection_name}' "
375+
"cannot be found, known ones are "
376+
f"{list(connections.keys())}"
377+
)
378+
kwargs = {**connections[def_connection_name]}
352379
self.__set_error_attributes()
353380
self.connect(**kwargs)
354381
self._telemetry = TelemetryClient(self._rest)

test/integ/test_connection.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,7 @@ def test_connection_name_loading(monkeypatch, db_parameters, tmp_path, mode):
12261226

12271227
doc = tomlkit.document()
12281228
default_con = tomlkit.table()
1229-
tmp_config_file: None | pathlib.Path = None
1229+
tmp_connections_file: None | pathlib.Path = None
12301230
try:
12311231
# If anything unexpected fails here, don't want to expose password
12321232
for k, v in db_parameters.items():
@@ -1236,11 +1236,11 @@ def test_connection_name_loading(monkeypatch, db_parameters, tmp_path, mode):
12361236
if mode == "env":
12371237
m.setenv("SF_CONNECTIONS", tomlkit.dumps(doc))
12381238
else:
1239-
tmp_config_file = tmp_path / "connections.toml"
1240-
tmp_config_file.write_text(tomlkit.dumps(doc))
1239+
tmp_connections_file = tmp_path / "connections.toml"
1240+
tmp_connections_file.write_text(tomlkit.dumps(doc))
12411241
with snowflake.connector.connect(
12421242
connection_name="default",
1243-
connections_file_path=tmp_config_file,
1243+
connections_file_path=tmp_connections_file,
12441244
) as conn:
12451245
with conn.cursor() as cur:
12461246
assert cur.execute("select 1;").fetchall() == [
@@ -1253,6 +1253,32 @@ def test_connection_name_loading(monkeypatch, db_parameters, tmp_path, mode):
12531253
pytest.fail("something failed", pytrace=False)
12541254

12551255

1256+
@pytest.mark.skipolddriver
1257+
def test_default_connection_name_loading(monkeypatch, db_parameters):
1258+
import tomlkit
1259+
1260+
doc = tomlkit.document()
1261+
default_con = tomlkit.table()
1262+
try:
1263+
# If anything unexpected fails here, don't want to expose password
1264+
for k, v in db_parameters.items():
1265+
default_con[k] = v
1266+
doc["default"] = default_con
1267+
with monkeypatch.context() as m:
1268+
m.setenv("SNOWFLAKE_CONNECTIONS", tomlkit.dumps(doc))
1269+
m.setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "default")
1270+
with snowflake.connector.connect() as conn:
1271+
with conn.cursor() as cur:
1272+
assert cur.execute("select 1;").fetchall() == [
1273+
(1,),
1274+
]
1275+
except Exception:
1276+
# This is my way of guaranteeing that we'll not expose the
1277+
# sensitive information that this test needs to handle.
1278+
# db_parameter contains passwords.
1279+
pytest.fail("something failed", pytrace=False)
1280+
1281+
12561282
@pytest.mark.skipolddriver
12571283
def test_not_found_connection_name():
12581284
connection_name = random_string(5)

test/unit/test_connection.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
import json
99
import os
1010
import sys
11+
from textwrap import dedent
1112
from unittest.mock import patch
1213

1314
import pytest
1415

1516
import snowflake.connector
17+
from snowflake.connector.errors import Error
18+
19+
from ..randomize import random_string
1620

1721
try:
1822
from snowflake.connector.auth import (
@@ -28,10 +32,11 @@
2832

2933
try: # pragma: no cover
3034
from snowflake.connector.auth import AuthByUsrPwdMfa
35+
from snowflake.connector.config_manager import CONFIG_MANAGER
3136
from snowflake.connector.constants import ENV_VAR_PARTNER, QueryStatus
3237
except ImportError:
3338
ENV_VAR_PARTNER = "SF_PARTNER"
34-
QueryStatus = None
39+
QueryStatus = CONFIG_MANAGER = None
3540

3641
class AuthByUsrPwdMfa(AuthByDefault):
3742
def __init__(self, password: str, mfa_token: str) -> None:
@@ -211,3 +216,103 @@ def test_negative_custom_auth(auth_class):
211216
user="user",
212217
auth_class=auth_class,
213218
)
219+
220+
221+
def test_missing_default_connection(monkeypatch, tmp_path):
222+
connections_file = tmp_path / "connections.toml"
223+
config_file = tmp_path / "config.toml"
224+
with monkeypatch.context() as m:
225+
m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False)
226+
m.delenv("SNOWFLAKE_CONNECTIONS", raising=False)
227+
m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
228+
m.setattr(CONFIG_MANAGER, "file_path", config_file)
229+
230+
with pytest.raises(
231+
Error,
232+
match="Default connection with name 'default' cannot be found, known ones are \\[\\]",
233+
):
234+
snowflake.connector.connect(connections_file_path=connections_file)
235+
236+
237+
def test_missing_default_connection_conf_file(monkeypatch, tmp_path):
238+
connection_name = random_string(5)
239+
connections_file = tmp_path / "connections.toml"
240+
config_file = tmp_path / "config.toml"
241+
config_file.write_text(
242+
dedent(
243+
f"""\
244+
default_connection_name = "{connection_name}"
245+
"""
246+
)
247+
)
248+
with monkeypatch.context() as m:
249+
m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False)
250+
m.delenv("SNOWFLAKE_CONNECTIONS", raising=False)
251+
m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
252+
m.setattr(CONFIG_MANAGER, "file_path", config_file)
253+
254+
with pytest.raises(
255+
Error,
256+
match=f"Default connection with name '{connection_name}' cannot be found, known ones are \\[\\]",
257+
):
258+
snowflake.connector.connect(connections_file_path=connections_file)
259+
260+
261+
def test_missing_default_connection_conn_file(monkeypatch, tmp_path):
262+
connections_file = tmp_path / "connections.toml"
263+
config_file = tmp_path / "config.toml"
264+
connections_file.write_text(
265+
dedent(
266+
"""\
267+
[con_a]
268+
user = "test user"
269+
account = "test account"
270+
password = "test password"
271+
"""
272+
)
273+
)
274+
with monkeypatch.context() as m:
275+
m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False)
276+
m.delenv("SNOWFLAKE_CONNECTIONS", raising=False)
277+
m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
278+
m.setattr(CONFIG_MANAGER, "file_path", config_file)
279+
280+
with pytest.raises(
281+
Error,
282+
match="Default connection with name 'default' cannot be found, known ones are \\['con_a'\\]",
283+
):
284+
snowflake.connector.connect(connections_file_path=connections_file)
285+
286+
287+
def test_missing_default_connection_conf_conn_file(monkeypatch, tmp_path):
288+
connection_name = random_string(5)
289+
connections_file = tmp_path / "connections.toml"
290+
config_file = tmp_path / "config.toml"
291+
config_file.write_text(
292+
dedent(
293+
f"""\
294+
default_connection_name = "{connection_name}"
295+
"""
296+
)
297+
)
298+
connections_file.write_text(
299+
dedent(
300+
"""\
301+
[con_a]
302+
user = "test user"
303+
account = "test account"
304+
password = "test password"
305+
"""
306+
)
307+
)
308+
with monkeypatch.context() as m:
309+
m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False)
310+
m.delenv("SNOWFLAKE_CONNECTIONS", raising=False)
311+
m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
312+
m.setattr(CONFIG_MANAGER, "file_path", config_file)
313+
314+
with pytest.raises(
315+
Error,
316+
match=f"Default connection with name '{connection_name}' cannot be found, known ones are \\['con_a'\\]",
317+
):
318+
snowflake.connector.connect(connections_file_path=connections_file)

0 commit comments

Comments
 (0)