Skip to content

Commit 77c3188

Browse files
committed
SNOW-2306184: config refactor - simplified implementation
1 parent 01ec265 commit 77c3188

File tree

7 files changed

+282
-166
lines changed

7 files changed

+282
-166
lines changed

src/snowflake/cli/api/config.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -313,19 +313,35 @@ def config_section_exists(*path) -> bool:
313313

314314

315315
def get_all_connections() -> dict[str, ConnectionConfig]:
316-
return {
317-
k: ConnectionConfig.from_dict(connection_dict)
318-
for k, connection_dict in get_config_section("connections").items()
319-
}
316+
# Use config provider if available
317+
try:
318+
from snowflake.cli.api.config_provider import get_config_provider_singleton
319+
320+
provider = get_config_provider_singleton()
321+
return provider.get_all_connections()
322+
except Exception:
323+
# Fall back to legacy implementation
324+
return {
325+
k: ConnectionConfig.from_dict(connection_dict)
326+
for k, connection_dict in get_config_section("connections").items()
327+
}
320328

321329

322330
def get_connection_dict(connection_name: str) -> dict:
331+
# Use config provider if available
323332
try:
324-
return get_config_section(CONNECTIONS_SECTION, connection_name)
325-
except KeyError:
326-
raise MissingConfigurationError(
327-
f"Connection {connection_name} is not configured"
328-
)
333+
from snowflake.cli.api.config_provider import get_config_provider_singleton
334+
335+
provider = get_config_provider_singleton()
336+
return provider.get_connection_dict(connection_name)
337+
except Exception:
338+
# Fall back to legacy implementation
339+
try:
340+
return get_config_section(CONNECTIONS_SECTION, connection_name)
341+
except KeyError:
342+
raise MissingConfigurationError(
343+
f"Connection {connection_name} is not configured"
344+
)
329345

330346

331347
def get_default_connection_name() -> str:

src/snowflake/cli/api/config_ng/resolver.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,11 @@ def resolve(self, key: Optional[str] = None, default: Any = None) -> Dict[str, A
348348
Resolution Process:
349349
1. Iterate sources in order (lowest to highest priority)
350350
2. Record all discovered values in history
351-
3. Later sources overwrite earlier sources (simple dict update)
352-
4. Mark which value was selected
353-
5. Return final resolved values
351+
3. For connection keys (connections.{name}.{param}):
352+
- Merge connection-by-connection: later sources extend/overwrite individual params
353+
4. For flat keys: later sources overwrite earlier sources
354+
5. Mark which value was selected
355+
6. Return final resolved values
354356
355357
Args:
356358
key: Specific key to resolve (None = all keys)
@@ -360,9 +362,10 @@ def resolve(self, key: Optional[str] = None, default: Any = None) -> Dict[str, A
360362
Dictionary of resolved values (key -> value)
361363
"""
362364
all_values: Dict[str, ConfigValue] = {}
365+
# Track connection values separately for intelligent merging
366+
connections: Dict[str, Dict[str, ConfigValue]] = defaultdict(dict)
363367

364368
# Process sources in order (first = lowest priority, last = highest)
365-
# Later sources overwrite earlier ones via dict.update()
366369
for source in self._sources:
367370
try:
368371
source_values = source.discover(key)
@@ -371,12 +374,29 @@ def resolve(self, key: Optional[str] = None, default: Any = None) -> Dict[str, A
371374
for k, config_value in source_values.items():
372375
self._history_tracker.record_discovery(k, config_value)
373376

374-
# Update current values (later source overwrites earlier)
375-
all_values.update(source_values)
377+
# Separate connection keys from flat keys
378+
for k, config_value in source_values.items():
379+
if k.startswith("connections."):
380+
# Parse: connections.{name}.{param}
381+
parts = k.split(".", 2)
382+
if len(parts) == 3:
383+
conn_name = parts[1]
384+
param = parts[2]
385+
param_key = f"connections.{conn_name}.{param}"
386+
387+
# Merge at parameter level: later source overwrites/extends
388+
connections[conn_name][param_key] = config_value
389+
else:
390+
# Flat key: later source overwrites
391+
all_values[k] = config_value
376392

377393
except Exception as e:
378394
log.warning("Error from source %s: %s", source.source_name, e)
379395

396+
# Flatten connection data back into all_values
397+
for conn_name, conn_params in connections.items():
398+
all_values.update(conn_params)
399+
380400
# Mark which values were selected in history
381401
for k, config_value in all_values.items():
382402
self._history_tracker.mark_selected(k, config_value.source_name)

0 commit comments

Comments
 (0)