Skip to content

Commit e050a89

Browse files
committed
SNOW-2306184: config refactor - variables merge
1 parent 9300890 commit e050a89

File tree

7 files changed

+491
-11
lines changed

7 files changed

+491
-11
lines changed

src/snowflake/cli/_plugins/dcm/manager.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import yaml
1919
from snowflake.cli._plugins.stage.manager import StageManager
2020
from snowflake.cli.api.artifacts.upload import sync_artifacts_with_stage
21-
from snowflake.cli.api.commands.utils import parse_key_value_variables
2221
from snowflake.cli.api.console.console import cli_console
2322
from snowflake.cli.api.constants import (
2423
DEFAULT_SIZE_LIMIT_MB,
@@ -102,8 +101,15 @@ def execute(
102101
if configuration:
103102
query += f" CONFIGURATION {configuration}"
104103
if variables:
104+
from snowflake.cli.api.commands.common import Variable
105+
from snowflake.cli.api.config_ng import get_merged_variables
106+
107+
# Get merged variables from SnowSQL config and CLI -D parameters
108+
merged_vars_dict = get_merged_variables(variables)
109+
# Convert dict to List[Variable] for compatibility with parse_execute_variables
110+
parsed_variables = [Variable(k, v) for k, v in merged_vars_dict.items()]
105111
query += StageManager.parse_execute_variables(
106-
parse_key_value_variables(variables)
112+
parsed_variables
107113
).removeprefix(" using")
108114
stage_path = StagePath.from_stage_str(from_stage)
109115
query += f" FROM {stage_path.absolute_path()}"

src/snowflake/cli/_plugins/sql/commands.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
)
3030
from snowflake.cli.api.commands.overrideable_parameter import OverrideableOption
3131
from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
32-
from snowflake.cli.api.commands.utils import parse_key_value_variables
3332
from snowflake.cli.api.exceptions import CliArgumentError
3433
from snowflake.cli.api.output.types import (
3534
CommandResult,
@@ -136,9 +135,9 @@ def execute_sql(
136135
The command supports variable substitution that happens on client-side.
137136
"""
138137

139-
data = {}
140-
if data_override:
141-
data = {v.key: v.value for v in parse_key_value_variables(data_override)}
138+
from snowflake.cli.api.config_ng import get_merged_variables
139+
140+
data = get_merged_variables(data_override)
142141

143142
template_syntax_config = _parse_template_syntax_config(enabled_templating)
144143

src/snowflake/cli/_plugins/stage/manager.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
OnErrorType,
3939
Variable,
4040
)
41-
from snowflake.cli.api.commands.utils import parse_key_value_variables
4241
from snowflake.cli.api.console import cli_console
4342
from snowflake.cli.api.constants import PYTHON_3_12
4443
from snowflake.cli.api.exceptions import CliError
@@ -608,7 +607,12 @@ def execute(
608607
filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f))
609608
)
610609

611-
parsed_variables = parse_key_value_variables(variables)
610+
from snowflake.cli.api.config_ng import get_merged_variables
611+
612+
# Get merged variables from SnowSQL config and CLI -D parameters
613+
merged_vars_dict = get_merged_variables(variables)
614+
# Convert dict back to List[Variable] for compatibility with existing methods
615+
parsed_variables = [Variable(k, v) for k, v in merged_vars_dict.items()]
612616
sql_variables = self.parse_execute_variables(parsed_variables)
613617
python_variables = self._parse_python_variables(parsed_variables)
614618
results = []

src/snowflake/cli/api/config_ng/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
ConnectionSpecificEnvironment,
5656
SnowSQLConfigFile,
5757
SnowSQLEnvironment,
58+
SnowSQLSection,
59+
get_merged_variables,
5860
)
5961

6062
__all__ = [
@@ -69,6 +71,7 @@
6971
"explain_configuration",
7072
"export_resolution_history",
7173
"format_summary_for_display",
74+
"get_merged_variables",
7275
"get_resolution_summary",
7376
"get_resolver",
7477
"is_resolution_logging_available",
@@ -80,6 +83,7 @@
8083
"show_resolution_chain",
8184
"SnowSQLConfigFile",
8285
"SnowSQLEnvironment",
86+
"SnowSQLSection",
8387
"SourceType",
8488
"ValueSource",
8589
]

src/snowflake/cli/api/config_ng/sources.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,27 @@
3131
import configparser
3232
import logging
3333
import os
34+
from enum import Enum
3435
from pathlib import Path
35-
from typing import Any, Dict, Final, Optional
36+
from typing import Any, Dict, Final, List, Optional
3637

3738
from snowflake.cli.api.config_ng.core import ConfigValue, SourceType, ValueSource
3839

3940
log = logging.getLogger(__name__)
4041

42+
43+
class SnowSQLSection(Enum):
44+
"""
45+
SnowSQL configuration file section names.
46+
47+
These sections can be present in SnowSQL INI config files.
48+
"""
49+
50+
CONNECTIONS = "connections"
51+
VARIABLES = "variables"
52+
OPTIONS = "options"
53+
54+
4155
# Try to import tomllib (Python 3.11+) or fall back to tomli
4256
try:
4357
import tomllib
@@ -142,9 +156,9 @@ def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
142156

143157
# Process all connection sections
144158
for section in config.sections():
145-
if section.startswith("connections"):
159+
if section.startswith(SnowSQLSection.CONNECTIONS.value):
146160
# Extract connection name
147-
if section == "connections":
161+
if section == SnowSQLSection.CONNECTIONS.value:
148162
# This is default connection
149163
connection_name = "default"
150164
else:
@@ -170,6 +184,19 @@ def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
170184
raw_value=f"{param_key}={param_value}", # Show original key in raw_value
171185
)
172186

187+
elif section == SnowSQLSection.VARIABLES.value:
188+
# Process variables section (global, not connection-specific)
189+
section_data = dict(config[section])
190+
for var_key, var_value in section_data.items():
191+
full_key = f"variables.{var_key}"
192+
if key is None or full_key == key:
193+
merged_values[full_key] = ConfigValue(
194+
key=full_key,
195+
value=var_value,
196+
source_name=self.source_name,
197+
raw_value=f"{var_key}={var_value}",
198+
)
199+
173200
except Exception as e:
174201
log.debug("Failed to read SnowSQL config %s: %s", config_file, e)
175202

@@ -755,3 +782,36 @@ def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
755782
def supports_key(self, key: str) -> bool:
756783
"""Check if key is present in CLI context with non-None value."""
757784
return key in self._cli_context and self._cli_context[key] is not None
785+
786+
787+
def get_merged_variables(cli_variables: Optional[List[str]] = None) -> Dict[str, str]:
788+
"""
789+
Merge SnowSQL [variables] with CLI -D parameters.
790+
791+
Precedence: SnowSQL variables (lower) < -D parameters (higher)
792+
793+
Args:
794+
cli_variables: List of "key=value" strings from -D parameters
795+
796+
Returns:
797+
Dictionary of merged variables (key -> value)
798+
"""
799+
from snowflake.cli.api.config_provider import get_config_provider_singleton
800+
801+
# Start with SnowSQL variables from config
802+
provider = get_config_provider_singleton()
803+
try:
804+
snowsql_vars = provider.get_section(SnowSQLSection.VARIABLES.value)
805+
except Exception:
806+
# If variables section doesn't exist or provider not initialized, start with empty dict
807+
snowsql_vars = {}
808+
809+
# Parse and overlay -D parameters (higher precedence)
810+
if cli_variables:
811+
from snowflake.cli.api.commands.utils import parse_key_value_variables
812+
813+
cli_vars_parsed = parse_key_value_variables(cli_variables)
814+
for var in cli_vars_parsed:
815+
snowsql_vars[var.key] = var.value
816+
817+
return snowsql_vars

src/snowflake/cli/api/config_provider.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,15 @@ def get_section(self, *path) -> dict:
335335
connection_name = path[1]
336336
return self._get_connection_dict_internal(connection_name)
337337

338+
# For variables section, return all variables as flat dict
339+
if len(path) == 1 and path[0] == "variables":
340+
result = {}
341+
for key, value in self._config_cache.items():
342+
if key.startswith("variables."):
343+
var_name = key[len("variables.") :]
344+
result[var_name] = value
345+
return result
346+
338347
# For other sections, try to resolve with path prefix
339348
section_prefix = ".".join(path)
340349
result = {}

0 commit comments

Comments
 (0)