diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b2495b7336..181285a064 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -12,7 +12,7 @@ on: - features/* env: - TERM: unknown # Disables colors in rich + TERM: unknown # Disables colors in rich permissions: contents: read @@ -43,3 +43,29 @@ jobs: - name: Test with hatch run: hatch run test-cov - uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24 + + tests-config-ng: + needs: define-matrix + strategy: + fail-fast: true + matrix: + os: ${{ fromJSON(needs.define-matrix.outputs.os) }} + python-version: ${{ fromJSON(needs.define-matrix.outputs.python) }} + runs-on: ${{ matrix.os }} + env: + SNOWFLAKE_CLI_CONFIG_V2_ENABLED: 1 + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install hatch + run: | + pip install -U click==8.2.1 hatch + hatch env create default + - name: Test with hatch + run: hatch run test-cov + - uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24 diff --git a/.github/workflows/test_trusted.yaml b/.github/workflows/test_trusted.yaml index 20e12e0f32..a101d03631 100644 --- a/.github/workflows/test_trusted.yaml +++ b/.github/workflows/test_trusted.yaml @@ -46,3 +46,31 @@ jobs: SNOWFLAKE_CONNECTIONS_INTEGRATION_DATABASE: ${{ secrets.SNOWFLAKE_DATABASE }} SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_RAW: ${{ secrets.SNOWFLAKE_PRIVATE_KEY_RAW }} run: python -m hatch run ${{ inputs.hatch-run }} + + tests-trusted-ng: + runs-on: ${{ inputs.runs-on }} + env: + SNOWFLAKE_CLI_CONFIG_V2_ENABLED: 1 + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + - name: Install dependencies + run: | + python -m pip install 
--upgrade pip click==8.2.1 hatch + python -m hatch env create ${{ inputs.python-env }} + - name: Run integration tests + env: + GH_TOKEN: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} + TERM: unknown + SNOWFLAKE_CONNECTIONS_INTEGRATION_AUTHENTICATOR: SNOWFLAKE_JWT + SNOWFLAKE_CONNECTIONS_INTEGRATION_HOST: ${{ secrets.SNOWFLAKE_HOST }} + SNOWFLAKE_CONNECTIONS_INTEGRATION_USER: ${{ secrets.SNOWFLAKE_USER }} + SNOWFLAKE_CONNECTIONS_INTEGRATION_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }} + SNOWFLAKE_CONNECTIONS_INTEGRATION_DATABASE: ${{ secrets.SNOWFLAKE_DATABASE }} + SNOWFLAKE_CONNECTIONS_INTEGRATION_PRIVATE_KEY_RAW: ${{ secrets.SNOWFLAKE_PRIVATE_KEY_RAW }} + run: python -m hatch run ${{ inputs.hatch-run }} diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index a1b4af1ca8..5ab9f66845 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -15,6 +15,7 @@ --> # Unreleased version ## Backward incompatibility +* **Configuration System (NG)**: File-based configuration sources (`snowsql_config`, `cli_config_toml`, `connections_toml`) now use **connection-level replacement** instead of field-level merging. When a later file source defines a connection, it completely replaces the entire connection from earlier file sources - fields are NOT inherited. Environment variables and CLI arguments continue to overlay per-field on top of the file-derived connection. This provides more predictable configuration behavior where file-defined connections are atomic units. 
## Deprecations diff --git a/src/snowflake/cli/_app/telemetry.py b/src/snowflake/cli/_app/telemetry.py index 77ba8def26..10198c6a6f 100644 --- a/src/snowflake/cli/_app/telemetry.py +++ b/src/snowflake/cli/_app/telemetry.py @@ -61,6 +61,11 @@ class CLITelemetryField(Enum): COMMAND_CI_ENVIRONMENT = "command_ci_environment" # Configuration CONFIG_FEATURE_FLAGS = "config_feature_flags" + CONFIG_PROVIDER_TYPE = "config_provider_type" + CONFIG_SOURCES_USED = "config_sources_used" + CONFIG_SOURCE_WINS = "config_source_wins" + CONFIG_TOTAL_KEYS_RESOLVED = "config_total_keys_resolved" + CONFIG_KEYS_WITH_OVERRIDES = "config_keys_with_overrides" # Metrics COUNTERS = "counters" SPANS = "spans" @@ -219,6 +224,55 @@ def python_version() -> str: return f"{py_ver.major}.{py_ver.minor}.{py_ver.micro}" +def _get_config_telemetry() -> TelemetryDict: + """Get configuration resolution telemetry data.""" + try: + from snowflake.cli.api.config_ng.telemetry_integration import ( + get_config_telemetry_payload, + ) + from snowflake.cli.api.config_provider import ( + AlternativeConfigProvider, + get_config_provider_singleton, + ) + + provider = get_config_provider_singleton() + + # Identify which config provider is being used + provider_type = ( + "ng" if isinstance(provider, AlternativeConfigProvider) else "legacy" + ) + + result: TelemetryDict = {CLITelemetryField.CONFIG_PROVIDER_TYPE: provider_type} + + # Get detailed telemetry if using ng config + if isinstance(provider, AlternativeConfigProvider): + provider._ensure_initialized() # noqa: SLF001 + payload = get_config_telemetry_payload(provider._resolver) # noqa: SLF001 + + # Map payload keys to telemetry fields + if payload: + if "config_sources_used" in payload: + result[CLITelemetryField.CONFIG_SOURCES_USED] = payload[ + "config_sources_used" + ] + if "config_source_wins" in payload: + result[CLITelemetryField.CONFIG_SOURCE_WINS] = payload[ + "config_source_wins" + ] + if "config_total_keys_resolved" in payload: + 
result[CLITelemetryField.CONFIG_TOTAL_KEYS_RESOLVED] = payload[ + "config_total_keys_resolved" + ] + if "config_keys_with_overrides" in payload: + result[CLITelemetryField.CONFIG_KEYS_WITH_OVERRIDES] = payload[ + "config_keys_with_overrides" + ] + + return result + except Exception: + return {} + + class CLITelemetryClient: @property def _ctx(self) -> _CliGlobalContextAccess: @@ -239,6 +293,7 @@ def generate_telemetry_data_dict( k: str(v) for k, v in get_feature_flags_section().items() }, **_find_command_info(), + **_get_config_telemetry(), **telemetry_payload, } # To map Enum to string, so we don't have to use .value every time diff --git a/src/snowflake/cli/_plugins/connection/commands.py b/src/snowflake/cli/_plugins/connection/commands.py index f177a963be..7d39c12e8f 100644 --- a/src/snowflake/cli/_plugins/connection/commands.py +++ b/src/snowflake/cli/_plugins/connection/commands.py @@ -94,11 +94,32 @@ def _mask_sensitive_parameters(connection_params: dict): @app.command(name="list") -def list_connections(**options) -> CommandResult: +def list_connections( + all_sources: bool = typer.Option( + False, + "--all", + "-a", + help="Include connections from all sources (environment variables, SnowSQL config). " + "By default, only shows connections from configuration files.", + ), + **options, +) -> CommandResult: """ Lists configured connections. 
""" - connections = get_all_connections() + from snowflake.cli.api.config_provider import ( + get_config_provider_singleton, + is_alternative_config_enabled, + ) + + # Use provider directly for config_ng to pass the flag + if is_alternative_config_enabled(): + provider = get_config_provider_singleton() + connections = provider.get_all_connections(include_env_connections=all_sources) + else: + # Legacy provider ignores the flag + connections = get_all_connections() + default_connection = get_default_connection_name() result = ( { diff --git a/src/snowflake/cli/_plugins/dcm/manager.py b/src/snowflake/cli/_plugins/dcm/manager.py index 717239b4cf..8357f40150 100644 --- a/src/snowflake/cli/_plugins/dcm/manager.py +++ b/src/snowflake/cli/_plugins/dcm/manager.py @@ -18,7 +18,6 @@ import yaml from snowflake.cli._plugins.stage.manager import StageManager from snowflake.cli.api.artifacts.upload import sync_artifacts_with_stage -from snowflake.cli.api.commands.utils import parse_key_value_variables from snowflake.cli.api.console.console import cli_console from snowflake.cli.api.constants import ( DEFAULT_SIZE_LIMIT_MB, @@ -102,8 +101,15 @@ def execute( if configuration: query += f" CONFIGURATION {configuration}" if variables: + from snowflake.cli.api.commands.common import Variable + from snowflake.cli.api.config_ng import get_merged_variables + + # Get merged variables from SnowSQL config and CLI -D parameters + merged_vars_dict = get_merged_variables(variables) + # Convert dict to List[Variable] for compatibility with parse_execute_variables + parsed_variables = [Variable(k, v) for k, v in merged_vars_dict.items()] query += StageManager.parse_execute_variables( - parse_key_value_variables(variables) + parsed_variables ).removeprefix(" using") stage_path = StagePath.from_stage_str(from_stage) query += f" FROM {stage_path.absolute_path()}" diff --git a/src/snowflake/cli/_plugins/helpers/commands.py b/src/snowflake/cli/_plugins/helpers/commands.py index 45a02faf79..e1363fc55d 
100644 --- a/src/snowflake/cli/_plugins/helpers/commands.py +++ b/src/snowflake/cli/_plugins/helpers/commands.py @@ -30,6 +30,7 @@ get_all_connections, set_config_value, ) +from snowflake.cli.api.config_provider import ALTERNATIVE_CONFIG_ENV_VAR from snowflake.cli.api.console import cli_console from snowflake.cli.api.output.types import ( CollectionResult, @@ -317,3 +318,89 @@ def check_snowsql_env_vars(**options): results.append(MessageResult(summary)) return MultipleResults(results) + + +@app.command( + name="show-config-sources", + requires_connection=False, + hidden=os.environ.get(ALTERNATIVE_CONFIG_ENV_VAR, "").lower() + not in ("1", "true", "yes", "on"), +) +def show_config_sources( + key: Optional[str] = typer.Argument( + None, + help="Specific configuration key to show resolution for (e.g., 'account', 'user'). If not provided, shows summary for all keys.", + ), + show_details: bool = typer.Option( + False, + "--show-details", + "-d", + help="Show detailed resolution chains for all sources consulted.", + ), + export_file: Optional[Path] = typer.Option( + None, + "--export", + "-e", + help="Export complete resolution history to JSON file for support or debugging.", + file_okay=True, + dir_okay=False, + ), + **options, +) -> CommandResult: + """ + Show where configuration values come from. + + This command displays the configuration resolution process, showing which + source (CLI arguments, environment variables, or config files) provided + each configuration value. Useful for debugging configuration issues. 
+ + Examples: + + # Show summary of all configuration resolution + snow helpers show-config-sources + + # Show detailed resolution for all keys + snow helpers show-config-sources --show-details + + # Show resolution for a specific key + snow helpers show-config-sources account + + # Show detailed resolution for a specific key + snow helpers show-config-sources account --show-details + + # Export complete resolution history to file + snow helpers show-config-sources --export config_debug.json + + Note: This command requires the enhanced configuration system to be enabled. + Set SNOWFLAKE_CLI_CONFIG_V2_ENABLED=true to enable it. + """ + from snowflake.cli.api.config_ng import ( + export_resolution_history, + is_resolution_logging_available, + ) + from snowflake.cli.api.config_ng.resolution_logger import ( + get_configuration_explanation_results, + ) + + if not is_resolution_logging_available(): + return MessageResult( + f"⚠️ Configuration resolution logging is not available.\n\n" + f"To enable it, set the environment variable:\n" + f" export {ALTERNATIVE_CONFIG_ENV_VAR}=true\n\n" + f"Then run this command again to see where configuration values come from." + ) + + # Export if requested + if export_file: + success = export_resolution_history(export_file) + if not success: + return MessageResult( + f"❌ Failed to export resolution history to {export_file}" + ) + return MessageResult( + f"✅ Resolution history exported to: {export_file}\n\n" + f"This file contains complete details about configuration resolution " + f"and can be attached to support tickets." 
+ ) + + return get_configuration_explanation_results(key=key, verbose=show_details) diff --git a/src/snowflake/cli/_plugins/sql/commands.py b/src/snowflake/cli/_plugins/sql/commands.py index 7a3b2ad817..9b73929f65 100644 --- a/src/snowflake/cli/_plugins/sql/commands.py +++ b/src/snowflake/cli/_plugins/sql/commands.py @@ -29,7 +29,6 @@ ) from snowflake.cli.api.commands.overrideable_parameter import OverrideableOption from snowflake.cli.api.commands.snow_typer import SnowTyperFactory -from snowflake.cli.api.commands.utils import parse_key_value_variables from snowflake.cli.api.exceptions import CliArgumentError from snowflake.cli.api.output.types import ( CommandResult, @@ -136,9 +135,9 @@ def execute_sql( The command supports variable substitution that happens on client-side. """ - data = {} - if data_override: - data = {v.key: v.value for v in parse_key_value_variables(data_override)} + from snowflake.cli.api.config_ng import get_merged_variables + + data = get_merged_variables(data_override) template_syntax_config = _parse_template_syntax_config(enabled_templating) diff --git a/src/snowflake/cli/_plugins/stage/manager.py b/src/snowflake/cli/_plugins/stage/manager.py index 3e8a72c9f7..4865697f3d 100644 --- a/src/snowflake/cli/_plugins/stage/manager.py +++ b/src/snowflake/cli/_plugins/stage/manager.py @@ -38,7 +38,6 @@ OnErrorType, Variable, ) -from snowflake.cli.api.commands.utils import parse_key_value_variables from snowflake.cli.api.console import cli_console from snowflake.cli.api.constants import PYTHON_3_12 from snowflake.cli.api.exceptions import CliError @@ -608,7 +607,12 @@ def execute( filtered_file_list, key=lambda f: (path.dirname(f), path.basename(f)) ) - parsed_variables = parse_key_value_variables(variables) + from snowflake.cli.api.config_ng import get_merged_variables + + # Get merged variables from SnowSQL config and CLI -D parameters + merged_vars_dict = get_merged_variables(variables) + # Convert dict back to List[Variable] for compatibility 
with existing methods + parsed_variables = [Variable(k, v) for k, v in merged_vars_dict.items()] sql_variables = self.parse_execute_variables(parsed_variables) python_variables = self._parse_python_variables(parsed_variables) results = [] diff --git a/src/snowflake/cli/api/config.py b/src/snowflake/cli/api/config.py index 5be189c2fb..77c9aac461 100644 --- a/src/snowflake/cli/api/config.py +++ b/src/snowflake/cli/api/config.py @@ -99,6 +99,10 @@ class ConnectionConfig: authenticator: Optional[str] = None workload_identity_provider: Optional[str] = None private_key_file: Optional[str] = None + private_key_passphrase: Optional[str] = field(default=None, repr=False) + token: Optional[str] = field(default=None, repr=False) + session_token: Optional[str] = field(default=None, repr=False) + master_token: Optional[str] = field(default=None, repr=False) token_file_path: Optional[str] = None oauth_client_id: Optional[str] = None oauth_client_secret: Optional[str] = None @@ -106,7 +110,7 @@ class ConnectionConfig: oauth_token_request_url: Optional[str] = None oauth_redirect_uri: Optional[str] = None oauth_scope: Optional[str] = None - oatuh_enable_pkce: Optional[bool] = None + oauth_enable_pkce: Optional[bool] = None oauth_enable_refresh_tokens: Optional[bool] = None oauth_enable_single_use_refresh_tokens: Optional[bool] = None client_store_temporary_credential: Optional[bool] = None @@ -217,6 +221,16 @@ def _config_file(): yield conf_file_cache _dump_config(conf_file_cache) + # Reset config provider cache after writing to ensure it re-reads on next access + try: + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + if hasattr(provider, "invalidate_cache"): + provider.invalidate_cache() + except Exception: + pass + def _read_config_file(): config_manager = get_config_manager() @@ -308,19 +322,17 @@ def config_section_exists(*path) -> bool: def get_all_connections() -> dict[str, ConnectionConfig]: - 
return { - k: ConnectionConfig.from_dict(connection_dict) - for k, connection_dict in get_config_section("connections").items() - } + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + return provider.get_all_connections() def get_connection_dict(connection_name: str) -> dict: - try: - return get_config_section(CONNECTIONS_SECTION, connection_name) - except KeyError: - raise MissingConfigurationError( - f"Connection {connection_name} is not configured" - ) + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + return provider.get_connection_dict(connection_name) def get_default_connection_name() -> str: diff --git a/src/snowflake/cli/api/config_ng/__init__.py b/src/snowflake/cli/api/config_ng/__init__.py new file mode 100644 index 0000000000..ab7e7ddded --- /dev/null +++ b/src/snowflake/cli/api/config_ng/__init__.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Enhanced Configuration System - Next Generation (NG) + +This package implements a simple, extensible configuration system with: +- Two-phase resolution: file sources use connection-level replacement, + overlay sources (env/CLI) use field-level merging +- List-order precedence (explicit ordering in source list) +- Migration support (SnowCLI and SnowSQL compatibility) +- Complete resolution history tracking +- Read-only, immutable configuration sources +""" + +from snowflake.cli.api.config_ng.constants import ( + FILE_SOURCE_NAMES, + INTERNAL_CLI_PARAMETERS, + ConfigSection, +) +from snowflake.cli.api.config_ng.core import ( + ConfigValue, + ResolutionEntry, + ResolutionHistory, + SourceType, + ValueSource, +) +from snowflake.cli.api.config_ng.dict_utils import deep_merge +from snowflake.cli.api.config_ng.merge_operations import ( + create_default_connection_from_params, + extract_root_level_connection_params, + merge_params_into_connections, +) +from snowflake.cli.api.config_ng.parsers import SnowSQLParser, TOMLParser +from snowflake.cli.api.config_ng.presentation import ResolutionPresenter +from snowflake.cli.api.config_ng.resolution_logger import ( + check_value_source, + explain_configuration, + export_resolution_history, + format_summary_for_display, + get_resolution_summary, + get_resolver, + is_resolution_logging_available, + show_all_resolution_chains, + show_resolution_chain, +) +from snowflake.cli.api.config_ng.resolver import ( + ConfigurationResolver, + ResolutionHistoryTracker, +) +from snowflake.cli.api.config_ng.source_factory import create_default_sources +from snowflake.cli.api.config_ng.source_manager import SourceManager +from snowflake.cli.api.config_ng.sources import ( + CliConfigFile, + CliEnvironment, + CliParameters, + ConnectionsConfigFile, + ConnectionSpecificEnvironment, + SnowSQLConfigFile, + SnowSQLEnvironment, + SnowSQLSection, + get_merged_variables, +) +from snowflake.cli.api.config_ng.telemetry_integration import ( + 
get_config_telemetry_payload, + record_config_source_usage, +) + +__all__ = [ + "check_value_source", + "CliConfigFile", + "CliEnvironment", + "CliParameters", + "ConfigSection", + "ConfigurationResolver", + "ConfigValue", + "ConnectionsConfigFile", + "ConnectionSpecificEnvironment", + "create_default_connection_from_params", + "create_default_sources", + "deep_merge", + "explain_configuration", + "export_resolution_history", + "extract_root_level_connection_params", + "FILE_SOURCE_NAMES", + "format_summary_for_display", + "get_config_telemetry_payload", + "get_merged_variables", + "get_resolution_summary", + "get_resolver", + "INTERNAL_CLI_PARAMETERS", + "is_resolution_logging_available", + "merge_params_into_connections", + "record_config_source_usage", + "ResolutionEntry", + "ResolutionHistory", + "ResolutionHistoryTracker", + "ResolutionPresenter", + "show_all_resolution_chains", + "show_resolution_chain", + "SnowSQLConfigFile", + "SnowSQLEnvironment", + "SnowSQLParser", + "SnowSQLSection", + "SourceManager", + "SourceType", + "TOMLParser", + "ValueSource", +] diff --git a/src/snowflake/cli/api/config_ng/constants.py b/src/snowflake/cli/api/config_ng/constants.py new file mode 100644 index 0000000000..40b938507a --- /dev/null +++ b/src/snowflake/cli/api/config_ng/constants.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Constants for configuration system.""" + +from enum import Enum +from typing import Final, Literal + + +class ConfigSection(str, Enum): + """Configuration section names.""" + + CONNECTIONS = "connections" + VARIABLES = "variables" + CLI = "cli" + CLI_LOGS = "cli.logs" + CLI_FEATURES = "cli.features" + + def __str__(self) -> str: + """Return the string value for backward compatibility.""" + return self.value + + +# Environment variable names +SNOWFLAKE_HOME_ENV: Final[str] = "SNOWFLAKE_HOME" + +# Internal CLI parameters that should not be treated as connection parameters +INTERNAL_CLI_PARAMETERS: Final[set[str]] = { + "enable_diag", + "temporary_connection", + "default_connection_name", + "connection_name", + "diag_log_path", + "diag_allowlist_path", + "mfa_passcode", +} + +# Define Literal type for file source names +FileSourceName = Literal[ + "snowsql_config", + "cli_config_toml", + "connections_toml", +] + +# Source names that represent file-based configuration sources +FILE_SOURCE_NAMES: Final[set[str]] = { + "snowsql_config", + "cli_config_toml", + "connections_toml", +} diff --git a/src/snowflake/cli/api/config_ng/core.py b/src/snowflake/cli/api/config_ng/core.py new file mode 100644 index 0000000000..7d8bff20f0 --- /dev/null +++ b/src/snowflake/cli/api/config_ng/core.py @@ -0,0 +1,271 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Core abstractions for the enhanced configuration system. 
+ +This module implements the foundational data structures and interfaces: +- ConfigValue: Immutable value container with provenance +- ValueSource: Common protocol for all configuration sources +- ResolutionHistory: Tracks the complete resolution process +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional + + +class SourceType(Enum): + """ + Classification of configuration sources by merging behavior. + + FILE sources use connection-level replacement (later file replaces entire connection). + OVERLAY sources use field-level overlay (add/override individual fields). + """ + + FILE = "file" + OVERLAY = "overlay" + + +@dataclass(frozen=True) +class ConfigValue: + """ + Immutable configuration value with full provenance tracking. + Stores both parsed value and original raw value. + """ + + key: str + value: Any + source_name: str + raw_value: Optional[Any] = None + + def __repr__(self) -> str: + """Readable representation showing conversion if applicable.""" + value_display = f"{self.value}" + if self.raw_value is not None and self.raw_value != self.value: + value_display = f"{self.raw_value} → {self.value}" + return f"ConfigValue({self.key}={value_display}, from {self.source_name})" + + @classmethod + def from_source( + cls, + key: str, + raw_value: str, + source_name: str, + value_parser: Optional[Callable[[str], Any]] = None, + ) -> ConfigValue: + """ + Factory method to create ConfigValue from a source. 
+ + Args: + key: Configuration key + raw_value: Raw string value from the source + source_name: Name of the configuration source + value_parser: Optional parser function; if None, raw_value is used as-is + + Returns: + ConfigValue instance with parsed value + """ + parsed_value = value_parser(raw_value) if value_parser else raw_value + return cls( + key=key, + value=parsed_value, + source_name=source_name, + raw_value=raw_value, + ) + + +class ValueSource(ABC): + """ + Common interface for all configuration sources. + All implementations are READ-ONLY discovery mechanisms. + Precedence is determined by the order sources are provided to the resolver. + """ + + # Allowed source names for config resolution + SourceName = Literal[ + "snowsql_config", + "cli_config_toml", + "connections_toml", + "snowsql_env", + "connection_specific_env", + "cli_env", + "cli_arguments", + ] + + @property + @abstractmethod + def source_name(self) -> SourceName: + """ + Unique identifier for this source. + Examples: "cli_arguments", "snowsql_config", "cli_env" + """ + ... + + @property + @abstractmethod + def source_type(self) -> SourceType: + """ + Classification of this source for merging behavior. + FILE sources replace entire connections, OVERLAY sources merge per-field. + """ + ... + + @abstractmethod + def discover(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Discover configuration values as nested dict structure. + + Sources return configuration as nested dictionaries that reflect + the natural structure of the configuration. For example: + {"connections": {"prod": {"account": "val"}}} + + Empty connections are represented as empty dicts: + {"connections": {"prod": {}}} + + General parameters (not connection-specific) are at the root level: + {"database": "mydb", "role": "myrole"} + + Args: + key: Specific key path to discover (dot-separated), or None for all + + Returns: + Nested dictionary of configuration values. Returns empty dict + if no values found. 
+ """ + ... + + @abstractmethod + def supports_key(self, key: str) -> bool: + """ + Check if this source can provide the given configuration key. + + Args: + key: Configuration key to check + + Returns: + True if this source supports the key, False otherwise + """ + ... + + +@dataclass(frozen=True) +class ResolutionEntry: + """ + Represents a single value discovery during resolution. + Immutable record of what was found where and when. + """ + + config_value: ConfigValue + timestamp: datetime + was_used: bool + overridden_by: Optional[str] = None + + +@dataclass +class ResolutionHistory: + """ + Complete resolution history for a single configuration key. + Shows the full precedence chain from lowest to highest priority. + """ + + key: str + entries: List[ResolutionEntry] = field(default_factory=list) + final_value: Optional[Any] = None + default_used: bool = False + + @property + def sources_consulted(self) -> List[str]: + """List of all source names that were consulted.""" + return [entry.config_value.source_name for entry in self.entries] + + @property + def values_considered(self) -> List[Any]: + """List of all values that were considered.""" + return [entry.config_value.value for entry in self.entries] + + @property + def selected_entry(self) -> Optional[ResolutionEntry]: + """The entry that was ultimately selected.""" + for entry in self.entries: + if entry.was_used: + return entry + return None + + @property + def overridden_entries(self) -> List[ResolutionEntry]: + """All entries that were overridden by higher priority sources.""" + return [entry for entry in self.entries if not entry.was_used] + + def format_chain(self) -> str: + """ + Format the resolution chain as a readable string. + + Example output: + account resolution chain (4 sources): + 1. ❌ snowsql_config: "old_account" (overridden by cli_arguments) + 2. ❌ toml:connections: "new_account" (overridden by cli_arguments) + 3. ❌ snowflake_cli_env: "env_account" (overridden by cli_arguments) + 4. 
✅ cli_arguments: "final_account" (SELECTED) + """ + lines = [f"{self.key} resolution chain ({len(self.entries)} sources):"] + + for i, entry in enumerate(self.entries, 1): + cv = entry.config_value + status_icon = "✅" if entry.was_used else "❌" + + if entry.was_used: + status_text = "(SELECTED)" + elif entry.overridden_by: + status_text = f"(overridden by {entry.overridden_by})" + else: + status_text = "(not used)" + + # Show raw value if different from parsed value + value_display = f'"{cv.value}"' + if cv.raw_value is not None and cv.raw_value != cv.value: + value_display = f'"{cv.raw_value}" → {cv.value}' + + lines.append( + f" {i}. {status_icon} {cv.source_name}: {value_display} {status_text}" + ) + + if self.default_used: + lines.append(f" Default value used: {self.final_value}") + + return "\n".join(lines) + + def to_dict(self) -> dict: + """Convert to dictionary for JSON serialization/export.""" + return { + "key": self.key, + "final_value": self.final_value, + "default_used": self.default_used, + "sources_consulted": self.sources_consulted, + "entries": [ + { + "source": entry.config_value.source_name, + "value": entry.config_value.value, + "raw_value": entry.config_value.raw_value, + "was_used": entry.was_used, + "overridden_by": entry.overridden_by, + "timestamp": entry.timestamp.isoformat(), + } + for entry in self.entries + ], + } diff --git a/src/snowflake/cli/api/config_ng/dict_utils.py b/src/snowflake/cli/api/config_ng/dict_utils.py new file mode 100644 index 0000000000..740e1b31f2 --- /dev/null +++ b/src/snowflake/cli/api/config_ng/dict_utils.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for nested dictionary operations.""" + +from typing import Any, Dict + + +def deep_merge(base: Dict[str, Any], overlay: Dict[str, Any]) -> Dict[str, Any]: + """ + Deep merge two dictionaries. Overlay values win on conflict. + + Recursively merges nested dictionaries. Non-dict values from overlay + replace values in base. + + Example: + base = {"a": {"b": 1, "c": 2}} + overlay = {"a": {"c": 3, "d": 4}} + result = {"a": {"b": 1, "c": 3, "d": 4}} + + Args: + base: Base dictionary + overlay: Overlay dictionary (wins on conflicts) + + Returns: + Merged dictionary + """ + result = base.copy() + + for key, value in overlay.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = deep_merge(result[key], value) + else: + result[key] = value + + return result diff --git a/src/snowflake/cli/api/config_ng/merge_operations.py b/src/snowflake/cli/api/config_ng/merge_operations.py new file mode 100644 index 0000000000..f789a66632 --- /dev/null +++ b/src/snowflake/cli/api/config_ng/merge_operations.py @@ -0,0 +1,118 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure functions for configuration merging operations.""" + +from typing import Any, Dict + +from snowflake.cli.api.config_ng.constants import ( + INTERNAL_CLI_PARAMETERS, + ConfigSection, +) + + +def extract_root_level_connection_params( + config: Dict[str, Any], +) -> tuple[Dict[str, Any], Dict[str, Any]]: + """ + Extract root-level connection parameters from config. + + Connection parameters at root level (not under any section) should + be treated as general connection parameters that apply to all connections. + + Args: + config: Configuration dictionary with mixed sections and parameters + + Returns: + Tuple of (connection_params, remaining_config) + + Example: + Input: {"account": "acc", "cli": {...}, "connections": {...}} + Output: ({"account": "acc"}, {"cli": {...}, "connections": {...}}) + """ + known_sections = {s.value for s in ConfigSection} + + connection_params = {} + remaining = {} + + for key, value in config.items(): + # Check if this key is a known section or internal parameter + is_section = key in known_sections or any( + key.startswith(s + ".") for s in known_sections + ) + is_internal = key in INTERNAL_CLI_PARAMETERS + + if not is_section and not is_internal: + # Root-level parameter that's not a section = connection parameter + connection_params[key] = value + else: + remaining[key] = value + + return connection_params, remaining + + +def merge_params_into_connections( + connections: Dict[str, Dict[str, Any]], params: Dict[str, Any] +) -> Dict[str, Dict[str, Any]]: + """ + Merge parameters into all existing connections. + + Used for overlay sources where root-level connection params apply to all connections. + The params overlay (override) values in each connection. 
+ + Args: + connections: Dictionary of connection configurations + params: Parameters to merge into each connection + + Returns: + Dictionary of connections with params merged in + + Example: + Input: + connections = {"dev": {"account": "dev_acc", "user": "dev_user"}} + params = {"user": "override_user", "password": "new_pass"} + Output: + {"dev": {"account": "dev_acc", "user": "override_user", "password": "new_pass"}} + """ + from snowflake.cli.api.config_ng.dict_utils import deep_merge + + result = {} + for conn_name, conn_config in connections.items(): + if isinstance(conn_config, dict): + result[conn_name] = deep_merge(conn_config, params) + else: + result[conn_name] = conn_config + + return result + + +def create_default_connection_from_params( + params: Dict[str, Any], +) -> Dict[str, Dict[str, Any]]: + """ + Create a default connection from connection parameters. + + Args: + params: Connection parameters + + Returns: + Dictionary with "default" connection containing the params + + Example: + Input: {"account": "acc", "user": "usr"} + Output: {"default": {"account": "acc", "user": "usr"}} + """ + if not params: + return {} + return {"default": params.copy()} diff --git a/src/snowflake/cli/api/config_ng/parsers.py b/src/snowflake/cli/api/config_ng/parsers.py new file mode 100644 index 0000000000..5ca6cf6f53 --- /dev/null +++ b/src/snowflake/cli/api/config_ng/parsers.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configuration parsers - decouple parsing from file I/O.""" + +import configparser +from typing import Any, Dict + +# Try to import tomllib (Python 3.11+) or fall back to tomli +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore + + +class SnowSQLParser: + """Parse SnowSQL INI format to nested dict.""" + + # Mapping of SnowSQL key names to CLI standard names + SNOWSQL_KEY_MAP: Dict[str, str] = { + "accountname": "account", + "username": "user", + "rolename": "role", + "warehousename": "warehouse", + "schemaname": "schema", + "dbname": "database", + "pwd": "password", + # Keys that don't need mapping (already correct) + "password": "password", + "database": "database", + "schema": "schema", + "role": "role", + "warehouse": "warehouse", + "host": "host", + "port": "port", + "protocol": "protocol", + "authenticator": "authenticator", + "private_key_path": "private_key_path", + "private_key_passphrase": "private_key_passphrase", + "account": "account", + "user": "user", + } + + @classmethod + def parse(cls, content: str) -> Dict[str, Any]: + """ + Parse SnowSQL INI format from string. + + Args: + content: INI format configuration as string + + Returns: + Nested dict: {"connections": {...}, "variables": {...}} + + Example: + Input: + [connections.dev] + accountname = myaccount + username = myuser + + [variables] + stage = mystage + + Output: + { + "connections": { + "dev": {"account": "myaccount", "user": "myuser"} + }, + "variables": {"stage": "mystage"} + } + """ + config = configparser.ConfigParser() + config.read_string(content) + + result: Dict[str, Any] = {} + + for section in config.sections(): + if section.startswith("connections"): + # Extract connection name from section + if section == "connections": + conn_name = "default" + else: + conn_name = ( + section.split(".", 1)[1] if "." 
in section else "default" + ) + + if "connections" not in result: + result["connections"] = {} + if conn_name not in result["connections"]: + result["connections"][conn_name] = {} + + for key, value in config[section].items(): + mapped_key = cls.SNOWSQL_KEY_MAP.get(key, key) + result["connections"][conn_name][mapped_key] = value + + elif section == "variables": + result["variables"] = dict(config[section]) + + return result + + +class TOMLParser: + """Parse TOML format to nested dict.""" + + @staticmethod + def parse(content: str) -> Dict[str, Any]: + """ + Parse TOML format from string. + + TOML is already nested, so this just wraps tomllib.loads(). + All TOML sources (CLI config, connections.toml) use this parser. + + Args: + content: TOML format configuration as string + + Returns: + Nested dict with TOML structure preserved + + Example: + Input: + [connections.prod] + account = "myaccount" + user = "myuser" + + Output: + {"connections": {"prod": {"account": "myaccount", "user": "myuser"}}} + """ + return tomllib.loads(content) diff --git a/src/snowflake/cli/api/config_ng/presentation.py b/src/snowflake/cli/api/config_ng/presentation.py new file mode 100644 index 0000000000..549c5c659f --- /dev/null +++ b/src/snowflake/cli/api/config_ng/presentation.py @@ -0,0 +1,361 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Resolution presentation utilities. 
+ +This module handles all formatting, display, and export of configuration +resolution data. It separates presentation concerns from resolution logic. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple + +from snowflake.cli.api.console import cli_console +from snowflake.cli.api.output.types import CollectionResult, MessageResult + +if TYPE_CHECKING: + from snowflake.cli.api.config_ng.resolver import ConfigurationResolver + +# Sensitive configuration keys that should be masked when displayed +SENSITIVE_KEYS = { + "password", + "pwd", + "oauth_client_secret", + "token", + "session_token", + "master_token", + "mfa_passcode", + "private_key", # Private key content (not path) + "passphrase", + "secret", +} + +# Keys that contain file paths (paths are OK to display, but not file contents) +PATH_KEYS = { + "private_key_file", + "private_key_path", + "token_file_path", +} + +# Fixed table columns ordered from most important (left) to least (right) +SourceColumn = Literal[ + "params", + "global_envs", + "connections_env", + "snowsql_env", + "connections.toml", + "config.toml", + "snowsql", +] + +TABLE_COLUMNS: Tuple[str, ...] = ( + "key", + "value", + "params", + "global_envs", + "connections_env", + "snowsql_env", + "connections.toml", + "config.toml", + "snowsql", +) + +# Mapping of internal source names to fixed table columns +SOURCE_TO_COLUMN: Dict[str, SourceColumn] = { + "cli_arguments": "params", + "cli_env": "global_envs", + "connection_specific_env": "connections_env", + "snowsql_env": "snowsql_env", + "connections_toml": "connections.toml", + "cli_config_toml": "config.toml", + "snowsql_config": "snowsql", +} + + +def _should_mask_value(key: str) -> bool: + """ + Determine if a configuration value should be masked for security. 
+ + Args: + key: Configuration key name + + Returns: + True if the value should be masked, False if it can be displayed + """ + key_lower = key.lower() + + if any(path_key in key_lower for path_key in PATH_KEYS): + return False + + return any(sensitive_key in key_lower for sensitive_key in SENSITIVE_KEYS) + + +def _mask_sensitive_value(key: str, value: Any) -> str: + """ + Mask sensitive configuration values for display. + + Args: + key: Configuration key name + value: Value to potentially mask + + Returns: + Masked string if sensitive, otherwise string representation of value + """ + if _should_mask_value(key): + return "****" + return str(value) + + +class ResolutionPresenter: + """ + Handles all presentation, formatting, and export of resolution data. + + This class is responsible for: + - Console output with colors and formatting + - Building CommandResult objects for the output system + - Exporting resolution data to files + - Masking sensitive values in all outputs + """ + + def __init__(self, resolver: ConfigurationResolver): + """ + Initialize presenter with a resolver. + + Args: + resolver: ConfigurationResolver instance to present data from + """ + self._resolver = resolver + + def get_summary(self) -> dict: + """ + Get summary statistics about configuration resolution. + + Returns: + Dictionary with statistics: + - total_keys_resolved + - keys_with_overrides + - keys_using_defaults + - source_usage (how many values each source provided) + - source_wins (how many final values came from each source) + """ + return self._resolver.get_tracker().get_summary() + + def build_sources_table(self, key: Optional[str] = None) -> CollectionResult: + """ + Build a tabular view of configuration sources per key. + + Columns (left to right): key, value, params, env, connections.toml, cli_config.toml, snowsql. 
+ - value: masked final selected value for the key + - presence columns: "+" if a given source provided a value for the key, empty otherwise + + Args: + key: Optional specific key to build table for, or None for all keys + """ + tracker = self._resolver.get_tracker() + if key is None and not tracker.get_all_histories(): + self._resolver.resolve() + elif key is not None and tracker.get_history(key) is None: + self._resolver.resolve(key=key) + + histories = ( + {key: tracker.get_history(key)} + if key is not None + else tracker.get_all_histories() + ) + + def _row_items(): + for k, history in histories.items(): + if history is None: + continue + row: Dict[str, Any] = {c: "" for c in TABLE_COLUMNS} + row["key"] = k + + masked_final = _mask_sensitive_value(k, history.final_value) + row["value"] = masked_final + + for entry in history.entries: + source_column = SOURCE_TO_COLUMN.get(entry.config_value.source_name) + if source_column is not None: + row[source_column] = "+" + + ordered_row = {column: row[column] for column in TABLE_COLUMNS} + yield ordered_row + + return CollectionResult(_row_items()) + + def format_history_message(self, key: Optional[str] = None) -> MessageResult: + """ + Build a masked, human-readable history of merging as a single message. + If key is None, returns concatenated histories for all keys. 
+ + Args: + key: Optional specific key to format, or None for all keys + """ + histories = ( + {key: self._resolver.get_resolution_history(key)} + if key is not None + else self._resolver.get_all_histories() + ) + + if not histories: + return MessageResult("No resolution history available") + + lines = [] + lines.append("Configuration Resolution History") + lines.append("=" * 80) + lines.append("") + + for k in sorted(histories.keys()): + history = histories[k] + if history is None: + continue + + lines.append(f"Key: {k}") + lines.append( + f"Final Value: {_mask_sensitive_value(k, history.final_value)}" + ) + + if history.entries: + lines.append("Resolution Chain:") + for i, entry in enumerate(history.entries, 1): + cv = entry.config_value + status = "SELECTED" if entry.was_used else "overridden" + masked_value = _mask_sensitive_value(cv.key, cv.value) + lines.append(f" {i}. [{status}] {cv.source_name}: {masked_value}") + + if history.default_used: + lines.append(" (default value used)") + + lines.append("") + + return MessageResult("\n".join(lines)) + + def print_resolution_chain(self, key: str) -> None: + """ + Print the resolution chain for a key using cli_console formatting. + Sensitive values (passwords, tokens, etc.) are automatically masked. 
+ + Args: + key: Configuration key + """ + history = self._resolver.get_resolution_history(key) + if not history: + cli_console.warning(f"No resolution history found for key: {key}") + return + + with cli_console.phase( + f"{key} resolution chain ({len(history.entries)} sources):" + ): + for i, entry in enumerate(history.entries, 1): + cv = entry.config_value + status_icon = "✅" if entry.was_used else "❌" + + if entry.was_used: + status_text = "(SELECTED)" + elif entry.overridden_by: + status_text = f"(overridden by {entry.overridden_by})" + else: + status_text = "(not used)" + + # Mask sensitive values + masked_value = _mask_sensitive_value(cv.key, cv.value) + masked_raw = ( + _mask_sensitive_value(cv.key, cv.raw_value) + if cv.raw_value is not None + else None + ) + + # Show raw value if different from parsed value + value_display = f'"{masked_value}"' + if masked_raw is not None and cv.raw_value != cv.value: + value_display = f'"{masked_raw}" → {masked_value}' + + cli_console.step( + f"{i}. {status_icon} {cv.source_name}: {value_display} {status_text}" + ) + + if history.default_used: + masked_default = _mask_sensitive_value(key, history.final_value) + cli_console.step(f"Default value used: {masked_default}") + + def print_all_chains(self) -> None: + """ + Print resolution chains for all keys using cli_console formatting. + Sensitive values (passwords, tokens, etc.) are automatically masked. 
+ """ + histories = self._resolver.get_all_histories() + if not histories: + cli_console.warning("No resolution history available") + return + + with cli_console.phase( + f"Configuration Resolution History ({len(histories)} keys)" + ): + for key in sorted(histories.keys()): + history = histories[key] + cli_console.message( + f"\n{key} resolution chain ({len(history.entries)} sources):" + ) + with cli_console.indented(): + for i, entry in enumerate(history.entries, 1): + cv = entry.config_value + status_icon = "✅" if entry.was_used else "❌" + + if entry.was_used: + status_text = "(SELECTED)" + elif entry.overridden_by: + status_text = f"(overridden by {entry.overridden_by})" + else: + status_text = "(not used)" + + # Mask sensitive values + masked_value = _mask_sensitive_value(cv.key, cv.value) + masked_raw = ( + _mask_sensitive_value(cv.key, cv.raw_value) + if cv.raw_value is not None + else None + ) + + # Show raw value if different from parsed value + value_display = f'"{masked_value}"' + if masked_raw is not None and cv.raw_value != cv.value: + value_display = f'"{masked_raw}" → {masked_value}' + + cli_console.step( + f"{i}. {status_icon} {cv.source_name}: {value_display} {status_text}" + ) + + if history.default_used: + masked_default = _mask_sensitive_value(key, history.final_value) + cli_console.step(f"Default value used: {masked_default}") + + def export_history(self, filepath: Path) -> None: + """ + Export resolution history to JSON file. 
+ + Args: + filepath: Path to output file + """ + histories = self._resolver.get_all_histories() + data = { + "summary": self.get_summary(), + "histories": {key: history.to_dict() for key, history in histories.items()}, + } + + with open(filepath, "w") as f: + json.dump(data, f, indent=2) diff --git a/src/snowflake/cli/api/config_ng/resolution_logger.py b/src/snowflake/cli/api/config_ng/resolution_logger.py new file mode 100644 index 0000000000..e64cf97a37 --- /dev/null +++ b/src/snowflake/cli/api/config_ng/resolution_logger.py @@ -0,0 +1,345 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration resolution logging utilities. + +This module provides internal utilities for logging and displaying configuration +resolution information. It's designed to be used independently of CLI commands, +allowing it to be used in any context where configuration debugging is needed. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Dict, Optional + +from snowflake.cli.api.config_ng.presentation import ResolutionPresenter +from snowflake.cli.api.config_provider import ( + ALTERNATIVE_CONFIG_ENV_VAR, + AlternativeConfigProvider, + get_config_provider_singleton, +) +from snowflake.cli.api.console import cli_console +from snowflake.cli.api.output.types import ( + CollectionResult, + CommandResult, + MessageResult, + MultipleResults, +) + +if TYPE_CHECKING: + from snowflake.cli.api.config_ng.resolver import ConfigurationResolver + + +def is_resolution_logging_available() -> bool: + """ + Check if configuration resolution logging is available. + + Returns: + True if the alternative config provider is enabled and has resolution history + """ + provider = get_config_provider_singleton() + return isinstance(provider, AlternativeConfigProvider) + + +def get_resolver() -> Optional[ConfigurationResolver]: + """ + Get the ConfigurationResolver from the current provider. + + Returns: + ConfigurationResolver instance if available, None otherwise + """ + provider = get_config_provider_singleton() + if not isinstance(provider, AlternativeConfigProvider): + return None + + provider._ensure_initialized() # noqa: SLF001 + return provider._resolver # noqa: SLF001 + + +def show_resolution_chain(key: str) -> None: + """ + Display the resolution chain for a specific configuration key. + + This shows: + - All sources that provided values for the key + - The order in which values were considered + - Which value overrode which + - The final selected value + + Args: + key: Configuration key to show resolution for + """ + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + + provider.read_config() + + resolver = get_resolver() + + if resolver is None: + cli_console.warning( + "Configuration resolution logging is not available. 
" + f"Set {ALTERNATIVE_CONFIG_ENV_VAR}=true to enable it." + ) + return + + presenter = ResolutionPresenter(resolver) + presenter.print_resolution_chain(key) + + +def show_all_resolution_chains() -> None: + """ + Display resolution chains for all configured keys. + + This provides a complete overview of the configuration resolution process, + showing how every configuration value was determined. + """ + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + + provider.read_config() + + resolver = get_resolver() + + if resolver is None: + cli_console.warning( + "Configuration resolution logging is not available. " + f"Set {ALTERNATIVE_CONFIG_ENV_VAR}=true to enable it." + ) + return + + presenter = ResolutionPresenter(resolver) + presenter.print_all_chains() + + +def get_resolution_summary() -> Optional[Dict]: + """ + Get summary statistics about configuration resolution. + + Returns: + Dictionary with statistics including: + - total_keys_resolved: Number of keys resolved + - keys_with_overrides: Number of keys where values were overridden + - keys_using_defaults: Number of keys using default values + - source_usage: Dict of source_name -> count of values provided + - source_wins: Dict of source_name -> count of values selected + + None if resolution logging is not available + """ + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + + provider.read_config() + + resolver = get_resolver() + + if resolver is None: + return None + + presenter = ResolutionPresenter(resolver) + return presenter.get_summary() + + +def export_resolution_history(output_path: Path) -> bool: + """ + Export complete resolution history to a JSON file. 
+ + This creates a detailed JSON report that can be: + - Attached to support tickets + - Used for configuration debugging + - Analyzed programmatically + + Args: + output_path: Path where the JSON file should be saved + + Returns: + True if export succeeded, False otherwise + """ + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + + provider.read_config() + + resolver = get_resolver() + + if resolver is None: + cli_console.warning( + "Configuration resolution logging is not available. " + f"Set {ALTERNATIVE_CONFIG_ENV_VAR}=true to enable it." + ) + return False + + try: + presenter = ResolutionPresenter(resolver) + presenter.export_history(output_path) + cli_console.message(f"✅ Resolution history exported to: {output_path}") + return True + except Exception as e: + cli_console.warning(f"❌ Failed to export resolution history: {e}") + return False + + +def format_summary_for_display() -> Optional[str]: + """ + Format resolution summary as a human-readable string. 
+ + Returns: + Formatted summary string, or None if resolution logging not available + """ + summary = get_resolution_summary() + + if summary is None: + return None + + lines = [ + "\n" + "=" * 80, + "Configuration Resolution Summary", + "=" * 80, + f"Total keys resolved: {summary['total_keys_resolved']}", + f"Keys with overrides: {summary['keys_with_overrides']}", + f"Keys using defaults: {summary['keys_using_defaults']}", + "", + "Source Usage:", + ] + + # Sort sources by number of values provided (descending) + source_usage = summary["source_usage"] + source_wins = summary["source_wins"] + + for source_name in sorted(source_usage, key=source_usage.get, reverse=True): + provided = source_usage[source_name] + wins = source_wins.get(source_name, 0) + lines.append( + f" {source_name:30s} provided: {provided:3d} selected: {wins:3d}" + ) + + lines.append("=" * 80 + "\n") + return "\n".join(lines) + + +def check_value_source(key: str) -> Optional[str]: + """ + Check which source provided the value for a specific configuration key. + + Args: + key: Configuration key to check + + Returns: + Name of the source that provided the final value, or None if not found + """ + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + + provider.read_config() + + resolver = get_resolver() + + if resolver is None: + return None + + history = resolver.get_resolution_history(key) + if history and history.selected_entry: + return history.selected_entry.config_value.source_name + + return None + + +def explain_configuration(key: Optional[str] = None, verbose: bool = False) -> None: + """ + Explain configuration resolution for a key or all keys. + + This is a high-level function that combines multiple resolution + logging capabilities to provide comprehensive configuration explanation. 
+ + Args: + key: Specific key to explain, or None to explain all + verbose: If True, show detailed resolution chains + """ + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + + provider.read_config() + + resolver = get_resolver() + + if resolver is None: + cli_console.warning( + "Configuration resolution logging is not available. " + f"Set {ALTERNATIVE_CONFIG_ENV_VAR}=true to enable the new config system." + ) + return + + presenter = ResolutionPresenter(resolver) + + if key: + # Explain specific key + with cli_console.phase(f"Configuration Resolution: {key}"): + source = check_value_source(key) + if source: + cli_console.message(f"Current value from: {source}") + else: + cli_console.message("No value found for this key") + + if verbose: + presenter.print_resolution_chain(key) + else: + # Explain all configuration + with cli_console.phase("Complete Configuration Resolution"): + summary_text = format_summary_for_display() + if summary_text: + cli_console.message(summary_text) + + if verbose: + presenter.print_all_chains() + + +def get_configuration_explanation_results( + key: Optional[str] = None, verbose: bool = False +) -> CommandResult: + """ + Build CommandResult(s) representing a fixed-column sources table and optional + masked history message, suitable for Snow's output formats. + + Returns: + - CollectionResult for the table (always) + - If verbose is True, MultipleResults with the table and a MessageResult + containing the masked resolution history (for the key or all keys) + """ + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + provider.read_config() + + resolver = get_resolver() + if resolver is None: + return MessageResult( + "Configuration resolution logging is not available. " + f"Set {ALTERNATIVE_CONFIG_ENV_VAR}=true to enable it." 
+ ) + + presenter = ResolutionPresenter(resolver) + table_result: CollectionResult = presenter.build_sources_table(key) + if not verbose: + return table_result + + history_message: MessageResult = presenter.format_history_message(key) + return MultipleResults([table_result, history_message]) diff --git a/src/snowflake/cli/api/config_ng/resolver.py b/src/snowflake/cli/api/config_ng/resolver.py new file mode 100644 index 0000000000..6a14b8d3af --- /dev/null +++ b/src/snowflake/cli/api/config_ng/resolver.py @@ -0,0 +1,860 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration resolver with resolution history tracking. + +This module implements: +- ResolutionHistoryTracker: Tracks configuration value discoveries and precedence +- ConfigurationResolver: Orchestrates sources and resolves configuration values +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from datetime import datetime +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from snowflake.cli.api.config_ng.core import ( + ConfigValue, + ResolutionEntry, + ResolutionHistory, + SourceType, +) + +if TYPE_CHECKING: + from snowflake.cli.api.config_ng.core import ValueSource + +log = logging.getLogger(__name__) + + +class ResolutionHistoryTracker: + """ + Tracks the complete resolution process for all configuration keys. 
+ + This class records: + - Every value discovered from every source + - The order in which values were considered + - Which value was ultimately selected + - Which values were overridden and by what + + Provides debugging utilities and export functionality. + """ + + def __init__(self): + """Initialize empty history tracker.""" + self._histories: Dict[str, ResolutionHistory] = {} + self._discoveries: Dict[str, List[tuple[ConfigValue, datetime]]] = defaultdict( + list + ) + self._enabled = True + + def enable(self) -> None: + """Enable history tracking.""" + self._enabled = True + + def disable(self) -> None: + """Disable history tracking for performance.""" + self._enabled = False + + def is_enabled(self) -> bool: + """Check if history tracking is enabled.""" + return self._enabled + + def clear(self) -> None: + """Clear all recorded history.""" + self._histories.clear() + self._discoveries.clear() + + def _flatten_nested_dict( + self, nested: Dict[str, Any], prefix: str = "" + ) -> Dict[str, Any]: + """ + Flatten nested dict to dot-separated keys for internal storage. + + Args: + nested: Nested dictionary structure + prefix: Current key prefix + + Returns: + Flat dictionary with dot-separated keys + + Example: + {"connections": {"test": {"account": "val"}}} + -> {"connections.test.account": "val"} + """ + result = {} + for key, value in nested.items(): + flat_key = f"{prefix}.{key}" if prefix else key + + if isinstance(value, dict) and value: + # Recursively flatten nested dicts + result.update(self._flatten_nested_dict(value, flat_key)) + else: + # Leaf value - store it + result[flat_key] = value + + return result + + def record_nested_discovery( + self, nested_data: Dict[str, Any], source_name: str + ) -> None: + """ + Record discoveries from a source that returns nested dict. 
+ + Args: + nested_data: Nested dictionary from source + source_name: Name of the source providing this data + """ + if not self._enabled: + return + + flat_data = self._flatten_nested_dict(nested_data) + + timestamp = datetime.now() + for flat_key, value in flat_data.items(): + config_value = ConfigValue( + key=flat_key, value=value, source_name=source_name + ) + self._discoveries[flat_key].append((config_value, timestamp)) + + def record_discovery(self, key: str, config_value: ConfigValue) -> None: + """ + Record a value discovery from a source. + + Args: + key: Configuration key + config_value: The discovered ConfigValue with metadata + """ + if not self._enabled: + return + + timestamp = datetime.now() + self._discoveries[key].append((config_value, timestamp)) + + def mark_selected(self, key: str, source_name: str) -> None: + """ + Mark which source's value was selected for a key. + + Args: + key: Configuration key + source_name: Name of the source whose value was selected + """ + if not self._enabled or key not in self._discoveries: + return + + entries: List[ResolutionEntry] = [] + selected_value = None + + for config_value, timestamp in self._discoveries[key]: + was_selected = config_value.source_name == source_name + overridden_by = source_name if not was_selected else None + + entry = ResolutionEntry( + config_value=config_value, + timestamp=timestamp, + was_used=was_selected, + overridden_by=overridden_by, + ) + entries.append(entry) + + if was_selected: + selected_value = config_value.value + + self._histories[key] = ResolutionHistory( + key=key, entries=entries, final_value=selected_value, default_used=False + ) + + def mark_default_used(self, key: str, default_value: Any) -> None: + """ + Mark that a default value was used for a key. 
+ + Args: + key: Configuration key + default_value: The default value used + """ + if not self._enabled: + return + + if key in self._histories: + self._histories[key].default_used = True + self._histories[key].final_value = default_value + else: + self._histories[key] = ResolutionHistory( + key=key, entries=[], final_value=default_value, default_used=True + ) + + def get_history(self, key: str) -> Optional[ResolutionHistory]: + """ + Get resolution history for a specific key. + + Args: + key: Configuration key + + Returns: + ResolutionHistory object or None if key not tracked + """ + return self._histories.get(key) + + def get_all_histories(self) -> Dict[str, ResolutionHistory]: + """ + Get all resolution histories. + + Returns: + Dictionary mapping keys to their ResolutionHistory objects + """ + return self._histories.copy() + + def finalize_with_result(self, final_config: Dict[str, Any]) -> None: + """ + Mark which values were selected in the final configuration. + + This method flattens the final nested config and marks the selected + source for each value. + + Args: + final_config: The final resolved configuration (nested dict) + """ + if not self._enabled: + return + + flat_final = self._flatten_nested_dict(final_config) + + for flat_key, final_value in flat_final.items(): + if flat_key not in self._discoveries: + continue + + discoveries = self._discoveries[flat_key] + for config_value, timestamp in reversed(discoveries): + if config_value.value == final_value: + self.mark_selected(flat_key, config_value.source_name) + break + + def record_general_params_merged_to_connections( + self, + general_params: Dict[str, Any], + connection_names: List[str], + source_name: str, + ) -> None: + """ + Record when general parameters are merged into connections. + + When overlay sources provide general params (like SNOWFLAKE_ACCOUNT), + these get merged into each existing connection. This method records + that merge operation for history tracking. 
+ + Args: + general_params: Dictionary of general parameters + connection_names: List of connection names to merge into + source_name: Name of the source providing these params + """ + if not self._enabled: + return + + timestamp = datetime.now() + for param_key, param_value in general_params.items(): + for conn_name in connection_names: + flat_key = f"connections.{conn_name}.{param_key}" + config_value = ConfigValue( + key=flat_key, value=param_value, source_name=source_name + ) + self._discoveries[flat_key].append((config_value, timestamp)) + + def replicate_root_level_discoveries_to_connection( + self, param_keys: List[str], connection_name: str + ) -> None: + """ + Replicate discoveries from root-level keys to connection-specific keys. + + This is used when creating a default connection from general parameters + (e.g., SNOWFLAKE_ACCOUNT -> connections.default.account). + + Args: + param_keys: List of parameter keys that exist at root level + connection_name: Name of the connection to replicate discoveries to + """ + if not self._enabled: + return + + for param_key in param_keys: + if param_key in self._discoveries: + conn_key = f"connections.{connection_name}.{param_key}" + for config_value, timestamp in self._discoveries[param_key]: + self._discoveries[conn_key].append((config_value, timestamp)) + + def get_summary(self) -> dict: + """ + Get summary statistics about configuration resolution. 
+ + Returns: + Dictionary with statistics: + - total_keys_resolved: Number of keys resolved + - keys_with_overrides: Number of keys where values were overridden + - keys_using_defaults: Number of keys using default values + - source_usage: Dict of source_name -> count of values provided + - source_wins: Dict of source_name -> count of values selected + """ + total_keys = len(self._histories) + keys_with_overrides = sum( + 1 for h in self._histories.values() if len(h.overridden_entries) > 0 + ) + keys_using_defaults = sum(1 for h in self._histories.values() if h.default_used) + + source_usage: Dict[str, int] = defaultdict(int) + source_wins: Dict[str, int] = defaultdict(int) + + for history in self._histories.values(): + for entry in history.entries: + source_usage[entry.config_value.source_name] += 1 + if entry.was_used: + source_wins[entry.config_value.source_name] += 1 + + return { + "total_keys_resolved": total_keys, + "keys_with_overrides": keys_with_overrides, + "keys_using_defaults": keys_using_defaults, + "source_usage": dict(source_usage), + "source_wins": dict(source_wins), + } + + +class ConfigurationResolver: + """ + Orchestrates configuration sources with resolution history tracking. + + This is the main entry point for configuration resolution. It: + - Manages multiple configuration sources in precedence order + - Applies precedence rules based on source list order + - Tracks complete resolution history + + Sources should be provided in precedence order (lowest to highest priority). + Later sources in the list override earlier sources. + + For presentation/formatting of resolution data, use ResolutionPresenter + from the presentation module. 
+ + Example: + from snowflake.cli.api.config_ng import ConfigurationResolver + from snowflake.cli.api.config_ng.presentation import ResolutionPresenter + + resolver = ConfigurationResolver( + sources=[ + snowsql_config, # Lowest priority + cli_config, + env_source, + cli_arguments, # Highest priority + ] + ) + + # Resolve all configuration + config = resolver.resolve() + + # For debugging/presentation, use the presenter + presenter = ResolutionPresenter(resolver) + presenter.print_resolution_chain("account") + presenter.export_history(Path("debug_config.json")) + """ + + def __init__( + self, + sources: Optional[List["ValueSource"]] = None, + ): + """ + Initialize resolver with sources and history tracking. + + Args: + sources: List of configuration sources in precedence order + (first = lowest priority, last = highest priority) + """ + self._sources = sources or [] + self._history_tracker = ResolutionHistoryTracker() + + def add_source(self, source: "ValueSource") -> None: + """ + Add a configuration source to the end of the list (highest priority). + + Args: + source: ValueSource to add + """ + self._sources.append(source) + + def get_sources(self) -> List["ValueSource"]: + """Get list of all sources in precedence order (for inspection).""" + return self._sources.copy() + + def _parse_connection_key(self, key: str) -> Optional[Tuple[str, str]]: + """ + Parse a connection key into (connection_name, parameter). + + Args: + key: Configuration key (e.g., "connections.prod.account") + + Returns: + Tuple of (connection_name, parameter) or None if not a connection key + """ + if not key.startswith("connections."): + return None + + parts = key.split(".", 2) + if len(parts) != 3: + return None + + return parts[1], parts[2] # (conn_name, param) + + def _get_sources_by_type(self, source_type: SourceType) -> List["ValueSource"]: + """ + Get all sources matching the specified type. 
+ + Args: + source_type: Type of source to filter by + + Returns: + List of sources matching the type + """ + return [s for s in self._sources if s.source_type is source_type] + + def _record_discoveries(self, source_values: Dict[str, ConfigValue]) -> None: + """ + Record all discovered values in history tracker. + + Args: + source_values: Dictionary of discovered configuration values + """ + for k, config_value in source_values.items(): + self._history_tracker.record_discovery(k, config_value) + + def _finalize_history(self, all_values: Dict[str, ConfigValue]) -> None: + """ + Mark which values were selected in resolution history. + + Args: + all_values: Final dictionary of selected configuration values + """ + for k, config_value in all_values.items(): + self._history_tracker.mark_selected(k, config_value.source_name) + + def _apply_default( + self, resolved: Dict[str, Any], key: str, default: Any + ) -> Dict[str, Any]: + """ + Apply default value for a specific key if provided. + + Args: + resolved: Current resolved configuration dictionary + key: Configuration key + default: Default value to apply + + Returns: + Updated resolved dictionary + """ + if default is not None: + resolved[key] = default + self._history_tracker.mark_default_used(key, default) + return resolved + + def _group_by_connection( + self, source_values: Dict[str, ConfigValue] + ) -> Tuple[Dict[str, Dict[str, ConfigValue]], set[str]]: + """ + Group connection parameters by connection name. 
+ + Args: + source_values: All values discovered from a source + + Returns: + Tuple of (per_conn, empty_connections): + - per_conn: Dict mapping connection name to its ConfigValue parameters + - empty_connections: Set of connection names that are empty + """ + per_conn: Dict[str, Dict[str, ConfigValue]] = defaultdict(dict) + empty_connections: set[str] = set() + + for k, config_value in source_values.items(): + parsed = self._parse_connection_key(k) + if parsed is None: + continue + + conn_name, param = parsed + + # Track empty connection markers + if param == "_empty_connection": + empty_connections.add(conn_name) + else: + per_conn[conn_name][k] = config_value + + return per_conn, empty_connections + + def _extract_flat_values( + self, source_values: Dict[str, ConfigValue] + ) -> Dict[str, ConfigValue]: + """ + Extract non-connection (flat) configuration values. + + Args: + source_values: All values discovered from a source + + Returns: + Dictionary of flat configuration values (non-connection keys) + """ + return { + k: v for k, v in source_values.items() if not k.startswith("connections.") + } + + def _replace_connections( + self, + file_connections: Dict[str, Dict[str, ConfigValue]], + per_conn: Dict[str, Dict[str, ConfigValue]], + empty_connections: set[str], + source: "ValueSource", + ) -> None: + """ + Replace entire connections with new definitions from source. + + This implements connection-level replacement: when a FILE source defines + a connection, it completely replaces any previous definition. 
+ + Args: + file_connections: Accumulator for all file-based connections + per_conn: New connection definitions from current source + empty_connections: Set of empty connection names from current source + source: The source providing these connections + """ + all_conn_names = set(per_conn.keys()) | empty_connections + + for conn_name in all_conn_names: + conn_params = per_conn.get(conn_name, {}) + log.debug( + "Connection %s replaced by file source %s (%d params)", + conn_name, + source.source_name, + len(conn_params), + ) + file_connections[conn_name] = conn_params + + def _resolve_file_sources(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Process FILE sources with connection-level replacement semantics. + + FILE sources replace entire connections rather than merging fields. + Later FILE sources override earlier ones completely. + + Args: + key: Specific key to resolve (None = all keys) + + Returns: + Nested dict with merged file source data + """ + result: Dict[str, Any] = {} + + for source in self._get_sources_by_type(SourceType.FILE): + try: + source_data = source.discover(key) + + self._history_tracker.record_nested_discovery( + source_data, source.source_name + ) + + if "connections" in source_data: + if "connections" not in result: + result["connections"] = {} + + for conn_name, conn_data in source_data["connections"].items(): + result["connections"][conn_name] = conn_data + + for k, v in source_data.items(): + if k != "connections": + result[k] = v + + except Exception as e: + log.warning("Error from source %s: %s", source.source_name, e) + + return result + + def _merge_file_results( + self, + file_connections: Dict[str, Dict[str, ConfigValue]], + file_flat_values: Dict[str, ConfigValue], + ) -> Dict[str, ConfigValue]: + """ + Merge file connections and flat values into single dictionary. 
+ + Args: + file_connections: Connection parameters from file sources + file_flat_values: Flat configuration values from file sources + + Returns: + Merged dictionary of all file-based configuration values + """ + all_values: Dict[str, ConfigValue] = {} + + for conn_params in file_connections.values(): + all_values.update(conn_params) + + all_values.update(file_flat_values) + + return all_values + + def _apply_overlay_sources( + self, base: Dict[str, Any], key: Optional[str] = None + ) -> Dict[str, Any]: + """ + Apply OVERLAY sources with field-level merging. + + OVERLAY sources (env vars, CLI args) add or override individual fields + without replacing entire connections. General params are merged into + each existing connection. + + Args: + base: Base configuration (typically from file sources) + key: Specific key to resolve (None = all keys) + + Returns: + Updated dictionary with overlay values applied + """ + from snowflake.cli.api.config_ng.dict_utils import deep_merge + from snowflake.cli.api.config_ng.merge_operations import ( + extract_root_level_connection_params, + merge_params_into_connections, + ) + + result = base.copy() + + for source in self._get_sources_by_type(SourceType.OVERLAY): + try: + source_data = source.discover(key) + + self._history_tracker.record_nested_discovery( + source_data, source.source_name + ) + + general_params, other_data = extract_root_level_connection_params( + source_data + ) + + result = deep_merge(result, other_data) + + if general_params and "connections" in result and result["connections"]: + connection_names = [ + name + for name in result["connections"] + if isinstance(result["connections"][name], dict) + ] + + self._history_tracker.record_general_params_merged_to_connections( + general_params, connection_names, source.source_name + ) + + result["connections"] = merge_params_into_connections( + result["connections"], general_params + ) + elif general_params: + result = deep_merge(result, general_params) + + except 
Exception as e:
+                log.warning("Error from source %s: %s", source.source_name, e)
+
+        if "connections" in result and result["connections"]:
+            remaining_general_params, _ = extract_root_level_connection_params(result)
+
+            if remaining_general_params:
+                for conn_name in result["connections"]:
+                    if isinstance(result["connections"][conn_name], dict):
+                        result["connections"][conn_name] = deep_merge(
+                            remaining_general_params, result["connections"][conn_name]
+                        )
+
+                for key in remaining_general_params:
+                    if key in result:
+                        result.pop(key)
+
+        return result
+
+    def _ensure_default_connection(self, config: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Ensure a default connection exists when general connection params are present.
+
+        Boundary conditions for creating default connection:
+        1. No connections exist in config (empty or missing "connections" key)
+        2. At least one general connection parameter exists at root level
+        3. General params are NOT internal CLI parameters or variables
+
+        This allows users to set SNOWFLAKE_ACCOUNT, SNOWFLAKE_USER etc. without
+        needing --temporary-connection flag or defining connections in config files.
+ + Args: + config: Resolved configuration dictionary + + Returns: + Configuration with default connection created if conditions are met + """ + from snowflake.cli.api.config_ng.constants import INTERNAL_CLI_PARAMETERS + + connections = config.get("connections", {}) + if connections: + return config + + general_params = {} + for key, value in config.items(): + if ( + key not in ("connections", "variables") + and key not in INTERNAL_CLI_PARAMETERS + ): + general_params[key] = value + + if not general_params: + return config + + result = config.copy() + result["connections"] = {"default": general_params.copy()} + + self._history_tracker.replicate_root_level_discoveries_to_connection( + list(general_params.keys()), "default" + ) + + for key in general_params: + result.pop(key, None) + + return result + + def resolve(self, key: Optional[str] = None, default: Any = None) -> Dict[str, Any]: + """ + Resolve configuration to nested dict. + + Resolution Process (Four-Phase): + + Phase A - File Sources (Connection-Level Replacement): + - Process FILE sources in precedence order (lowest to highest priority) + - For each connection, later FILE sources completely REPLACE earlier ones + - Fields from earlier file sources are NOT inherited + + Phase B - Overlay Sources (Field-Level Overlay): + - Start with the file-derived configuration + - Process OVERLAY sources (env vars, CLI args) in precedence order + - These add/override individual fields without replacing entire connections + - Uses deep merge for nested structures + + Phase C - Default Connection Creation: + - If no connections exist but general params present, create "default" connection + - Allows env-only configuration without --temporary-connection flag + + Phase D - Resolution History Finalization: + - Mark which values were selected in the final configuration + - Enables debugging and diagnostics + + Args: + key: Specific key to resolve (None = all keys) + default: Default value if key not found + + Returns: + 
Nested dictionary of resolved configuration + """ + result = self._resolve_file_sources(key) + + result = self._apply_overlay_sources(result, key) + + result = self._ensure_default_connection(result) + + self._finalize_resolution_history(result) + + return result + + def resolve_value(self, key: str, default: Any = None) -> Any: + """ + Resolve a single configuration value. + + Args: + key: Configuration key + default: Default value if not found + + Returns: + Resolved value or default + """ + resolved = self.resolve(key=key, default=default) + return resolved.get(key, default) + + def get_value_metadata(self, key: str) -> Optional[ConfigValue]: + """ + Get metadata for the selected value. + + Args: + key: Configuration key + + Returns: + ConfigValue for the selected value, or None if not found + """ + history = self._history_tracker.get_history(key) + if history and history.selected_entry: + return history.selected_entry.config_value + + # Fallback to live query if history not available + for source in self._sources: + values = source.discover(key) + if key in values: + return values[key] + + return None + + def get_tracker(self) -> ResolutionHistoryTracker: + """ + Get the history tracker for direct access to resolution data. + + Returns: + ResolutionHistoryTracker instance + """ + return self._history_tracker + + def _finalize_resolution_history(self, final_config: Dict[str, Any]) -> None: + """ + Mark which values were selected in final configuration. + + Delegates to the history tracker which handles all history-related logic. + + Args: + final_config: The final resolved configuration (nested dict) + """ + self._history_tracker.finalize_with_result(final_config) + + def get_resolution_history(self, key: str) -> Optional[ResolutionHistory]: + """ + Get complete resolution history for a key. 
+ + Supports both formats: + - Flat: "connections.test.account" + - Root-level: "account" (checks connections for this key) + + Args: + key: Configuration key (flat or simple) + + Returns: + ResolutionHistory showing the full precedence chain + """ + history = self._history_tracker.get_history(key) + if history: + return history + + # If not found and it's a simple key (no dots), search in connections + if "." not in key: + # Look for any connection that has this key + all_histories = self._history_tracker.get_all_histories() + for hist_key, hist in all_histories.items(): + # Match pattern: "connections.*.{key}" or root level "{key}" + if hist_key.endswith(f".{key}"): + return hist + + return None + + def get_all_histories(self) -> Dict[str, ResolutionHistory]: + """Get resolution histories for all keys.""" + return self._history_tracker.get_all_histories() diff --git a/src/snowflake/cli/api/config_ng/source_factory.py b/src/snowflake/cli/api/config_ng/source_factory.py new file mode 100644 index 0000000000..a76a36b353 --- /dev/null +++ b/src/snowflake/cli/api/config_ng/source_factory.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Factory for creating configuration sources.""" + +from typing import Any, Dict, List, Optional + +from snowflake.cli.api.config_ng.core import ValueSource + + +def create_default_sources( + cli_context: Optional[Dict[str, Any]] = None, +) -> List[ValueSource]: + """ + Create default source list in precedence order. + + Creates the standard 7-source configuration stack from lowest + to highest priority: + 1. SnowSQL config files (merged) + 2. CLI config.toml (first-found) + 3. Dedicated connections.toml + 4. SnowSQL environment variables (SNOWSQL_*) + 5. Connection-specific environment variables (SNOWFLAKE_CONNECTIONS_*) + 6. General CLI environment variables (SNOWFLAKE_*) + 7. CLI command-line arguments (highest priority) + + Args: + cli_context: Optional CLI context dictionary for CliParameters source + + Returns: + List of ValueSource instances in precedence order + """ + from snowflake.cli.api.config_ng import ( + CliConfigFile, + CliEnvironment, + CliParameters, + ConnectionsConfigFile, + ConnectionSpecificEnvironment, + SnowSQLConfigFile, + SnowSQLEnvironment, + ) + + return [ + SnowSQLConfigFile(), + CliConfigFile(), + ConnectionsConfigFile(), + SnowSQLEnvironment(), + ConnectionSpecificEnvironment(), + CliEnvironment(), + CliParameters(cli_context=cli_context or {}), + ] diff --git a/src/snowflake/cli/api/config_ng/source_manager.py b/src/snowflake/cli/api/config_ng/source_manager.py new file mode 100644 index 0000000000..4a13c9918a --- /dev/null +++ b/src/snowflake/cli/api/config_ng/source_manager.py @@ -0,0 +1,87 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Manager for configuration sources.""" + +from typing import Any, Dict, List, Optional + +from snowflake.cli.api.config_ng.constants import FILE_SOURCE_NAMES +from snowflake.cli.api.config_ng.core import ValueSource + + +class SourceManager: + """ + Manages configuration sources and derives their priorities. + + Provides a clean interface for working with configuration sources + without exposing implementation details. + """ + + def __init__(self, sources: List[ValueSource]): + """ + Initialize with a list of sources. + + Args: + sources: List of sources in precedence order (lowest to highest) + """ + self._sources = sources + + @classmethod + def with_default_sources( + cls, cli_context: Optional[Dict[str, Any]] = None + ) -> "SourceManager": + """ + Class method constructor with default sources. + + Args: + cli_context: Optional CLI context for CliParameters source + + Returns: + SourceManager configured with default 7-source stack + """ + from snowflake.cli.api.config_ng.source_factory import create_default_sources + + sources = create_default_sources(cli_context) + return cls(sources) + + def get_source_priorities(self) -> Dict[str, int]: + """ + Derive priorities from source list order. + + Priority numbers are 1-indexed (1 = lowest, higher = higher priority). + This is dynamically derived from the source list order to eliminate + duplication and ensure consistency. 
+ + Returns: + Dictionary mapping source names to priority levels + """ + return {s.source_name: idx + 1 for idx, s in enumerate(self._sources)} + + def get_file_sources(self) -> List[ValueSource]: + """ + Get only file-based sources. + + Returns: + List of sources that are file-based + """ + return [s for s in self._sources if s.source_name in FILE_SOURCE_NAMES] + + def get_sources(self) -> List[ValueSource]: + """ + Get all sources. + + Returns: + Copy of the sources list + """ + return self._sources.copy() diff --git a/src/snowflake/cli/api/config_ng/sources.py b/src/snowflake/cli/api/config_ng/sources.py new file mode 100644 index 0000000000..42814c8fdd --- /dev/null +++ b/src/snowflake/cli/api/config_ng/sources.py @@ -0,0 +1,743 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration sources for the Snowflake CLI. + +This module implements concrete configuration sources that discover values from: +- SnowSQL configuration files (INI format, merged from multiple locations) +- CLI configuration files (TOML format, first-found) +- Connections configuration files (dedicated connections.toml) +- SnowSQL environment variables (SNOWSQL_* prefix) +- CLI environment variables (SNOWFLAKE_* patterns) +- CLI command-line parameters + +Precedence is determined by the order sources are provided to the resolver. 
+""" + +from __future__ import annotations + +import configparser +import logging +import os +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Final, List, Optional + +from snowflake.cli.api.config_ng.constants import SNOWFLAKE_HOME_ENV +from snowflake.cli.api.config_ng.core import SourceType, ValueSource + +log = logging.getLogger(__name__) + + +class SnowSQLSection(Enum): + """ + SnowSQL configuration file section names. + + These sections can be present in SnowSQL INI config files. + """ + + CONNECTIONS = "connections" + VARIABLES = "variables" + OPTIONS = "options" + + +class SnowSQLConfigFile(ValueSource): + """ + SnowSQL configuration file source with two-phase design. + + Phase 1: Acquire content (read and merge multiple config files) + Phase 2: Parse content (using SnowSQLParser) + + Reads multiple config files in order and MERGES them (SnowSQL behavior). + Later files override earlier files for the same keys. + Returns configuration for ALL connections. + + Config files searched (in order): + 1. /etc/snowsql.cnf (system-wide) + 2. /etc/snowflake/snowsql.cnf (alternative system) + 3. /usr/local/etc/snowsql.cnf (local system) + 4. ~/.snowsql.cnf (legacy user config) + 5. ~/.snowsql/config (current user config) + """ + + def __init__( + self, content: Optional[str] = None, config_paths: Optional[List[Path]] = None + ): + """ + Initialize SnowSQL config file source. 
+ + Args: + content: Optional string content for testing (bypasses file I/O) + config_paths: Optional custom config file paths + """ + self._content = content + self._config_paths = config_paths or self._get_default_paths() + + @staticmethod + def _get_default_paths() -> List[Path]: + """Get standard SnowSQL config file paths.""" + snowflake_home = os.environ.get(SNOWFLAKE_HOME_ENV) + if snowflake_home: + snowflake_home_path = Path(snowflake_home).expanduser() + if snowflake_home_path.exists(): + return [snowflake_home_path / "config"] + + return [ + Path("/etc/snowsql.cnf"), + Path("/etc/snowflake/snowsql.cnf"), + Path("/usr/local/etc/snowsql.cnf"), + Path.home() / ".snowsql.cnf", + Path.home() / ".snowsql" / "config", + ] + + @classmethod + def from_string(cls, content: str) -> "SnowSQLConfigFile": + """ + Create source from string content (for testing). + + Args: + content: INI format configuration as string + + Returns: + SnowSQLConfigFile instance using string content + """ + return cls(content=content) + + @property + def source_name(self) -> "ValueSource.SourceName": + return "snowsql_config" + + @property + def source_type(self) -> SourceType: + return SourceType.FILE + + def discover(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Two-phase discovery: acquire content → parse. + + Phase 1: Get content (from string or by reading and merging files) + Phase 2: Parse content using SnowSQLParser + + Returns: + Nested dict structure: {"connections": {...}, "variables": {...}} + """ + from snowflake.cli.api.config_ng.parsers import SnowSQLParser + + if self._content is not None: + content = self._content + else: + content = self._read_and_merge_files() + + return SnowSQLParser.parse(content) + + def _read_and_merge_files(self) -> str: + """ + Read all config files and merge into single INI string. 
+ + Returns: + Merged INI content as string + """ + merged_config = configparser.ConfigParser() + + for config_file in self._config_paths: + if config_file.exists(): + try: + merged_config.read(config_file) + except Exception as e: + log.debug("Failed to read SnowSQL config %s: %s", config_file, e) + + from io import StringIO + + output = StringIO() + merged_config.write(output) + return output.getvalue() + + def supports_key(self, key: str) -> bool: + return key in self.discover() + + +class CliConfigFile(ValueSource): + """ + CLI config.toml file source with two-phase design. + + Phase 1: Acquire content (find and read first config file) + Phase 2: Parse content (using TOMLParser) + + Scans for config.toml files in order and uses FIRST file found (CLI behavior). + Does NOT merge multiple files - first found wins. + Returns configuration for ALL connections. + + Search order (when no override is set): + 1. ./config.toml (current directory) + 2. ~/.snowflake/config.toml (user config) + """ + + def __init__( + self, content: Optional[str] = None, search_paths: Optional[List[Path]] = None + ): + """ + Initialize CLI config file source. 
+ + Args: + content: Optional string content for testing (bypasses file I/O) + search_paths: Optional custom search paths + """ + self._content = content + self._search_paths = search_paths or self._get_default_paths() + + @staticmethod + def _get_default_paths() -> List[Path]: + """Get standard CLI config search paths.""" + try: + from snowflake.cli.api.cli_global_context import get_cli_context + + cli_context = get_cli_context() + config_override = cli_context.config_file_override + if config_override: + return [Path(config_override)] + except Exception: + log.debug("CLI context not available, using standard config paths") + + snowflake_home = os.environ.get(SNOWFLAKE_HOME_ENV) + if snowflake_home: + snowflake_home_path = Path(snowflake_home).expanduser() + if snowflake_home_path.exists(): + return [snowflake_home_path / "config.toml"] + + return [ + Path.cwd() / "config.toml", + Path.home() / ".snowflake" / "config.toml", + ] + + @classmethod + def from_string(cls, content: str) -> "CliConfigFile": + """ + Create source from TOML string (for testing). + + Args: + content: TOML format configuration as string + + Returns: + CliConfigFile instance using string content + """ + return cls(content=content) + + @property + def source_name(self) -> "ValueSource.SourceName": + return "cli_config_toml" + + @property + def source_type(self) -> SourceType: + return SourceType.FILE + + def discover(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Two-phase discovery: acquire content → parse. 
+
+        Phase 1: Get content (from string or by reading first existing file)
+        Phase 2: Parse content using TOMLParser
+
+        Returns:
+            Nested dict structure with all TOML sections preserved
+        """
+        from snowflake.cli.api.config_ng.parsers import TOMLParser
+
+        if self._content is not None:
+            content = self._content
+        else:
+            content = self._read_first_file()
+
+        if not content:
+            return {}
+
+        return TOMLParser.parse(content)
+
+    def _read_first_file(self) -> str:
+        """
+        Read first existing config file.
+
+        Returns:
+            File content as string, or empty string if no file found
+        """
+        for config_file in self._search_paths:
+            if config_file.exists():
+                try:
+                    return config_file.read_text()
+                except Exception as e:
+                    log.debug("Failed to read CLI config %s: %s", config_file, e)
+
+        return ""
+
+    def supports_key(self, key: str) -> bool:
+        return key in self.discover()
+
+
+class ConnectionsConfigFile(ValueSource):
+    """
+    Dedicated connections.toml file source with three-phase design.
+
+    Phase 1: Acquire content (read file)
+    Phase 2: Parse content (using TOMLParser)
+    Phase 3: Normalize legacy format (connections.toml specific)
+
+    Reads ~/.snowflake/connections.toml specifically.
+    Returns configuration for ALL connections.
+
+    Supports two formats:
+    1. Direct connection sections (legacy):
+       [default]
+       database = "value"
+
+    2. Nested under [connections] section:
+       [connections.default]
+       database = "value"
+
+    Both are normalized to nested format: {"connections": {"default": {...}}}
+    """
+
+    def __init__(self, content: Optional[str] = None, file_path: Optional[Path] = None):
+        """
+        Initialize connections.toml source.
+ + Args: + content: Optional string content for testing (bypasses file I/O) + file_path: Optional custom file path + """ + self._content = content + self._file_path = file_path or self._get_default_path() + + @staticmethod + def _get_default_path() -> Path: + """Get standard connections.toml path.""" + snowflake_home = os.environ.get(SNOWFLAKE_HOME_ENV) + if snowflake_home: + snowflake_home_path = Path(snowflake_home).expanduser() + if snowflake_home_path.exists(): + return snowflake_home_path / "connections.toml" + return Path.home() / ".snowflake" / "connections.toml" + + @classmethod + def from_string(cls, content: str) -> "ConnectionsConfigFile": + """ + Create source from TOML string (for testing). + + Args: + content: TOML format configuration as string + + Returns: + ConnectionsConfigFile instance using string content + """ + return cls(content=content) + + @property + def source_name(self) -> "ValueSource.SourceName": + return "connections_toml" + + @property + def source_type(self) -> SourceType: + return SourceType.FILE + + @property + def is_connections_file(self) -> bool: + """Mark this as the dedicated connections file source.""" + return True + + def get_defined_connections(self) -> set[str]: + """ + Return set of connection names that are defined in connections.toml. + This is used by the resolver to implement replacement behavior. + """ + try: + data = self.discover() + connections_section = data.get("connections", {}) + if isinstance(connections_section, dict): + return set(connections_section.keys()) + return set() + except Exception as e: + log.debug("Failed to get defined connections: %s", e) + return set() + + def discover(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Three-phase discovery: acquire content → parse → normalize. 
+ + Phase 1: Get content (from string or file) + Phase 2: Parse TOML (generic parser) + Phase 3: Normalize legacy format (connections.toml specific) + + Returns: + Nested dict structure: {"connections": {"conn_name": {...}}} + """ + from snowflake.cli.api.config_ng.parsers import TOMLParser + + if self._content is not None: + content = self._content + else: + if not self._file_path.exists(): + return {} + try: + content = self._file_path.read_text() + except Exception as e: + log.debug("Failed to read connections.toml: %s", e) + return {} + + try: + data = TOMLParser.parse(content) + except Exception as e: + log.debug("Failed to parse connections.toml: %s", e) + return {} + + return self._normalize_connections_format(data) + + @staticmethod + def _normalize_connections_format(data: Dict[str, Any]) -> Dict[str, Any]: + """ + Normalize connections.toml format to standard structure. + + Supports: + - Legacy: [connection_name] → {"connections": {"connection_name": {...}}} + - New: [connections.connection_name] → {"connections": {"connection_name": {...}}} + + Args: + data: Parsed TOML data + + Returns: + Normalized structure with connections under "connections" key + """ + result: Dict[str, Any] = {} + + for section_name, section_data in data.items(): + if isinstance(section_data, dict) and section_name != "connections": + if "connections" not in result: + result["connections"] = {} + result["connections"][section_name] = section_data + + connections_section = data.get("connections", {}) + if isinstance(connections_section, dict) and connections_section: + if "connections" not in result: + result["connections"] = {} + result["connections"].update(connections_section) + + return result + + def supports_key(self, key: str) -> bool: + return key in self.discover() + + +class SnowSQLEnvironment(ValueSource): + """ + SnowSQL environment variables source. + + Discovers SNOWSQL_* environment variables only. + Simple prefix mapping without connection-specific variants. 
+ + Examples: + SNOWSQL_ACCOUNT -> account + SNOWSQL_USER -> user + SNOWSQL_PWD -> password + """ + + # Mapping of SNOWSQL_* env vars to configuration keys + ENV_VAR_MAPPING = { + "SNOWSQL_ACCOUNT": "account", + "SNOWSQL_ACCOUNTNAME": "account", # Alternative + "SNOWSQL_USER": "user", + "SNOWSQL_USERNAME": "user", # Alternative + "SNOWSQL_PWD": "password", + "SNOWSQL_PASSWORD": "password", # Alternative + "SNOWSQL_DATABASE": "database", + "SNOWSQL_DBNAME": "database", # Alternative + "SNOWSQL_SCHEMA": "schema", + "SNOWSQL_SCHEMANAME": "schema", # Alternative + "SNOWSQL_ROLE": "role", + "SNOWSQL_ROLENAME": "role", # Alternative + "SNOWSQL_WAREHOUSE": "warehouse", + "SNOWSQL_WAREHOUSENAME": "warehouse", # Alternative + "SNOWSQL_PROTOCOL": "protocol", + "SNOWSQL_HOST": "host", + "SNOWSQL_PORT": "port", + "SNOWSQL_REGION": "region", + "SNOWSQL_AUTHENTICATOR": "authenticator", + "SNOWSQL_PRIVATE_KEY_PASSPHRASE": "private_key_passphrase", + } + + @property + def source_name(self) -> "ValueSource.SourceName": + return "snowsql_env" + + @property + def source_type(self) -> SourceType: + return SourceType.OVERLAY + + def discover(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Discover SNOWSQL_* environment variables. + Returns flat values at root level (no connection prefix). 
+ """ + result: Dict[str, Any] = {} + + for env_var, config_key in self.ENV_VAR_MAPPING.items(): + env_value = os.getenv(env_var) + if env_value is not None: + if config_key not in result: + result[config_key] = env_value + + return result + + def supports_key(self, key: str) -> bool: + for env_var, config_key in self.ENV_VAR_MAPPING.items(): + if config_key == key and os.getenv(env_var) is not None: + return True + return False + + +# Base configuration keys that can be set via environment +_ENV_CONFIG_KEYS: Final[list[str]] = [ + "account", + "user", + "password", + "database", + "schema", + "role", + "warehouse", + "protocol", + "host", + "port", + "region", + "authenticator", + "workload_identity_provider", + "private_key_file", + "private_key_path", # Used by integration tests + "private_key_raw", # Used by integration tests + "private_key_passphrase", # Private key passphrase for encrypted keys + "token", # OAuth token + "session_token", # Session token for session-based authentication + "master_token", # Master token for advanced authentication + "token_file_path", + "oauth_client_id", + "oauth_client_secret", + "oauth_authorization_url", + "oauth_token_request_url", + "oauth_redirect_uri", + "oauth_scope", + "oauth_enable_pkce", # Fixed typo: was "oatuh_enable_pkce" + "oauth_enable_refresh_tokens", + "oauth_enable_single_use_refresh_tokens", + "client_store_temporary_credential", +] + + +class ConnectionSpecificEnvironment(ValueSource): + """ + Connection-specific environment variables source. + + Discovers SNOWFLAKE_CONNECTIONS__ environment variables. 
+ Returns prefixed keys: connections.{name}.{key} + + Examples: + SNOWFLAKE_CONNECTIONS_INTEGRATION_ACCOUNT=x -> connections.integration.account=x + SNOWFLAKE_CONNECTIONS_DEV_USER=y -> connections.dev.user=y + """ + + @property + def source_name(self) -> "ValueSource.SourceName": + return "connection_specific_env" + + @property + def source_type(self) -> SourceType: + return SourceType.OVERLAY + + def discover(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Discover SNOWFLAKE_CONNECTIONS_* environment variables. + Returns nested dict structure. + + Pattern: SNOWFLAKE_CONNECTIONS__=value + -> {"connections": {"{name}": {"{key}": value}}} + """ + result: Dict[str, Any] = {} + + for env_name, env_value in os.environ.items(): + if env_name.startswith("SNOWFLAKE_CONNECTIONS_"): + remainder = env_name[len("SNOWFLAKE_CONNECTIONS_") :] + + match: tuple[str, str] | None = None + for candidate in sorted(_ENV_CONFIG_KEYS, key=len, reverse=True): + key_suffix = "_" + candidate.upper() + if remainder.endswith(key_suffix): + conn_name_upper = remainder[: -len(key_suffix)] + if conn_name_upper: # ensure non-empty connection name + match = (conn_name_upper, candidate) + break + + if not match: + continue + + conn_name_upper, config_key = match + conn_name = conn_name_upper.lower() + + if "connections" not in result: + result["connections"] = {} + if conn_name not in result["connections"]: + result["connections"][conn_name] = {} + + result["connections"][conn_name][config_key] = env_value + + return result + + def supports_key(self, key: str) -> bool: + if key.startswith("connections."): + parts = key.split(".", 2) + if len(parts) == 3: + _, conn_name, config_key = parts + env_var = ( + f"SNOWFLAKE_CONNECTIONS_{conn_name.upper()}_{config_key.upper()}" + ) + return os.getenv(env_var) is not None + return False + + +class CliEnvironment(ValueSource): + """ + CLI general environment variables source. 
+ + Discovers general SNOWFLAKE_* environment variables (not connection-specific). + Returns flat keys that apply to all connections. + + Examples: + SNOWFLAKE_ACCOUNT -> account (general, applies to all connections) + SNOWFLAKE_USER -> user + SNOWFLAKE_PASSWORD -> password + """ + + @property + def source_name(self) -> "ValueSource.SourceName": + return "cli_env" + + @property + def source_type(self) -> SourceType: + return SourceType.OVERLAY + + def discover(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Discover general SNOWFLAKE_* environment variables. + Returns flat values at root level. + + Pattern: SNOWFLAKE_=value -> {key: value} + """ + result: Dict[str, Any] = {} + + for env_name, env_value in os.environ.items(): + if not env_name.startswith("SNOWFLAKE_"): + continue + + if env_name.startswith("SNOWFLAKE_CONNECTIONS_"): + continue + + config_key_upper = env_name[len("SNOWFLAKE_") :] + config_key = config_key_upper.lower() + + if config_key in _ENV_CONFIG_KEYS: + result[config_key] = env_value + + return result + + def supports_key(self, key: str) -> bool: + if "." in key: + return False + + env_var = f"SNOWFLAKE_{key.upper()}" + return os.getenv(env_var) is not None + + +class CliParameters(ValueSource): + """ + CLI command-line parameters source. + + Highest priority source that extracts values from parsed CLI arguments. + Values are already parsed by Typer/Click framework. + + Examples: + --account my_account -> account: "my_account" + --user alice -> user: "alice" + -a my_account -> account: "my_account" + """ + + def __init__(self, cli_context: Optional[Dict[str, Any]] = None): + """ + Initialize CLI parameters source. 
+ + Args: + cli_context: Dictionary of CLI arguments (key -> value) + """ + self._cli_context = cli_context or {} + + @property + def source_name(self) -> "ValueSource.SourceName": + return "cli_arguments" + + @property + def source_type(self) -> SourceType: + return SourceType.OVERLAY + + def discover(self, key: Optional[str] = None) -> Dict[str, Any]: + """ + Extract non-None values from CLI context. + CLI arguments are already parsed by the framework. + Returns flat values at root level. + """ + result: Dict[str, Any] = {} + + for k, v in self._cli_context.items(): + if v is None: + continue + + result[k] = v + + return result + + def supports_key(self, key: str) -> bool: + """Check if key is present in CLI context with non-None value.""" + return key in self._cli_context and self._cli_context[key] is not None + + +def get_merged_variables(cli_variables: Optional[List[str]] = None) -> Dict[str, str]: + """ + Merge SnowSQL [variables] with CLI -D parameters. + + Precedence: SnowSQL variables (lower) < -D parameters (higher) + + Args: + cli_variables: List of "key=value" strings from -D parameters + + Returns: + Dictionary of merged variables (key -> value) + """ + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + try: + snowsql_vars = provider.get_section(SnowSQLSection.VARIABLES.value) + except Exception: + snowsql_vars = {} + + if cli_variables: + from snowflake.cli.api.commands.utils import parse_key_value_variables + + cli_vars_parsed = parse_key_value_variables(cli_variables) + for var in cli_vars_parsed: + snowsql_vars[var.key] = var.value + + return snowsql_vars diff --git a/src/snowflake/cli/api/config_ng/telemetry_integration.py b/src/snowflake/cli/api/config_ng/telemetry_integration.py new file mode 100644 index 0000000000..56495c69cd --- /dev/null +++ b/src/snowflake/cli/api/config_ng/telemetry_integration.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024 Snowflake Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Telemetry integration for config_ng system. + +This module provides functions to track configuration source usage +and integrate with the CLI's telemetry system. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +if TYPE_CHECKING: + from snowflake.cli.api.config_ng.resolver import ConfigurationResolver + + +# Map source names to counter field names +SOURCE_TO_COUNTER = { + "snowsql_config": "config_source_snowsql", + "cli_config_toml": "config_source_cli_toml", + "connections_toml": "config_source_connections_toml", + "snowsql_env": "config_source_snowsql_env", + "connection_specific_env": "config_source_connection_env", + "cli_env": "config_source_cli_env", + "cli_arguments": "config_source_cli_args", +} + + +def record_config_source_usage(resolver: ConfigurationResolver) -> None: + """ + Record configuration source usage to CLI metrics. + + This should be called after configuration resolution completes. + Sets counters to 1 for sources that provided winning values, 0 otherwise. 
+ + Args: + resolver: The ConfigurationResolver instance + """ + try: + from snowflake.cli.api.cli_global_context import get_cli_context + from snowflake.cli.api.metrics import CLICounterField + + cli_context = get_cli_context() + summary = resolver.get_tracker().get_summary() + + source_wins = summary.get("source_wins", {}) + + for source_name, counter_name in SOURCE_TO_COUNTER.items(): + value = 1 if source_wins.get(source_name, 0) > 0 else 0 + counter_field = getattr(CLICounterField, counter_name.upper(), None) + if counter_field: + cli_context.metrics.set_counter(counter_field, value) + + except Exception: + pass + + +def get_config_telemetry_payload( + resolver: Optional[ConfigurationResolver], +) -> Dict[str, Any]: + """ + Get configuration telemetry payload for inclusion in command telemetry. + + Args: + resolver: Optional ConfigurationResolver instance + + Returns: + Dictionary with config telemetry data + """ + if resolver is None: + return {} + + try: + summary = resolver.get_tracker().get_summary() + + return { + "config_sources_used": list(summary.get("source_usage", {}).keys()), + "config_source_wins": summary.get("source_wins", {}), + "config_total_keys_resolved": summary.get("total_keys_resolved", 0), + "config_keys_with_overrides": summary.get("keys_with_overrides", 0), + } + except Exception: + return {} diff --git a/src/snowflake/cli/api/config_provider.py b/src/snowflake/cli/api/config_provider.py new file mode 100644 index 0000000000..04f6c59d4d --- /dev/null +++ b/src/snowflake/cli/api/config_provider.py @@ -0,0 +1,618 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import atexit +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Final, Optional + +if TYPE_CHECKING: + from snowflake.cli.api.config_ng.resolver import ConfigurationResolver + from snowflake.cli.api.config_ng.source_manager import SourceManager + +ALTERNATIVE_CONFIG_ENV_VAR: Final[str] = "SNOWFLAKE_CLI_CONFIG_V2_ENABLED" + + +class ConfigProvider(ABC): + """ + Abstract base class for configuration providers. + All methods must return data in the same format as current implementation. + """ + + @abstractmethod + def get_section(self, *path) -> dict: + """Get configuration section at specified path.""" + ... + + @abstractmethod + def get_value(self, *path, key: str, default: Optional[Any] = None) -> Any: + """Get single configuration value.""" + ... + + @abstractmethod + def set_value(self, path: list[str], value: Any) -> None: + """Set configuration value at path.""" + ... + + @abstractmethod + def unset_value(self, path: list[str]) -> None: + """Remove configuration value at path.""" + ... + + @abstractmethod + def section_exists(self, *path) -> bool: + """Check if configuration section exists.""" + ... + + @abstractmethod + def read_config(self) -> None: + """Load configuration from source.""" + ... + + @abstractmethod + def get_connection_dict(self, connection_name: str) -> dict: + """Get connection configuration by name.""" + ... 
+ + @abstractmethod + def get_all_connections(self, include_env_connections: bool = False) -> dict: + """Get all connection configurations. + + Args: + include_env_connections: If True, include connections created from + environment variables. Default False. + """ + ... + + def _transform_private_key_raw(self, connection_dict: dict) -> dict: + """ + Transform private_key_raw to private_key_file for ConnectionContext compatibility. + + The ConnectionContext dataclass doesn't have a private_key_raw field, so it gets + filtered out by merge_with_config. To work around this, we write private_key_raw + content to a temporary file and return it as private_key_file. + + Args: + connection_dict: Connection configuration dictionary + + Returns: + Modified connection dictionary with private_key_raw transformed to private_key_file + """ + if "private_key_raw" not in connection_dict: + return connection_dict + + # Don't transform if private_key_file is already set + if "private_key_file" in connection_dict: + return connection_dict + + import tempfile + + try: + # Create a temporary file with the private key content + with tempfile.NamedTemporaryFile( + mode="w", suffix=".pem", delete=False + ) as f: + f.write(connection_dict["private_key_raw"]) + temp_file_path = f.name + + # Set restrictive permissions on the temporary file + os.chmod(temp_file_path, 0o600) + + # Create a copy of the connection dict with the transformation + result = connection_dict.copy() + result["private_key_file"] = temp_file_path + del result["private_key_raw"] + + # Track created temp file on the provider instance for cleanup + temp_files_attr = "_temp_private_key_files" + existing = getattr(self, temp_files_attr, None) + if existing is None: + setattr(self, temp_files_attr, {temp_file_path}) + else: + existing.add(temp_file_path) + + return result + + except Exception: + # If transformation fails, return original dict + # The error will be handled downstream + return connection_dict + + def 
cleanup_temp_files(self) -> None: + """Delete any temporary files created from private_key_raw transformation.""" + temp_files = getattr(self, "_temp_private_key_files", None) + if not temp_files: + return + to_remove = list(temp_files) + for path in to_remove: + try: + Path(path).unlink(missing_ok=True) + except Exception: + # Best-effort cleanup; ignore failures + pass + temp_files.clear() + + +class LegacyConfigProvider(ConfigProvider): + """ + Current TOML-based configuration provider. + Wraps existing implementation for compatibility. + """ + + def get_section(self, *path) -> dict: + from snowflake.cli.api.config import get_config_section + + return get_config_section(*path) + + def get_value(self, *path, key: str, default: Optional[Any] = None) -> Any: + from snowflake.cli.api.config import Empty, get_config_value + + return get_config_value( + *path, key=key, default=default if default is not None else Empty + ) + + def set_value(self, path: list[str], value: Any) -> None: + from snowflake.cli.api.config import set_config_value + + set_config_value(path, value) + + def unset_value(self, path: list[str]) -> None: + from snowflake.cli.api.config import unset_config_value + + unset_config_value(path) + + def section_exists(self, *path) -> bool: + from snowflake.cli.api.config import config_section_exists + + return config_section_exists(*path) + + def read_config(self) -> None: + from snowflake.cli.api.config import get_config_manager + + config_manager = get_config_manager() + config_manager.read_config() + + def get_connection_dict(self, connection_name: str) -> dict: + from snowflake.cli.api.config import get_config_section + + try: + result = get_config_section("connections", connection_name) + return self._transform_private_key_raw(result) + except KeyError: + from snowflake.cli.api.exceptions import MissingConfigurationError + + raise MissingConfigurationError( + f"Connection {connection_name} is not configured" + ) + + def get_all_connections(self, 
include_env_connections: bool = False) -> dict: + from snowflake.cli.api.config import ConnectionConfig, get_config_section + + # Legacy provider ignores the flag since it never had env connections + connections = get_config_section("connections") + return { + name: ConnectionConfig.from_dict(self._transform_private_key_raw(config)) + for name, config in connections.items() + } + + +class AlternativeConfigProvider(ConfigProvider): + """ + New configuration provider using config_ng resolution system. + + This provider uses ConfigurationResolver to discover values from: + - CLI arguments (highest priority) + - Environment variables (SNOWFLAKE_* and SNOWSQL_*) + - Configuration files (SnowCLI TOML and SnowSQL config) + + Maintains backward compatibility with LegacyConfigProvider output format. + """ + + def __init__( + self, + source_manager: Optional["SourceManager"] = None, + cli_context_getter: Optional[Any] = None, + ) -> None: + """ + Initialize provider with optional dependencies for testing. 
+ + Args: + source_manager: Optional source manager (for testing) + cli_context_getter: Optional CLI context getter function (for testing) + """ + self._source_manager = source_manager + self._cli_context_getter = ( + cli_context_getter or self._default_cli_context_getter + ) + self._resolver: Optional["ConfigurationResolver"] = None + self._config_cache: Dict[str, Any] = {} + self._initialized: bool = False + self._last_config_override: Optional[Path] = None + + @staticmethod + def _default_cli_context_getter(): + """Default implementation that accesses global CLI context.""" + from snowflake.cli.api.cli_global_context import get_cli_context + + return get_cli_context() + + def _ensure_initialized(self) -> None: + """Lazily initialize the resolver on first use.""" + # Check if config_file_override has changed + try: + cli_context = self._cli_context_getter() + current_override = cli_context.config_file_override + + # If override changed, force re-initialization + if current_override != self._last_config_override: + self._initialized = False + self._config_cache.clear() + self._last_config_override = current_override + except Exception: + pass + + if self._initialized: + return + + from snowflake.cli.api.config_ng import ConfigurationResolver + from snowflake.cli.api.config_ng.source_factory import create_default_sources + from snowflake.cli.api.config_ng.source_manager import SourceManager + + # Get CLI context + try: + cli_context = self._cli_context_getter() + cli_context_dict = cli_context.connection_context.present_values_as_dict() + except Exception: + cli_context_dict = {} + + # Create or use provided source manager + if self._source_manager is None: + sources = create_default_sources(cli_context_dict) + self._source_manager = SourceManager(sources) + + # Create resolver + self._resolver = ConfigurationResolver( + sources=self._source_manager.get_sources() + ) + + # Initialize cache (resolver returns nested dict) + if not self._config_cache: + 
self._config_cache = self._resolver.resolve() + + # Record telemetry about config sources used + self._record_config_telemetry() + + self._initialized = True + + def _record_config_telemetry(self) -> None: + """Record configuration source usage to telemetry system.""" + if self._resolver is None: + return + + try: + from snowflake.cli.api.config_ng.telemetry_integration import ( + record_config_source_usage, + ) + + record_config_source_usage(self._resolver) + except Exception: + # Don't break initialization if telemetry fails + pass + + def read_config(self) -> None: + """ + Load configuration from all sources. + Resolver returns nested dict structure. + """ + self._initialized = False + self._config_cache.clear() + self._last_config_override = None + self._ensure_initialized() + + # Resolver returns nested dict + assert self._resolver is not None + self._config_cache = self._resolver.resolve() + + def get_section(self, *path) -> dict: + """ + Navigate nested dict to get configuration section. + + Args: + *path: Section path components (e.g., "connections", "prod") + + Returns: + Dictionary of section contents + + Example: + Cache: {"connections": {"prod": {"account": "val"}}} + get_section("connections", "prod") -> {"account": "val"} + """ + self._ensure_initialized() + + if not path: + return self._config_cache + + # Navigate nested structure + result = self._config_cache + for part in path: + if not isinstance(result, dict) or part not in result: + return {} + result = result[part] + + return result if isinstance(result, dict) else {} + + def get_value(self, *path, key: str, default: Optional[Any] = None) -> Any: + """ + Get single configuration value by navigating nested dict. 
+ + Args: + *path: Path to section + key: Configuration key + default: Default value if not found + + Returns: + Configuration value or default + """ + self._ensure_initialized() + + # Navigate to section, then get key + section = self.get_section(*path) + return section.get(key, default) + + def set_value(self, path: list[str], value: Any) -> None: + """ + Set configuration value at path. + + Note: config_ng is read-only for resolution. This delegates to + legacy config system for writing. + """ + from snowflake.cli.api.config import set_config_value as legacy_set_value + + legacy_set_value(path, value) + # Clear cache to force re-read on next access + self._config_cache.clear() + self._initialized = False + + def unset_value(self, path: list[str]) -> None: + """ + Remove configuration value at path. + + Note: config_ng is read-only for resolution. This delegates to + legacy config system for writing. + """ + from snowflake.cli.api.config import unset_config_value as legacy_unset_value + + legacy_unset_value(path) + # Clear cache to force re-read on next access + self._config_cache.clear() + self._initialized = False + + def section_exists(self, *path) -> bool: + """ + Check if configuration section exists by navigating nested dict. + + Args: + *path: Section path + + Returns: + True if section exists + """ + self._ensure_initialized() + + if not path: + return True + + # Navigate nested structure + result = self._config_cache + for part in path: + if not isinstance(result, dict) or part not in result: + return False + result = result[part] + + return True + + def _get_connection_dict_internal(self, connection_name: str) -> Dict[str, Any]: + """ + Get connection configuration by navigating nested dict. + + Note: The resolver already merged general params into each connection + during the OVERLAY phase, so we just return the connection dict directly. 
+ + Args: + connection_name: Name of the connection + + Returns: + Dictionary of connection parameters + """ + from snowflake.cli.api.exceptions import MissingConfigurationError + + self._ensure_initialized() + + # Get connection from nested dict + connections = self._config_cache.get("connections", {}) + if connection_name in connections and isinstance( + connections[connection_name], dict + ): + result = connections[connection_name] + # Allow empty connections - they're valid (just have no parameters set) + return result + + raise MissingConfigurationError( + f"Connection {connection_name} is not configured" + ) + + def get_connection_dict(self, connection_name: str) -> dict: + """ + Get connection configuration by name. + + Args: + connection_name: Name of the connection + + Returns: + Dictionary of connection parameters + """ + result = self._get_connection_dict_internal(connection_name) + return self._transform_private_key_raw(result) + + def _get_all_connections_dict(self) -> Dict[str, Dict[str, Any]]: + """ + Get all connections from nested dict. + + Returns: + Dictionary mapping connection names to their configurations + """ + self._ensure_initialized() + + connections = self._config_cache.get("connections", {}) + return connections if isinstance(connections, dict) else {} + + def get_all_connections(self, include_env_connections: bool = False) -> dict: + """ + Get all connection configurations. + + Args: + include_env_connections: If True, include connections created from + environment variables. Default False for + backward compatibility with legacy behavior. 
+ + Returns: + Dictionary mapping connection names to ConnectionConfig objects + """ + from snowflake.cli.api.config import ConnectionConfig + + if not include_env_connections: + # Only return connections from file sources (matching legacy behavior) + return self._get_file_based_connections() + + # Return all connections including environment-based ones + connections_dict = self._get_all_connections_dict() + return { + name: ConnectionConfig.from_dict(config) + for name, config in connections_dict.items() + } + + def _get_file_based_connections(self) -> dict: + """ + Get connections only from file sources. + + Excludes connections that exist solely due to environment variables + or CLI parameters. Matches legacy behavior. + + Returns: + Dictionary mapping connection names to ConnectionConfig objects + """ + from snowflake.cli.api.config import ConnectionConfig + from snowflake.cli.api.config_ng.constants import FILE_SOURCE_NAMES + + self._ensure_initialized() + + connections: Dict[str, Dict[str, Any]] = {} + + assert self._resolver is not None + for source in self._resolver.get_sources(): + if source.source_name not in FILE_SOURCE_NAMES: + continue + + try: + source_data = source.discover() # Returns nested dict + if "connections" in source_data: + for conn_name, conn_config in source_data["connections"].items(): + if isinstance(conn_config, dict): + connections[conn_name] = conn_config + except Exception: + # Silently skip sources that fail to discover + pass + + return { + name: ConnectionConfig.from_dict(config) + for name, config in connections.items() + } + + def invalidate_cache(self) -> None: + """ + Invalidate the provider's cache, forcing it to re-read configuration on next access. + + This is useful when configuration files are modified externally. 
+ """ + self._initialized = False + self._config_cache.clear() + if hasattr(self, "_last_config_override"): + self._last_config_override = None + + +def is_alternative_config_enabled() -> bool: + """ + Check if alternative configuration handling is enabled via environment variable. + Does not use the built-in feature flags mechanism. + """ + return os.environ.get(ALTERNATIVE_CONFIG_ENV_VAR, "").lower() in ( + "1", + "true", + "yes", + "on", + ) + + +def get_config_provider() -> ConfigProvider: + """ + Factory function to get the appropriate configuration provider + based on environment variable. + """ + if is_alternative_config_enabled(): + return AlternativeConfigProvider() + return LegacyConfigProvider() + + +_config_provider_instance: Optional[ConfigProvider] = None + + +def get_config_provider_singleton() -> ConfigProvider: + """ + Get or create singleton instance of configuration provider. + """ + global _config_provider_instance + if _config_provider_instance is None: + _config_provider_instance = get_config_provider() + return _config_provider_instance + + +def reset_config_provider(): + """ + Reset the config provider singleton. + Useful for testing and when config source changes. 
+ """ + global _config_provider_instance + # Cleanup any temp files created by the current provider instance + if _config_provider_instance is not None: + try: + _config_provider_instance.cleanup_temp_files() + except Exception: + pass + _config_provider_instance = None + + +def _cleanup_provider_at_exit() -> None: + """Process-exit cleanup for provider-managed temporary files.""" + global _config_provider_instance + if _config_provider_instance is not None: + try: + _config_provider_instance.cleanup_temp_files() + except Exception: + pass + + +atexit.register(_cleanup_provider_at_exit) diff --git a/src/snowflake/cli/api/connections.py b/src/snowflake/cli/api/connections.py index 671c22e55d..671d9db3d8 100644 --- a/src/snowflake/cli/api/connections.py +++ b/src/snowflake/cli/api/connections.py @@ -47,6 +47,7 @@ class ConnectionContext: authenticator: Optional[str] = None workload_identity_provider: Optional[str] = None private_key_file: Optional[str] = None + private_key_passphrase: Optional[str] = field(default=None, repr=False) warehouse: Optional[str] = None mfa_passcode: Optional[str] = None token: Optional[str] = None @@ -68,11 +69,18 @@ class ConnectionContext: oauth_enable_single_use_refresh_tokens: Optional[bool] = None client_store_temporary_credential: Optional[bool] = None + # Internal flag to track if config has been loaded + _config_loaded: bool = field(default=False, repr=False, init=False) + VALIDATED_FIELD_NAMES = ["schema"] def present_values_as_dict(self) -> dict: """Dictionary representation of this ConnectionContext for values that are not None""" - return {k: v for (k, v) in asdict(self).items() if v is not None} + return { + k: v + for (k, v) in asdict(self).items() + if v is not None and not k.startswith("_") + } def clone(self) -> ConnectionContext: return replace(self) @@ -111,6 +119,7 @@ def update_from_config(self) -> ConnectionContext: del connection_config["private_key_path"] self.merge_with_config(**connection_config) + 
self._config_loaded = True return self def __repr__(self) -> str: @@ -137,6 +146,8 @@ def validate_schema(self, value: Optional[str]): def validate_and_complete(self): """ Ensure we can create a connection from this context. + Sets default connection name if needed, but does not load configuration. + Configuration is loaded lazily in build_connection(). """ if not self.temporary_connection and not self.connection_name: self.connection_name = get_default_connection_name() @@ -153,7 +164,38 @@ def build_connection(self): module="snowflake.connector.config_manager", ) - return connect_to_snowflake(**self.present_values_as_dict()) + if self.temporary_connection: + # For temporary connections, pass all parameters + # connect_to_snowflake will use these directly without loading config + conn_params = self.present_values_as_dict() + else: + # For named connections, pass connection_name and all override parameters + # connect_to_snowflake will load the connection config internally and apply overrides + all_params = self.present_values_as_dict() + control_params = { + "connection_name", + "enable_diag", + "diag_log_path", + "diag_allowlist_path", + "temporary_connection", + "mfa_passcode", + } + + # Separate control parameters from connection overrides + conn_params = {} + overrides = {} + + for k, v in all_params.items(): + if k in control_params: + conn_params[k] = v + else: + # These are connection parameters that should override config values + overrides[k] = v + + # Merge overrides into conn_params + conn_params.update(overrides) + + return connect_to_snowflake(**conn_params) class OpenConnectionCache: diff --git a/src/snowflake/cli/api/metrics.py b/src/snowflake/cli/api/metrics.py index 69778872e1..94c898db2f 100644 --- a/src/snowflake/cli/api/metrics.py +++ b/src/snowflake/cli/api/metrics.py @@ -76,6 +76,28 @@ class CLICounterField: EVENT_SHARING_ERROR = ( f"{_TypePrefix.FEATURES}.{_DomainPrefix.APP}.event_sharing_error" ) + # Config source usage tracking + 
CONFIG_SOURCE_SNOWSQL = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.config_source_snowsql" + ) + CONFIG_SOURCE_CLI_TOML = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.config_source_cli_toml" + ) + CONFIG_SOURCE_CONNECTIONS_TOML = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.config_source_connections_toml" + ) + CONFIG_SOURCE_SNOWSQL_ENV = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.config_source_snowsql_env" + ) + CONFIG_SOURCE_CONNECTION_ENV = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.config_source_connection_env" + ) + CONFIG_SOURCE_CLI_ENV = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.config_source_cli_env" + ) + CONFIG_SOURCE_CLI_ARGS = ( + f"{_TypePrefix.FEATURES}.{_DomainPrefix.GLOBAL}.config_source_cli_args" + ) @dataclass diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index 25e07a8453..2b8175001a 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -4264,6 +4264,9 @@ Lists configured connections. +- Options --------------------------------------------------------------------+ + | --all -a Include connections from all sources (environment | + | variables, SnowSQL config). By default, only shows | + | connections from configuration files. | | --help -h Show this message and exit. | +------------------------------------------------------------------------------+ +- Global configuration -------------------------------------------------------+ diff --git a/tests/api/test_config_provider.py b/tests/api/test_config_provider.py new file mode 100644 index 0000000000..310a2f7d72 --- /dev/null +++ b/tests/api/test_config_provider.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +from snowflake.cli.api.config_provider import ( + ALTERNATIVE_CONFIG_ENV_VAR, + AlternativeConfigProvider, + LegacyConfigProvider, + get_config_provider, + reset_config_provider, +) + + +def test_legacy_provider_by_default(): + """Should use legacy provider when env var not set.""" + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + + reset_config_provider() + provider = get_config_provider() + assert isinstance(provider, LegacyConfigProvider) + + +def test_alternative_provider_when_enabled(): + """Should use alternative provider when env var is set.""" + os.environ[ALTERNATIVE_CONFIG_ENV_VAR] = "1" + + reset_config_provider() + provider = get_config_provider() + assert isinstance(provider, AlternativeConfigProvider) + + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + + +@pytest.mark.parametrize("value", ["true", "True", "TRUE", "yes", "Yes", "on", "1"]) +def test_alternative_provider_various_values(value): + """Should enable alternative provider for various truthy values.""" + os.environ[ALTERNATIVE_CONFIG_ENV_VAR] = value + + reset_config_provider() + provider = get_config_provider() + assert isinstance(provider, AlternativeConfigProvider) + + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + + +@pytest.mark.parametrize("value", ["0", "false", "False", "no", "off", ""]) +def test_legacy_provider_for_falsy_values(value): + """Should use legacy provider for falsy env var values.""" + os.environ[ALTERNATIVE_CONFIG_ENV_VAR] = value + + reset_config_provider() + provider = get_config_provider() + 
assert isinstance(provider, LegacyConfigProvider) + + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + + +def test_provider_singleton(): + """Should return same instance on multiple calls.""" + reset_config_provider() + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider1 = get_config_provider_singleton() + provider2 = get_config_provider_singleton() + assert provider1 is provider2 + + +def test_reset_provider(): + """Should create new instance after reset.""" + reset_config_provider() + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider1 = get_config_provider_singleton() + reset_config_provider() + provider2 = get_config_provider_singleton() + assert provider1 is not provider2 diff --git a/tests/app/test_telemetry.py b/tests/app/test_telemetry.py index 8792e1f5e1..a4967cf0c1 100644 --- a/tests/app/test_telemetry.py +++ b/tests/app/test_telemetry.py @@ -19,6 +19,10 @@ import pytest import typer from click import ClickException +from snowflake.cli.api.config_provider import ( + ALTERNATIVE_CONFIG_ENV_VAR, + reset_config_provider, +) from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.exceptions import CouldNotUseObjectError from snowflake.cli.api.feature_flags import BooleanFlag, FeatureFlagMixin @@ -41,52 +45,122 @@ class _TestFlags(FeatureFlagMixin): @mock.patch("snowflake.connector.connect") @mock.patch("snowflake.cli._plugins.connection.commands.ObjectManager") @with_feature_flags({_TestFlags.FOO: False}) -def test_executing_command_sends_telemetry_usage_data( +def test_executing_command_sends_telemetry_usage_data_legacy_config( _, mock_conn, mock_time, mock_uuid4, mock_platform, mock_version, runner ): - mock_time.return_value = "123" - mock_platform.return_value = "FancyOS" - mock_version.return_value = "2.3.4" - mock_uuid4.return_value = uuid.UUID("8a2225b3800c4017a4a9eab941db58fa") - result = runner.invoke(["connection", "test"], catch_exceptions=False) - assert 
result.exit_code == 0, result.output - # The method is called with a TelemetryData type, so we cast it to dict for simpler comparison - usage_command_event = ( - mock_conn.return_value._telemetry.try_add_log_to_batch.call_args_list[ # noqa: SLF001 - 0 - ] - .args[0] - .to_dict() - ) + """Test telemetry with legacy config provider.""" + # Ensure legacy config is used + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() - del usage_command_event["message"][ - "command_ci_environment" - ] # to avoid side effect from CI - assert usage_command_event == { - "message": { - "driver_type": "PythonConnector", - "driver_version": ".".join(str(s) for s in DRIVER_VERSION[:3]), - "source": "snowcli", - "version_cli": "0.0.0-test_patched", - "version_os": "FancyOS", - "version_python": "2.3.4", - "installation_source": "pypi", - "command": ["connection", "test"], - "command_group": "connection", - "command_execution_id": "8a2225b3800c4017a4a9eab941db58fa", - "command_flags": {"diag_log_path": "DEFAULT", "format": "DEFAULT"}, - "command_output_type": "TABLE", - "type": "executing_command", - "project_definition_version": "None", - "config_feature_flags": { - "dummy_flag": "True", - "foo": "False", - "wrong_type_flag": "UNKNOWN", + mock_time.return_value = "123" + mock_platform.return_value = "FancyOS" + mock_version.return_value = "2.3.4" + mock_uuid4.return_value = uuid.UUID("8a2225b3800c4017a4a9eab941db58fa") + result = runner.invoke(["connection", "test"], catch_exceptions=False) + assert result.exit_code == 0, result.output + # The method is called with a TelemetryData type, so we cast it to dict for simpler comparison + usage_command_event = ( + mock_conn.return_value._telemetry.try_add_log_to_batch.call_args_list[ # noqa: SLF001 + 0 + ] + .args[0] + .to_dict() + ) + + del usage_command_event["message"][ + "command_ci_environment" + ] # to avoid side effect 
from CI + assert usage_command_event == { + "message": { + "driver_type": "PythonConnector", + "driver_version": ".".join(str(s) for s in DRIVER_VERSION[:3]), + "source": "snowcli", + "version_cli": "0.0.0-test_patched", + "version_os": "FancyOS", + "version_python": "2.3.4", + "installation_source": "pypi", + "command": ["connection", "test"], + "command_group": "connection", + "command_execution_id": "8a2225b3800c4017a4a9eab941db58fa", + "command_flags": {"diag_log_path": "DEFAULT", "format": "DEFAULT"}, + "command_output_type": "TABLE", + "type": "executing_command", + "project_definition_version": "None", + "config_feature_flags": { + "dummy_flag": "True", + "foo": "False", + "wrong_type_flag": "UNKNOWN", + }, + "config_provider_type": "legacy", + "mode": "cmd", }, - "mode": "cmd", - }, - "timestamp": "123", - } + "timestamp": "123", + } + + +@mock.patch( + "snowflake.cli._app.telemetry.python_version", +) +@mock.patch("snowflake.cli._app.telemetry.platform.platform") +@mock.patch("uuid.uuid4") +@mock.patch("snowflake.cli._app.telemetry.get_time_millis") +@mock.patch("snowflake.connector.connect") +@mock.patch("snowflake.cli._plugins.connection.commands.ObjectManager") +@with_feature_flags({_TestFlags.FOO: False}) +def test_executing_command_sends_telemetry_usage_data_ng_config( + _, mock_conn, mock_time, mock_uuid4, mock_platform, mock_version, runner +): + """Test telemetry with NG config provider.""" + # Enable NG config + with mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "true"}): + reset_config_provider() + + mock_time.return_value = "123" + mock_platform.return_value = "FancyOS" + mock_version.return_value = "2.3.4" + mock_uuid4.return_value = uuid.UUID("8a2225b3800c4017a4a9eab941db58fa") + result = runner.invoke(["connection", "test"], catch_exceptions=False) + assert result.exit_code == 0, result.output + + # The method is called with a TelemetryData type, so we cast it to dict for simpler comparison + usage_command_event = ( + 
mock_conn.return_value._telemetry.try_add_log_to_batch.call_args_list[ # noqa: SLF001 + 0 + ] + .args[0] + .to_dict() + ) + + del usage_command_event["message"][ + "command_ci_environment" + ] # to avoid side effect from CI + + # Verify common fields + message = usage_command_event["message"] + assert message["driver_type"] == "PythonConnector" + assert message["source"] == "snowcli" + assert message["version_cli"] == "0.0.0-test_patched" + assert message["version_os"] == "FancyOS" + assert message["version_python"] == "2.3.4" + assert message["command"] == ["connection", "test"] + assert message["command_group"] == "connection" + assert message["type"] == "executing_command" + assert message["config_provider_type"] == "ng" + + # Verify NG-specific config fields are present + assert "config_sources_used" in message + assert "config_source_wins" in message + assert "config_total_keys_resolved" in message + assert "config_keys_with_overrides" in message + + # These fields should be present (values will vary based on test config) + assert isinstance(message["config_sources_used"], list) + assert isinstance(message["config_source_wins"], dict) + assert isinstance(message["config_total_keys_resolved"], int) + assert isinstance(message["config_keys_with_overrides"], int) @pytest.mark.parametrize( diff --git a/tests/config_ng/__init__.py b/tests/config_ng/__init__.py new file mode 100644 index 0000000000..c64cc0301a --- /dev/null +++ b/tests/config_ng/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the Enhanced Configuration System (config_ng). +""" diff --git a/tests/config_ng/conftest.py b/tests/config_ng/conftest.py new file mode 100644 index 0000000000..adb09d3f96 --- /dev/null +++ b/tests/config_ng/conftest.py @@ -0,0 +1,143 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration testing utilities for config_ng tests. + +Provides fixtures for setting up temporary configuration environments. +""" + +import copy +import os +import tempfile +from contextlib import contextmanager +from pathlib import Path +from textwrap import dedent +from typing import Dict, Optional + +import pytest + + +@contextmanager +def _temp_environment(env_vars: Dict[str, str]): + """ + Context manager for temporarily setting environment variables. + + Saves the entire environment, applies new variables, then restores + the original environment completely on exit. + + Args: + env_vars: Dictionary of environment variables to set + + Yields: + None + """ + original_env = copy.deepcopy(dict(os.environ)) + try: + os.environ.update(env_vars) + yield + finally: + os.environ.clear() + os.environ.update(original_env) + + +@pytest.fixture +def config_ng_setup(): + """ + Fixture that provides a context manager for setting up config_ng test environments. + + Returns a context manager function that: + 1. 
Creates temp SNOWFLAKE_HOME + 2. Writes config files + 3. Sets env vars + 4. Enables config_ng + 5. Resets provider + 6. Yields (test can now call get_connection_dict()) + 7. Cleans up + + Usage: + def test_something(config_ng_setup): + with config_ng_setup( + cli_config="[connections.test]\\naccount = 'test'", + env_vars={"SNOWFLAKE_USER": "alice"} + ): + from snowflake.cli.api.config import get_connection_dict + conn = get_connection_dict("test") + assert conn["account"] == "test" + + Args (to returned context manager): + snowsql_config: SnowSQL INI config content (will be dedented) + cli_config: CLI TOML config content (will be dedented) + connections_toml: Connections TOML content (will be dedented) + env_vars: Environment variables to set + """ + + @contextmanager + def _setup( + snowsql_config: Optional[str] = None, + cli_config: Optional[str] = None, + connections_toml: Optional[str] = None, + env_vars: Optional[Dict[str, str]] = None, + ): + with tempfile.TemporaryDirectory() as tmpdir: + snowflake_home = Path(tmpdir) / ".snowflake" + snowflake_home.mkdir() + + # Write config files if provided + if snowsql_config: + (snowflake_home / "config").write_text(dedent(snowsql_config)) + if cli_config: + (snowflake_home / "config.toml").write_text(dedent(cli_config)) + if connections_toml: + (snowflake_home / "connections.toml").write_text( + dedent(connections_toml) + ) + + # Prepare environment variables + env_to_set = { + "SNOWFLAKE_HOME": str(snowflake_home), + "SNOWFLAKE_CLI_CONFIG_V2_ENABLED": "true", + } + if env_vars: + env_to_set.update(env_vars) + + # Set up environment and run test + with _temp_environment(env_to_set): + # Clear config_file_override to use SNOWFLAKE_HOME instead + from snowflake.cli.api.cli_global_context import ( + get_cli_context_manager, + ) + + cli_ctx_mgr = get_cli_context_manager() + original_config_override = cli_ctx_mgr.config_file_override + cli_ctx_mgr.config_file_override = None + + try: + # Reset config provider to use new 
config + from snowflake.cli.api.config_provider import reset_config_provider + + reset_config_provider() + + yield + + finally: + # Restore config_file_override + if original_config_override is not None: + cli_ctx_mgr = get_cli_context_manager() + cli_ctx_mgr.config_file_override = original_config_override + + # Reset config provider + reset_config_provider() + + return _setup diff --git a/tests/config_ng/test_config_value.py b/tests/config_ng/test_config_value.py new file mode 100644 index 0000000000..cfd9748065 --- /dev/null +++ b/tests/config_ng/test_config_value.py @@ -0,0 +1,190 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for ConfigValue dataclass. 
+ +Tests verify: +- Field values and types +- Raw value preservation +- Type conversions +- Representation formatting +""" + +from snowflake.cli.api.config_ng.core import ConfigValue + + +class TestConfigValue: + """Test suite for ConfigValue dataclass.""" + + def test_create_basic_config_value(self): + """Should create a basic ConfigValue with required fields.""" + cv = ConfigValue( + key="account", + value="my_account", + source_name="cli_arguments", + ) + + assert cv.key == "account" + assert cv.value == "my_account" + assert cv.source_name == "cli_arguments" + assert cv.raw_value is None + + def test_create_config_value_with_raw_value(self): + """Should create ConfigValue with raw value preservation.""" + cv = ConfigValue( + key="port", + value=443, + source_name="cli_env", + raw_value="443", + ) + + assert cv.key == "port" + assert cv.value == 443 + assert cv.raw_value == "443" + assert isinstance(cv.value, int) + assert isinstance(cv.raw_value, str) + + def test_repr_without_conversion(self): + """__repr__ should show value only when no conversion occurred.""" + cv = ConfigValue( + key="account", + value="my_account", + source_name="cli_arguments", + ) + + repr_str = repr(cv) + assert "account=my_account" in repr_str + assert "cli_arguments" in repr_str + assert "→" not in repr_str + + def test_repr_with_conversion(self): + """__repr__ should show conversion when raw_value differs from value.""" + cv = ConfigValue( + key="port", + value=443, + source_name="cli_env", + raw_value="443", + ) + + repr_str = repr(cv) + assert "port" in repr_str + assert "443" in repr_str + assert "→" in repr_str + assert "cli_env" in repr_str + + def test_repr_with_same_raw_and_parsed_value(self): + """__repr__ should not show conversion when values are the same.""" + cv = ConfigValue( + key="account", + value="my_account", + source_name="cli_arguments", + raw_value="my_account", + ) + + repr_str = repr(cv) + assert "→" not in repr_str + + def 
test_boolean_conversion_example(self): + """Should handle boolean conversion from string.""" + cv = ConfigValue( + key="enable_diag", + value=True, + source_name="cli_env", + raw_value="true", + ) + + assert cv.value is True + assert cv.raw_value == "true" + assert isinstance(cv.value, bool) + assert isinstance(cv.raw_value, str) + + def test_integer_conversion_example(self): + """Should handle integer conversion from string.""" + cv = ConfigValue( + key="timeout", + value=30, + source_name="cli_env", + raw_value="30", + ) + + assert cv.value == 30 + assert cv.raw_value == "30" + assert isinstance(cv.value, int) + assert isinstance(cv.raw_value, str) + + def test_snowsql_key_mapping_example(self): + """Should preserve original SnowSQL key in raw_value.""" + cv = ConfigValue( + key="account", + value="my_account", + source_name="snowsql_config", + raw_value="accountname=my_account", + ) + + assert cv.key == "account" + assert cv.value == "my_account" + assert cv.raw_value == "accountname=my_account" + + def test_none_value(self): + """Should handle None as a value.""" + cv = ConfigValue( + key="optional_field", + value=None, + source_name="cli_arguments", + ) + + assert cv.value is None + assert cv.key == "optional_field" + + def test_complex_value_types(self): + """Should handle complex value types like lists and dicts.""" + cv_list = ConfigValue( + key="tags", + value=["tag1", "tag2"], + source_name="connections_toml", + ) + + cv_dict = ConfigValue( + key="metadata", + value={"key1": "value1", "key2": "value2"}, + source_name="connections_toml", + ) + + assert cv_list.value == ["tag1", "tag2"] + assert cv_dict.value == {"key1": "value1", "key2": "value2"} + + def test_different_source_names(self): + """Should work with different source names.""" + cv_cli = ConfigValue( + key="account", + value="cli_account", + source_name="cli_arguments", + ) + + cv_env = ConfigValue( + key="account", + value="env_account", + source_name="cli_env", + ) + + cv_file = ConfigValue( + 
key="account", + value="file_account", + source_name="connections_toml", + ) + + assert cv_cli.source_name == "cli_arguments" + assert cv_env.source_name == "cli_env" + assert cv_file.source_name == "connections_toml" diff --git a/tests/config_ng/test_configuration.py b/tests/config_ng/test_configuration.py new file mode 100644 index 0000000000..ad4d6c953f --- /dev/null +++ b/tests/config_ng/test_configuration.py @@ -0,0 +1,1038 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simplified tests for config_ng using minimal mocking. + +This test file demonstrates a simpler approach to testing configuration +resolution by: +1. Setting up temporary SNOWFLAKE_HOME with config files +2. Setting environment variables directly +3. Calling get_connection_dict() to test the actual public API +4. Minimal mocking - only using the real config_ng system + +Uses the config_ng_setup fixture from conftest.py. 
+""" + + +# Tests for all 7 precedence levels + + +def test_level1_snowsql_config(config_ng_setup): + """Base level: SnowSQL config provides values""" + snowsql_config = """ + [connections.test] + accountname = from-snowsql + user = test-user + password = test-password + """ + + with config_ng_setup(snowsql_config=snowsql_config): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + assert conn["account"] == "from-snowsql" + assert conn["user"] == "test-user" + assert conn["password"] == "test-password" + + +def test_level2_cli_config_overrides_snowsql(config_ng_setup): + """CLI config.toml overrides SnowSQL config""" + snowsql_config = """ + [connections.test] + accountname = from-snowsql + user = snowsql-user + """ + + cli_config = """ + [connections.test] + account = "from-cli-config" + user = "cli-config-user" + """ + + with config_ng_setup(snowsql_config=snowsql_config, cli_config=cli_config): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + assert conn["account"] == "from-cli-config" + assert conn["user"] == "cli-config-user" + + +def test_level3_connections_toml_overrides_cli_config(config_ng_setup): + """connections.toml overrides cli config.toml""" + cli_config = """ + [connections.test] + account = "from-cli-config" + warehouse = "cli-warehouse" + """ + + connections_toml = """ + [connections.test] + account = "from-connections-toml" + warehouse = "connections-warehouse" + """ + + with config_ng_setup(cli_config=cli_config, connections_toml=connections_toml): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + assert conn["account"] == "from-connections-toml" + assert conn["warehouse"] == "connections-warehouse" + + +def test_level4_snowsql_env_overrides_connections_toml(config_ng_setup): + """SNOWSQL_* env vars override connections.toml""" + connections_toml = """ + [connections.test] + account = 
"from-connections-toml" + database = "connections-db" + """ + + env_vars = {"SNOWSQL_ACCOUNT": "from-snowsql-env", "SNOWSQL_DATABASE": "env-db"} + + with config_ng_setup(connections_toml=connections_toml, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + assert conn["account"] == "from-snowsql-env" + assert conn["database"] == "env-db" + + +def test_level5_connection_specific_env_overrides_snowsql_env(config_ng_setup): + """SNOWFLAKE_CONNECTIONS_* overrides SNOWSQL_*""" + env_vars = { + "SNOWSQL_ACCOUNT": "from-snowsql-env", + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "from-conn-specific-env", + "SNOWFLAKE_CONNECTIONS_TEST_ROLE": "conn-specific-role", + } + + with config_ng_setup(env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + assert conn["account"] == "from-conn-specific-env" + assert conn["role"] == "conn-specific-role" + + +def test_level6_general_env_overrides_connection_specific(config_ng_setup): + """SNOWFLAKE_* overrides SNOWFLAKE_CONNECTIONS_*""" + env_vars = { + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "from-conn-specific", + "SNOWFLAKE_ACCOUNT": "from-general-env", + "SNOWFLAKE_SCHEMA": "general-schema", + } + + with config_ng_setup(env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + assert conn["account"] == "from-general-env" + assert conn["schema"] == "general-schema" + + +def test_complete_7_level_chain(config_ng_setup): + """All 7 levels with different keys showing complete precedence""" + snowsql_config = """ + [connections.test] + accountname = level1 + user = level1-user + """ + + cli_config = """ + [connections.test] + account = "level2" + password = "level2-pass" + """ + + connections_toml = """ + [connections.test] + account = "level3" + warehouse = "level3-wh" + """ + + env_vars = { + "SNOWSQL_ACCOUNT": "level4", + "SNOWSQL_DATABASE": 
"level4-db", + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "level5", + "SNOWFLAKE_CONNECTIONS_TEST_ROLE": "level5-role", + "SNOWFLAKE_ACCOUNT": "level6", + "SNOWFLAKE_SCHEMA": "level6-schema", + } + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + # Level 6 should win for account (general env) + assert conn["account"] == "level6" + + # Level 6 provides schema (only level with it) + assert conn["schema"] == "level6-schema" + + # Level 5 provides role (highest level with it) + assert conn["role"] == "level5-role" + + # Level 4 provides database + assert conn["database"] == "level4-db" + + # Level 3 provides warehouse + assert conn["warehouse"] == "level3-wh" + + # Level 2 (cli_config) is skipped because connections.toml defines this connection + # password from cli_config is NOT present + + # user NOT in connection - connections_toml (level 3) replaced entire connection + # and didn't include user field + assert "user" not in conn + + +def test_get_connection_dict_uses_config_ng_when_enabled(config_ng_setup): + """Validate that get_connection_dict delegates to config_ng when flag is set""" + + cli_config = """ + [connections.test] + account = "test-account" + user = "test-user" + """ + + with config_ng_setup(cli_config=cli_config): + from snowflake.cli.api.config import get_connection_dict + from snowflake.cli.api.config_provider import ( + AlternativeConfigProvider, + get_config_provider_singleton, + ) + + # Verify we're using AlternativeConfigProvider + provider = get_config_provider_singleton() + assert isinstance(provider, AlternativeConfigProvider) + + # Verify resolution works + conn = get_connection_dict("test") + + assert conn["account"] == "test-account" + assert conn["user"] == "test-user" + + +def test_precedence_with_multiple_connections(config_ng_setup): + """Test 
that precedence works correctly for multiple connections""" + cli_config = """ + [connections.conn1] + account = "conn1-account" + user = "conn1-user" + + [connections.conn2] + account = "conn2-account" + user = "conn2-user" + """ + + env_vars = { + "SNOWFLAKE_CONNECTIONS_CONN1_ACCOUNT": "conn1-env", + "SNOWFLAKE_SCHEMA": "common-schema", + } + + with config_ng_setup(cli_config=cli_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + # conn1 should have env override + conn1 = get_connection_dict("conn1") + assert conn1["account"] == "conn1-env" # From connection-specific env + assert conn1["user"] == "conn1-user" # From config file + assert conn1["schema"] == "common-schema" # From general env + + # conn2 should use config values + conn2 = get_connection_dict("conn2") + assert conn2["account"] == "conn2-account" # From config file + assert conn2["user"] == "conn2-user" # From config file + assert conn2["schema"] == "common-schema" # From general env + + +def test_snowsql_key_mapping(config_ng_setup): + """Test that SnowSQL key names are properly mapped to CLI names""" + snowsql_config = """ + [connections.test] + accountname = test-account + username = test-user + dbname = test-db + schemaname = test-schema + rolename = test-role + warehousename = test-warehouse + pwd = test-password + """ + + with config_ng_setup(snowsql_config=snowsql_config): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + # All SnowSQL names should be mapped to CLI names + assert conn["account"] == "test-account" + assert conn["user"] == "test-user" + assert conn["database"] == "test-db" + assert conn["schema"] == "test-schema" + assert conn["role"] == "test-role" + assert conn["warehouse"] == "test-warehouse" + assert conn["password"] == "test-password" + + +def test_empty_config_files(config_ng_setup): + """Test behavior with empty/missing config files""" + # Only set env vars, no config files + 
env_vars = { + "SNOWFLAKE_ACCOUNT": "env-only-account", + "SNOWFLAKE_USER": "env-only-user", + } + + with config_ng_setup(env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("default") + assert conn["account"] == "env-only-account" + assert conn["user"] == "env-only-user" + + +# Group 1: Non-Adjacent 2-Source Tests + + +def test_snowsql_config_with_snowsql_env_direct(config_ng_setup): + """Test SnowSQL env overrides SnowSQL config when intermediate sources absent""" + snowsql_config = """ + [connections.test] + accountname = from-config + user = config-user + database = config-db + """ + + env_vars = {"SNOWSQL_ACCOUNT": "from-env", "SNOWSQL_DATABASE": "env-db"} + + with config_ng_setup(snowsql_config=snowsql_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + # Level 4 (SnowSQL env) wins for account and database + # Level 1 (SnowSQL config) wins for user + expected = { + "account": "from-env", + "user": "config-user", + "database": "env-db", + } + assert conn == expected + + +def test_snowsql_config_with_general_env_direct(config_ng_setup): + """Test general env overrides SnowSQL config across all intermediate levels""" + snowsql_config = """ + [connections.test] + accountname = from-config + user = config-user + warehouse = config-warehouse + """ + + env_vars = {"SNOWFLAKE_ACCOUNT": "from-env", "SNOWFLAKE_WAREHOUSE": "env-warehouse"} + + with config_ng_setup(snowsql_config=snowsql_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "from-env", + "user": "config-user", + "warehouse": "env-warehouse", + } + assert conn == expected + + +def test_cli_config_with_general_env_direct(config_ng_setup): + """Test general env overrides CLI config when intermediate sources absent""" + cli_config = """ + [connections.test] + account = 
"from-cli" + user = "cli-user" + role = "cli-role" + """ + + env_vars = {"SNOWFLAKE_ACCOUNT": "from-env", "SNOWFLAKE_ROLE": "env-role"} + + with config_ng_setup(cli_config=cli_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "from-env", + "user": "cli-user", + "role": "env-role", + } + assert conn == expected + + +def test_connections_toml_with_general_env_direct(config_ng_setup): + """Test general env overrides Connections TOML directly""" + connections_toml = """ + [connections.test] + account = "from-toml" + user = "toml-user" + schema = "toml-schema" + """ + + env_vars = {"SNOWFLAKE_ACCOUNT": "from-env", "SNOWFLAKE_SCHEMA": "env-schema"} + + with config_ng_setup(connections_toml=connections_toml, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "from-env", + "user": "toml-user", + "schema": "env-schema", + } + assert conn == expected + + +# Group 2: Strategic 3-Source Tests + + +def test_all_file_sources_precedence(config_ng_setup): + """Test precedence among all three file-based sources""" + snowsql_config = """ + [connections.test] + accountname = from-snowsql + user = snowsql-user + warehouse = snowsql-warehouse + password = snowsql-pass + """ + + cli_config = """ + [connections.test] + account = "from-cli" + user = "cli-user" + password = "cli-pass" + """ + + connections_toml = """ + [connections.test] + account = "from-connections" + password = "connections-pass" + """ + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + connections_toml=connections_toml, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "from-connections", # connections_toml (Level 3) wins + # user and warehouse NOT inherited - connections_toml replaced entire connection + 
"password": "connections-pass", # From connections_toml (FILE) + } + assert conn == expected + + +def test_all_env_sources_precedence(config_ng_setup): + """Test precedence among all three environment variable types""" + env_vars = { + "SNOWSQL_ACCOUNT": "snowsql-env", + "SNOWSQL_DATABASE": "snowsql-db", + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "conn-specific", + "SNOWFLAKE_CONNECTIONS_TEST_ROLE": "conn-role", + "SNOWFLAKE_ACCOUNT": "general-env", + "SNOWFLAKE_SCHEMA": "general-schema", + } + + with config_ng_setup(env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "general-env", # Level 6 wins + "schema": "general-schema", # Level 6 only source + "role": "conn-role", # Level 5 only source + "database": "snowsql-db", # Level 4 only source + } + assert conn == expected + + +def test_file_and_env_mix_with_gaps(config_ng_setup): + """Test precedence with gaps in source chain""" + snowsql_config = """ + [connections.test] + accountname = snowsql-account + user = snowsql-user + """ + + connections_toml = """ + [connections.test] + account = "toml-account" + warehouse = "toml-warehouse" + """ + + env_vars = {"SNOWFLAKE_ACCOUNT": "env-account"} + + with config_ng_setup( + snowsql_config=snowsql_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "env-account", # Level 6 OVERLAY wins + # user NOT inherited - connections_toml replaced entire snowsql connection + "warehouse": "toml-warehouse", # From connections_toml (FILE) + } + assert conn == expected + + +def test_cli_config_with_two_env_types(config_ng_setup): + """Test CLI config as base with two env override types""" + cli_config = """ + [connections.test] + account = "cli-account" + user = "cli-user" + database = "cli-db" + """ + + env_vars = { + "SNOWSQL_ACCOUNT": "snowsql-env", 
+ "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "conn-specific", + "SNOWFLAKE_CONNECTIONS_TEST_DATABASE": "conn-db", + } + + with config_ng_setup(cli_config=cli_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "conn-specific", # Level 5 wins + "user": "cli-user", # Level 2 only source + "database": "conn-db", # Level 5 wins + } + assert conn == expected + + +# Group 3: 4-Source Combinations + + +def test_all_files_plus_snowsql_env(config_ng_setup): + """Test all file sources with SnowSQL environment override""" + snowsql_config = """ + [connections.test] + accountname = snowsql-config + user = snowsql-user + """ + + cli_config = """ + [connections.test] + account = "cli-config" + warehouse = "cli-warehouse" + """ + + connections_toml = """ + [connections.test] + account = "toml-account" + database = "toml-db" + """ + + env_vars = {"SNOWSQL_ACCOUNT": "env-account"} + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "env-account", # Level 4 OVERLAY wins + # user NOT inherited - connections_toml replaced entire connection chain + "database": "toml-db", # From connections_toml (FILE) + } + assert conn == expected + + +def test_all_files_plus_general_env(config_ng_setup): + """Test all file sources with general environment override""" + snowsql_config = """ + [connections.test] + accountname = snowsql-config + user = snowsql-user + """ + + cli_config = """ + [connections.test] + account = "cli-config" + role = "cli-role" + """ + + connections_toml = """ + [connections.test] + account = "toml-account" + warehouse = "toml-warehouse" + """ + + env_vars = { + "SNOWFLAKE_ACCOUNT": "env-account", + "SNOWFLAKE_WAREHOUSE": "env-warehouse", + } + + with 
config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "env-account", # Level 6 OVERLAY wins + # user NOT inherited - connections_toml replaced entire connection chain + # role from cli_config is skipped (connections.toml replaces cli_config) + "warehouse": "env-warehouse", # Level 6 OVERLAY wins + } + assert conn == expected + + +def test_cli_config_with_all_env_types(config_ng_setup): + """Test single file source with all three environment types""" + cli_config = """ + [connections.test] + account = "cli-account" + user = "cli-user" + """ + + env_vars = { + "SNOWSQL_ACCOUNT": "snowsql-env", + "SNOWSQL_DATABASE": "snowsql-db", + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "conn-specific", + "SNOWFLAKE_CONNECTIONS_TEST_ROLE": "conn-role", + "SNOWFLAKE_ACCOUNT": "general-env", + "SNOWFLAKE_SCHEMA": "general-schema", + } + + with config_ng_setup(cli_config=cli_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "general-env", # Level 6 wins + "user": "cli-user", # Level 2 only + "database": "snowsql-db", # Level 4 only + "role": "conn-role", # Level 5 only + "schema": "general-schema", # Level 6 only + } + assert conn == expected + + +def test_two_files_two_envs_with_gap(config_ng_setup): + """Test non-adjacent file sources with non-adjacent env sources""" + snowsql_config = """ + [connections.test] + accountname = snowsql-config + user = snowsql-user + """ + + connections_toml = """ + [connections.test] + account = "toml-account" + warehouse = "toml-warehouse" + """ + + env_vars = { + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "conn-specific", + "SNOWFLAKE_CONNECTIONS_TEST_DATABASE": "conn-db", + "SNOWFLAKE_ACCOUNT": "general-env", + "SNOWFLAKE_SCHEMA": 
"general-schema", + } + + with config_ng_setup( + snowsql_config=snowsql_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "general-env", # Level 6 OVERLAY wins + # user NOT inherited - connections_toml replaced entire connection + "warehouse": "toml-warehouse", # From connections_toml (FILE) + "database": "conn-db", # Level 5 OVERLAY + "schema": "general-schema", # Level 6 OVERLAY + } + assert conn == expected + + +# Group 4: 5-Source Combinations + + +def test_all_files_plus_two_env_types(config_ng_setup): + """Test all file sources with two environment override types""" + snowsql_config = """ + [connections.test] + accountname = snowsql-config + user = snowsql-user + """ + + cli_config = """ + [connections.test] + account = "cli-config" + password = "cli-password" + """ + + connections_toml = """ + [connections.test] + account = "toml-account" + warehouse = "toml-warehouse" + """ + + env_vars = { + "SNOWSQL_ACCOUNT": "snowsql-env", + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "conn-specific", + "SNOWFLAKE_CONNECTIONS_TEST_WAREHOUSE": "conn-warehouse", + } + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "conn-specific", # Level 5 OVERLAY wins + # user NOT inherited - connections_toml replaced entire connection + "warehouse": "conn-warehouse", # Level 5 OVERLAY (on top of toml) + } + assert conn == expected + + +def test_two_files_all_envs(config_ng_setup): + """Test two file sources with all three environment types""" + snowsql_config = """ + [connections.test] + accountname = snowsql-config + user = snowsql-user + """ + + cli_config = """ + [connections.test] + account = "cli-config" + password = 
"cli-password" + """ + + env_vars = { + "SNOWSQL_ACCOUNT": "snowsql-env", + "SNOWSQL_DATABASE": "snowsql-db", + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "conn-specific", + "SNOWFLAKE_CONNECTIONS_TEST_ROLE": "conn-role", + "SNOWFLAKE_ACCOUNT": "general-env", + "SNOWFLAKE_WAREHOUSE": "general-warehouse", + } + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "general-env", # Level 6 OVERLAY wins + # user NOT inherited - cli_config (level 2) replaced snowsql connection + "password": "cli-password", # From cli_config (FILE) + "database": "snowsql-db", # Level 4 OVERLAY + "role": "conn-role", # Level 5 OVERLAY + "warehouse": "general-warehouse", # Level 6 OVERLAY + } + assert conn == expected + + +def test_connections_toml_with_all_env_types(config_ng_setup): + """Test Connections TOML with all environment override types""" + connections_toml = """ + [connections.test] + account = "toml-account" + user = "toml-user" + warehouse = "toml-warehouse" + """ + + env_vars = { + "SNOWSQL_ACCOUNT": "snowsql-env", + "SNOWSQL_DATABASE": "snowsql-db", + "SNOWFLAKE_CONNECTIONS_TEST_WAREHOUSE": "conn-warehouse", + "SNOWFLAKE_CONNECTIONS_TEST_ROLE": "conn-role", + "SNOWFLAKE_ACCOUNT": "general-env", + "SNOWFLAKE_SCHEMA": "general-schema", + } + + with config_ng_setup(connections_toml=connections_toml, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "general-env", # Level 6 wins + "user": "toml-user", # Level 3 only + "warehouse": "conn-warehouse", # Level 5 wins + "database": "snowsql-db", # Level 4 only + "role": "conn-role", # Level 5 only + "schema": "general-schema", # Level 6 only + } + assert conn == expected + + +def test_snowsql_and_connections_with_all_envs(config_ng_setup): + """Test two 
non-adjacent file sources with all environment types""" + snowsql_config = """ + [connections.test] + accountname = snowsql-config + user = snowsql-user + password = snowsql-password + """ + + connections_toml = """ + [connections.test] + account = "toml-account" + warehouse = "toml-warehouse" + """ + + env_vars = { + "SNOWSQL_ACCOUNT": "snowsql-env", + "SNOWSQL_WAREHOUSE": "snowsql-warehouse", + "SNOWFLAKE_CONNECTIONS_TEST_PASSWORD": "conn-password", + "SNOWFLAKE_CONNECTIONS_TEST_ROLE": "conn-role", + "SNOWFLAKE_ACCOUNT": "general-env", + "SNOWFLAKE_DATABASE": "general-db", + } + + with config_ng_setup( + snowsql_config=snowsql_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "general-env", # Level 6 OVERLAY wins + # user NOT inherited - connections_toml replaced snowsql connection + "password": "conn-password", # Level 5 OVERLAY + "warehouse": "snowsql-warehouse", # Level 4 OVERLAY (on top of toml FILE) + "role": "conn-role", # Level 5 OVERLAY + "database": "general-db", # Level 6 OVERLAY + } + assert conn == expected + + +# Group 5: Edge Cases + + +def test_multiple_connections_different_source_patterns(config_ng_setup): + """Test that different connections can have different active sources""" + cli_config = """ + [connections.conn1] + account = "conn1-cli" + user = "conn1-user" + + [connections.conn2] + account = "conn2-cli" + user = "conn2-user" + """ + + connections_toml = """ + [connections.conn1] + warehouse = "conn1-warehouse" + + [connections.conn3] + account = "conn3-toml" + user = "conn3-user" + """ + + env_vars = { + "SNOWFLAKE_CONNECTIONS_CONN1_ACCOUNT": "conn1-env", + "SNOWFLAKE_CONNECTIONS_CONN2_DATABASE": "conn2-db", + "SNOWFLAKE_SCHEMA": "common-schema", + } + + with config_ng_setup( + cli_config=cli_config, connections_toml=connections_toml, env_vars=env_vars + ): + from snowflake.cli.api.config 
import get_connection_dict + + conn1 = get_connection_dict("conn1") + expected1 = { + "account": "conn1-env", # Connection-specific env wins + # user from cli_config is skipped (connections.toml replaces cli_config) + "warehouse": "conn1-warehouse", # Connections TOML + "schema": "common-schema", # General env + } + assert conn1 == expected1 + + conn2 = get_connection_dict("conn2") + expected2 = { + "account": "conn2-cli", # CLI config + "user": "conn2-user", # CLI config + "database": "conn2-db", # Connection-specific env + "schema": "common-schema", # General env + } + assert conn2 == expected2 + + conn3 = get_connection_dict("conn3") + expected3 = { + "account": "conn3-toml", # Connections TOML + "user": "conn3-user", # Connections TOML + "schema": "common-schema", # General env + } + assert conn3 == expected3 + + +def test_snowsql_key_mapping_with_precedence(config_ng_setup): + """Test SnowSQL legacy key names work correctly across precedence levels""" + snowsql_config = """ + [connections.test] + accountname = snowsql-account + username = snowsql-user + dbname = snowsql-db + schemaname = snowsql-schema + rolename = snowsql-role + warehousename = snowsql-warehouse + """ + + cli_config = """ + [connections.test] + account = "cli-account" + database = "cli-db" + """ + + env_vars = { + "SNOWFLAKE_ACCOUNT": "env-account", + "SNOWFLAKE_SCHEMA": "env-schema", + } + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "env-account", # Level 6 OVERLAY wins + # user, role, warehouse NOT inherited - cli_config replaced snowsql connection + "database": "cli-db", # From cli_config (FILE) + "schema": "env-schema", # Level 6 OVERLAY + } + assert conn == expected + + +def test_empty_intermediate_sources_dont_break_chain(config_ng_setup): + """Test that empty config files don't prevent higher 
sources from working""" + snowsql_config = """ + [connections.test] + accountname = snowsql-account + user = snowsql-user + """ + + # Empty CLI config and connections.toml + cli_config = "" + connections_toml = "" + + env_vars = {"SNOWFLAKE_ACCOUNT": "env-account"} + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + expected = { + "account": "env-account", # Level 6 wins + "user": "snowsql-user", # Level 1 only + } + assert conn == expected + + +def test_account_parameter_across_all_sources(config_ng_setup): + """Test account parameter defined in all sources follows precedence""" + snowsql_config = """ + [connections.test] + accountname = level1-account + """ + + cli_config = """ + [connections.test] + account = "level2-account" + """ + + connections_toml = """ + [connections.test] + account = "level3-account" + """ + + env_vars = { + "SNOWSQL_ACCOUNT": "level4-account", + "SNOWFLAKE_CONNECTIONS_TEST_ACCOUNT": "level5-account", + "SNOWFLAKE_ACCOUNT": "level6-account", + } + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + # Only account should be present since all sources only provide account + expected = { + "account": "level6-account", # Level 6 (general env) wins + } + assert conn == expected diff --git a/tests/config_ng/test_connection_replacement.py b/tests/config_ng/test_connection_replacement.py new file mode 100644 index 0000000000..9f572d8efb --- /dev/null +++ b/tests/config_ng/test_connection_replacement.py @@ -0,0 +1,462 @@ +# Copyright (c) 2024 Snowflake Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for connection-level replacement behavior in config_ng. + +These tests verify that: +1. FILE sources (snowsql_config, cli_config_toml, connections_toml) use + connection-level replacement (later file replaces entire connection) +2. OVERLAY sources (env vars, CLI args) use field-level overlay +3. SnowSQL's multi-file merge acts as a single FILE source +""" + + +def test_file_replacement_basic(config_ng_setup): + """ + Test basic file replacement: later file replaces entire connection. + Fields from earlier file should NOT be inherited. 
+ """ + snowsql_config = """ + [connections.test] + accountname = snowsql-account + user = snowsql-user + warehouse = snowsql-warehouse + database = snowsql-database + """ + + cli_config = """ + [connections.test] + account = "cli-account" + user = "cli-user" + # Note: warehouse and database are NOT included + """ + + with config_ng_setup(snowsql_config=snowsql_config, cli_config=cli_config): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + # Values from cli_config (later FILE source) + assert conn["account"] == "cli-account" + assert conn["user"] == "cli-user" + + # warehouse and database from snowsql NOT inherited (connection replaced) + assert "warehouse" not in conn + assert "database" not in conn + + +def test_file_replacement_connections_toml_replaces_cli_config(config_ng_setup): + """ + Test that connections.toml replaces cli_config.toml entirely. + """ + cli_config = """ + [connections.prod] + account = "cli-account" + user = "cli-user" + warehouse = "cli-warehouse" + database = "cli-database" + schema = "cli-schema" + """ + + connections_toml = """ + [connections.prod] + account = "conn-account" + database = "conn-database" + # Note: user, warehouse, schema are NOT included + """ + + with config_ng_setup(cli_config=cli_config, connections_toml=connections_toml): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("prod") + + # Values from connections.toml + assert conn["account"] == "conn-account" + assert conn["database"] == "conn-database" + + # user, warehouse, schema from cli_config NOT inherited + assert "user" not in conn + assert "warehouse" not in conn + assert "schema" not in conn + + +def test_file_replacement_three_levels(config_ng_setup): + """ + Test replacement across three file sources: snowsql -> cli_config -> connections.toml + """ + snowsql_config = """ + [connections.dev] + accountname = snowsql-account + user = snowsql-user + warehouse = 
snowsql-warehouse + database = snowsql-database + schema = snowsql-schema + """ + + cli_config = """ + [connections.dev] + account = "cli-account" + user = "cli-user" + warehouse = "cli-warehouse" + # database and schema not included + """ + + connections_toml = """ + [connections.dev] + account = "conn-account" + # Only account specified + """ + + with config_ng_setup( + snowsql_config=snowsql_config, + cli_config=cli_config, + connections_toml=connections_toml, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("dev") + + # Only value from connections.toml (highest FILE source) + assert conn["account"] == "conn-account" + + # All other fields NOT inherited from earlier FILE sources + assert "user" not in conn + assert "warehouse" not in conn + assert "database" not in conn + assert "schema" not in conn + + +def test_overlay_adds_fields_without_replacing_connection(config_ng_setup): + """ + Test that OVERLAY sources (env vars) add/override individual fields + without replacing the entire connection from FILE sources. 
+ """ + cli_config = """ + [connections.test] + account = "cli-account" + database = "cli-database" + schema = "cli-schema" + """ + + env_vars = { + "SNOWFLAKE_CONNECTIONS_TEST_USER": "env-user", + "SNOWFLAKE_CONNECTIONS_TEST_WAREHOUSE": "env-warehouse", + } + + with config_ng_setup(cli_config=cli_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + # Values from FILE source (cli_config) + assert conn["account"] == "cli-account" + assert conn["database"] == "cli-database" + assert conn["schema"] == "cli-schema" + + # Values from OVERLAY source (env vars) - added without replacing + assert conn["user"] == "env-user" + assert conn["warehouse"] == "env-warehouse" + + +def test_overlay_overrides_file_field_without_replacing_connection(config_ng_setup): + """ + Test that OVERLAY sources can override individual fields from FILE sources + without replacing the entire connection. + """ + connections_toml = """ + [connections.prod] + account = "file-account" + user = "file-user" + warehouse = "file-warehouse" + database = "file-database" + """ + + env_vars = { + "SNOWFLAKE_CONNECTIONS_PROD_USER": "env-user", + "SNOWFLAKE_CONNECTIONS_PROD_WAREHOUSE": "env-warehouse", + } + + with config_ng_setup(connections_toml=connections_toml, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("prod") + + # account and database from FILE source (not overridden) + assert conn["account"] == "file-account" + assert conn["database"] == "file-database" + + # user and warehouse overridden by OVERLAY source + assert conn["user"] == "env-user" + assert conn["warehouse"] == "env-warehouse" + + +def test_snowsql_env_overlay_on_replaced_connection(config_ng_setup): + """ + Test that SNOWSQL_* env vars (OVERLAY) overlay on replaced connections. 
+ """ + snowsql_config = """ + [connections.test] + accountname = snowsql-account + user = snowsql-user + warehouse = snowsql-warehouse + database = snowsql-database + """ + + connections_toml = """ + [connections.test] + account = "conn-account" + user = "conn-user" + # warehouse and database not included (connection replaced) + """ + + env_vars = { + "SNOWSQL_WAREHOUSE": "env-warehouse", + } + + with config_ng_setup( + snowsql_config=snowsql_config, + connections_toml=connections_toml, + env_vars=env_vars, + ): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + # Values from connections.toml (FILE source) + assert conn["account"] == "conn-account" + assert conn["user"] == "conn-user" + + # database from snowsql NOT inherited (connection replaced by connections.toml) + assert "database" not in conn + + # warehouse from SNOWSQL_* env (OVERLAY source) + assert conn["warehouse"] == "env-warehouse" + + +def test_cli_env_overlay_on_file_connection(config_ng_setup): + """ + Test that SNOWFLAKE_* env vars (OVERLAY) add fields to file connections. + """ + cli_config = """ + [connections.dev] + account = "cli-account" + database = "cli-database" + """ + + env_vars = { + "SNOWFLAKE_USER": "global-env-user", + "SNOWFLAKE_WAREHOUSE": "global-env-warehouse", + } + + with config_ng_setup(cli_config=cli_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("dev") + + # Values from FILE source + assert conn["account"] == "cli-account" + assert conn["database"] == "cli-database" + + # Values from global OVERLAY source (apply to all connections) + assert conn["user"] == "global-env-user" + assert conn["warehouse"] == "global-env-warehouse" + + +def test_multiple_connections_independent_replacement(config_ng_setup): + """ + Test that replacement is per-connection: different connections can be + replaced independently. 
+ """ + snowsql_config = """ + [connections.conn1] + accountname = snowsql-account1 + user = snowsql-user1 + warehouse = snowsql-warehouse1 + + [connections.conn2] + accountname = snowsql-account2 + user = snowsql-user2 + database = snowsql-database2 + """ + + connections_toml = """ + [connections.conn1] + account = "conn-account1" + # Only account specified for conn1 - warehouse NOT inherited + + # conn2 NOT defined in connections.toml + """ + + with config_ng_setup( + snowsql_config=snowsql_config, connections_toml=connections_toml + ): + from snowflake.cli.api.config import get_connection_dict + + # conn1: replaced by connections.toml + conn1 = get_connection_dict("conn1") + assert conn1["account"] == "conn-account1" + assert "user" not in conn1 # Not inherited + assert "warehouse" not in conn1 # Not inherited + + # conn2: NOT replaced, uses snowsql_config values + conn2 = get_connection_dict("conn2") + assert conn2["account"] == "snowsql-account2" + assert conn2["user"] == "snowsql-user2" + assert conn2["database"] == "snowsql-database2" + + +def test_empty_connection_replacement(config_ng_setup): + """ + Test that an empty connection in a later FILE source still replaces + the entire connection from earlier sources. + + The empty connection is considered valid (it exists in config), but has + no parameters. Validation of required fields happens when actually using + the connection to connect to Snowflake. 
+ """ + cli_config = """ + [connections.test] + account = "cli-account" + user = "cli-user" + warehouse = "cli-warehouse" + """ + + connections_toml = """ + [connections.test] + # Empty connection section + """ + + with config_ng_setup(cli_config=cli_config, connections_toml=connections_toml): + from snowflake.cli.api.config import get_connection_dict + + # Empty connection replacement: connection exists but has no parameters + conn = get_connection_dict("test") + assert conn == {} # Connection exists but is empty + + # No parameters from cli_config are inherited (connection was replaced) + assert "account" not in conn + assert "user" not in conn + assert "warehouse" not in conn + + +def test_overlay_precedence_connection_specific_over_global(config_ng_setup): + """ + Test OVERLAY precedence: global env (SNOWFLAKE_*) overrides connection-specific env. + Source order: connection_specific_env (#5) < cli_env (#6) + """ + cli_config = """ + [connections.test] + account = "cli-account" + """ + + env_vars = { + "SNOWFLAKE_USER": "global-user", + "SNOWFLAKE_CONNECTIONS_TEST_USER": "specific-user", + "SNOWFLAKE_WAREHOUSE": "global-warehouse", + } + + with config_ng_setup(cli_config=cli_config, env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("test") + + assert conn["account"] == "cli-account" + # Global env (later OVERLAY source) overrides connection-specific env + assert conn["user"] == "global-user" + # Global env applies when no specific override exists + assert conn["warehouse"] == "global-warehouse" + + +def test_resolution_history_shows_replacement(config_ng_setup): + """ + Test that resolution history correctly shows file replacement behavior. 
+ """ + snowsql_config = """ + [connections.test] + accountname = snowsql-account + user = snowsql-user + warehouse = snowsql-warehouse + """ + + cli_config = """ + [connections.test] + account = "cli-account" + user = "cli-user" + # warehouse not included + """ + + with config_ng_setup(snowsql_config=snowsql_config, cli_config=cli_config): + from snowflake.cli.api.config import get_connection_dict + from snowflake.cli.api.config_ng import get_resolver + + # Trigger resolution to populate history + conn = get_connection_dict("test") + assert conn["account"] == "cli-account" + + resolver = get_resolver() + assert resolver is not None + + # account: both sources provide, cli_config wins + account_history = resolver.get_resolution_history("connections.test.account") + assert account_history is not None + assert len(account_history.entries) == 2 + assert account_history.final_value == "cli-account" + + # user: both sources provide, cli_config wins + user_history = resolver.get_resolution_history("connections.test.user") + assert user_history is not None + assert len(user_history.entries) == 2 + assert user_history.final_value == "cli-user" + + # warehouse: only snowsql provides, but NOT in final config (connection replaced) + # Since the connection was replaced and warehouse wasn't in the new connection, + # it was discovered but never made it to the final resolution, so no history entry + warehouse_history = resolver.get_resolution_history( + "connections.test.warehouse" + ) + # Note: warehouse was discovered but since connection was replaced by cli_config, + # and cli_config didn't include warehouse, it's not in the final resolved values + # The history tracking only marks selected values, so warehouse has no marked entry + if warehouse_history: + # If history exists, it should show discovery but no selection + assert warehouse_history.selected_entry is None + + +def test_flat_keys_still_use_simple_override(config_ng_setup): + """ + Test that flat keys 
(non-connection) still use simple override behavior. + """ + snowsql_config = """ + [connections] + some_global = snowsql-global + """ + + cli_config = """ + [connections] + some_global = "cli-global" + """ + + # Note: This test is somewhat artificial as flat keys in connections sections + # are not commonly used, but verifies the logic handles them correctly + + with config_ng_setup(snowsql_config=snowsql_config, cli_config=cli_config): + from snowflake.cli.api.config_ng import get_resolver + + resolver = get_resolver() + if resolver: + resolved = resolver.resolve() + # This would need actual flat key support in sources to fully test + # For now, just verify no errors occur + assert resolved is not None diff --git a/tests/config_ng/test_constants.py b/tests/config_ng/test_constants.py new file mode 100644 index 0000000000..fdb0639eaf --- /dev/null +++ b/tests/config_ng/test_constants.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for configuration constants.""" + +from snowflake.cli.api.config_ng.constants import ( + FILE_SOURCE_NAMES, + INTERNAL_CLI_PARAMETERS, + SNOWFLAKE_HOME_ENV, + ConfigSection, +) + + +class TestConfigSection: + """Test ConfigSection enum.""" + + def test_enum_values(self): + """Test that enum has expected values.""" + assert ConfigSection.CONNECTIONS.value == "connections" + assert ConfigSection.VARIABLES.value == "variables" + assert ConfigSection.CLI.value == "cli" + assert ConfigSection.CLI_LOGS.value == "cli.logs" + assert ConfigSection.CLI_FEATURES.value == "cli.features" + + def test_enum_string_representation(self): + """Test that enum converts to string correctly.""" + assert str(ConfigSection.CONNECTIONS) == "connections" + assert str(ConfigSection.VARIABLES) == "variables" + assert str(ConfigSection.CLI) == "cli" + + def test_enum_is_string(self): + """Test that enum instances are strings.""" + assert isinstance(ConfigSection.CONNECTIONS, str) + assert isinstance(ConfigSection.VARIABLES, str) + + def test_enum_comparison(self): + """Test that enum can be compared with strings.""" + assert ConfigSection.CONNECTIONS == "connections" + assert ConfigSection.VARIABLES == "variables" + + +class TestConstants: + """Test other constants.""" + + def test_snowflake_home_env(self): + """Test SNOWFLAKE_HOME environment variable constant.""" + assert SNOWFLAKE_HOME_ENV == "SNOWFLAKE_HOME" + + def test_internal_cli_parameters(self): + """Test INTERNAL_CLI_PARAMETERS set.""" + expected_params = { + "enable_diag", + "temporary_connection", + "default_connection_name", + "connection_name", + "diag_log_path", + "diag_allowlist_path", + "mfa_passcode", + } + assert INTERNAL_CLI_PARAMETERS == expected_params + + def test_file_source_names(self): + """Test FILE_SOURCE_NAMES set.""" + expected_sources = { + "snowsql_config", + "cli_config_toml", + "connections_toml", + } + assert FILE_SOURCE_NAMES == expected_sources diff --git a/tests/config_ng/test_env_parsing.py 
b/tests/config_ng/test_env_parsing.py new file mode 100644 index 0000000000..f5bb085e53 --- /dev/null +++ b/tests/config_ng/test_env_parsing.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Snowflake Inc. + +"""Focused tests for environment variable parsing in config_ng.""" + + +def test_connection_specific_env_with_underscores(config_ng_setup): + """Connection names containing underscores should parse correctly. + + Also validate keys that themselves contain underscores (e.g., PRIVATE_KEY_PATH). + """ + + env_vars = { + # Connection-specific variables for connection name with underscores + "SNOWFLAKE_CONNECTIONS_DEV_US_EAST_ACCOUNT": "from-specific", + "SNOWFLAKE_CONNECTIONS_DEV_US_EAST_PRIVATE_KEY_PATH": "/tmp/example_key.pem", + # General env remains available for other flat keys + "SNOWFLAKE_SCHEMA": "general-schema", + } + + with config_ng_setup(env_vars=env_vars): + from snowflake.cli.api.config import get_connection_dict + + conn = get_connection_dict("dev_us_east") + + assert conn["account"] == "from-specific" + assert conn["private_key_path"] == "/tmp/example_key.pem" + # Ensure general env still contributes flat keys + assert conn["schema"] == "general-schema" diff --git a/tests/config_ng/test_merge_operations.py b/tests/config_ng/test_merge_operations.py new file mode 100644 index 0000000000..b3115e4f6a --- /dev/null +++ b/tests/config_ng/test_merge_operations.py @@ -0,0 +1,269 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for configuration merge operations.""" + +from snowflake.cli.api.config_ng.merge_operations import ( + create_default_connection_from_params, + extract_root_level_connection_params, + merge_params_into_connections, +) + + +class TestExtractRootLevelConnectionParams: + """Test extract_root_level_connection_params function.""" + + def test_extract_connection_params_from_mixed_config(self): + """Test extracting connection params from config with sections.""" + config = { + "account": "test_account", + "user": "test_user", + "connections": {"dev": {"database": "db"}}, + "cli": {"enable_diag": True}, + } + + conn_params, remaining = extract_root_level_connection_params(config) + + assert conn_params == {"account": "test_account", "user": "test_user"} + assert "connections" in remaining + assert "cli" in remaining + assert "account" not in remaining + assert "user" not in remaining + + def test_extract_with_no_connection_params(self): + """Test extraction when no root-level connection params exist.""" + config = { + "connections": {"dev": {"account": "acc"}}, + "variables": {"key": "value"}, + } + + conn_params, remaining = extract_root_level_connection_params(config) + + assert conn_params == {} + assert remaining == config + + def test_extract_with_only_connection_params(self): + """Test extraction when only connection params exist.""" + config = {"account": "acc", "user": "usr", "password": "pwd"} + + conn_params, remaining = extract_root_level_connection_params(config) + + assert conn_params == config + assert remaining == {} + + def test_extract_ignores_internal_cli_parameters(self): + """Test that internal CLI parameters are not treated as connection params.""" + config = { + "account": "acc", + "enable_diag": True, + "temporary_connection": True, + "default_connection_name": "dev", + } + + conn_params, remaining = extract_root_level_connection_params(config) + + assert conn_params == {"account": "acc"} + assert "enable_diag" in remaining + assert 
"temporary_connection" in remaining + assert "default_connection_name" in remaining + + def test_extract_recognizes_all_sections(self): + """Test that all ConfigSection values are recognized as sections.""" + config = { + "account": "acc", + "connections": {}, + "variables": {}, + "cli": {}, + } + + conn_params, remaining = extract_root_level_connection_params(config) + + assert conn_params == {"account": "acc"} + assert "connections" in remaining + assert "variables" in remaining + assert "cli" in remaining + + def test_extract_with_nested_section_names(self): + """Test extraction with nested section names like cli.logs.""" + config = { + "account": "acc", + "cli.logs": {"save_logs": True}, + "cli.features": {"feature1": True}, + } + + conn_params, remaining = extract_root_level_connection_params(config) + + assert conn_params == {"account": "acc"} + assert "cli.logs" in remaining + assert "cli.features" in remaining + + def test_extract_empty_config(self): + """Test extraction with empty config.""" + conn_params, remaining = extract_root_level_connection_params({}) + + assert conn_params == {} + assert remaining == {} + + def test_extract_preserves_nested_structures(self): + """Test that nested structures in sections are preserved.""" + config = { + "account": "acc", + "connections": {"dev": {"nested": {"deep": "value"}}}, + } + + conn_params, remaining = extract_root_level_connection_params(config) + + assert conn_params == {"account": "acc"} + assert remaining["connections"]["dev"]["nested"]["deep"] == "value" + + +class TestMergeParamsIntoConnections: + """Test merge_params_into_connections function.""" + + def test_merge_params_into_single_connection(self): + """Test merging params into a single connection.""" + connections = {"dev": {"account": "dev_acc", "user": "dev_user"}} + params = {"password": "new_pass"} + + result = merge_params_into_connections(connections, params) + + assert result["dev"]["account"] == "dev_acc" + assert result["dev"]["user"] == 
"dev_user" + assert result["dev"]["password"] == "new_pass" + + def test_merge_params_into_multiple_connections(self): + """Test merging params into multiple connections.""" + connections = { + "dev": {"account": "dev_acc"}, + "prod": {"account": "prod_acc"}, + } + params = {"user": "global_user", "password": "global_pass"} + + result = merge_params_into_connections(connections, params) + + assert result["dev"]["user"] == "global_user" + assert result["dev"]["password"] == "global_pass" + assert result["prod"]["user"] == "global_user" + assert result["prod"]["password"] == "global_pass" + + def test_merge_params_override_connection_values(self): + """Test that params override existing connection values.""" + connections = {"dev": {"account": "old_acc", "user": "old_user"}} + params = {"user": "new_user"} + + result = merge_params_into_connections(connections, params) + + assert result["dev"]["account"] == "old_acc" + assert result["dev"]["user"] == "new_user" + + def test_merge_empty_params(self): + """Test merging with empty params.""" + connections = {"dev": {"account": "acc"}} + params = {} + + result = merge_params_into_connections(connections, params) + + assert result == connections + + def test_merge_into_empty_connections(self): + """Test merging params into empty connections dict.""" + connections = {} + params = {"account": "acc"} + + result = merge_params_into_connections(connections, params) + + assert result == {} + + def test_merge_preserves_original_connections(self): + """Test that original connections dict is not modified.""" + connections = {"dev": {"account": "acc"}} + params = {"user": "usr"} + + result = merge_params_into_connections(connections, params) + + # Original should be unchanged + assert "user" not in connections["dev"] + # Result should have merged values + assert result["dev"]["user"] == "usr" + + def test_merge_nested_connection_values(self): + """Test merging with nested connection structures.""" + connections = {"dev": 
{"account": "acc", "nested": {"key": "value"}}} + params = {"nested": {"key": "new_value", "new_key": "new"}} + + result = merge_params_into_connections(connections, params) + + assert result["dev"]["nested"]["key"] == "new_value" + assert result["dev"]["nested"]["new_key"] == "new" + + def test_merge_handles_non_dict_connection(self): + """Test that non-dict connection values are preserved.""" + connections = {"dev": {"account": "acc"}, "invalid": "not_a_dict"} + params = {"user": "usr"} + + result = merge_params_into_connections(connections, params) + + assert result["dev"]["user"] == "usr" + assert result["invalid"] == "not_a_dict" + + +class TestCreateDefaultConnectionFromParams: + """Test create_default_connection_from_params function.""" + + def test_create_default_connection(self): + """Test creating default connection from params.""" + params = {"account": "test_acc", "user": "test_user"} + + result = create_default_connection_from_params(params) + + assert "default" in result + assert result["default"] == params + + def test_create_default_with_single_param(self): + """Test creating default connection with single param.""" + params = {"account": "test_acc"} + + result = create_default_connection_from_params(params) + + assert result == {"default": {"account": "test_acc"}} + + def test_create_default_with_empty_params(self): + """Test creating default connection with empty params.""" + result = create_default_connection_from_params({}) + + assert result == {} + + def test_create_default_preserves_original_params(self): + """Test that original params dict is not modified.""" + params = {"account": "acc"} + + result = create_default_connection_from_params(params) + + # Modify result + result["default"]["user"] = "usr" + + # Original should be unchanged + assert "user" not in params + + def test_create_default_with_complex_params(self): + """Test creating default connection with nested params.""" + params = { + "account": "acc", + "user": "usr", + "nested": 
{"key": "value"}, + } + + result = create_default_connection_from_params(params) + + assert result["default"]["nested"]["key"] == "value" diff --git a/tests/config_ng/test_parsers.py b/tests/config_ng/test_parsers.py new file mode 100644 index 0000000000..493280434f --- /dev/null +++ b/tests/config_ng/test_parsers.py @@ -0,0 +1,345 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for configuration parsers.""" + +import pytest +from snowflake.cli.api.config_ng.parsers import SnowSQLParser, TOMLParser + + +class TestSnowSQLParser: + """Test SnowSQL INI parser.""" + + def test_parse_single_connection(self): + """Test parsing a single connection.""" + content = """ +[connections.dev] +accountname = myaccount +username = myuser +password = mypass +""" + result = SnowSQLParser.parse(content) + + assert "connections" in result + assert "dev" in result["connections"] + assert result["connections"]["dev"] == { + "account": "myaccount", + "user": "myuser", + "password": "mypass", + } + + def test_parse_multiple_connections(self): + """Test parsing multiple connections.""" + content = """ +[connections.dev] +accountname = dev_account +username = dev_user + +[connections.prod] +accountname = prod_account +username = prod_user +""" + result = SnowSQLParser.parse(content) + + assert "connections" in result + assert len(result["connections"]) == 2 + assert result["connections"]["dev"]["account"] == "dev_account" + assert 
result["connections"]["prod"]["account"] == "prod_account" + + def test_parse_default_connection(self): + """Test parsing default connection (no name suffix).""" + content = """ +[connections] +accountname = default_account +username = default_user +""" + result = SnowSQLParser.parse(content) + + assert "connections" in result + assert "default" in result["connections"] + assert result["connections"]["default"]["account"] == "default_account" + + def test_key_mapping_accountname_to_account(self): + """Test that accountname is mapped to account.""" + content = """ +[connections.test] +accountname = test_account +""" + result = SnowSQLParser.parse(content) + + assert "account" in result["connections"]["test"] + assert result["connections"]["test"]["account"] == "test_account" + + def test_key_mapping_username_to_user(self): + """Test that username is mapped to user.""" + content = """ +[connections.test] +username = test_user +""" + result = SnowSQLParser.parse(content) + + assert "user" in result["connections"]["test"] + assert result["connections"]["test"]["user"] == "test_user" + + def test_key_mapping_dbname_to_database(self): + """Test that dbname is mapped to database.""" + content = """ +[connections.test] +dbname = test_db +""" + result = SnowSQLParser.parse(content) + + assert "database" in result["connections"]["test"] + assert result["connections"]["test"]["database"] == "test_db" + + def test_key_mapping_pwd_to_password(self): + """Test that pwd is mapped to password.""" + content = """ +[connections.test] +pwd = test_pass +""" + result = SnowSQLParser.parse(content) + + assert "password" in result["connections"]["test"] + assert result["connections"]["test"]["password"] == "test_pass" + + def test_key_mapping_multiple_keys(self): + """Test mapping multiple keys at once.""" + content = """ +[connections.test] +accountname = acc +username = usr +dbname = db +schemaname = sch +warehousename = wh +rolename = rol +""" + result = SnowSQLParser.parse(content) + 
+ conn = result["connections"]["test"] + assert conn["account"] == "acc" + assert conn["user"] == "usr" + assert conn["database"] == "db" + assert conn["schema"] == "sch" + assert conn["warehouse"] == "wh" + assert conn["role"] == "rol" + + def test_parse_variables_section(self): + """Test parsing variables section.""" + content = """ +[variables] +stage = mystage +table = mytable +schema = myschema +""" + result = SnowSQLParser.parse(content) + + assert "variables" in result + assert result["variables"]["stage"] == "mystage" + assert result["variables"]["table"] == "mytable" + assert result["variables"]["schema"] == "myschema" + + def test_parse_connections_and_variables(self): + """Test parsing both connections and variables.""" + content = """ +[connections.dev] +accountname = dev_account + +[variables] +env = development +""" + result = SnowSQLParser.parse(content) + + assert "connections" in result + assert "variables" in result + assert result["connections"]["dev"]["account"] == "dev_account" + assert result["variables"]["env"] == "development" + + def test_parse_empty_content(self): + """Test parsing empty content.""" + result = SnowSQLParser.parse("") + + assert result == {} + + def test_parse_no_connections_section(self): + """Test parsing config without connections section.""" + content = """ +[variables] +key = value +""" + result = SnowSQLParser.parse(content) + + assert "connections" not in result + assert "variables" in result + + def test_parse_preserves_unmapped_keys(self): + """Test that unmapped keys are preserved as-is.""" + content = """ +[connections.test] +custom_key = custom_value +another_key = another_value +""" + result = SnowSQLParser.parse(content) + + conn = result["connections"]["test"] + assert conn["custom_key"] == "custom_value" + assert conn["another_key"] == "another_value" + + def test_parse_connection_with_special_characters(self): + """Test parsing connection with special characters in name.""" + content = """ 
+[connections.my-test_conn] +accountname = test +""" + result = SnowSQLParser.parse(content) + + assert "my-test_conn" in result["connections"] + + def test_parse_values_with_spaces(self): + """Test parsing values that contain spaces.""" + content = """ +[connections.test] +accountname = my account name +""" + result = SnowSQLParser.parse(content) + + assert result["connections"]["test"]["account"] == "my account name" + + +class TestTOMLParser: + """Test TOML parser.""" + + def test_parse_simple_toml(self): + """Test parsing simple TOML.""" + content = """ +[connections.test] +account = "test_account" +user = "test_user" +""" + result = TOMLParser.parse(content) + + assert "connections" in result + assert "test" in result["connections"] + assert result["connections"]["test"]["account"] == "test_account" + assert result["connections"]["test"]["user"] == "test_user" + + def test_parse_nested_toml(self): + """Test parsing nested TOML structure.""" + content = """ +[cli] +enable_diag = true + +[cli.logs] +save_logs = true + +[connections.prod] +account = "prod_account" +""" + result = TOMLParser.parse(content) + + assert "cli" in result + assert result["cli"]["enable_diag"] is True + assert result["cli"]["logs"]["save_logs"] is True + assert result["connections"]["prod"]["account"] == "prod_account" + + def test_parse_multiple_connections(self): + """Test parsing multiple connections in TOML.""" + content = """ +[connections.dev] +account = "dev_account" + +[connections.prod] +account = "prod_account" +""" + result = TOMLParser.parse(content) + + assert len(result["connections"]) == 2 + assert result["connections"]["dev"]["account"] == "dev_account" + assert result["connections"]["prod"]["account"] == "prod_account" + + def test_parse_variables(self): + """Test parsing variables section.""" + content = """ +[variables] +stage = "mystage" +table = "mytable" +""" + result = TOMLParser.parse(content) + + assert "variables" in result + assert result["variables"]["stage"] 
== "mystage" + assert result["variables"]["table"] == "mytable" + + def test_parse_empty_content(self): + """Test parsing empty TOML.""" + result = TOMLParser.parse("") + + assert result == {} + + def test_parse_toml_with_types(self): + """Test parsing TOML with different value types.""" + content = """ +[test] +string_val = "text" +int_val = 42 +float_val = 3.14 +bool_val = true +array_val = ["a", "b", "c"] +""" + result = TOMLParser.parse(content) + + assert result["test"]["string_val"] == "text" + assert result["test"]["int_val"] == 42 + assert result["test"]["float_val"] == 3.14 + assert result["test"]["bool_val"] is True + assert result["test"]["array_val"] == ["a", "b", "c"] + + def test_parse_malformed_toml_raises_error(self): + """Test that malformed TOML raises an error.""" + content = """ +[connections.test +account = "broken +""" + with pytest.raises(Exception): # tomllib raises TOMLDecodeError + TOMLParser.parse(content) + + def test_parse_toml_with_inline_table(self): + """Test parsing TOML with inline tables.""" + content = """ +[connections] +dev = { account = "dev_acc", user = "dev_user" } +""" + result = TOMLParser.parse(content) + + assert result["connections"]["dev"]["account"] == "dev_acc" + assert result["connections"]["dev"]["user"] == "dev_user" + + def test_parse_legacy_connections_format(self): + """Test parsing legacy connections.toml format (direct sections).""" + content = """ +[dev] +account = "dev_account" +user = "dev_user" + +[prod] +account = "prod_account" +user = "prod_user" +""" + result = TOMLParser.parse(content) + + # Note: TOMLParser just parses, doesn't normalize + assert "dev" in result + assert "prod" in result + assert result["dev"]["account"] == "dev_account" + assert result["prod"]["account"] == "prod_account" diff --git a/tests/config_ng/test_private_key_cleanup.py b/tests/config_ng/test_private_key_cleanup.py new file mode 100644 index 0000000000..99ad720596 --- /dev/null +++ 
b/tests/config_ng/test_private_key_cleanup.py
@@ -0,0 +1,35 @@
+"""Tests for temporary private_key_raw file lifecycle and cleanup."""
+
+from pathlib import Path
+
+
+def test_private_key_raw_creates_and_cleans_temp_file(config_ng_setup, tmp_path):
+    """private_key_raw from env is materialized as a temp file and removed on provider reset."""
+    priv_key_content = (
+        """-----BEGIN PRIVATE KEY-----\nABC\n-----END PRIVATE KEY-----\n"""
+    )
+
+    cli_config = """
+    [connections.test]
+    user = "cli-user"
+    """
+
+    env_vars = {
+        # Provide private_key_raw via env to trigger transformation
+        "SNOWFLAKE_CONNECTIONS_TEST_PRIVATE_KEY_RAW": priv_key_content,
+    }
+
+    with config_ng_setup(cli_config=cli_config, env_vars=env_vars):
+        from snowflake.cli.api.config import get_connection_dict
+        from snowflake.cli.api.config_provider import reset_config_provider
+
+        conn = get_connection_dict("test")
+        temp_path = Path(conn["private_key_file"])  # should exist now
+        assert temp_path.exists()
+        assert temp_path.read_text() == priv_key_content
+
+        # Reset provider triggers cleanup
+        reset_config_provider()
+
+        # File should be gone after cleanup
+        assert not temp_path.exists()
diff --git a/tests/config_ng/test_resolution_logger.py b/tests/config_ng/test_resolution_logger.py
new file mode 100644
index 0000000000..1349d82681
--- /dev/null
+++ b/tests/config_ng/test_resolution_logger.py
@@ -0,0 +1,317 @@
+# Copyright (c) 2024 Snowflake Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# ruff: noqa: SLF001 +""" +Tests for configuration resolution logger module. + +This tests the internal resolution logging utilities that are independent +of CLI commands and can be used in any context. +""" + +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import mock + +from snowflake.cli.api.config_ng.resolution_logger import ( + check_value_source, + explain_configuration, + export_resolution_history, + format_summary_for_display, + get_resolution_summary, + get_resolver, + is_resolution_logging_available, + show_all_resolution_chains, + show_resolution_chain, +) +from snowflake.cli.api.config_provider import ( + ALTERNATIVE_CONFIG_ENV_VAR, + AlternativeConfigProvider, + reset_config_provider, +) + + +class TestResolutionLoggingAvailability: + """Tests for checking if resolution logging is available.""" + + def test_logging_not_available_with_legacy_provider(self): + """Test that logging is not available with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + assert not is_resolution_logging_available() + + def test_logging_available_with_alternative_provider(self): + """Test that logging is available with alternative provider.""" + with mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "true"}): + reset_config_provider() + + assert is_resolution_logging_available() + + def test_get_resolver_returns_none_with_legacy(self): + """Test that get_resolver returns None with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + resolver = get_resolver() + assert resolver is None + + def test_get_resolver_returns_instance_with_alternative(self): + """Test that get_resolver returns resolver with alternative provider.""" + with 
mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "true"}): + reset_config_provider() + + resolver = get_resolver() + assert resolver is not None + + +class TestShowResolutionChain: + """Tests for showing resolution chains.""" + + def test_show_chain_with_legacy_provider_shows_warning(self, capsys): + """Test that show_resolution_chain shows warning with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + show_resolution_chain("test_key") + + captured = capsys.readouterr() + assert "not available" in captured.out.lower() + + def test_show_all_chains_with_legacy_provider_shows_warning(self, capsys): + """Test that show_all_resolution_chains shows warning with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + show_all_resolution_chains() + + captured = capsys.readouterr() + assert "not available" in captured.out.lower() + + +class TestResolutionSummary: + """Tests for resolution summary functionality.""" + + def test_summary_returns_none_with_legacy_provider(self): + """Test that get_resolution_summary returns None with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + summary = get_resolution_summary() + assert summary is None + + def test_format_summary_returns_none_with_legacy_provider(self): + """Test that format_summary_for_display returns None with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + formatted = format_summary_for_display() + assert formatted is None + + def 
test_format_summary_with_alternative_provider(self): + """Test that format_summary_for_display returns formatted string.""" + with mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "true"}): + reset_config_provider() + + # Mock the resolver's tracker to have some data + provider = AlternativeConfigProvider() + provider._ensure_initialized() + + with mock.patch.object( + provider._resolver.get_tracker(), "get_summary" + ) as mock_summary: + mock_summary.return_value = { + "total_keys_resolved": 5, + "keys_with_overrides": 2, + "keys_using_defaults": 1, + "source_usage": { + "cli_arguments": 2, + "snowflake_cli_env": 3, + }, + "source_wins": { + "cli_arguments": 2, + "snowflake_cli_env": 3, + }, + } + + # Need to mock the provider singleton + with mock.patch( + "snowflake.cli.api.config_ng.resolution_logger.get_config_provider_singleton", + return_value=provider, + ): + formatted = format_summary_for_display() + + assert formatted is not None + assert "Total keys resolved: 5" in formatted + assert "Keys with overrides: 2" in formatted + assert "Keys using defaults: 1" in formatted + assert "cli_arguments" in formatted + assert "snowflake_cli_env" in formatted + + +class TestCheckValueSource: + """Tests for checking value source.""" + + def test_check_value_source_returns_none_with_legacy(self): + """Test that check_value_source returns None with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + source = check_value_source("test_key") + assert source is None + + +class TestExportResolutionHistory: + """Tests for exporting resolution history.""" + + def test_export_returns_false_with_legacy_provider(self, capsys): + """Test that export_resolution_history returns False with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del 
os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + with TemporaryDirectory() as tmpdir: + export_path = Path(tmpdir) / "test_export.json" + success = export_resolution_history(export_path) + + assert not success + captured = capsys.readouterr() + assert "not available" in captured.out.lower() + + def test_export_succeeds_with_alternative_provider(self): + """Test that export_resolution_history succeeds with alternative provider.""" + with mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "true"}): + reset_config_provider() + + with TemporaryDirectory() as tmpdir: + export_path = Path(tmpdir) / "test_export.json" + success = export_resolution_history(export_path) + + assert success + assert export_path.exists() + + # Verify JSON is valid + import json + + with open(export_path) as f: + data = json.load(f) + + assert "summary" in data + assert "histories" in data + + +class TestExplainConfiguration: + """Tests for explain_configuration function.""" + + def test_explain_with_legacy_provider_shows_warning(self, capsys): + """Test that explain_configuration shows warning with legacy provider.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + reset_config_provider() + + explain_configuration() + + captured = capsys.readouterr() + assert "not available" in captured.out.lower() + + def test_explain_specific_key(self, capsys): + """Test explaining a specific key.""" + with mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "true"}): + reset_config_provider() + + # Just test that it doesn't crash + # Actual display testing would require more setup + explain_configuration(key="account") + + def test_explain_all_keys_verbose(self, capsys): + """Test explaining all keys in verbose mode.""" + with mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "true"}): + reset_config_provider() + + # Just test that it doesn't crash + 
explain_configuration(verbose=True) + + +class TestIntegrationWithRealConfig: + """Integration tests with actual configuration.""" + + def test_resolution_with_env_vars(self): + """Test resolution logging with actual environment variables.""" + with mock.patch.dict( + os.environ, + { + ALTERNATIVE_CONFIG_ENV_VAR: "true", + "SNOWFLAKE_ACCOUNT": "test_account", + "SNOWFLAKE_USER": "test_user", + }, + ): + reset_config_provider() + + # Verify logging is available + assert is_resolution_logging_available() + + # Get resolver and check it has data + resolver = get_resolver() + assert resolver is not None + + # Force resolution + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + provider.read_config() + + # Check that we can get summary + summary = get_resolution_summary() + assert summary is not None + assert summary["total_keys_resolved"] > 0 + + def test_check_value_source_for_env_var(self): + """Test checking the source of an environment variable.""" + with mock.patch.dict( + os.environ, + { + ALTERNATIVE_CONFIG_ENV_VAR: "true", + "SNOWFLAKE_ACCOUNT": "test_account", + }, + ): + reset_config_provider() + + # Force resolution + from snowflake.cli.api.config_provider import get_config_provider_singleton + + provider = get_config_provider_singleton() + provider.read_config() + + # Check source + source = check_value_source("account") + # Should be from environment (snowflake_cli_env or similar) + assert source is not None + assert "env" in source.lower() diff --git a/tests/config_ng/test_snowsql_variables.py b/tests/config_ng/test_snowsql_variables.py new file mode 100644 index 0000000000..15bd91621c --- /dev/null +++ b/tests/config_ng/test_snowsql_variables.py @@ -0,0 +1,390 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for SnowSQL [variables] section reading and merging with -D parameters. +""" + +import tempfile +from pathlib import Path +from unittest import mock + +from snowflake.cli.api.config_ng import ( + ConfigurationResolver, + SnowSQLConfigFile, + SnowSQLSection, + get_merged_variables, +) + + +class TestSnowSQLVariablesSection: + """Tests for reading [variables] section from SnowSQL config files.""" + + def test_read_variables_section_from_snowsql_config(self): + """Test that [variables] section is correctly read from SnowSQL config.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file = Path(temp_dir) / "config" + config_file.write_text( + """ +[connections] +accountname = test_account +username = test_user + +[variables] +var1=value1 +var2=value2 +example_variable=27 +""" + ) + + source = SnowSQLConfigFile(config_paths=[config_file]) + + discovered = source.discover() + + # Check nested structure + assert "variables" in discovered + assert "var1" in discovered["variables"] + assert "var2" in discovered["variables"] + assert "example_variable" in discovered["variables"] + + # Values are plain strings now (not ConfigValue objects) + assert discovered["variables"]["var1"] == "value1" + assert discovered["variables"]["var2"] == "value2" + assert discovered["variables"]["example_variable"] == "27" + + def test_variables_section_empty(self): + """Test that empty [variables] section doesn't cause errors.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file = Path(temp_dir) / "config" + config_file.write_text( + """ 
+[connections] +accountname = test_account + +[variables] +""" + ) + + source = SnowSQLConfigFile(config_paths=[config_file]) + + discovered = source.discover() + + # Should have connections but no variables (or empty variables dict) + assert "connections" in discovered + assert not discovered.get("variables", {}) + + def test_no_variables_section(self): + """Test that config without [variables] section works correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file = Path(temp_dir) / "config" + config_file.write_text( + """ +[connections] +accountname = test_account +username = test_user +""" + ) + + source = SnowSQLConfigFile(config_paths=[config_file]) + + discovered = source.discover() + + # Should have connections but no variables key + assert "connections" in discovered + assert "variables" not in discovered + + def test_variables_merged_from_multiple_files(self): + """Test that variables from multiple SnowSQL config files are merged.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file1 = Path(temp_dir) / "config1" + config_file1.write_text( + """ +[variables] +var1=value1 +var2=original_value2 +""" + ) + + config_file2 = Path(temp_dir) / "config2" + config_file2.write_text( + """ +[variables] +var2=overridden_value2 +var3=value3 +""" + ) + + source = SnowSQLConfigFile(config_paths=[config_file1, config_file2]) + + discovered = source.discover() + + # Check nested structure with merged values + assert discovered["variables"]["var1"] == "value1" + assert discovered["variables"]["var2"] == "overridden_value2" + assert discovered["variables"]["var3"] == "value3" + + def test_variables_with_special_characters(self): + """Test that variables with special characters in values are handled.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file = Path(temp_dir) / "config" + config_file.write_text( + """ +[variables] +var_with_equals=key=value +var_with_spaces=value with spaces +var_with_quotes='quoted value' +""" + ) + + 
source = SnowSQLConfigFile(config_paths=[config_file]) + + discovered = source.discover() + + assert discovered["variables"]["var_with_equals"] == "key=value" + assert discovered["variables"]["var_with_spaces"] == "value with spaces" + assert discovered["variables"]["var_with_quotes"] == "'quoted value'" + + +class TestAlternativeConfigProviderVariables: + """Tests for getting variables section from AlternativeConfigProvider.""" + + def test_get_variables_section(self): + """Test get_section('variables') returns nested dict.""" + from snowflake.cli.api.config_provider import AlternativeConfigProvider + + provider = AlternativeConfigProvider() + + # Mock with nested structure + mock_cache = { + "variables": {"var1": "value1", "var2": "value2"}, + "connections": {"default": {"account": "test_account"}}, + } + + setattr(provider, "_initialized", True) + setattr(provider, "_config_cache", mock_cache) + + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + setattr( + provider, + "_last_config_override", + get_cli_context().config_file_override, + ) + except Exception: + setattr(provider, "_last_config_override", None) + + result = provider.get_section("variables") + + # Should return nested dict under "variables" key + assert result == {"var1": "value1", "var2": "value2"} + + def test_get_variables_section_empty(self): + """Test get_section('variables') with no variables returns empty dict.""" + from snowflake.cli.api.config_provider import AlternativeConfigProvider + + provider = AlternativeConfigProvider() + + mock_cache = { + "connections": {"default": {"account": "test_account"}}, + } + + setattr(provider, "_initialized", True) + setattr(provider, "_config_cache", mock_cache) + + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + setattr( + provider, + "_last_config_override", + get_cli_context().config_file_override, + ) + except Exception: + setattr(provider, "_last_config_override", None) + + result = 
provider.get_section("variables") + + assert result == {} + + +class TestGetMergedVariables: + """Tests for get_merged_variables() utility function.""" + + def test_get_merged_variables_no_cli_params(self): + """Test get_merged_variables with only SnowSQL variables.""" + with mock.patch( + "snowflake.cli.api.config_provider.get_config_provider_singleton" + ) as mock_provider: + mock_instance = mock.Mock() + mock_instance.get_section.return_value = { + "var1": "snowsql_value1", + "var2": "snowsql_value2", + } + mock_provider.return_value = mock_instance + + result = get_merged_variables(None) + + assert result == {"var1": "snowsql_value1", "var2": "snowsql_value2"} + mock_instance.get_section.assert_called_once_with("variables") + + def test_get_merged_variables_with_cli_params(self): + """Test get_merged_variables with both SnowSQL and CLI -D parameters.""" + with mock.patch( + "snowflake.cli.api.config_provider.get_config_provider_singleton" + ) as mock_provider: + mock_instance = mock.Mock() + mock_instance.get_section.return_value = { + "var1": "snowsql_value1", + "var2": "snowsql_value2", + } + mock_provider.return_value = mock_instance + + cli_vars = ["var2=cli_value2", "var3=cli_value3"] + result = get_merged_variables(cli_vars) + + # var1 from SnowSQL + assert result["var1"] == "snowsql_value1" + # var2 should be overridden by CLI + assert result["var2"] == "cli_value2" + # var3 from CLI + assert result["var3"] == "cli_value3" + + def test_get_merged_variables_cli_only(self): + """Test get_merged_variables with only CLI -D parameters.""" + with mock.patch( + "snowflake.cli.api.config_provider.get_config_provider_singleton" + ) as mock_provider: + mock_instance = mock.Mock() + mock_instance.get_section.return_value = {} + mock_provider.return_value = mock_instance + + cli_vars = ["var1=cli_value1", "var2=cli_value2"] + result = get_merged_variables(cli_vars) + + assert result == {"var1": "cli_value1", "var2": "cli_value2"} + + def 
test_get_merged_variables_precedence(self): + """Test that CLI -D parameters have higher precedence than SnowSQL variables.""" + with mock.patch( + "snowflake.cli.api.config_provider.get_config_provider_singleton" + ) as mock_provider: + mock_instance = mock.Mock() + mock_instance.get_section.return_value = { + "database": "snowsql_db", + "schema": "snowsql_schema", + "custom_var": "snowsql_value", + } + mock_provider.return_value = mock_instance + + cli_vars = ["database=cli_db", "custom_var=cli_value"] + result = get_merged_variables(cli_vars) + + # CLI should override SnowSQL + assert result["database"] == "cli_db" + assert result["custom_var"] == "cli_value" + # SnowSQL value should remain for non-overridden keys + assert result["schema"] == "snowsql_schema" + + def test_get_merged_variables_provider_error(self): + """Test get_merged_variables handles provider errors gracefully.""" + with mock.patch( + "snowflake.cli.api.config_provider.get_config_provider_singleton" + ) as mock_provider: + mock_instance = mock.Mock() + mock_instance.get_section.side_effect = Exception("Provider error") + mock_provider.return_value = mock_instance + + cli_vars = ["var1=cli_value1"] + result = get_merged_variables(cli_vars) + + # Should fall back to only CLI variables + assert result == {"var1": "cli_value1"} + + def test_get_merged_variables_empty(self): + """Test get_merged_variables with no variables at all.""" + with mock.patch( + "snowflake.cli.api.config_provider.get_config_provider_singleton" + ) as mock_provider: + mock_instance = mock.Mock() + mock_instance.get_section.return_value = {} + mock_provider.return_value = mock_instance + + result = get_merged_variables(None) + + assert result == {} + + +class TestConfigurationResolverVariables: + """Integration tests for variables in ConfigurationResolver.""" + + def test_resolver_with_variables(self): + """Test that resolver correctly processes variables from SnowSQL config.""" + with tempfile.TemporaryDirectory() as 
temp_dir: + config_file = Path(temp_dir) / "config" + config_file.write_text( + """ +[connections] +accountname = test_account + +[variables] +var1=value1 +var2=value2 +""" + ) + + source = SnowSQLConfigFile(config_paths=[config_file]) + + resolver = ConfigurationResolver(sources=[source]) + config = resolver.resolve() + + # Check nested structure + assert "variables" in config + assert "var1" in config["variables"] + assert "var2" in config["variables"] + assert config["variables"]["var1"] == "value1" + assert config["variables"]["var2"] == "value2" + + +class TestSnowSQLSectionEnum: + """Tests for SnowSQLSection enum.""" + + def test_section_enum_values(self): + """Test that SnowSQLSection enum has correct values.""" + assert SnowSQLSection.CONNECTIONS.value == "connections" + assert SnowSQLSection.VARIABLES.value == "variables" + assert SnowSQLSection.OPTIONS.value == "options" + + def test_section_enum_in_snowsql_source(self): + """Test that SnowSQLConfigFile uses the enum.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file = Path(temp_dir) / "config" + config_file.write_text( + """ +[connections] +accountname = test_account + +[variables] +var1=value1 +""" + ) + + source = SnowSQLConfigFile(config_paths=[config_file]) + + # Should discover both connections and variables + discovered = source.discover() + + # Check nested structure contains both sections + assert "connections" in discovered + assert "variables" in discovered diff --git a/tests/config_ng/test_sources.py b/tests/config_ng/test_sources.py new file mode 100644 index 0000000000..8ac9e4e2c5 --- /dev/null +++ b/tests/config_ng/test_sources.py @@ -0,0 +1,380 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for configuration sources with string-based testing.""" + +from snowflake.cli.api.config_ng.sources import ( + CliConfigFile, + ConnectionsConfigFile, + SnowSQLConfigFile, +) + + +class TestSnowSQLConfigFileFromString: + """Test SnowSQLConfigFile with string-based initialization.""" + + def test_from_string_single_connection(self): + """Test creating source from string with single connection.""" + content = """ +[connections.dev] +accountname = test_account +username = test_user +password = test_pass +""" + source = SnowSQLConfigFile.from_string(content) + result = source.discover() + + assert "connections" in result + assert "dev" in result["connections"] + assert result["connections"]["dev"]["account"] == "test_account" + assert result["connections"]["dev"]["user"] == "test_user" + assert result["connections"]["dev"]["password"] == "test_pass" + + def test_from_string_multiple_connections(self): + """Test creating source from string with multiple connections.""" + content = """ +[connections.dev] +accountname = dev_account + +[connections.prod] +accountname = prod_account +""" + source = SnowSQLConfigFile.from_string(content) + result = source.discover() + + assert len(result["connections"]) == 2 + assert result["connections"]["dev"]["account"] == "dev_account" + assert result["connections"]["prod"]["account"] == "prod_account" + + def test_from_string_with_variables(self): + """Test creating source from string with variables section.""" + content = """ +[connections.test] +accountname = test_account + +[variables] +stage = mystage +table = 
mytable +""" + source = SnowSQLConfigFile.from_string(content) + result = source.discover() + + assert "variables" in result + assert result["variables"]["stage"] == "mystage" + assert result["variables"]["table"] == "mytable" + + def test_from_string_key_mapping(self): + """Test that SnowSQL key mapping works with string source.""" + content = """ +[connections.test] +accountname = acc +username = usr +dbname = db +schemaname = sch +warehousename = wh +rolename = rol +pwd = pass +""" + source = SnowSQLConfigFile.from_string(content) + result = source.discover() + + conn = result["connections"]["test"] + assert conn["account"] == "acc" + assert conn["user"] == "usr" + assert conn["database"] == "db" + assert conn["schema"] == "sch" + assert conn["warehouse"] == "wh" + assert conn["role"] == "rol" + assert conn["password"] == "pass" + + def test_from_string_empty_content(self): + """Test creating source from empty string.""" + source = SnowSQLConfigFile.from_string("") + result = source.discover() + + assert result == {} + + def test_from_string_default_connection(self): + """Test creating source with default connection (no name).""" + content = """ +[connections] +accountname = default_account +""" + source = SnowSQLConfigFile.from_string(content) + result = source.discover() + + assert "default" in result["connections"] + assert result["connections"]["default"]["account"] == "default_account" + + +class TestCliConfigFileFromString: + """Test CliConfigFile with string-based initialization.""" + + def test_from_string_single_connection(self): + """Test creating CLI config source from string.""" + content = """ +[connections.dev] +account = "test_account" +user = "test_user" +""" + source = CliConfigFile.from_string(content) + result = source.discover() + + assert "connections" in result + assert "dev" in result["connections"] + assert result["connections"]["dev"]["account"] == "test_account" + assert result["connections"]["dev"]["user"] == "test_user" + + def 
test_from_string_multiple_connections(self): + """Test creating CLI config with multiple connections.""" + content = """ +[connections.dev] +account = "dev_acc" + +[connections.prod] +account = "prod_acc" +""" + source = CliConfigFile.from_string(content) + result = source.discover() + + assert len(result["connections"]) == 2 + assert result["connections"]["dev"]["account"] == "dev_acc" + assert result["connections"]["prod"]["account"] == "prod_acc" + + def test_from_string_with_cli_section(self): + """Test creating CLI config with cli section.""" + content = """ +[cli] +enable_diag = true + +[cli.logs] +save_logs = true + +[connections.test] +account = "test_account" +""" + source = CliConfigFile.from_string(content) + result = source.discover() + + assert "cli" in result + assert result["cli"]["enable_diag"] is True + assert result["cli"]["logs"]["save_logs"] is True + assert result["connections"]["test"]["account"] == "test_account" + + def test_from_string_with_variables(self): + """Test creating CLI config with variables.""" + content = """ +[variables] +stage = "mystage" +env = "dev" + +[connections.test] +account = "test_account" +""" + source = CliConfigFile.from_string(content) + result = source.discover() + + assert "variables" in result + assert result["variables"]["stage"] == "mystage" + assert result["variables"]["env"] == "dev" + + def test_from_string_empty_content(self): + """Test creating CLI config from empty string.""" + source = CliConfigFile.from_string("") + result = source.discover() + + assert result == {} + + def test_from_string_nested_structure(self): + """Test creating CLI config with deeply nested structure.""" + content = """ +[cli.features] +feature1 = true +feature2 = false + +[cli.logs] +level = "INFO" +path = "/var/log" +""" + source = CliConfigFile.from_string(content) + result = source.discover() + + assert result["cli"]["features"]["feature1"] is True + assert result["cli"]["logs"]["level"] == "INFO" + + +class 
TestConnectionsConfigFileFromString: + """Test ConnectionsConfigFile with string-based initialization.""" + + def test_from_string_nested_format(self): + """Test creating connections file with nested format.""" + content = """ +[connections.dev] +account = "dev_account" +user = "dev_user" + +[connections.prod] +account = "prod_account" +user = "prod_user" +""" + source = ConnectionsConfigFile.from_string(content) + result = source.discover() + + assert "connections" in result + assert len(result["connections"]) == 2 + assert result["connections"]["dev"]["account"] == "dev_account" + assert result["connections"]["prod"]["account"] == "prod_account" + + def test_from_string_legacy_format(self): + """Test creating connections file with legacy format (direct sections).""" + content = """ +[dev] +account = "dev_account" +user = "dev_user" + +[prod] +account = "prod_account" +user = "prod_user" +""" + source = ConnectionsConfigFile.from_string(content) + result = source.discover() + + # Legacy format should be normalized to nested format + assert "connections" in result + assert len(result["connections"]) == 2 + assert result["connections"]["dev"]["account"] == "dev_account" + assert result["connections"]["prod"]["account"] == "prod_account" + + def test_from_string_mixed_format(self): + """Test creating connections file with mixed legacy and nested format.""" + content = """ +[legacy_conn] +account = "legacy_account" + +[connections.new_conn] +account = "new_account" +""" + source = ConnectionsConfigFile.from_string(content) + result = source.discover() + + # Both should be normalized to nested format + assert "connections" in result + assert len(result["connections"]) == 2 + assert result["connections"]["legacy_conn"]["account"] == "legacy_account" + assert result["connections"]["new_conn"]["account"] == "new_account" + + def test_from_string_nested_takes_precedence(self): + """Test that nested format takes precedence over legacy format.""" + content = """ +[test] 
+account = "legacy_account" + +[connections.test] +account = "new_account" +""" + source = ConnectionsConfigFile.from_string(content) + result = source.discover() + + # Nested format should win + assert result["connections"]["test"]["account"] == "new_account" + + def test_from_string_empty_content(self): + """Test creating connections file from empty string.""" + source = ConnectionsConfigFile.from_string("") + result = source.discover() + + # Empty TOML should return empty dict (no connections) + assert result == {} + + def test_from_string_single_connection(self): + """Test creating connections file with single connection.""" + content = """ +[connections.default] +account = "test_account" +user = "test_user" +password = "test_pass" +""" + source = ConnectionsConfigFile.from_string(content) + result = source.discover() + + assert "default" in result["connections"] + assert result["connections"]["default"]["account"] == "test_account" + + def test_get_defined_connections(self): + """Test getting defined connection names.""" + content = """ +[connections.dev] +account = "dev_acc" + +[connections.prod] +account = "prod_acc" +""" + source = ConnectionsConfigFile.from_string(content) + defined_connections = source.get_defined_connections() + + assert defined_connections == {"dev", "prod"} + + def test_get_defined_connections_legacy_format(self): + """Test getting defined connections with legacy format.""" + content = """ +[dev] +account = "dev_acc" + +[prod] +account = "prod_acc" +""" + source = ConnectionsConfigFile.from_string(content) + defined_connections = source.get_defined_connections() + + assert defined_connections == {"dev", "prod"} + + +class TestSourceProperties: + """Test source properties and metadata.""" + + def test_snowsql_config_source_name(self): + """Test SnowSQLConfigFile source name.""" + source = SnowSQLConfigFile.from_string("") + assert source.source_name == "snowsql_config" + + def test_cli_config_source_name(self): + """Test CliConfigFile 
source name.""" + source = CliConfigFile.from_string("") + assert source.source_name == "cli_config_toml" + + def test_connections_config_source_name(self): + """Test ConnectionsConfigFile source name.""" + source = ConnectionsConfigFile.from_string("") + assert source.source_name == "connections_toml" + + def test_connections_file_marker(self): + """Test that ConnectionsConfigFile is marked as connections file.""" + source = ConnectionsConfigFile.from_string("") + assert source.is_connections_file is True + + def test_non_connections_file_marker(self): + """Test that other sources don't have is_connections_file property.""" + cli_source = CliConfigFile.from_string("") + snowsql_source = SnowSQLConfigFile.from_string("") + + # These should not have the is_connections_file property + # or it should be False (default) + assert ( + not hasattr(cli_source, "is_connections_file") + or not cli_source.is_connections_file + ) + assert ( + not hasattr(snowsql_source, "is_connections_file") + or not snowsql_source.is_connections_file + ) diff --git a/tests/config_ng/test_telemetry_integration.py b/tests/config_ng/test_telemetry_integration.py new file mode 100644 index 0000000000..129ad19c1a --- /dev/null +++ b/tests/config_ng/test_telemetry_integration.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for config_ng telemetry integration.""" + +from unittest.mock import MagicMock, patch + +from snowflake.cli.api.config_ng import ( + CliParameters, + ConfigurationResolver, + get_config_telemetry_payload, + record_config_source_usage, +) +from snowflake.cli.api.config_ng.sources import CliConfigFile + + +class TestRecordConfigSourceUsage: + """Tests for record_config_source_usage function.""" + + def test_records_winning_sources(self): + """Test that winning sources are recorded as counters.""" + # Create resolver with some sources + cli_config = CliConfigFile.from_string( + """ + [connections.test] + account = "test_account" + user = "test_user" + """ + ) + cli_params = CliParameters(cli_context={"password": "secret"}) + resolver = ConfigurationResolver(sources=[cli_config, cli_params]) + + # Resolve to populate history + resolver.resolve() + + # Mock CLI context + mock_context = MagicMock() + mock_metrics = MagicMock() + mock_context.metrics = mock_metrics + + with patch( + "snowflake.cli.api.cli_global_context.get_cli_context", + return_value=mock_context, + ): + record_config_source_usage(resolver) + + # Verify counters were set + assert mock_metrics.set_counter.called + # Should have calls for all sources + call_args = [call[0] for call in mock_metrics.set_counter.call_args_list] + counter_fields = [arg[0] for arg in call_args] + + # Verify some expected counter fields are present + assert any("config_source" in str(field) for field in counter_fields) + + def test_handles_no_cli_context_gracefully(self): + """Test that function doesn't fail if CLI context unavailable.""" + cli_config = CliConfigFile.from_string( + """ + [connections.test] + account = "test_account" + """ + ) + resolver = ConfigurationResolver(sources=[cli_config]) + resolver.resolve() + + with patch( + "snowflake.cli.api.cli_global_context.get_cli_context", + side_effect=Exception("No context"), + ): + # Should not raise + record_config_source_usage(resolver) + + def 
test_sets_counter_to_zero_for_unused_sources(self): + """Test that unused sources get counter value 0.""" + # Only use CLI config, not CLI params + cli_config = CliConfigFile.from_string( + """ + [connections.test] + account = "test_account" + """ + ) + resolver = ConfigurationResolver(sources=[cli_config]) + resolver.resolve() + + mock_context = MagicMock() + mock_metrics = MagicMock() + mock_context.metrics = mock_metrics + + with patch( + "snowflake.cli.api.cli_global_context.get_cli_context", + return_value=mock_context, + ): + record_config_source_usage(resolver) + + # Check that at least one source was set to 0 + call_args = mock_metrics.set_counter.call_args_list + values = [call[0][1] for call in call_args] + assert 0 in values + + +class TestGetConfigTelemetryPayload: + """Tests for get_config_telemetry_payload function.""" + + def test_returns_empty_dict_for_none_resolver(self): + """Test that None resolver returns empty dict.""" + result = get_config_telemetry_payload(None) + assert result == {} + + def test_returns_summary_data(self): + """Test that function returns summary data from resolver.""" + cli_config = CliConfigFile.from_string( + """ + [connections.test] + account = "test_account" + user = "test_user" + """ + ) + cli_params = CliParameters(cli_context={"password": "secret"}) + resolver = ConfigurationResolver(sources=[cli_config, cli_params]) + + # Resolve to populate history + resolver.resolve() + + result = get_config_telemetry_payload(resolver) + + # Verify expected keys are present + assert "config_sources_used" in result + assert "config_source_wins" in result + assert "config_total_keys_resolved" in result + assert "config_keys_with_overrides" in result + + # Verify data types + assert isinstance(result["config_sources_used"], list) + assert isinstance(result["config_source_wins"], dict) + assert isinstance(result["config_total_keys_resolved"], int) + assert isinstance(result["config_keys_with_overrides"], int) + + def 
test_handles_resolver_errors_gracefully(self): + """Test that function handles resolver errors gracefully.""" + mock_resolver = MagicMock() + mock_resolver.get_tracker.side_effect = Exception("Tracker error") + + result = get_config_telemetry_payload(mock_resolver) + assert result == {} + + def test_tracks_source_wins_correctly(self): + """Test that source wins are tracked correctly.""" + cli_config = CliConfigFile.from_string( + """ + [connections.test] + account = "test_account" + """ + ) + # CLI params should win for password + cli_params = CliParameters(cli_context={"password": "override"}) + resolver = ConfigurationResolver(sources=[cli_config, cli_params]) + + resolver.resolve() + + result = get_config_telemetry_payload(resolver) + + # Verify that cli_arguments won for password + source_wins = result["config_source_wins"] + assert "cli_arguments" in source_wins + assert source_wins["cli_arguments"] > 0 + + +class TestTelemetryIntegration: + """Integration tests for telemetry system.""" + + def test_telemetry_records_from_config_provider(self): + """Test that config provider records telemetry on initialization.""" + from snowflake.cli.api.config_provider import AlternativeConfigProvider + + mock_context = MagicMock() + mock_metrics = MagicMock() + mock_context.metrics = mock_metrics + mock_context.connection_context.present_values_as_dict.return_value = {} + mock_context.config_file_override = None + + def mock_getter(): + return mock_context + + with patch( + "snowflake.cli.api.cli_global_context.get_cli_context", + return_value=mock_context, + ): + provider = AlternativeConfigProvider(cli_context_getter=mock_getter) + provider._ensure_initialized() # noqa: SLF001 + + # Verify that telemetry was recorded + assert mock_metrics.set_counter.called diff --git a/tests/conftest.py b/tests/conftest.py index c36de67ae3..65073bf5af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,8 +104,23 @@ def os_agnostic_snapshot(snapshot): # In addition to its own 
CliContextManager, each test gets its own OpenConnectionCache # which is cleared after the test completes. def reset_global_context_and_setup_config_and_logging_levels( - request, test_snowcli_config + request, test_snowcli_config, monkeypatch ): + # Reset config provider singleton to prevent test interference + from snowflake.cli.api.config_provider import reset_config_provider + + reset_config_provider() + + # Clear SNOWFLAKE_CONNECTIONS_* env vars for test isolation with config_ng + # These may be set in CI/dev environments and interfere with tests + import os + + for key in list(os.environ.keys()): + if key.startswith("SNOWFLAKE_CONNECTIONS_") or key.startswith( + "SNOWSQL_CONNECTIONS_" + ): + monkeypatch.delenv(key, raising=False) + with fork_cli_context(): connection_cache = OpenConnectionCache() cli_context_manager = get_cli_context_manager() diff --git a/tests/helpers/test_show_config_sources.py b/tests/helpers/test_show_config_sources.py new file mode 100644 index 0000000000..a2e5484531 --- /dev/null +++ b/tests/helpers/test_show_config_sources.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import mock + +from snowflake.cli.api.config_provider import ALTERNATIVE_CONFIG_ENV_VAR + +COMMAND = "show-config-sources" + + +class TestHiddenLogic: + """ + Test the logic that determines if the command should be hidden. 
+ + Note: The 'hidden' parameter in Typer decorators is evaluated at module import time, + so we test the logic itself rather than the runtime visibility in help output. + """ + + def test_hidden_logic_with_truthy_values(self): + """Test that the hidden logic correctly identifies truthy values.""" + truthy_values = ["1", "true", "yes", "on", "TRUE", "Yes", "ON"] + for value in truthy_values: + # This is the logic used in the command decorator + is_hidden = value.lower() not in ("1", "true", "yes", "on") + assert ( + not is_hidden + ), f"Value '{value}' should make command visible (not hidden)" + + def test_hidden_logic_with_falsy_values(self): + """Test that the hidden logic correctly identifies falsy values.""" + falsy_values = ["", "0", "false", "no", "off", "random"] + for value in falsy_values: + # This is the logic used in the command decorator + is_hidden = value.lower() not in ("1", "true", "yes", "on") + assert is_hidden, f"Value '{value}' should make command hidden" + + +class TestCommandFunctionality: + """Test that the command functions correctly when called.""" + + @mock.patch.dict(os.environ, {}, clear=True) + def test_command_unavailable_without_env_var(self, runner): + """Command should indicate resolution logging is unavailable without env var.""" + result = runner.invoke(["helpers", COMMAND]) + assert result.exit_code == 0 + assert "Configuration resolution logging is not available" in result.output + assert ALTERNATIVE_CONFIG_ENV_VAR in result.output + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + def test_command_unavailable_message_when_logging_not_available( + self, mock_is_available, runner + ): + """Command should show unavailable message when resolution logging is not available.""" + mock_is_available.return_value = False + result = runner.invoke(["helpers", COMMAND]) + assert result.exit_code == 0 + assert "Configuration resolution 
logging is not available" in result.output + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + @mock.patch( + "snowflake.cli.api.config_ng.resolution_logger.get_configuration_explanation_results" + ) + def test_command_shows_summary_without_arguments( + self, mock_get_results, mock_is_available, runner + ): + """Command should show configuration summary when called without arguments.""" + from snowflake.cli.api.output.types import CollectionResult + + mock_is_available.return_value = True + mock_get_results.return_value = CollectionResult([]) + result = runner.invoke(["helpers", COMMAND]) + assert result.exit_code == 0 + mock_get_results.assert_called_once_with(key=None, verbose=False) + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + @mock.patch( + "snowflake.cli.api.config_ng.resolution_logger.get_configuration_explanation_results" + ) + def test_command_shows_specific_key( + self, mock_get_results, mock_is_available, runner + ): + """Command should show resolution for specific key when provided.""" + from snowflake.cli.api.output.types import CollectionResult + + mock_is_available.return_value = True + mock_get_results.return_value = CollectionResult([]) + result = runner.invoke(["helpers", COMMAND, "account"]) + assert result.exit_code == 0 + mock_get_results.assert_called_once_with(key="account", verbose=False) + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + @mock.patch( + "snowflake.cli.api.config_ng.resolution_logger.get_configuration_explanation_results" + ) + def test_command_shows_details_with_flag( + self, mock_get_results, mock_is_available, runner + ): + """Command should show detailed resolution when --show-details flag is 
used.""" + from snowflake.cli.api.output.types import ( + CollectionResult, + MessageResult, + MultipleResults, + ) + + mock_is_available.return_value = True + mock_get_results.return_value = MultipleResults( + [CollectionResult([]), MessageResult("test history")] + ) + result = runner.invoke(["helpers", COMMAND, "--show-details"]) + assert result.exit_code == 0 + mock_get_results.assert_called_once_with(key=None, verbose=True) + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + @mock.patch( + "snowflake.cli.api.config_ng.resolution_logger.get_configuration_explanation_results" + ) + def test_command_shows_details_with_short_flag( + self, mock_get_results, mock_is_available, runner + ): + """Command should show detailed resolution when -d flag is used.""" + from snowflake.cli.api.output.types import ( + CollectionResult, + MessageResult, + MultipleResults, + ) + + mock_is_available.return_value = True + mock_get_results.return_value = MultipleResults( + [CollectionResult([]), MessageResult("test history")] + ) + result = runner.invoke(["helpers", COMMAND, "-d"]) + assert result.exit_code == 0 + mock_get_results.assert_called_once_with(key=None, verbose=True) + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + @mock.patch( + "snowflake.cli.api.config_ng.resolution_logger.get_configuration_explanation_results" + ) + def test_command_shows_key_with_details( + self, mock_get_results, mock_is_available, runner + ): + """Command should show detailed resolution for specific key.""" + from snowflake.cli.api.output.types import ( + CollectionResult, + MessageResult, + MultipleResults, + ) + + mock_is_available.return_value = True + mock_get_results.return_value = MultipleResults( + [CollectionResult([]), MessageResult("test history")] + ) + result = 
runner.invoke(["helpers", COMMAND, "user", "--show-details"]) + assert result.exit_code == 0 + mock_get_results.assert_called_once_with(key="user", verbose=True) + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + @mock.patch("snowflake.cli.api.config_ng.export_resolution_history") + def test_command_exports_to_file_success( + self, mock_export, mock_is_available, runner, tmp_path + ): + """Command should export resolution history to file when --export is used.""" + mock_is_available.return_value = True + mock_export.return_value = True + export_file = tmp_path / "config_debug.json" + + result = runner.invoke(["helpers", COMMAND, "--export", str(export_file)]) + assert result.exit_code == 0 + mock_export.assert_called_once_with(export_file) + assert "Resolution history exported to:" in result.output + assert str(export_file) in result.output + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + @mock.patch("snowflake.cli.api.config_ng.export_resolution_history") + def test_command_exports_to_file_with_short_flag( + self, mock_export, mock_is_available, runner, tmp_path + ): + """Command should export resolution history to file when -e is used.""" + mock_is_available.return_value = True + mock_export.return_value = True + export_file = tmp_path / "debug.json" + + result = runner.invoke(["helpers", COMMAND, "-e", str(export_file)]) + assert result.exit_code == 0 + mock_export.assert_called_once_with(export_file) + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + @mock.patch("snowflake.cli.api.config_ng.is_resolution_logging_available") + @mock.patch("snowflake.cli.api.config_ng.export_resolution_history") + def test_command_export_failure( + self, mock_export, mock_is_available, runner, tmp_path + ): + """Command should show 
error message when export fails.""" + mock_is_available.return_value = True + mock_export.return_value = False + export_file = tmp_path / "config_debug.json" + + result = runner.invoke(["helpers", COMMAND, "--export", str(export_file)]) + assert result.exit_code == 0 + assert "Failed to export resolution history" in result.output + + +class TestCommandHelp: + """Test the command help output.""" + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + def test_command_help_message(self, runner): + """Command help should display correctly.""" + result = runner.invoke(["helpers", COMMAND, "--help"]) + assert result.exit_code == 0 + assert "Show where configuration values come from" in result.output + assert "--show-details" in result.output + assert "--export" in result.output + assert "Examples:" in result.output + + @mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "1"}, clear=True) + def test_command_help_shows_key_argument(self, runner): + """Command help should show the optional key argument.""" + result = runner.invoke(["helpers", COMMAND, "--help"]) + assert result.exit_code == 0 + assert "KEY" in result.output or "key" in result.output.lower() + assert "account" in result.output or "user" in result.output diff --git a/tests/test_config.py b/tests/test_config.py index b8747cfbb3..9e3c2ba165 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -394,6 +394,8 @@ def test_connections_toml_override_config_toml( ) config_init(test_snowcli_config) + # Both legacy and config_ng: Only connections from connections.toml are present + # connections.toml REPLACES config.toml connections (not merge) assert get_default_connection_dict() == {"database": "overridden_database"} assert config_manager["connections"] == { "default": {"database": "overridden_database"} diff --git a/tests/test_config_provider_integration.py b/tests/test_config_provider_integration.py new file mode 100644 index 0000000000..99cbfdd288 --- /dev/null +++ 
b/tests/test_config_provider_integration.py @@ -0,0 +1,496 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: SLF001 +""" +Integration tests for ConfigProvider Phase 7: Provider Integration. + +Tests the AlternativeConfigProvider implementation and its compatibility +with LegacyConfigProvider. + +Note: This file accesses private members for testing purposes, which is expected +in test code to verify internal state and behavior. +""" + +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import mock + +import pytest +from snowflake.cli.api.config_provider import ( + ALTERNATIVE_CONFIG_ENV_VAR, + AlternativeConfigProvider, + LegacyConfigProvider, + get_config_provider, + get_config_provider_singleton, + reset_config_provider, +) + + +class TestProviderSelection: + """Tests for provider selection via environment variable.""" + + def test_default_provider_is_legacy(self): + """Test that LegacyConfigProvider is used by default.""" + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + + provider = get_config_provider() + assert isinstance(provider, LegacyConfigProvider) + + @pytest.mark.parametrize( + "env_value", + ["true", "1", "yes", "on", "TRUE", "True", "Yes", "YES", "ON"], + ) + def test_alternative_provider_enabled_with_various_values(self, env_value): + """Test enabling alternative 
provider with various truthy values.""" + with mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: env_value}): + provider = get_config_provider() + assert isinstance(provider, AlternativeConfigProvider) + + def test_singleton_pattern(self): + """Test that singleton returns same instance.""" + with mock.patch.dict(os.environ, {}): + reset_config_provider() + + provider1 = get_config_provider_singleton() + provider2 = get_config_provider_singleton() + + assert provider1 is provider2 + + def test_reset_config_provider(self): + """Test that reset_config_provider creates new instance.""" + with mock.patch.dict(os.environ, {}): + reset_config_provider() + + provider1 = get_config_provider_singleton() + reset_config_provider() + provider2 = get_config_provider_singleton() + + assert provider1 is not provider2 + + +class TestAlternativeConfigProviderInitialization: + """Tests for AlternativeConfigProvider initialization.""" + + def test_lazy_initialization(self): + """Test that provider initializes lazily on first use.""" + provider = AlternativeConfigProvider() + assert provider._resolver is None + assert not provider._initialized + + # Accessing any method should trigger initialization + provider._ensure_initialized() + assert provider._resolver is not None + assert provider._initialized + + def test_reinitialization_clears_cache(self): + """Test that re-initialization clears cache.""" + provider = AlternativeConfigProvider() + provider._config_cache = {"old": "data"} + provider._initialized = True + + provider.read_config() + + # Cache should be cleared during re-init + assert provider._config_cache != {"old": "data"} + + +class TestAlternativeConfigProviderBasicOperations: + """Tests for basic config provider operations.""" + + def test_section_exists_root(self): + """Test section_exists for root.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = {"key": "value"} 
+ provider._initialized = True + + assert provider.section_exists() + + def test_section_exists_with_prefix(self): + """Test section_exists for specific section.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = { + "connections": { + "default": {"account": "test_account", "user": "test_user"} + } + } + provider._initialized = True + provider._config_cache = mock_resolver.resolve.return_value + + assert provider.section_exists("connections") + assert provider.section_exists("connections", "default") + assert not provider.section_exists("nonexistent") + + def test_get_value_simple(self): + """Test get_value for simple key.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = {"account": "test_account"} + provider._initialized = True + provider._config_cache = mock_resolver.resolve.return_value + # Prevent re-initialization due to config_file_override check + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + provider._last_config_override = get_cli_context().config_file_override + except Exception: + provider._last_config_override = None + + value = provider.get_value(key="account") + assert value == "test_account" + + def test_get_value_with_path(self): + """Test get_value with path.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = { + "connections": {"default": {"account": "test_account"}} + } + provider._initialized = True + provider._config_cache = mock_resolver.resolve.return_value + # Prevent re-initialization due to config_file_override check + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + provider._last_config_override = get_cli_context().config_file_override + except Exception: + 
provider._last_config_override = None + + value = provider.get_value("connections", "default", key="account") + assert value == "test_account" + + def test_get_value_with_default(self): + """Test get_value returns default when key not found.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = {} + provider._initialized = True + + value = provider.get_value(key="nonexistent", default="default_value") + assert value == "default_value" + + def test_get_section_root(self): + """Test get_section for root.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + config_data = {"key1": "value1", "key2": "value2"} + mock_resolver.resolve.return_value = config_data + provider._initialized = True + provider._config_cache = config_data + # Prevent re-initialization due to config_file_override check + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + provider._last_config_override = get_cli_context().config_file_override + except Exception: + provider._last_config_override = None + + section = provider.get_section() + assert section == config_data + + def test_get_section_connections(self): + """Test get_section for connections.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = { + "connections": { + "default": {"account": "test_account", "user": "test_user"}, + "prod": {"account": "prod_account"}, + } + } + provider._initialized = True + provider._config_cache = mock_resolver.resolve.return_value + # Prevent re-initialization due to config_file_override check + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + provider._last_config_override = get_cli_context().config_file_override + except Exception: + provider._last_config_override = None + + section = 
provider.get_section("connections") + assert "default" in section + assert "prod" in section + assert section["default"]["account"] == "test_account" + assert section["prod"]["account"] == "prod_account" + + def test_get_section_specific_connection(self): + """Test get_section for specific connection.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = { + "connections": { + "default": {"account": "test_account", "user": "test_user"} + } + } + provider._initialized = True + provider._config_cache = mock_resolver.resolve.return_value + # Prevent re-initialization due to config_file_override check + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + provider._last_config_override = get_cli_context().config_file_override + except Exception: + provider._last_config_override = None + + section = provider.get_section("connections", "default") + assert section == {"account": "test_account", "user": "test_user"} + + +class TestAlternativeConfigProviderConnectionOperations: + """Tests for connection-specific operations.""" + + def test_get_connection_dict(self): + """Test get_connection_dict retrieves connection config.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = { + "connections": { + "default": { + "account": "test_account", + "user": "test_user", + "password": "secret", + } + } + } + provider._initialized = True + provider._config_cache = mock_resolver.resolve.return_value + # Prevent re-initialization due to config_file_override check + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + provider._last_config_override = get_cli_context().config_file_override + except Exception: + provider._last_config_override = None + + conn_dict = provider.get_connection_dict("default") + assert conn_dict == { + "account": 
"test_account", + "user": "test_user", + "password": "secret", + } + + def test_get_connection_dict_not_found(self): + """Test get_connection_dict raises error for missing connection.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = {} + provider._initialized = True + + with pytest.raises(Exception): # MissingConfigurationError + provider.get_connection_dict("nonexistent") + + def test_get_all_connections_dict(self): + """Test _get_all_connections_dict returns nested dict.""" + provider = AlternativeConfigProvider() + + with mock.patch.object(provider, "_resolver") as mock_resolver: + mock_resolver.resolve.return_value = { + "connections": { + "default": {"account": "test_account", "user": "test_user"}, + "prod": {"account": "prod_account", "user": "prod_user"}, + } + } + provider._initialized = True + provider._config_cache = mock_resolver.resolve.return_value + # Prevent re-initialization due to config_file_override check + from snowflake.cli.api.cli_global_context import get_cli_context + + try: + provider._last_config_override = get_cli_context().config_file_override + except Exception: + provider._last_config_override = None + + all_conns = provider._get_all_connections_dict() + assert "default" in all_conns + assert "prod" in all_conns + assert all_conns["default"] == { + "account": "test_account", + "user": "test_user", + } + assert all_conns["prod"] == { + "account": "prod_account", + "user": "prod_user", + } + + @mock.patch("snowflake.cli.api.config.ConnectionConfig") + def test_get_all_connections(self, mock_connection_config): + """Test get_all_connections returns ConnectionConfig objects.""" + provider = AlternativeConfigProvider() + + # Mock ConnectionConfig.from_dict + mock_config_instance = mock.Mock() + mock_connection_config.from_dict.return_value = mock_config_instance + + # Mock _get_file_based_connections to avoid resolver._sources access + with 
mock.patch.object( + provider, "_get_file_based_connections" + ) as mock_get_file_based: + mock_get_file_based.return_value = {"default": mock_config_instance} + + all_conns = provider.get_all_connections() + + assert "default" in all_conns + assert all_conns["default"] == mock_config_instance + mock_get_file_based.assert_called_once() + + +class TestAlternativeConfigProviderWriteOperations: + """Tests for write operations that delegate to legacy system.""" + + @mock.patch("snowflake.cli.api.config.set_config_value") + def test_set_value_delegates_to_legacy(self, mock_set_value): + """Test that set_value delegates to legacy system.""" + provider = AlternativeConfigProvider() + provider._initialized = True + + provider.set_value(["test", "path"], "value") + + mock_set_value.assert_called_once_with(["test", "path"], "value") + assert not provider._initialized # Should reset + assert not provider._config_cache # Should clear cache + + @mock.patch("snowflake.cli.api.config.unset_config_value") + def test_unset_value_delegates_to_legacy(self, mock_unset_value): + """Test that unset_value delegates to legacy system.""" + provider = AlternativeConfigProvider() + provider._initialized = True + + provider.unset_value(["test", "path"]) + + mock_unset_value.assert_called_once_with(["test", "path"]) + assert not provider._initialized # Should reset + assert not provider._config_cache # Should clear cache + + +class TestProviderIntegrationEndToEnd: + """End-to-end integration tests with real config files.""" + + def test_alternative_provider_with_toml_file(self): + """Test alternative provider reads from TOML file.""" + with TemporaryDirectory() as tmpdir: + # Create a test config file + config_file = Path(tmpdir) / "connections.toml" + config_file.write_text( + """ +[default] +account = "test_account" +user = "test_user" +password = "test_password" +""" + ) + + # Create provider and test + # Note: This requires mocking the config manager to use our temp file + # Full 
integration testing would be done in separate test suite + + def test_provider_switching_via_environment(self): + """Test switching between providers via environment variable.""" + # Test legacy provider (default) + reset_config_provider() + with mock.patch.dict(os.environ, {}, clear=False): + if ALTERNATIVE_CONFIG_ENV_VAR in os.environ: + del os.environ[ALTERNATIVE_CONFIG_ENV_VAR] + + provider = get_config_provider_singleton() + assert isinstance(provider, LegacyConfigProvider) + + # Test alternative provider (enabled) + reset_config_provider() + with mock.patch.dict(os.environ, {ALTERNATIVE_CONFIG_ENV_VAR: "true"}): + provider = get_config_provider_singleton() + assert isinstance(provider, AlternativeConfigProvider) + + +class TestAlternativeConfigProviderConnections: + """Tests for AlternativeConfigProvider connection filtering.""" + + def test_get_all_connections_excludes_env_by_default(self, monkeypatch): + """Test that get_all_connections excludes env-only connections by default.""" + monkeypatch.setenv(ALTERNATIVE_CONFIG_ENV_VAR, "1") + + # Set up environment variable for connection + monkeypatch.setenv("SNOWFLAKE_CONNECTIONS_ENVONLY_ACCOUNT", "test_account") + monkeypatch.setenv("SNOWFLAKE_CONNECTIONS_ENVONLY_USER", "test_user") + + reset_config_provider() + provider = get_config_provider_singleton() + + # Default: should not include env-only connection + connections = provider.get_all_connections(include_env_connections=False) + assert "envonly" not in connections + + # With flag: should include env-only connection + reset_config_provider() + all_connections = provider.get_all_connections(include_env_connections=True) + assert "envonly" in all_connections + assert all_connections["envonly"].account == "test_account" + assert all_connections["envonly"].user == "test_user" + + def test_get_all_connections_with_mixed_sources(self, monkeypatch): + """Test that file-based connections are included but env-only excluded by default.""" + 
monkeypatch.setenv(ALTERNATIVE_CONFIG_ENV_VAR, "1") + + # Set env variable for env-only connection + monkeypatch.setenv("SNOWFLAKE_CONNECTIONS_ENVCONN_ACCOUNT", "env_account") + + reset_config_provider() + provider = get_config_provider_singleton() + + # Without flag: should have file connections but not env-only connection + connections = provider.get_all_connections(include_env_connections=False) + # Test fixture connections should be present (from test.toml) + assert len(connections) > 0 + assert "envconn" not in connections + + # With flag: should have both file and env connections + reset_config_provider() + all_connections = provider.get_all_connections(include_env_connections=True) + assert "envconn" in all_connections + # Should have more connections when including env + assert len(all_connections) >= len(connections) + + def test_legacy_provider_ignores_include_env_flag(self, monkeypatch): + """Test that LegacyConfigProvider ignores the include_env_connections flag.""" + # Ensure legacy provider is used + monkeypatch.delenv(ALTERNATIVE_CONFIG_ENV_VAR, raising=False) + + reset_config_provider() + provider = get_config_provider_singleton() + + assert isinstance(provider, LegacyConfigProvider) + + # Both calls should return the same result (flag is ignored) + connections_default = provider.get_all_connections( + include_env_connections=False + ) + connections_all = provider.get_all_connections(include_env_connections=True) + + # Should be same connections (legacy doesn't filter) + assert set(connections_default.keys()) == set(connections_all.keys()) diff --git a/tests/test_connection.py b/tests/test_connection.py index 81cfaebd10..5bbf32d928 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -293,6 +293,7 @@ def test_fails_if_existing_connection(runner): @mock.patch("snowflake.cli._plugins.connection.commands.get_default_connection_name") +@mock.patch.dict(os.environ, {}, clear=True) def 
test_lists_connection_information(mock_get_default_conn_name, runner): mock_get_default_conn_name.return_value = "empty" result = runner.invoke(["connection", "list", "--format", "json"]) @@ -450,6 +451,52 @@ def test_connection_list_does_not_print_too_many_env_variables( ] +@mock.patch.dict( + os.environ, + { + "SNOWFLAKE_CLI_CONFIG_V2_ENABLED": "1", + "SNOWFLAKE_CONNECTIONS_INTEGRATION_ACCOUNT": "test_account", + "SNOWFLAKE_CONNECTIONS_INTEGRATION_USER": "test_user", + }, + clear=True, +) +@mock.patch("snowflake.cli._plugins.connection.commands.get_default_connection_name") +def test_connection_list_all_flag_includes_env_connections( + mock_get_default_conn_name, runner +): + """Test that --all flag shows environment-based connections in config_ng mode.""" + from snowflake.cli.api.config_provider import reset_config_provider + + mock_get_default_conn_name.return_value = "empty" + + # Reset config provider to pick up new environment + reset_config_provider() + + # Without --all: should not show env-only connections + result = runner.invoke(["connection", "list", "--format", "json"]) + assert result.exit_code == 0, result.output + connections = json.loads(result.output) + connection_names = {c["connection_name"] for c in connections} + assert "integration" not in connection_names + + # Reset for second call + reset_config_provider() + + # With --all: should show env-based connections + result_all = runner.invoke(["connection", "list", "--all", "--format", "json"]) + assert result_all.exit_code == 0, result_all.output + connections_all = json.loads(result_all.output) + connection_names_all = {c["connection_name"] for c in connections_all} + assert "integration" in connection_names_all + + # Verify integration connection has expected parameters + integration_conn = next( + c for c in connections_all if c["connection_name"] == "integration" + ) + assert integration_conn["parameters"]["account"] == "test_account" + assert integration_conn["parameters"]["user"] == 
"test_user" + + def test_second_connection_not_update_default_connection(runner, os_agnostic_snapshot): with NamedTemporaryFile("w+", suffix=".toml") as tmp_file: tmp_file.write( diff --git a/tests_common/__init__.py b/tests_common/__init__.py index b975752952..eb42887326 100644 --- a/tests_common/__init__.py +++ b/tests_common/__init__.py @@ -17,3 +17,5 @@ from tests_common.path_utils import * IS_WINDOWS = platform.system() == "Windows" + +__all__ = ["IS_WINDOWS", "ConfigModeSnapshotExtension", "config_snapshot"] diff --git a/tests_common/conftest.py b/tests_common/conftest.py index 1b889f02ed..734506a470 100644 --- a/tests_common/conftest.py +++ b/tests_common/conftest.py @@ -23,6 +23,7 @@ import pytest import yaml +from syrupy.extensions.amber import AmberSnapshotExtension from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity from snowflake.cli._plugins.streamlit.streamlit_entity_model import StreamlitEntityModel @@ -144,3 +145,23 @@ def _update(snowflake_yml_path: Path, parameter_path: str, value=None): sys.version_info >= PYTHON_3_12, reason="requires python3.11 or lower", ) + + +class ConfigModeSnapshotExtension(AmberSnapshotExtension): + """Snapshot extension that includes config mode in snapshot file name.""" + + @classmethod + def _get_file_basename(cls, *, test_location, index): + """Generate snapshot filename with config mode suffix.""" + config_mode = ( + "config_ng" if os.getenv("SNOWFLAKE_CLI_CONFIG_V2_ENABLED") else "legacy" + ) + basename = super()._get_file_basename(test_location=test_location, index=index) + # Insert config mode before .ambr extension + return f"{basename}_{config_mode}" + + +@pytest.fixture() +def config_snapshot(snapshot): + """Config-mode-aware snapshot fixture for tests that differ between legacy and config_ng.""" + return snapshot.use_extension(ConfigModeSnapshotExtension) diff --git a/tests_e2e/__snapshots__/test_import_snowsql_connections.ambr 
b/tests_e2e/__snapshots__/test_import_snowsql_connections.ambr deleted file mode 100644 index bf4d0ea7a7..0000000000 --- a/tests_e2e/__snapshots__/test_import_snowsql_connections.ambr +++ /dev/null @@ -1,25 +0,0 @@ -# serializer version: 1 -# name: test_import_confirm_on_conflict_with_existing_cli_connection - '[{"connection_name": "example", "parameters": {"user": "u1", "schema": "public", "authenticator": "SNOWFLAKE_JWT"}, "is_default": false}]' -# --- -# name: test_import_confirm_on_conflict_with_existing_cli_connection.1 - '[{"connection_name": "example", "parameters": {"account": "accountname", "user": "username"}, "is_default": false}, {"connection_name": "snowsql1", "parameters": {"account": "a1", "user": "u1", "host": "h1_override", "database": "d1", "schema": "public", "warehouse": "w1", "role": "r1"}, "is_default": false}, {"connection_name": "snowsql2", "parameters": {"account": "a2", "user": "u2", "host": "h2", "port": 1234, "database": "d2", "schema": "public", "warehouse": "w2", "role": "r2"}, "is_default": false}, {"connection_name": "snowsql3", "parameters": {"account": "a3", "user": "u3", "password": "****", "host": "h3", "database": "d3", "schema": "public", "warehouse": "w3", "role": "r3"}, "is_default": false}, {"connection_name": "default", "parameters": {"account": "default_connection_account", "user": "default_connection_user", "host": "localhost", "database": "default_connection_database_override", "schema": "public", "warehouse": "default_connection_warehouse", "role": "accountadmin"}, "is_default": true}]' -# --- -# name: test_import_of_snowsql_connections - '[]' -# --- -# name: test_import_of_snowsql_connections.1 - '[{"connection_name": "snowsql1", "parameters": {"account": "a1", "user": "u1", "host": "h1_override", "database": "d1", "schema": "public", "warehouse": "w1", "role": "r1"}, "is_default": false}, {"connection_name": "snowsql2", "parameters": {"account": "a2", "user": "u2", "host": "h2", "port": 1234, "database": "d2", 
"schema": "public", "warehouse": "w2", "role": "r2"}, "is_default": false}, {"connection_name": "example", "parameters": {"account": "accountname", "user": "username"}, "is_default": false}, {"connection_name": "snowsql3", "parameters": {"account": "a3", "user": "u3", "password": "****", "host": "h3", "database": "d3", "schema": "public", "warehouse": "w3", "role": "r3"}, "is_default": false}, {"connection_name": "default", "parameters": {"account": "default_connection_account", "user": "default_connection_user", "host": "localhost", "database": "default_connection_database_override", "schema": "public", "warehouse": "default_connection_warehouse", "role": "accountadmin"}, "is_default": true}]' -# --- -# name: test_import_prompt_for_different_default_connection_name_on_conflict - '[]' -# --- -# name: test_import_prompt_for_different_default_connection_name_on_conflict.1 - '[{"connection_name": "snowsql1", "parameters": {"account": "a1", "user": "u1", "host": "h1_override", "database": "d1", "schema": "public", "warehouse": "w1", "role": "r1"}, "is_default": false}, {"connection_name": "snowsql2", "parameters": {"account": "a2", "user": "u2", "host": "h2", "port": 1234, "database": "d2", "schema": "public", "warehouse": "w2", "role": "r2"}, "is_default": true}, {"connection_name": "example", "parameters": {"account": "accountname", "user": "username"}, "is_default": false}, {"connection_name": "snowsql3", "parameters": {"account": "a3", "user": "u3", "password": "****", "host": "h3", "database": "d3", "schema": "public", "warehouse": "w3", "role": "r3"}, "is_default": false}, {"connection_name": "default", "parameters": {"account": "default_connection_account", "user": "default_connection_user", "host": "localhost", "database": "default_connection_database_override", "schema": "public", "warehouse": "default_connection_warehouse", "role": "accountadmin"}, "is_default": false}]' -# --- -# name: test_import_reject_on_conflict_with_existing_cli_connection - 
'[{"connection_name": "example", "parameters": {"user": "u1", "schema": "public", "authenticator": "SNOWFLAKE_JWT"}, "is_default": false}]' -# --- -# name: test_import_reject_on_conflict_with_existing_cli_connection.1 - '[{"connection_name": "example", "parameters": {"user": "u1", "schema": "public", "authenticator": "SNOWFLAKE_JWT"}, "is_default": false}, {"connection_name": "snowsql1", "parameters": {"account": "a1", "user": "u1", "host": "h1_override", "database": "d1", "schema": "public", "warehouse": "w1", "role": "r1"}, "is_default": false}, {"connection_name": "snowsql2", "parameters": {"account": "a2", "user": "u2", "host": "h2", "port": 1234, "database": "d2", "schema": "public", "warehouse": "w2", "role": "r2"}, "is_default": false}, {"connection_name": "snowsql3", "parameters": {"account": "a3", "user": "u3", "password": "****", "host": "h3", "database": "d3", "schema": "public", "warehouse": "w3", "role": "r3"}, "is_default": false}, {"connection_name": "default", "parameters": {"account": "default_connection_account", "user": "default_connection_user", "host": "localhost", "database": "default_connection_database_override", "schema": "public", "warehouse": "default_connection_warehouse", "role": "accountadmin"}, "is_default": true}]' -# --- diff --git a/tests_e2e/conftest.py b/tests_e2e/conftest.py index 8286f0276d..d0190e7fb9 100644 --- a/tests_e2e/conftest.py +++ b/tests_e2e/conftest.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import shutil import subprocess import sys @@ -88,9 +89,12 @@ def test_root_path(): def disable_colors_and_styles_in_output(monkeypatch): """ Colors and styles in output cause mismatches in asserts, - this environment variable turn off styling + this environment variable turn off styling. + Also set consistent terminal width to avoid snapshot mismatches. 
""" monkeypatch.setenv("TERM", "unknown") + width = 81 if IS_WINDOWS else 80 + monkeypatch.setenv("COLUMNS", str(width)) @pytest.fixture(scope="session") @@ -111,6 +115,21 @@ def isolate_default_config_location(monkeypatch, temporary_directory): monkeypatch.setenv("SNOWFLAKE_HOME", temporary_directory) +@pytest.fixture(autouse=True) +def isolate_environment_variables(monkeypatch): + """ + Clear Snowflake-specific environment variables that could interfere with e2e tests. + This ensures tests run in a clean environment and only use the config files they specify. + Exception: Keep INTEGRATION connection vars for e2e testing. + """ + # Clear all SNOWFLAKE_CONNECTIONS_* environment variables except INTEGRATION + for env_var in list(os.environ.keys()): + if env_var.startswith(("SNOWFLAKE_CONNECTIONS_", "SNOWSQL_")): + # Preserve all INTEGRATION connection environment variables + if not env_var.startswith("SNOWFLAKE_CONNECTIONS_INTEGRATION_"): + monkeypatch.delenv(env_var, raising=False) + + def _create_venv(tmp_dir: Path) -> None: subprocess_check_output(["python", "-m", "venv", tmp_dir]) diff --git a/tests_e2e/test_import_snowsql_connections.py b/tests_e2e/test_import_snowsql_connections.py index 49ebc30554..b254c41ec4 100644 --- a/tests_e2e/test_import_snowsql_connections.py +++ b/tests_e2e/test_import_snowsql_connections.py @@ -1,37 +1,94 @@ import json -from typing import Optional import pytest from tests_e2e.conftest import subprocess_check_output, subprocess_run -@pytest.fixture() -def _assert_json_output_matches_snapshot(snapshot): - def f(cmd, stdin: Optional[str] = None): - output = subprocess_check_output(cmd, stdin) - parsed_json = json.loads(output) - snapshot.assert_match(json.dumps(parsed_json)) - - return f - - -@pytest.mark.e2e -def test_import_of_snowsql_connections( - snowcli, test_root_path, empty_config_file, _assert_json_output_matches_snapshot -): - _assert_json_output_matches_snapshot( +def _get_connections_list_output(snowcli, config_file) -> 
str: + """Helper function to get connections list output as string.""" + return subprocess_check_output( [ snowcli, "--config-file", - empty_config_file, + config_file, "connection", "list", "--format", "json", - ], + ] ) + +def _parse_connections(output: str) -> list: + """Parse connection list JSON output.""" + return json.loads(output) + + +def _assert_connection_structure(connection: dict) -> None: + """Assert that a connection has the expected structure.""" + assert "connection_name" in connection + assert "parameters" in connection + assert "is_default" in connection + assert isinstance(connection["parameters"], dict) + assert isinstance(connection["is_default"], bool) + + +def _assert_connections_present(connections: list, expected_names: set) -> None: + """Assert that specific connections are present in the list.""" + actual_names = {conn["connection_name"] for conn in connections} + assert expected_names.issubset( + actual_names + ), f"Expected connections {expected_names} not found. 
Got: {actual_names}" + + +def _assert_connection_parameters( + connections: list, connection_name: str, expected_params: dict +) -> None: + """Assert that a specific connection has expected parameters.""" + conn = next( + (c for c in connections if c["connection_name"] == connection_name), None + ) + assert conn is not None, f"Connection '{connection_name}' not found" + + # Check each expected parameter + for key, value in expected_params.items(): + assert ( + key in conn["parameters"] + ), f"Parameter '{key}' not found in connection '{connection_name}'" + assert conn["parameters"][key] == value, ( + f"Parameter '{key}' mismatch in connection '{connection_name}': " + f"expected {value}, got {conn['parameters'][key]}" + ) + + +def _assert_default_connection(connections: list, expected_name: str) -> None: + """Assert which connection is marked as default.""" + default_connections = [c for c in connections if c["is_default"]] + assert ( + len(default_connections) == 1 + ), f"Expected exactly one default connection, found {len(default_connections)}" + assert default_connections[0]["connection_name"] == expected_name + + +@pytest.mark.e2e +def test_import_of_snowsql_connections(snowcli, test_root_path, empty_config_file): + """Test connection import. + + Verifies that connections are imported from SnowSQL config files and + appear in the connection list. Environment-based connections are not + shown by default (matching legacy behavior). 
+ """ + # Initially should have empty or minimal connections list + initial_output = _get_connections_list_output(snowcli, empty_config_file) + initial_connections = _parse_connections(initial_output) + + # In isolated e2e tests, should start empty + assert ( + len(initial_connections) == 0 + ), f"Expected no file-based connections initially, found: {initial_connections}" + + # Import snowsql connections result = subprocess_run( [ snowcli, @@ -47,35 +104,63 @@ def test_import_of_snowsql_connections( ) assert result.returncode == 0 - _assert_json_output_matches_snapshot( - [ - snowcli, - "--config-file", - empty_config_file, - "connection", - "list", - "--format", - "json", - ] + # After import, should have multiple connections + final_output = _get_connections_list_output(snowcli, empty_config_file) + final_connections = _parse_connections(final_output) + + # Validate all connections have proper structure + for conn in final_connections: + _assert_connection_structure(conn) + + # Assert expected connections are present + expected_connections = {"snowsql1", "snowsql2", "example", "snowsql3", "default"} + _assert_connections_present(final_connections, expected_connections) + + # Validate default connection + _assert_default_connection(final_connections, "default") + + # Validate specific connection parameters (from snowsql config files) + _assert_connection_parameters( + final_connections, + "snowsql1", + { + "account": "a1", + "user": "u1", + "host": "h1_override", # overridden in overriding_config + "database": "d1", + "schema": "public", + "warehouse": "w1", + "role": "r1", + }, + ) + + _assert_connection_parameters( + final_connections, + "default", + { + "account": "default_connection_account", + "user": "default_connection_user", + "host": "localhost", + "database": "default_connection_database_override", # overridden + "schema": "public", + "warehouse": "default_connection_warehouse", + "role": "accountadmin", + }, ) @pytest.mark.e2e def 
test_import_prompt_for_different_default_connection_name_on_conflict( - snowcli, test_root_path, empty_config_file, _assert_json_output_matches_snapshot + snowcli, test_root_path, empty_config_file ): - _assert_json_output_matches_snapshot( - [ - snowcli, - "--config-file", - empty_config_file, - "connection", - "list", - "--format", - "json", - ], - ) + """Test importing with different default connection name.""" + # Initially should have empty or minimal connections list + initial_output = _get_connections_list_output(snowcli, empty_config_file) + initial_connections = _parse_connections(initial_output) + assert len(initial_connections) == 0 + + # Import with different default connection name result = subprocess_run( [ snowcli, @@ -94,16 +179,35 @@ def test_import_prompt_for_different_default_connection_name_on_conflict( ) assert result.returncode == 0 - _assert_json_output_matches_snapshot( - [ - snowcli, - "--config-file", - empty_config_file, - "connection", - "list", - "--format", - "json", - ] + # After import, snowsql2 should be the default + final_output = _get_connections_list_output(snowcli, empty_config_file) + final_connections = _parse_connections(final_output) + + # Validate all connections have proper structure + for conn in final_connections: + _assert_connection_structure(conn) + + # Assert expected connections are present + expected_connections = {"snowsql1", "snowsql2", "example", "snowsql3", "default"} + _assert_connections_present(final_connections, expected_connections) + + # Validate that snowsql2 is the default (not "default") + _assert_default_connection(final_connections, "snowsql2") + + # Validate snowsql2 parameters + _assert_connection_parameters( + final_connections, + "snowsql2", + { + "account": "a2", + "user": "u2", + "host": "h2", + "port": 1234, + "database": "d2", + "schema": "public", + "warehouse": "w2", + "role": "r2", + }, ) @@ -112,20 +216,18 @@ def test_import_confirm_on_conflict_with_existing_cli_connection( snowcli, 
test_root_path, example_connection_config_file, - _assert_json_output_matches_snapshot, ): - _assert_json_output_matches_snapshot( - [ - snowcli, - "--config-file", - example_connection_config_file, - "connection", - "list", - "--format", - "json", - ], + """Test import with confirmation on conflict.""" + # Initially should have example connection + initial_output = _get_connections_list_output( + snowcli, example_connection_config_file ) + initial_connections = _parse_connections(initial_output) + # Should have the example connection + _assert_connections_present(initial_connections, {"example"}) + + # Import with confirmation (y) - this will overwrite "example" connection result = subprocess_run( [ snowcli, @@ -142,16 +244,29 @@ def test_import_confirm_on_conflict_with_existing_cli_connection( ) assert result.returncode == 0 - _assert_json_output_matches_snapshot( - [ - snowcli, - "--config-file", - example_connection_config_file, - "connection", - "list", - "--format", - "json", - ], + # After import, example connection should be overwritten with snowsql data + final_output = _get_connections_list_output(snowcli, example_connection_config_file) + final_connections = _parse_connections(final_output) + + # Validate all connections have proper structure + for conn in final_connections: + _assert_connection_structure(conn) + + # Assert all expected connections are present (including overwritten example) + expected_connections = {"example", "snowsql1", "snowsql2", "snowsql3", "default"} + _assert_connections_present(final_connections, expected_connections) + + # Validate default connection + _assert_default_connection(final_connections, "default") + + # Validate that "example" was overwritten with snowsql config values + _assert_connection_parameters( + final_connections, + "example", + { + "account": "accountname", + "user": "username", + }, ) @@ -160,20 +275,24 @@ def test_import_reject_on_conflict_with_existing_cli_connection( snowcli, test_root_path, 
example_connection_config_file, - _assert_json_output_matches_snapshot, ): - _assert_json_output_matches_snapshot( - [ - snowcli, - "--config-file", - example_connection_config_file, - "connection", - "list", - "--format", - "json", - ], + """Test import with rejection on conflict.""" + # Initially should have example connection + initial_output = _get_connections_list_output( + snowcli, example_connection_config_file + ) + initial_connections = _parse_connections(initial_output) + + # Should have the example connection with original values + _assert_connections_present(initial_connections, {"example"}) + + # Get initial example connection parameters + initial_example = next( + c for c in initial_connections if c["connection_name"] == "example" ) + initial_example_params = initial_example["parameters"].copy() + # Import with rejection (n) - should NOT overwrite "example" connection result = subprocess_run( [ snowcli, @@ -190,21 +309,37 @@ def test_import_reject_on_conflict_with_existing_cli_connection( ) assert result.returncode == 0 - _assert_json_output_matches_snapshot( - [ - snowcli, - "--config-file", - example_connection_config_file, - "connection", - "list", - "--format", - "json", - ], + # After import, example connection should remain unchanged + # But other connections should still be imported + final_output = _get_connections_list_output(snowcli, example_connection_config_file) + final_connections = _parse_connections(final_output) + + # Validate all connections have proper structure + for conn in final_connections: + _assert_connection_structure(conn) + + # Assert all expected connections are present + expected_connections = {"example", "snowsql1", "snowsql2", "snowsql3", "default"} + _assert_connections_present(final_connections, expected_connections) + + # Validate default connection + _assert_default_connection(final_connections, "default") + + # Validate that "example" connection was NOT overwritten (kept original values) + final_example = next( + c 
for c in final_connections if c["connection_name"] == "example" ) + assert ( + final_example["parameters"] == initial_example_params + ), "Example connection should not have been overwritten after rejection" @pytest.mark.e2e def test_connection_imported_from_snowsql(snowcli, test_root_path, empty_config_file): + """Test that imported connection works.""" + # Always provide confirmation to avoid interactive abort. + stdin = "y\n" + result = subprocess_run( [ snowcli, @@ -215,9 +350,11 @@ def test_connection_imported_from_snowsql(snowcli, test_root_path, empty_config_ "--snowsql-config-file", test_root_path / "config" / "snowsql" / "integration_config", ], + stdin=stdin, ) assert result.returncode == 0 + # Test that the imported integration connection works result = subprocess_run( [ snowcli, diff --git a/tests_integration/conftest.py b/tests_integration/conftest.py index 362033205a..9eeb3788aa 100644 --- a/tests_integration/conftest.py +++ b/tests_integration/conftest.py @@ -54,6 +54,7 @@ "tests_integration.snowflake_connector", ] + TEST_DIR = Path(__file__).parent DEFAULT_TEST_CONFIG = "connection_configs.toml" WORLD_READABLE_CONFIG = "world_readable.toml" @@ -120,6 +121,12 @@ def invoke(self, *a, **kw): kw.update(catch_exceptions=False) kw = self._with_env_vars(kw) + # Reset config provider to ensure fresh config resolution + # This is critical for tests that set environment variables + from snowflake.cli.api.config_provider import reset_config_provider + + reset_config_provider() + # between every invocation, we need to reset the CLI context # and ensure no connections are cached going forward (to prevent # test cases from impacting each other / align with CLI usage) diff --git a/tests_integration/nativeapp/test_metrics.py b/tests_integration/nativeapp/test_metrics.py index fb794bf418..f86d15d893 100644 --- a/tests_integration/nativeapp/test_metrics.py +++ b/tests_integration/nativeapp/test_metrics.py @@ -12,7 +12,7 @@ # See the License for the specific language 
governing permissions and # limitations under the License. from shlex import split -from typing import Dict, Callable +from typing import Callable, Dict, List from unittest import mock from snowflake.cli._app.telemetry import TelemetryEvent, CLITelemetryField @@ -111,13 +111,12 @@ def test_feature_counters_v1_post_deploy_set_and_package_scripts_available( mock_telemetry, TelemetryEvent.CMD_EXECUTION_RESULT.value ) - assert message[CLITelemetryField.COUNTERS.value] == { - CLICounterField.SNOWPARK_PROCESSOR: 0, - CLICounterField.TEMPLATES_PROCESSOR: 0, - CLICounterField.PDF_TEMPLATES: 0, - CLICounterField.POST_DEPLOY_SCRIPTS: 1, - CLICounterField.PACKAGE_SCRIPTS: 0, - } + counters = message[CLITelemetryField.COUNTERS.value] + assert counters[CLICounterField.SNOWPARK_PROCESSOR] == 0 + assert counters[CLICounterField.TEMPLATES_PROCESSOR] == 0 + assert counters[CLICounterField.PDF_TEMPLATES] == 0 + assert counters[CLICounterField.POST_DEPLOY_SCRIPTS] == 1 + assert counters[CLICounterField.PACKAGE_SCRIPTS] == 0 @pytest.mark.integration @@ -161,11 +160,10 @@ def test_feature_counters_v2_post_deploy_not_available_in_bundle( mock_telemetry, TelemetryEvent.CMD_EXECUTION_RESULT.value ) - assert message[CLITelemetryField.COUNTERS.value] == { - CLICounterField.SNOWPARK_PROCESSOR: 0, - CLICounterField.TEMPLATES_PROCESSOR: 0, - CLICounterField.PDF_TEMPLATES: 1, - } + counters = message[CLITelemetryField.COUNTERS.value] + assert counters[CLICounterField.SNOWPARK_PROCESSOR] == 0 + assert counters[CLICounterField.TEMPLATES_PROCESSOR] == 0 + assert counters[CLICounterField.PDF_TEMPLATES] == 1 @pytest.mark.integration @@ -207,15 +205,14 @@ def test_feature_counter_v2_templates_processor_set( mock_telemetry, TelemetryEvent.CMD_EXECUTION_RESULT.value ) - assert message[CLITelemetryField.COUNTERS.value] == { - CLICounterField.SNOWPARK_PROCESSOR: 0, - CLICounterField.TEMPLATES_PROCESSOR: 1, - CLICounterField.PDF_TEMPLATES: 0, - CLICounterField.POST_DEPLOY_SCRIPTS: 0, - 
CLICounterField.EVENT_SHARING: 0, - CLICounterField.EVENT_SHARING_ERROR: 0, - CLICounterField.EVENT_SHARING_WARNING: 0, - } + counters = message[CLITelemetryField.COUNTERS.value] + assert counters[CLICounterField.SNOWPARK_PROCESSOR] == 0 + assert counters[CLICounterField.TEMPLATES_PROCESSOR] == 1 + assert counters[CLICounterField.PDF_TEMPLATES] == 0 + assert counters[CLICounterField.POST_DEPLOY_SCRIPTS] == 0 + assert counters[CLICounterField.EVENT_SHARING] == 0 + assert counters[CLICounterField.EVENT_SHARING_ERROR] == 0 + assert counters[CLICounterField.EVENT_SHARING_WARNING] == 0 @pytest.mark.integration @@ -244,13 +241,12 @@ def test_feature_counter_v1_package_scripts_converted_to_post_deploy_and_both_se mock_telemetry, TelemetryEvent.CMD_EXECUTION_RESULT.value ) - assert message[CLITelemetryField.COUNTERS.value] == { - CLICounterField.SNOWPARK_PROCESSOR: 0, - CLICounterField.TEMPLATES_PROCESSOR: 0, - CLICounterField.PDF_TEMPLATES: 0, - CLICounterField.POST_DEPLOY_SCRIPTS: 1, - CLICounterField.PACKAGE_SCRIPTS: 1, - } + counters = message[CLITelemetryField.COUNTERS.value] + assert counters[CLICounterField.SNOWPARK_PROCESSOR] == 0 + assert counters[CLICounterField.TEMPLATES_PROCESSOR] == 0 + assert counters[CLICounterField.PDF_TEMPLATES] == 0 + assert counters[CLICounterField.POST_DEPLOY_SCRIPTS] == 1 + assert counters[CLICounterField.PACKAGE_SCRIPTS] == 1 @pytest.mark.integration @@ -289,12 +285,11 @@ def test_feature_counter_v2_post_deploy_set_and_package_scripts_not_available( mock_telemetry, TelemetryEvent.CMD_EXECUTION_RESULT.value ) - assert message[CLITelemetryField.COUNTERS.value] == { - CLICounterField.SNOWPARK_PROCESSOR: 0, - CLICounterField.TEMPLATES_PROCESSOR: 0, - CLICounterField.PDF_TEMPLATES: 1, - CLICounterField.POST_DEPLOY_SCRIPTS: 1, - } + counters = message[CLITelemetryField.COUNTERS.value] + assert counters[CLICounterField.SNOWPARK_PROCESSOR] == 0 + assert counters[CLICounterField.TEMPLATES_PROCESSOR] == 0 + assert 
counters[CLICounterField.PDF_TEMPLATES] == 1 + assert counters[CLICounterField.POST_DEPLOY_SCRIPTS] == 1 @pytest.mark.integration diff --git a/tests_integration/plugin/__snapshots__/test_broken_plugin_config_ng.ambr b/tests_integration/plugin/__snapshots__/test_broken_plugin_config_ng.ambr new file mode 100644 index 0000000000..87d2a1f15c --- /dev/null +++ b/tests_integration/plugin/__snapshots__/test_broken_plugin_config_ng.ambr @@ -0,0 +1,15 @@ +# serializer version: 1 +# name: test_broken_command_path_plugin + ''' + [ + { + "connection_name": "test", + "parameters": { + "account": "test" + }, + "is_default": false + } + ] + + ''' +# --- diff --git a/tests_integration/plugin/__snapshots__/test_broken_plugin_legacy.ambr b/tests_integration/plugin/__snapshots__/test_broken_plugin_legacy.ambr new file mode 100644 index 0000000000..87d2a1f15c --- /dev/null +++ b/tests_integration/plugin/__snapshots__/test_broken_plugin_legacy.ambr @@ -0,0 +1,15 @@ +# serializer version: 1 +# name: test_broken_command_path_plugin + ''' + [ + { + "connection_name": "test", + "parameters": { + "account": "test" + }, + "is_default": false + } + ] + + ''' +# --- diff --git a/tests_integration/plugin/__snapshots__/test_failing_plugin_config_ng.ambr b/tests_integration/plugin/__snapshots__/test_failing_plugin_config_ng.ambr new file mode 100644 index 0000000000..144e641e4c --- /dev/null +++ b/tests_integration/plugin/__snapshots__/test_failing_plugin_config_ng.ambr @@ -0,0 +1,15 @@ +# serializer version: 1 +# name: test_failing_plugin + ''' + [ + { + "connection_name": "test", + "parameters": { + "account": "test" + }, + "is_default": false + } + ] + + ''' +# --- diff --git a/tests_integration/plugin/__snapshots__/test_failing_plugin_legacy.ambr b/tests_integration/plugin/__snapshots__/test_failing_plugin_legacy.ambr new file mode 100644 index 0000000000..144e641e4c --- /dev/null +++ b/tests_integration/plugin/__snapshots__/test_failing_plugin_legacy.ambr @@ -0,0 +1,15 @@ +# serializer 
version: 1 +# name: test_failing_plugin + ''' + [ + { + "connection_name": "test", + "parameters": { + "account": "test" + }, + "is_default": false + } + ] + + ''' +# --- diff --git a/tests_integration/plugin/__snapshots__/test_override_by_external_plugins_config_ng.ambr b/tests_integration/plugin/__snapshots__/test_override_by_external_plugins_config_ng.ambr new file mode 100644 index 0000000000..947fa54d21 --- /dev/null +++ b/tests_integration/plugin/__snapshots__/test_override_by_external_plugins_config_ng.ambr @@ -0,0 +1,30 @@ +# serializer version: 1 +# name: test_disabled_plugin_is_not_executed + ''' + [ + { + "connection_name": "test", + "parameters": { + "account": "test" + }, + "is_default": false + } + ] + + ''' +# --- +# name: test_override_build_in_commands + ''' + Outside command code + [ + { + "connection_name": "test", + "parameters": { + "account": "test" + }, + "is_default": false + } + ] + + ''' +# --- diff --git a/tests_integration/plugin/__snapshots__/test_override_by_external_plugins_legacy.ambr b/tests_integration/plugin/__snapshots__/test_override_by_external_plugins_legacy.ambr new file mode 100644 index 0000000000..947fa54d21 --- /dev/null +++ b/tests_integration/plugin/__snapshots__/test_override_by_external_plugins_legacy.ambr @@ -0,0 +1,30 @@ +# serializer version: 1 +# name: test_disabled_plugin_is_not_executed + ''' + [ + { + "connection_name": "test", + "parameters": { + "account": "test" + }, + "is_default": false + } + ] + + ''' +# --- +# name: test_override_build_in_commands + ''' + Outside command code + [ + { + "connection_name": "test", + "parameters": { + "account": "test" + }, + "is_default": false + } + ] + + ''' +# --- diff --git a/tests_integration/plugin/test_broken_plugin.py b/tests_integration/plugin/test_broken_plugin.py index d60ff6d197..9d9c8dc02c 100644 --- a/tests_integration/plugin/test_broken_plugin.py +++ b/tests_integration/plugin/test_broken_plugin.py @@ -12,18 +12,21 @@ # See the License for the specific 
language governing permissions and # limitations under the License. -from textwrap import dedent - import pytest @pytest.mark.integration -def test_broken_command_path_plugin(runner, test_root_path, _install_plugin, caplog): +def test_broken_command_path_plugin( + runner, test_root_path, _install_plugin, caplog, config_snapshot +): + """Test broken plugin.""" config_path = ( test_root_path / "config" / "plugin_tests" / "broken_plugin_config.toml" ) - result = runner.invoke(["--config-file", config_path, "connection", "list"]) + result = runner.invoke( + ["--config-file", config_path, "connection", "list", "--format", "JSON"] + ) assert result.exit_code == 0, result.output assert "Loaded external plugin: broken_plugin" in caplog.messages @@ -31,15 +34,9 @@ def test_broken_command_path_plugin(runner, test_root_path, _install_plugin, cap "Cannot register plugin [broken_plugin]: Invalid command path [snow broken run]. Command group [broken] does not exist." in caplog.messages ) - assert result.output == dedent( - """\ - +----------------------------------------------------+ - | connection_name | parameters | is_default | - |-----------------+---------------------+------------| - | test | {'account': 'test'} | False | - +----------------------------------------------------+ - """ - ) + + # Use snapshot to capture the output + assert result.output == config_snapshot @pytest.fixture(scope="module") diff --git a/tests_integration/plugin/test_failing_plugin.py b/tests_integration/plugin/test_failing_plugin.py index a7e8af70e2..32f0eebde1 100644 --- a/tests_integration/plugin/test_failing_plugin.py +++ b/tests_integration/plugin/test_failing_plugin.py @@ -12,32 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from textwrap import dedent - import pytest @pytest.mark.integration -def test_failing_plugin(runner, test_root_path, _install_plugin, caplog): +def test_failing_plugin( + runner, test_root_path, _install_plugin, caplog, config_snapshot +): + """Test failing plugin.""" config_path = ( test_root_path / "config" / "plugin_tests" / "failing_plugin_config.toml" ) - result = runner.invoke(["--config-file", config_path, "connection", "list"]) + result = runner.invoke( + ["--config-file", config_path, "connection", "list", "--format", "JSON"] + ) assert ( "Cannot register plugin [failing_plugin]: Some error in plugin" in caplog.messages ) - assert result.output == dedent( - """\ - +----------------------------------------------------+ - | connection_name | parameters | is_default | - |-----------------+---------------------+------------| - | test | {'account': 'test'} | False | - +----------------------------------------------------+ - """ - ) + + # Use snapshot to capture the output + assert result.output == config_snapshot @pytest.fixture(scope="module") diff --git a/tests_integration/plugin/test_override_by_external_plugins.py b/tests_integration/plugin/test_override_by_external_plugins.py index 2bf1043afa..951012fa29 100644 --- a/tests_integration/plugin/test_override_by_external_plugins.py +++ b/tests_integration/plugin/test_override_by_external_plugins.py @@ -12,39 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from textwrap import dedent - import pytest @pytest.mark.integration -def test_override_build_in_commands(runner, test_root_path, _install_plugin, caplog): +def test_override_build_in_commands( + runner, test_root_path, _install_plugin, caplog, config_snapshot +): + """Test plugin override attempt.""" config_path = ( test_root_path / "config" / "plugin_tests" / "override_plugin_config.toml" ) - result = runner.invoke(["--config-file", config_path, "connection", "list"]) + result = runner.invoke( + ["--config-file", config_path, "connection", "list", "--format", "JSON"] + ) assert ( "Cannot register plugin [override]: Cannot add command [snow connection list] because it already exists." in caplog.messages ) - assert result.output == dedent( - """\ - Outside command code - +----------------------------------------------------+ - | connection_name | parameters | is_default | - |-----------------+---------------------+------------| - | test | {'account': 'test'} | False | - +----------------------------------------------------+ - """ - ) + + # Use snapshot to capture the output + assert result.output == config_snapshot @pytest.mark.integration def test_disabled_plugin_is_not_executed( - runner, test_root_path, _install_plugin, caplog + runner, test_root_path, _install_plugin, caplog, config_snapshot ): + """Test disabled plugin.""" config_path = ( test_root_path / "config" @@ -52,18 +49,13 @@ def test_disabled_plugin_is_not_executed( / "disabled_override_plugin_config.toml" ) - result = runner.invoke(["--config-file", config_path, "connection", "list"]) - - assert result.output == dedent( - """\ - +----------------------------------------------------+ - | connection_name | parameters | is_default | - |-----------------+---------------------+------------| - | test | {'account': 'test'} | False | - +----------------------------------------------------+ - """ + result = runner.invoke( + ["--config-file", config_path, "connection", "list", "--format", "JSON"] ) + # Use 
snapshot to capture the output + assert result.output == config_snapshot + @pytest.fixture(scope="module") def _install_plugin(test_root_path):