Skip to content

Commit abbf282

Browse files
committed
SNOW-2306184: config refactor - old & new unit tests pass
1 parent 6f0325a commit abbf282

File tree

4 files changed

+214
-47
lines changed

4 files changed

+214
-47
lines changed

src/snowflake/cli/api/config_ng/sources.py

Lines changed: 103 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,16 @@ class SnowSQLConfigFile(ValueSource):
5353
Later files override earlier files for the same keys.
5454
Returns configuration for ALL connections.
5555
56-
Config files searched (in order):
56+
Config files searched (in order, when not in test mode):
5757
1. Bundled default config (if in package)
5858
2. /etc/snowsql.cnf (system-wide)
5959
3. /etc/snowflake/snowsql.cnf (alternative system)
6060
4. /usr/local/etc/snowsql.cnf (local system)
6161
5. ~/.snowsql.cnf (legacy user config)
6262
6. ~/.snowsql/config (current user config)
63+
64+
In test mode (when config_file_override is set), SnowSQL config files are skipped
65+
to ensure test isolation.
6366
"""
6467

6568
# SnowSQL uses different key names - map them to CLI standard names
@@ -87,13 +90,31 @@ class SnowSQLConfigFile(ValueSource):
8790

8891
def __init__(self):
8992
"""Initialize SnowSQL config file source."""
90-
self._config_files = [
91-
Path("/etc/snowsql.cnf"),
92-
Path("/etc/snowflake/snowsql.cnf"),
93-
Path("/usr/local/etc/snowsql.cnf"),
94-
Path.home() / ".snowsql.cnf",
95-
Path.home() / ".snowsql" / "config",
96-
]
93+
# Use SNOWFLAKE_HOME if set and directory exists, otherwise use standard paths
94+
snowflake_home = os.environ.get("SNOWFLAKE_HOME")
95+
if snowflake_home:
96+
snowflake_home_path = Path(snowflake_home).expanduser()
97+
if snowflake_home_path.exists():
98+
# Use only the SnowSQL config file within SNOWFLAKE_HOME
99+
self._config_files = [snowflake_home_path / "config"]
100+
else:
101+
# SNOWFLAKE_HOME set but doesn't exist, use standard paths
102+
self._config_files = [
103+
Path("/etc/snowsql.cnf"),
104+
Path("/etc/snowflake/snowsql.cnf"),
105+
Path("/usr/local/etc/snowsql.cnf"),
106+
Path.home() / ".snowsql.cnf",
107+
Path.home() / ".snowsql" / "config",
108+
]
109+
else:
110+
# Standard paths when SNOWFLAKE_HOME not set
111+
self._config_files = [
112+
Path("/etc/snowsql.cnf"),
113+
Path("/etc/snowflake/snowsql.cnf"),
114+
Path("/usr/local/etc/snowsql.cnf"),
115+
Path.home() / ".snowsql.cnf",
116+
Path.home() / ".snowsql" / "config",
117+
]
97118

98119
@property
99120
def source_name(self) -> str:
@@ -162,17 +183,46 @@ class CliConfigFile(ValueSource):
162183
Does NOT merge multiple files - first found wins.
163184
Returns configuration for ALL connections.
164185
165-
Search order:
186+
Search order (when no override is set):
166187
1. ./config.toml (current directory)
167188
2. ~/.snowflake/config.toml (user config)
189+
190+
When config_file_override is set (e.g., in tests), only that file is used.
168191
"""
169192

