diff --git a/src/snowflake/cli/_app/cli_app.py b/src/snowflake/cli/_app/cli_app.py index e57d6cdf95..d0041c6855 100644 --- a/src/snowflake/cli/_app/cli_app.py +++ b/src/snowflake/cli/_app/cli_app.py @@ -40,11 +40,14 @@ get_new_version_msg, show_new_version_banner_callback, ) -from snowflake.cli.api.config import config_init, get_feature_flags_section +from snowflake.cli.api.config import ( + config_init, + get_config_manager, + get_feature_flags_section, +) from snowflake.cli.api.output.formats import OutputFormat from snowflake.cli.api.output.types import CollectionResult from snowflake.cli.api.secure_path import SecurePath -from snowflake.connector.config_manager import CONFIG_MANAGER log = logging.getLogger(__name__) @@ -160,7 +163,7 @@ def callback(value: bool): {"key": "version", "value": __about__.VERSION}, { "key": "default_config_file_path", - "value": str(CONFIG_MANAGER.file_path), + "value": str(get_config_manager().file_path), }, {"key": "python_version", "value": sys.version}, {"key": "system_info", "value": platform.platform()}, diff --git a/src/snowflake/cli/_app/loggers.py b/src/snowflake/cli/_app/loggers.py index 0d9254bf49..efa53077e9 100644 --- a/src/snowflake/cli/_app/loggers.py +++ b/src/snowflake/cli/_app/loggers.py @@ -17,6 +17,7 @@ import logging import logging.config from dataclasses import asdict, dataclass, field +from pathlib import Path from typing import Any, Dict, List import typer @@ -136,7 +137,11 @@ def _check_log_level(self, config): @property def filename(self): - return self.path.path / _DEFAULT_LOG_FILENAME + from snowflake.cli.api.utils.path_utils import path_resolver + + # Ensure Windows short paths are resolved to prevent cleanup issues + resolved_path = path_resolver(str(self.path.path)) + return Path(resolved_path) / _DEFAULT_LOG_FILENAME def create_initial_loggers(): diff --git a/src/snowflake/cli/_app/version_check.py b/src/snowflake/cli/_app/version_check.py index 3e4a055dc3..b25cd570a6 100644 --- a/src/snowflake/cli/_app/version_check.py +++ b/src/snowflake/cli/_app/version_check.py @@ -10,9 +10,9 @@ CLI_SECTION, IGNORE_NEW_VERSION_WARNING_KEY, get_config_bool_value, + get_config_manager, ) from snowflake.cli.api.secure_path import SecurePath -from snowflake.connector.config_manager import CONFIG_MANAGER REPOSITORY_URL_PIP = "https://pypi.org/pypi/snowflake-cli/json" REPOSITORY_URL_BREW = "https://formulae.brew.sh/api/formula/snowflake-cli.json" @@ -69,12 +69,14 @@ class _VersionCache: _last_time = "last_time_check" _version = "version" _last_time_shown = "last_time_shown" - _version_cache_file = SecurePath( - CONFIG_MANAGER.file_path.parent / ".cli_version.cache" - ) + + @property + def _version_cache_file(self): + """Get version cache file path with lazy evaluation.""" + return SecurePath(get_config_manager().file_path.parent / ".cli_version.cache") def __init__(self): - self._cache_file = _VersionCache._version_cache_file + self._cache_file = self._version_cache_file def _save_latest_version(self, version: str): data = { diff --git a/src/snowflake/cli/_plugins/sql/repl.py b/src/snowflake/cli/_plugins/sql/repl.py index 509c3cab55..8cc6e897b6 100644 --- a/src/snowflake/cli/_plugins/sql/repl.py +++ b/src/snowflake/cli/_plugins/sql/repl.py @@ -13,21 +13,27 @@ from snowflake.cli._plugins.sql.manager import SqlManager from snowflake.cli._plugins.sql.repl_commands import detect_command from snowflake.cli.api.cli_global_context import get_cli_context_manager +from snowflake.cli.api.config import get_config_manager from snowflake.cli.api.console import cli_console from snowflake.cli.api.output.types import MultipleResults, QueryResult from snowflake.cli.api.rendering.sql_templates import SQLTemplateSyntaxConfig from snowflake.cli.api.secure_path import SecurePath -from snowflake.connector.config_manager import CONFIG_MANAGER from snowflake.connector.cursor import SnowflakeCursor log = getLogger(__name__) -HISTORY_FILE = SecurePath( - CONFIG_MANAGER.file_path.parent / "repl_history" -).path.expanduser() + +def _get_history_file(): + """Get history file path with lazy evaluation to avoid circular imports.""" + return SecurePath( + get_config_manager().file_path.parent / "repl_history" + ).path.expanduser() + + +HISTORY_FILE = None # Will be set lazily EXIT_KEYWORDS = ("exit", "quit") -log.debug("setting history file to: %s", HISTORY_FILE.as_posix()) +# History file path will be set when REPL is initialized @contextmanager @@ -65,7 +71,7 @@ def __init__( self._data = data or {} self._retain_comments = retain_comments self._template_syntax_config = template_syntax_config - self._history = FileHistory(HISTORY_FILE) + self._history = FileHistory(_get_history_file()) self._lexer = PygmentsLexer(CliLexer) self._completer = cli_completer self._repl_key_bindings = self._setup_key_bindings() diff --git a/src/snowflake/cli/api/cli_global_context.py b/src/snowflake/cli/api/cli_global_context.py index bbf491f732..4960a577ce 100644 --- a/src/snowflake/cli/api/cli_global_context.py +++ b/src/snowflake/cli/api/cli_global_context.py @@ -21,12 +21,19 @@ from pathlib import Path from typing import TYPE_CHECKING, Iterator +import tomlkit from snowflake.cli.api.connections import ConnectionContext, OpenConnectionCache from snowflake.cli.api.exceptions import MissingConfigurationError from snowflake.cli.api.metrics import CLIMetrics from snowflake.cli.api.output.formats import OutputFormat from snowflake.cli.api.rendering.jinja import CONTEXT_KEY from snowflake.connector import SnowflakeConnection +from snowflake.connector.config_manager import ( + ConfigManager, + ConfigSlice, + ConfigSliceOptions, +) +from snowflake.connector.constants import CONFIG_FILE if TYPE_CHECKING: from snowflake.cli._plugins.sql.repl import Repl @@ -66,6 +73,10 @@ class _CliGlobalContextManager: _definition_manager: DefinitionManager | None = None enhanced_exit_codes: bool = False + _config_manager: ConfigManager | None = None + config_file_override: Path | None = None + connections_file_override: Path | None = None + # which properties invalidate our current DefinitionManager? DEFINITION_MANAGER_DEPENDENCIES = [ "project_path_arg", @@ -73,6 +84,8 @@ class _CliGlobalContextManager: "project_env_overrides_args", ] + CONFIG_MANAGER_DEPENDENCIES = ["config_file_override", "connections_file_override"] + def reset(self): self.__init__() @@ -88,6 +101,9 @@ def __setattr__(self, prop, val): if prop in self.DEFINITION_MANAGER_DEPENDENCIES: self._clear_definition_manager() + if prop in self.CONFIG_MANAGER_DEPENDENCIES: + self._clear_config_manager() + super().__setattr__(prop, val) @property @@ -144,6 +160,63 @@ def _clear_definition_manager(self): """ self._definition_manager = None + @property + def config_manager(self) -> ConfigManager: + """ + Get or create the configuration manager instance. + Follows the same lazy initialization pattern as DefinitionManager. + """ + if self._config_manager is None: + self._config_manager = self._create_config_manager() + return self._config_manager + + def _create_config_manager(self) -> ConfigManager: + """ + Factory method to create ConfigManager instance with CLI-specific options. + Replicates the behavior of the imported CONFIG_MANAGER singleton. + """ + from snowflake.cli.api.config import get_connections_file + + connections_file = get_connections_file() + + connections_slice = ConfigSlice( + path=connections_file, + options=ConfigSliceOptions(check_permissions=True, only_in_slice=False), + section="connections", + ) + + manager = ConfigManager( + name="CONFIG_MANAGER", + file_path=self.config_file_override or CONFIG_FILE, + _slices=[connections_slice], + ) + + manager.add_option( + name="connections", + parse_str=tomlkit.parse, + default=dict(), + ) + + manager.add_option( + name="default_connection_name", parse_str=str, default="default" + ) + + from snowflake.cli.api.config import CLI_SECTION + + manager.add_option( + name=CLI_SECTION, + parse_str=tomlkit.parse, + default=dict(), + ) + + return manager + + def _clear_config_manager(self): + """ + Force re-creation of config manager when dependencies change. + """ + self._config_manager = None + class _CliGlobalContextAccess: def __init__(self, manager: _CliGlobalContextManager): @@ -216,6 +289,21 @@ def repl(self) -> Repl | None: """Get the current REPL instance if running in REPL mode.""" return self._manager.repl_instance + @property + def config_manager(self) -> ConfigManager: + """Get the current configuration manager.""" + return self._manager.config_manager + + @property + def config_file_override(self) -> Path | None: + """Get the current config file override path.""" + return self._manager.config_file_override + + @config_file_override.setter + def config_file_override(self, value: Path | None) -> None: + """Set the config file override path.""" + self._manager.config_file_override = value + _CLI_CONTEXT_MANAGER: ContextVar[_CliGlobalContextManager | None] = ContextVar( "cli_context", default=None diff --git a/src/snowflake/cli/api/config.py b/src/snowflake/cli/api/config.py index 93734ee120..5be189c2fb 100644 --- a/src/snowflake/cli/api/config.py +++ b/src/snowflake/cli/api/config.py @@ -35,10 +35,10 @@ windows_get_not_whitelisted_users_with_access, ) from snowflake.cli.api.utils.dict_utils import remove_key_from_nested_dict_if_exists +from snowflake.cli.api.utils.path_utils import path_resolver from snowflake.cli.api.utils.types import try_cast_to_bool from snowflake.connector.compat import IS_WINDOWS -from snowflake.connector.config_manager import CONFIG_MANAGER -from snowflake.connector.constants import CONFIG_FILE, CONNECTIONS_FILE +from snowflake.connector.constants import CONFIG_FILE from snowflake.connector.errors import ConfigSourceError, MissingConfigOptionError from tomlkit import TOMLDocument, dump from tomlkit.container import Container @@ -48,6 +48,26 @@ log = logging.getLogger(__name__) +def get_connections_file(): + """ + Dynamically get the current CONNECTIONS_FILE path. + This ensures we get the updated value after module reloads in tests. + """ + from snowflake.connector.constants import CONNECTIONS_FILE as _CONNECTIONS_FILE + + return _CONNECTIONS_FILE + + +def get_config_manager(): + """ + Get the current configuration manager from CLI context. + This replaces direct CONFIG_MANAGER access throughout the codebase. + """ + from snowflake.cli.api.cli_global_context import get_cli_context_manager + + return get_cli_context_manager().config_manager + + class Empty: pass @@ -63,12 +83,6 @@ class Empty: PLUGIN_ENABLED_KEY = "enabled" FEATURE_FLAGS_SECTION_PATH = [CLI_SECTION, "features"] -CONFIG_MANAGER.add_option( - name=CLI_SECTION, - parse_str=tomlkit.parse, - default=dict(), -) - @dataclass class ConnectionConfig: @@ -133,64 +147,79 @@ def config_init(config_file: Optional[Path]): If config file does not exist we create an empty one. """ from snowflake.cli._app.loggers import create_initial_loggers + from snowflake.cli.api.cli_global_context import get_cli_context_manager if config_file: - CONFIG_MANAGER.file_path = config_file + get_cli_context_manager().config_file_override = config_file else: _check_default_config_files_permissions() - if not CONFIG_MANAGER.file_path.exists(): - _initialise_config(CONFIG_MANAGER.file_path) + + config_manager = get_config_manager() + if not config_manager.file_path.exists(): + _initialise_config(config_manager.file_path) _read_config_file() create_initial_loggers() def add_connection_to_proper_file(name: str, connection_config: ConnectionConfig): - if CONNECTIONS_FILE.exists(): + connections_file = get_connections_file() + if connections_file.exists(): existing_connections = _read_connections_toml() existing_connections.update( {name: connection_config.to_dict_of_all_non_empty_values()} ) _update_connections_toml(existing_connections) - return CONNECTIONS_FILE + return connections_file else: set_config_value( path=[CONNECTIONS_SECTION, name], value=connection_config.to_dict_of_all_non_empty_values(), ) - return CONFIG_MANAGER.file_path + return get_config_manager().file_path def remove_connection_from_proper_file(name: str): - if CONNECTIONS_FILE.exists(): + connections_file = get_connections_file() + if connections_file.exists(): existing_connections = _read_connections_toml() if name not in existing_connections: raise MissingConfigurationError(f"Connection {name} is not configured") del existing_connections[name] _update_connections_toml(existing_connections) - return CONNECTIONS_FILE + return connections_file else: unset_config_value(path=[CONNECTIONS_SECTION, name]) - return CONFIG_MANAGER.file_path + return get_config_manager().file_path -_DEFAULT_LOGS_CONFIG = { - "save_logs": True, - "path": str(CONFIG_MANAGER.file_path.parent / "logs"), - "level": "info", -} +def _get_default_logs_config() -> dict: + """Get default logs configuration with lazy evaluation to avoid circular imports.""" + config_parent_path = get_config_manager().file_path.parent + resolved_parent_path = path_resolver(str(config_parent_path)) + + return { + "save_logs": True, + "path": str(Path(resolved_parent_path) / "logs"), + "level": "info", + } + -_DEFAULT_CLI_CONFIG = {LOGS_SECTION: _DEFAULT_LOGS_CONFIG} +def _get_default_cli_config() -> dict: + """Get default CLI configuration with lazy evaluation.""" + return {LOGS_SECTION: _get_default_logs_config()} @contextmanager def _config_file(): _read_config_file() - conf_file_cache = CONFIG_MANAGER.conf_file_cache + config_manager = get_config_manager() + conf_file_cache = config_manager.conf_file_cache yield conf_file_cache _dump_config(conf_file_cache) def _read_config_file(): + config_manager = get_config_manager() with warnings.catch_warnings(): if IS_WINDOWS: warnings.filterwarnings( @@ -199,19 +228,19 @@ def _read_config_file(): module="snowflake.connector.config_manager", ) - if not file_permissions_are_strict(CONFIG_MANAGER.file_path): + if not file_permissions_are_strict(config_manager.file_path): users = ", ".join( windows_get_not_whitelisted_users_with_access( - CONFIG_MANAGER.file_path + config_manager.file_path ) ) warnings.warn( - f"Unauthorized users ({users}) have access to configuration file {CONFIG_MANAGER.file_path}.\n" - f'Run `icacls "{CONFIG_MANAGER.file_path}" /remove:g ` on those users to restrict permissions.' + f"Unauthorized users ({users}) have access to configuration file {config_manager.file_path}.\n" + f'Run `icacls "{config_manager.file_path}" /remove:g ` on those users to restrict permissions.' ) try: - CONFIG_MANAGER.read_config() + config_manager.read_config() except ConfigSourceError as exception: raise ClickException( f"Configuration file seems to be corrupted. {str(exception.__cause__)}" @@ -220,7 +249,7 @@ def _read_config_file(): def _initialise_logs_section(): with _config_file() as conf_file_cache: - conf_file_cache[CLI_SECTION][LOGS_SECTION] = _DEFAULT_LOGS_CONFIG + conf_file_cache[CLI_SECTION][LOGS_SECTION] = _get_default_logs_config() def _initialise_cli_section(): @@ -253,7 +282,7 @@ def unset_config_value(path: List[str]) -> None: def get_logs_config() -> dict: - logs_config = _DEFAULT_LOGS_CONFIG.copy() + logs_config = _get_default_logs_config().copy() if config_section_exists(*LOGS_SECTION_PATH): logs_config.update(**get_config_section(*LOGS_SECTION_PATH)) return logs_config @@ -295,7 +324,7 @@ def get_connection_dict(connection_name: str) -> dict: def get_default_connection_name() -> str: - return CONFIG_MANAGER["default_connection_name"] + return get_config_manager()["default_connection_name"] def get_default_connection_dict() -> dict: @@ -350,7 +379,9 @@ def _initialise_config(config_file: Path) -> None: config_file.touch() _initialise_cli_section() _initialise_logs_section() - log.info("Created Snowflake configuration file at %s", CONFIG_MANAGER.file_path) + log.info( + "Created Snowflake configuration file at %s", get_config_manager().file_path + ) def get_env_variable_name(*path, key: str) -> str: @@ -362,7 +393,7 @@ def get_env_value(*path, key: str) -> str | None: def _find_section(*path) -> TOMLDocument: - section = CONFIG_MANAGER + section = get_config_manager() idx = 0 while idx < len(path): section = section[path[idx]] @@ -392,7 +423,8 @@ def _get_envs_for_path(*path) -> dict: def _dump_config(config_and_connections: Dict): config_toml_dict = config_and_connections.copy() - if CONNECTIONS_FILE.exists(): + connections_file = get_connections_file() + if connections_file.exists(): # update connections in connections.toml # it will add only connections (maybe updated) which were originally read from connections.toml # it won't add connections from config.toml @@ -405,16 +437,17 @@ def _dump_config(config_and_connections: Dict): else: config_toml_dict.pop("connections", None) - with SecurePath(CONFIG_MANAGER.file_path).open("w+") as fh: + with SecurePath(get_config_manager().file_path).open("w+") as fh: dump(config_toml_dict, fh) def _check_default_config_files_permissions() -> None: if not IS_WINDOWS: - if CONNECTIONS_FILE.exists() and not file_permissions_are_strict( - CONNECTIONS_FILE + connections_file = get_connections_file() + if connections_file.exists() and not file_permissions_are_strict( + connections_file ): - raise ConfigFileTooWidePermissionsError(CONNECTIONS_FILE) + raise ConfigFileTooWidePermissionsError(connections_file) if CONFIG_FILE.exists() and not file_permissions_are_strict(CONFIG_FILE): raise ConfigFileTooWidePermissionsError(CONFIG_FILE) @@ -438,13 +471,13 @@ def _bool_or_unknown(value): def _read_config_file_toml() -> dict: - return tomlkit.loads(CONFIG_MANAGER.file_path.read_text()).unwrap() + return tomlkit.loads(get_config_manager().file_path.read_text()).unwrap() def _read_connections_toml() -> dict: - return tomlkit.loads(CONNECTIONS_FILE.read_text()).unwrap() + return tomlkit.loads(get_connections_file().read_text()).unwrap() def _update_connections_toml(connections: dict): - with open(CONNECTIONS_FILE, "w") as f: + with open(get_connections_file(), "w") as f: f.write(tomlkit.dumps(connections)) diff --git a/tests/conftest.py b/tests/conftest.py index edfb41af64..c36de67ae3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,7 +113,10 @@ def reset_global_context_and_setup_config_and_logging_levels( cli_context_manager.verbose = False cli_context_manager.enable_tracebacks = False cli_context_manager.connection_cache = connection_cache - config_init(test_snowcli_config) + + cli_context_manager.config_file_override = test_snowcli_config + + config_init(None) loggers.create_loggers(verbose=False, debug=False) try: yield @@ -234,6 +237,96 @@ def app_zip(temporary_directory) -> Generator: yield create_temp_file(".zip", temporary_directory, []) +@pytest.fixture +def config_manager(): + """ + Direct access to current config manager. + Returns a proxy that always fetches the current manager from context, + ensuring tests see config changes after config_init() calls. + """ + from snowflake.cli.api.config import get_config_manager + + class ConfigManagerProxy: + def __getitem__(self, key): + return get_config_manager()[key] + + def __setitem__(self, key, value): + get_config_manager()[key] = value + + def __contains__(self, key): + return key in get_config_manager() + + def __getattr__(self, name): + return getattr(get_config_manager(), name) + + def __repr__(self): + return repr(get_config_manager()) + + def get(self, key, default=None): + return get_config_manager().get(key, default) + + return ConfigManagerProxy() + + +@pytest.fixture +def with_custom_config(): + """Context manager for custom config testing""" + + @contextmanager + def _with_custom_config(config_data: dict): + import tempfile + from pathlib import Path + + import tomlkit + + with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) as f: + tomlkit.dump(config_data, f) + config_file = Path(f.name) + + try: + with fork_cli_context() as ctx: + ctx.config_file_override = config_file + yield ctx.config_manager + finally: + config_file.unlink(missing_ok=True) + + return _with_custom_config + + +@pytest.fixture +def config_manager_factory(): + """Factory for creating config managers with specific settings""" + import tempfile + from pathlib import Path + + import tomlkit + + created_files = [] + + def _create_manager( + config_data: Optional[dict] = None, + config_file: Optional[Path] = None, + ): + with fork_cli_context() as ctx: + if config_file: + ctx.config_file_override = config_file + elif config_data: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".toml", delete=False + ) as f: + tomlkit.dump(config_data, f) + config_file = Path(f.name) + created_files.append(config_file) + ctx.config_file_override = config_file + + return ctx.config_manager + + yield _create_manager + + for file in created_files: + file.unlink(missing_ok=True) + + @pytest.fixture def correct_requirements_txt(temporary_directory) -> Generator: req_txt = create_named_file( @@ -423,11 +516,18 @@ def func(): @contextmanager def _named_temporary_file(suffix=None, prefix=None): with tempfile.TemporaryDirectory() as tmp_dir: + from snowflake.cli.api.utils.path_utils import path_resolver + + resolved_tmp_dir = path_resolver(tmp_dir) + suffix = suffix or "" prefix = prefix or "" - f = Path(tmp_dir) / f"{prefix}tmp_file{suffix}" + f = Path(resolved_tmp_dir) / f"{prefix}tmp_file{suffix}" f.touch() - yield f + try: + yield f + finally: + clean_logging_handlers() @pytest.fixture() diff --git a/tests/logs/snowflake-cli.log b/tests/logs/snowflake-cli.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/nativeapp/codegen/test_compiler.py b/tests/nativeapp/codegen/test_compiler.py index 6f703e3d9e..aa58e54291 100644 --- a/tests/nativeapp/codegen/test_compiler.py +++ b/tests/nativeapp/codegen/test_compiler.py @@ -122,7 +122,7 @@ def test_find_and_execute_processors_exception(test_proj_def, test_compiler): test_compiler.compile_artifacts() -class TestProcessor(ArtifactProcessor): +class AcmeProcessor(ArtifactProcessor): NAME = "test_processor" def __init__(self, *args, **kwargs): @@ -148,7 +148,7 @@ def test_skips_disabled_processors(test_proj_def, test_compiler): {"dest": "./", "src": "app/*", "processors": ["test_processor"]} ] test_compiler = NativeAppCompiler(_get_bundle_context(pkg_model)) - test_compiler.register(TestProcessor) + test_compiler.register(AcmeProcessor) - # TestProcessor is never invoked, otherwise calling its methods will make the test fail + # AcmeProcessor is never invoked, otherwise calling its methods will make the test fail test_compiler.compile_artifacts() diff --git a/tests/spcs/test_services.py b/tests/spcs/test_services.py index 5efdf13fb2..c9ee4915ac 100644 --- a/tests/spcs/test_services.py +++ b/tests/spcs/test_services.py @@ -68,14 +68,22 @@ @pytest.fixture() def enable_events_and_metrics_config(): + from snowflake.cli.api.utils.path_utils import path_resolver + + from tests.conftest import clean_logging_handlers + with TemporaryDirectory() as tempdir: - config_toml = Path(tempdir) / "config.toml" + resolved_tempdir = path_resolver(tempdir) + config_toml = Path(resolved_tempdir) / "config.toml" config_toml.write_text( "[cli.features]\n" "enable_spcs_service_events = true\n" "enable_spcs_service_metrics = true\n" ) - yield config_toml + try: + yield config_toml + finally: + clean_logging_handlers() @patch(EXECUTE_QUERY) diff --git a/tests/test_config.py b/tests/test_config.py index 475f28a9f8..b8747cfbb3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -18,6 +18,7 @@ from unittest import mock import pytest +from snowflake.cli.api.cli_global_context import fork_cli_context from snowflake.cli.api.config import ( ConfigFileTooWidePermissionsError, config_init, @@ -34,14 +35,21 @@ def test_empty_config_file_is_created_if_not_present(): + from snowflake.cli.api.utils.path_utils import path_resolver + + from tests.conftest import clean_logging_handlers + with TemporaryDirectory() as tmp_dir: - config_file = Path(tmp_dir) / "sub" / "config.toml" + resolved_tmp_dir = path_resolver(tmp_dir) + config_file = Path(resolved_tmp_dir) / "sub" / "config.toml" assert config_file.exists() is False - config_init(config_file) - assert config_file.exists() is True - - assert_file_permissions_are_strict(config_file) + try: + config_init(config_file) + assert config_file.exists() is True + assert_file_permissions_are_strict(config_file) + finally: + clean_logging_handlers() @mock.patch.dict(os.environ, {}, clear=True) @@ -156,23 +164,27 @@ def test_get_all_connections(test_snowcli_config): } -@mock.patch("snowflake.cli.api.config.CONFIG_MANAGER") @mock.patch("snowflake.cli.api.config.get_config_section") def test_create_default_config_if_not_exists_with_proper_permissions( mock_get_config_section, - mock_config_manager, ): + from snowflake.cli.api.utils.path_utils import path_resolver + + from tests.conftest import clean_logging_handlers + mock_get_config_section.return_value = {} with TemporaryDirectory() as tmp_dir: - config_path = Path(f"{tmp_dir}/snowflake/config.toml") - mock_config_manager.file_path = config_path - mock_config_manager.conf_file_cache = {} + resolved_tmp_dir = path_resolver(tmp_dir) + config_path = Path(f"{resolved_tmp_dir}/snowflake/config.toml") - config_init(None) + try: + config_init(config_path) - assert config_path.exists() - assert_file_permissions_are_strict(config_path.parent) - assert_file_permissions_are_strict(config_path) + assert config_path.exists() + assert_file_permissions_are_strict(config_path.parent) + assert_file_permissions_are_strict(config_path) + finally: + clean_logging_handlers() @mock.patch.dict( @@ -227,82 +239,82 @@ def test_not_found_default_connection_from_evn_variable(test_root_path): def test_correct_updates_of_connections_on_setting_default_connection( test_snowcli_config, snowflake_home ): - from snowflake.cli.api.config import CONFIG_MANAGER + with fork_cli_context() as ctx: + config = test_snowcli_config + connections_toml = snowflake_home / "connections.toml" + connections_toml.write_text( + """[asdf_a] + database = "asdf_a_database" + user = "asdf_a" + account = "asdf_a" + + [asdf_b] + database = "asdf_b_database" + user = "asdf_b" + account = "asdf_b" + """ + ) - config = test_snowcli_config - connections_toml = snowflake_home / "connections.toml" - connections_toml.write_text( - """[asdf_a] - database = "asdf_a_database" - user = "asdf_a" - account = "asdf_a" - - [asdf_b] - database = "asdf_b_database" - user = "asdf_b" - account = "asdf_b" - """ - ) - config_init(config) - set_config_value(path=["default_connection_name"], value="asdf_b") - - def assert_correct_connections_loaded(): - assert CONFIG_MANAGER["default_connection_name"] == "asdf_b" - assert CONFIG_MANAGER["connections"] == { - "asdf_a": { - "database": "asdf_a_database", - "user": "asdf_a", - "account": "asdf_a", - }, - "asdf_b": { - "database": "asdf_b_database", - "user": "asdf_b", - "account": "asdf_b", - }, - } - - # assert correct connections in memory after setting default connection name - assert_correct_connections_loaded() - - with open(connections_toml) as f: - connection_toml_content = f.read() - assert ( - connection_toml_content.count("asdf_a") == 4 - ) # connection still exists in connections.toml - assert ( - connection_toml_content.count("asdf_b") == 4 - ) # connection still exists in connections.toml - assert ( - connection_toml_content.count("jwt") == 0 - ) # connection from config.toml isn't copied to connections.toml - with open(config) as f: - config_toml_content = f.read() - assert ( - config_toml_content.count("asdf_a") == 0 - ) # connection from connections.toml isn't copied to config.toml - assert ( - config_toml_content.count("asdf_b") == 1 - ) # only default_config_name setting, connection from connections.toml isn't copied to config.toml - assert ( - config_toml_content.count("connections.full") == 1 - ) # connection wasn't erased from config.toml - assert ( - config_toml_content.count("connections.jwt") == 1 - ) # connection wasn't erased from config.toml - assert ( - config_toml_content.count("dummy_flag = true") == 1 - ) # other settings are not erased - - # reinit config file and recheck loaded connections - config_init(config) - assert_correct_connections_loaded() + ctx.config_file_override = config + config_init(None) + set_config_value(path=["default_connection_name"], value="asdf_b") + + config_manager = ctx.config_manager + + def assert_correct_connections_loaded(): + assert config_manager["default_connection_name"] == "asdf_b" + assert config_manager["connections"] == { + "asdf_a": { + "database": "asdf_a_database", + "user": "asdf_a", + "account": "asdf_a", + }, + "asdf_b": { + "database": "asdf_b_database", + "user": "asdf_b", + "account": "asdf_b", + }, + } + + # assert correct connections in memory after setting default connection name + assert_correct_connections_loaded() + + with open(connections_toml) as f: + connection_toml_content = f.read() + assert ( + connection_toml_content.count("asdf_a") == 4 + ) # connection still exists in connections.toml + assert ( + connection_toml_content.count("asdf_b") == 4 + ) # connection still exists in connections.toml + assert ( + connection_toml_content.count("jwt") == 0 + ) # connection from config.toml isn't copied to connections.toml + with open(config) as f: + config_toml_content = f.read() + assert ( + config_toml_content.count("asdf_a") == 0 + ) # connection from connections.toml isn't copied to config.toml + assert ( + config_toml_content.count("asdf_b") == 1 + ) # only default_config_name setting, connection from connections.toml isn't copied to config.toml + assert ( + config_toml_content.count("connections.full") == 1 + ) # connection wasn't erased from config.toml + assert ( + config_toml_content.count("connections.jwt") == 1 + ) # connection wasn't erased from config.toml + assert ( + config_toml_content.count("dummy_flag = true") == 1 + ) # other settings are not erased + + config_init(None) + assert_correct_connections_loaded() def test_correct_updates_of_connections_on_setting_default_connection_for_empty_config_file( - config_file, snowflake_home + config_file, snowflake_home, config_manager ): - from snowflake.cli.api.config import CONFIG_MANAGER - with config_file() as config: connections_toml = snowflake_home / "connections.toml" connections_toml.write_text( @@ -321,8 +333,8 @@ def test_correct_updates_of_connections_on_setting_default_connection_for_empty_ set_config_value(path=["default_connection_name"], value="asdf_b") def assert_correct_connections_loaded(): - assert CONFIG_MANAGER["default_connection_name"] == "asdf_b" - assert CONFIG_MANAGER["connections"] == { + assert config_manager["default_connection_name"] == "asdf_b" + assert config_manager["connections"] == { "asdf_a": { "database": "asdf_a_database", "user": "asdf_a", @@ -367,14 +379,13 @@ def assert_correct_connections_loaded(): config_toml_content.count("dummy_flag = true") == 0 ) # other settings are not erased - # reinit config file and recheck loaded connections config_init(config) assert_correct_connections_loaded() -def test_connections_toml_override_config_toml(test_snowcli_config, snowflake_home): - from snowflake.cli.api.config import CONFIG_MANAGER - +def test_connections_toml_override_config_toml( + test_snowcli_config, snowflake_home, config_manager +): connections_toml = snowflake_home / "connections.toml" connections_toml.write_text( """[default] @@ -384,7 +395,7 @@ def test_connections_toml_override_config_toml(test_snowcli_config, snowflake_ho config_init(test_snowcli_config) assert get_default_connection_dict() == {"database": "overridden_database"} - assert CONFIG_MANAGER["connections"] == { + assert config_manager["connections"] == { "default": {"database": "overridden_database"} } @@ -490,8 +501,13 @@ def test_too_wide_permissions_on_custom_config_file_causes_warning( def test_too_wide_permissions_on_custom_config_file_causes_warning_windows(permissions): import subprocess + from snowflake.cli.api.utils.path_utils import path_resolver + + from tests.conftest import clean_logging_handlers + with TemporaryDirectory() as tmp_dir: - config_path = Path(tmp_dir) / "config.toml" + resolved_tmp_dir = path_resolver(tmp_dir) + config_path = Path(resolved_tmp_dir) / "config.toml" config_path.touch() result = subprocess.run( ["icacls", str(config_path), "/GRANT", f"Everyone:{permissions}"], @@ -500,11 +516,14 @@ def test_too_wide_permissions_on_custom_config_file_causes_warning_windows(permi ) assert result.returncode == 0, result.stdout + result.stderr - with pytest.warns( - UserWarning, - match=r"Unauthorized users \(.*\) have access to configuration file .*", - ): - config_init(config_file=config_path) + try: + with pytest.warns( + UserWarning, + match=r"Unauthorized users \(.*\) have access to configuration file .*", + ): + config_init(config_file=config_path) + finally: + clean_logging_handlers() @parametrize_chmod diff --git a/tests/test_main.py b/tests/test_main.py index 3fa50be806..a4052c2b35 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -26,7 +26,6 @@ import pytest from click import Command -from snowflake.connector.config_manager import CONFIG_MANAGER from typer.core import TyperArgument, TyperOption @@ -67,13 +66,13 @@ def test_custom_config_path(mock_conn, runner, mock_cursor): @mock.patch.dict(os.environ, {"SNOWFLAKE_HOME": "FooBar"}, clear=True) -def test_info_callback(runner): +def test_info_callback(runner, config_manager): result = runner.invoke(["--info"]) assert result.exit_code == 0, result.output payload = json.loads(result.output) assert payload == [ {"key": "version", "value": "0.0.0-test_patched"}, - {"key": "default_config_file_path", "value": str(CONFIG_MANAGER.file_path)}, + {"key": "default_config_file_path", "value": str(config_manager.file_path)}, {"key": "python_version", "value": sys.version}, {"key": "system_info", "value": platform.platform()}, { diff --git a/tests_common/conftest.py b/tests_common/conftest.py index 19f43bde09..1b889f02ed 100644 --- a/tests_common/conftest.py +++ b/tests_common/conftest.py @@ -35,14 +35,19 @@ @pytest.fixture def temporary_directory(): + from snowflake.cli.api.utils.path_utils import path_resolver + from tests.conftest import clean_logging_handlers + initial_dir = os.getcwd() with tempfile.TemporaryDirectory() as tmp_dir: + resolved_tmp_dir = path_resolver(tmp_dir) try: - os.chdir(tmp_dir) - yield tmp_dir + os.chdir(resolved_tmp_dir) + yield resolved_tmp_dir finally: os.chdir(initial_dir) + clean_logging_handlers() # Borrowed from tests_integration/test_utils.py @@ -100,6 +105,7 @@ def snowflake_home(monkeypatch): sys.modules["snowflake.connector.config_manager"], sys.modules["snowflake.connector.log_configuration"], sys.modules["snowflake.cli.api.config"], + sys.modules["snowflake.cli.api.cli_global_context"], ]: importlib.reload(module)