170193
def __init__(self):
171194
"""Initialize CLI config file source."""
172-
self._search_paths = [
173-
Path.cwd() / "config.toml",
174-
Path.home() / ".snowflake" / "config.toml",
175-
]
195+
# Check for config file override from CLI context first
196+
try:
197+
from snowflake.cli.api.cli_global_context import get_cli_context
198+
199+
cli_context = get_cli_context()
200+
config_override = cli_context.config_file_override
201+
if config_override:
202+
self._search_paths = [Path(config_override)]
203+
return
204+
except Exception:
205+
pass
206+
207+
# Use SNOWFLAKE_HOME if set and directory exists, otherwise use standard paths
208+
snowflake_home = os.environ.get("SNOWFLAKE_HOME")
209+
if snowflake_home:
210+
snowflake_home_path = Path(snowflake_home).expanduser()
211+
if snowflake_home_path.exists():
212+
# Use only config.toml within SNOWFLAKE_HOME
213+
self._search_paths = [snowflake_home_path / "config.toml"]
214+
else:
215+
# SNOWFLAKE_HOME set but doesn't exist, use standard paths
216+
self._search_paths = [
217+
Path.cwd() / "config.toml",
218+
Path.home() / ".snowflake" / "config.toml",
219+
]
220+
else:
221+
# Standard paths when SNOWFLAKE_HOME not set
222+
self._search_paths = [
223+
Path.cwd() / "config.toml",
224+
Path.home() / ".snowflake" / "config.toml",
225+
]
176226

177227
@property
178228
def source_name(self) -> str:
@@ -234,7 +284,16 @@ class ConnectionsConfigFile(ValueSource):
234284

235285
def __init__(self):
236286
"""Initialize connections.toml source."""
237-
self._file_path = Path.home() / ".snowflake" / "connections.toml"
287+
# Use SNOWFLAKE_HOME if set and directory exists, otherwise use standard path
288+
snowflake_home = os.environ.get("SNOWFLAKE_HOME")
289+
if snowflake_home:
290+
snowflake_home_path = Path(snowflake_home).expanduser()
291+
if snowflake_home_path.exists():
292+
self._file_path = snowflake_home_path / "connections.toml"
293+
else:
294+
self._file_path = Path.home() / ".snowflake" / "connections.toml"
295+
else:
296+
self._file_path = Path.home() / ".snowflake" / "connections.toml"
238297

239298
@property
240299
def source_name(self) -> str:
@@ -244,6 +303,15 @@ def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
244303
"""
245304
Read connections.toml if it exists.
246305
Returns keys in format: connections.{name}.{param} for ALL connections.
306+
307+
Supports both legacy formats:
308+
1. Direct connection sections (legacy):
309+
[default]
310+
database = "value"
311+
312+
2. Nested under [connections] section:
313+
[connections.default]
314+
database = "value"
247315
"""
248316
if not self._file_path.exists():
249317
return {}
@@ -253,12 +321,13 @@ def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
253321
data = tomllib.load(f)
254322

255323
result = {}
256-
connections = data.get("connections", {})
257324

258-
for conn_name, conn_data in connections.items():
259-
if isinstance(conn_data, dict):
260-
for param_key, param_value in conn_data.items():
261-
full_key = f"connections.{conn_name}.{param_key}"
325+
# Check for direct connection sections (legacy format)
326+
for section_name, section_data in data.items():
327+
if isinstance(section_data, dict) and section_name != "connections":
328+
# This is a direct connection section like [default]
329+
for param_key, param_value in section_data.items():
330+
full_key = f"connections.{section_name}.{param_key}"
262331
if key is None or full_key == key:
263332
result[full_key] = ConfigValue(
264333
key=full_key,
@@ -267,6 +336,21 @@ def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
267336
raw_value=param_value,
268337
)
269338

339+
# Check for nested [connections] section format
340+
connections_section = data.get("connections", {})
341+
if isinstance(connections_section, dict):
342+
for conn_name, conn_data in connections_section.items():
343+
if isinstance(conn_data, dict):
344+
for param_key, param_value in conn_data.items():
345+
full_key = f"connections.{conn_name}.{param_key}"
346+
if key is None or full_key == key:
347+
result[full_key] = ConfigValue(
348+
key=full_key,
349+
value=param_value,
350+
source_name=self.source_name,
351+
raw_value=param_value,
352+
)
353+
270354
return result
271355

272356
except Exception as e:

src/snowflake/cli/api/config_provider.py

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,17 @@ def read_config(self) -> None:
158158
config_manager.read_config()
159159

160160
def get_connection_dict(self, connection_name: str) -> dict:
161-
from snowflake.cli.api.config import get_connection_dict
161+
from snowflake.cli.api.config import get_config_section
162162

163-
result = get_connection_dict(connection_name)
164-
return self._transform_private_key_raw(result)
163+
try:
164+
result = get_config_section("connections", connection_name)
165+
return self._transform_private_key_raw(result)
166+
except KeyError:
167+
from snowflake.cli.api.exceptions import MissingConfigurationError
168+
169+
raise MissingConfigurationError(
170+
f"Connection {connection_name} is not configured"
171+
)
165172

166173
def get_all_connections(self) -> dict:
167174
from snowflake.cli.api.config import get_all_connections
@@ -375,15 +382,16 @@ def _get_connection_dict_internal(self, connection_name: str) -> Dict[str, Any]:
375382
"""
376383
Get connection configuration by name.
377384
378-
Merges two types of keys:
379-
1. Connection-specific: connections.{name}.{param} (from files)
380-
2. Flat keys: {param} (from env/CLI, applies to active connection)
385+
Behavior is controlled by SNOWFLAKE_CLI_CONNECTIONS_TOML_REPLACE environment variable:
386+
- If set to "true" (default): connections.toml completely replaces connections
387+
from config.toml (legacy behavior)
388+
- If set to "false": connections.toml values are merged with config.toml values
381389
382390
Args:
383391
connection_name: Name of the connection
384392
385393
Returns:
386-
Dictionary of connection parameters
394+
Dictionary of connection parameters from file sources only
387395
"""
388396
self._ensure_initialized()
389397

@@ -392,19 +400,58 @@ def _get_connection_dict_internal(self, connection_name: str) -> Dict[str, Any]:
392400
self._config_cache = self._resolver.resolve()
393401

394402
connection_dict: Dict[str, Any] = {}
395-
396-
# First, get connection-specific keys (from file sources)
397403
connection_prefix = f"connections.{connection_name}."
398-
for key, value in self._config_cache.items():
399-
if key.startswith(connection_prefix):
400-
# Extract parameter name
401-
param_name = key[len(connection_prefix) :]
402-
connection_dict[param_name] = value
403404

404-
# Then, overlay flat keys (from env/CLI sources) - these have higher priority
405-
for key, value in self._config_cache.items():
406-
if "." not in key: # Flat key like "account", "user"
407-
connection_dict[key] = value
405+
# Check if replacement behavior is enabled (default: true for backward compatibility)
406+
import os
407+
408+
replace_behavior = os.environ.get(
409+
"SNOWFLAKE_CLI_CONNECTIONS_TOML_REPLACE", "true"
410+
).lower() in ("true", "1", "yes", "on")
411+
412+
if replace_behavior:
413+
# Legacy replacement behavior: if connections.toml has the connection,
414+
# use ONLY values from connections.toml
415+
has_connections_toml = False
416+
if self._resolver is not None:
417+
for key in self._config_cache.keys():
418+
if key.startswith(connection_prefix):
419+
# Check resolution history to see if this came from connections.toml
420+
history = self._resolver.get_resolution_history(key)
421+
if history and history.selected_entry:
422+
if (
423+
history.selected_entry.config_value.source_name
424+
== "connections_toml"
425+
):
426+
has_connections_toml = True
427+
break
428+
429+
if has_connections_toml:
430+
# Use ONLY connections.toml values (replacement behavior)
431+
for key, value in self._config_cache.items():
432+
if key.startswith(connection_prefix):
433+
# Check if this specific value comes from connections.toml
434+
if self._resolver is not None:
435+
history = self._resolver.get_resolution_history(key)
436+
if history and history.selected_entry:
437+
if (
438+
history.selected_entry.config_value.source_name
439+
== "connections_toml"
440+
):
441+
param_name = key[len(connection_prefix) :]
442+
connection_dict[param_name] = value
443+
else:
444+
# No connections.toml, use merged values from other sources
445+
for key, value in self._config_cache.items():
446+
if key.startswith(connection_prefix):
447+
param_name = key[len(connection_prefix) :]
448+
connection_dict[param_name] = value
449+
else:
450+
# New merging behavior: merge all sources normally
451+
for key, value in self._config_cache.items():
452+
if key.startswith(connection_prefix):
453+
param_name = key[len(connection_prefix) :]
454+
connection_dict[param_name] = value
408455

409456
if not connection_dict:
410457
from snowflake.cli.api.exceptions import MissingConfigurationError

src/snowflake/cli/api/connections.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,18 @@ class ConnectionContext:
6969
oauth_enable_single_use_refresh_tokens: Optional[bool] = None
7070
client_store_temporary_credential: Optional[bool] = None
7171

72+
# Internal flag to track if config has been loaded
73+
_config_loaded: bool = field(default=False, repr=False, init=False)
74+
7275
VALIDATED_FIELD_NAMES = ["schema"]
7376

7477
def present_values_as_dict(self) -> dict:
7578
"""Dictionary representation of this ConnectionContext for values that are not None"""
76-
return {k: v for (k, v) in asdict(self).items() if v is not None}
79+
return {
80+
k: v
81+
for (k, v) in asdict(self).items()
82+
if v is not None and not k.startswith("_")
83+
}
7784

7885
def clone(self) -> ConnectionContext:
7986
return replace(self)
@@ -112,6 +119,7 @@ def update_from_config(self) -> ConnectionContext:
112119
del connection_config["private_key_path"]
113120

114121
self.merge_with_config(**connection_config)
122+
self._config_loaded = True
115123
return self
116124

117125
def __repr__(self) -> str:
@@ -138,16 +146,12 @@ def validate_schema(self, value: Optional[str]):
138146
def validate_and_complete(self):
139147
"""
140148
Ensure we can create a connection from this context.
141-
Loads connection parameters from config if not already set.
149+
Sets default connection name if needed, but does not load configuration.
150+
Configuration is loaded lazily in build_connection().
142151
"""
143152
if not self.temporary_connection and not self.connection_name:
144153
self.connection_name = get_default_connection_name()
145154

146-
# Load connection parameters from config if we have a connection_name
147-
# and haven't loaded them yet (e.g., user is still None)
148-
if self.connection_name and not self.user:
149-
self.update_from_config()
150-
151155
def build_connection(self):
152156
from snowflake.cli._app.snow_connector import connect_to_snowflake
153157

@@ -160,9 +164,36 @@ def build_connection(self):
160164
module="snowflake.connector.config_manager",
161165
)
162166

163-
# Get connection parameters and pass them directly to connect_to_snowflake
164-
# This restores the original behavior before the change that enforced temporary_connection
165-
conn_params = self.present_values_as_dict()
167+
if self.temporary_connection:
168+
# For temporary connections, pass all parameters
169+
# connect_to_snowflake will use these directly without loading config
170+
conn_params = self.present_values_as_dict()
171+
else:
172+
# For named connections, pass connection_name and all override parameters
173+
# connect_to_snowflake will load the connection config internally and apply overrides
174+
all_params = self.present_values_as_dict()
175+
control_params = {
176+
"connection_name",
177+
"enable_diag",
178+
"diag_log_path",
179+
"diag_allowlist_path",
180+
"temporary_connection",
181+
"mfa_passcode",
182+
}
183+
184+
# Separate control parameters from connection overrides
185+
conn_params = {}
186+
overrides = {}
187+
188+
for k, v in all_params.items():
189+
if k in control_params:
190+
conn_params[k] = v
191+
else:
192+
# These are connection parameters that should override config values
193+
overrides[k] = v
194+
195+
# Merge overrides into conn_params
196+
conn_params.update(overrides)
166197

167198
return connect_to_snowflake(**conn_params)
168199

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def os_agnostic_snapshot(snapshot):
106106
def reset_global_context_and_setup_config_and_logging_levels(
107107
request, test_snowcli_config
108108
):
109+
# Reset config provider singleton to prevent test interference
110+
from snowflake.cli.api.config_provider import reset_config_provider
111+
112+
reset_config_provider()
113+
109114
with fork_cli_context():
110115
connection_cache = OpenConnectionCache()
111116
cli_context_manager = get_cli_context_manager()

0 commit comments

Comments
 (0)