Skip to content

Commit 2de73b5

Browse files
committed
SNOW-2306184: config refactor - clean tmp files
1 parent c9388df commit 2de73b5

File tree

5 files changed

+107
-17
lines changed

5 files changed

+107
-17
lines changed

src/snowflake/cli/api/config_ng/core.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from abc import ABC, abstractmethod
2727
from dataclasses import dataclass, field
2828
from datetime import datetime
29-
from typing import Any, Callable, Dict, List, Optional
29+
from typing import Any, Callable, Dict, List, Literal, Optional
3030

3131

3232
@dataclass(frozen=True)
@@ -84,9 +84,20 @@ class ValueSource(ABC):
8484
Precedence is determined by the order sources are provided to the resolver.
8585
"""
8686

87+
# Allowed source names for config resolution
88+
SourceName = Literal[
89+
"snowsql_config",
90+
"cli_config_toml",
91+
"connections_toml",
92+
"snowsql_env",
93+
"connection_specific_env",
94+
"cli_env",
95+
"cli_arguments",
96+
]
97+
8798
@property
8899
@abstractmethod
89-
def source_name(self) -> str:
100+
def source_name(self) -> SourceName:
90101
"""
91102
Unique identifier for this source.
92103
Examples: "cli_arguments", "snowsql_config", "cli_env"

src/snowflake/cli/api/config_ng/resolver.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,13 +404,11 @@ def resolve(self, key: Optional[str] = None, default: Any = None) -> Dict[str, A
404404
# Identify sources that connections.toml replaces
405405
# connections.toml only replaces cli_config_toml, not SnowSQL config
406406
cli_config_source = "cli_config_toml"
407-
connections_file_source = None
408407
connections_to_replace: set[str] = set()
409408

410409
# First pass: find connections.toml and identify connections to replace
411410
for source in self._sources:
412411
if hasattr(source, "is_connections_file") and source.is_connections_file:
413-
connections_file_source = source
414412
connections_to_replace = source.get_defined_connections()
415413
break
416414

src/snowflake/cli/api/config_ng/sources.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import logging
3333
import os
3434
from pathlib import Path
35-
from typing import Any, Dict, Optional
35+
from typing import Any, Dict, Final, Optional
3636

3737
from snowflake.cli.api.config_ng.core import ConfigValue, ValueSource
3838

@@ -117,7 +117,7 @@ def __init__(self):
117117
]
118118

119119
@property
120-
def source_name(self) -> str:
120+
def source_name(self) -> "ValueSource.SourceName":
121121
return "snowsql_config"
122122

123123
def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
@@ -225,7 +225,7 @@ def __init__(self):
225225
]
226226

227227
@property
228-
def source_name(self) -> str:
228+
def source_name(self) -> "ValueSource.SourceName":
229229
return "cli_config_toml"
230230

231231
def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
@@ -309,7 +309,7 @@ def __init__(self):
309309
self._file_path = Path.home() / ".snowflake" / "connections.toml"
310310

311311
@property
312-
def source_name(self) -> str:
312+
def source_name(self) -> "ValueSource.SourceName":
313313
return "connections_toml"
314314

315315
@property
@@ -448,7 +448,7 @@ class SnowSQLEnvironment(ValueSource):
448448
}
449449

450450
@property
451-
def source_name(self) -> str:
451+
def source_name(self) -> "ValueSource.SourceName":
452452
return "snowsql_env"
453453

454454
def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
@@ -485,7 +485,7 @@ def supports_key(self, key: str) -> bool:
485485

486486

487487
# Base configuration keys that can be set via environment
488-
_ENV_CONFIG_KEYS = [
488+
_ENV_CONFIG_KEYS: Final[list[str]] = [
489489
"account",
490490
"user",
491491
"password",
@@ -533,7 +533,7 @@ class ConnectionSpecificEnvironment(ValueSource):
533533
"""
534534

535535
@property
536-
def source_name(self) -> str:
536+
def source_name(self) -> "ValueSource.SourceName":
537537
return "connection_specific_env"
538538

539539
def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
@@ -608,7 +608,7 @@ class CliEnvironment(ValueSource):
608608
"""
609609

610610
@property
611-
def source_name(self) -> str:
611+
def source_name(self) -> "ValueSource.SourceName":
612612
return "cli_env"
613613

614614
def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:
@@ -677,7 +677,7 @@ def __init__(self, cli_context: Optional[Dict[str, Any]] = None):
677677
self._cli_context = cli_context or {}
678678

679679
@property
680-
def source_name(self) -> str:
680+
def source_name(self) -> "ValueSource.SourceName":
681681
return "cli_arguments"
682682

683683
def discover(self, key: Optional[str] = None) -> Dict[str, ConfigValue]:

src/snowflake/cli/api/config_provider.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@
1414

1515
from __future__ import annotations
1616

17+
import atexit
1718
import os
1819
from abc import ABC, abstractmethod
1920
from pathlib import Path
20-
from typing import TYPE_CHECKING, Any, Dict, Optional
21+
from typing import TYPE_CHECKING, Any, Dict, Final, Optional
2122

2223
if TYPE_CHECKING:
24+
from snowflake.cli.api.config_ng.core import ValueSource
2325
from snowflake.cli.api.config_ng.resolver import ConfigurationResolver
2426

25-
ALTERNATIVE_CONFIG_ENV_VAR = "SNOWFLAKE_CLI_CONFIG_V2_ENABLED"
27+
ALTERNATIVE_CONFIG_ENV_VAR: Final[str] = "SNOWFLAKE_CLI_CONFIG_V2_ENABLED"
2628

2729

2830
class ConfigProvider(ABC):
@@ -97,7 +99,6 @@ def _transform_private_key_raw(self, connection_dict: dict) -> dict:
9799
if "private_key_file" in connection_dict:
98100
return connection_dict
99101

100-
import os
101102
import tempfile
102103

103104
try:
@@ -116,13 +117,35 @@ def _transform_private_key_raw(self, connection_dict: dict) -> dict:
116117
result["private_key_file"] = temp_file_path
117118
del result["private_key_raw"]
118119

120+
# Track created temp file on the provider instance for cleanup
121+
temp_files_attr = "_temp_private_key_files"
122+
existing = getattr(self, temp_files_attr, None)
123+
if existing is None:
124+
setattr(self, temp_files_attr, {temp_file_path})
125+
else:
126+
existing.add(temp_file_path)
127+
119128
return result
120129

121130
except Exception:
122131
# If transformation fails, return original dict
123132
# The error will be handled downstream
124133
return connection_dict
125134

135+
def cleanup_temp_files(self) -> None:
136+
"""Delete any temporary files created from private_key_raw transformation."""
137+
temp_files = getattr(self, "_temp_private_key_files", None)
138+
if not temp_files:
139+
return
140+
to_remove = list(temp_files)
141+
for path in to_remove:
142+
try:
143+
Path(path).unlink(missing_ok=True)
144+
except Exception:
145+
# Best-effort cleanup; ignore failures
146+
pass
147+
temp_files.clear()
148+
126149

127150
class LegacyConfigProvider(ConfigProvider):
128151
"""
@@ -409,7 +432,7 @@ def section_exists(self, *path) -> bool:
409432
)
410433

411434
# Source priority levels (higher number = higher priority)
412-
_SOURCE_PRIORITIES = {
435+
_SOURCE_PRIORITIES: Final[dict["ValueSource.SourceName", int]] = {
413436
"snowsql_config": 1,
414437
"cli_config_toml": 2,
415438
"connections_toml": 3,
@@ -683,4 +706,23 @@ def reset_config_provider():
683706
Useful for testing and when config source changes.
684707
"""
685708
global _config_provider_instance
709+
# Cleanup any temp files created by the current provider instance
710+
if _config_provider_instance is not None:
711+
try:
712+
_config_provider_instance.cleanup_temp_files()
713+
except Exception:
714+
pass
686715
_config_provider_instance = None
716+
717+
718+
def _cleanup_provider_at_exit() -> None:
719+
"""Process-exit cleanup for provider-managed temporary files."""
720+
global _config_provider_instance
721+
if _config_provider_instance is not None:
722+
try:
723+
_config_provider_instance.cleanup_temp_files()
724+
except Exception:
725+
pass
726+
727+
728+
atexit.register(_cleanup_provider_at_exit)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Tests for temporary private_key_raw file lifecycle and cleanup."""
2+
3+
from pathlib import Path
4+
5+
6+
def test_private_key_raw_creates_and_cleans_temp_file(config_ng_setup, tmp_path):
7+
priv_key_content = (
8+
"""-----BEGIN PRIVATE KEY-----\nABC\n-----END PRIVATE KEY-----\n"""
9+
)
10+
11+
cli_config = """
12+
[connections.test]
13+
user = "cli-user"
14+
"""
15+
16+
env_vars = {
17+
# Provide private_key_raw via env to trigger transformation
18+
"SNOWFLAKE_CONNECTIONS_TEST_PRIVATE_KEY_RAW": priv_key_content,
19+
}
20+
21+
with config_ng_setup(cli_config=cli_config, env_vars=env_vars):
22+
from snowflake.cli.api.config import get_connection_dict
23+
from snowflake.cli.api.config_provider import (
24+
get_config_provider_singleton,
25+
reset_config_provider,
26+
)
27+
28+
provider = get_config_provider_singleton()
29+
30+
conn = get_connection_dict("test")
31+
temp_path = Path(conn["private_key_file"]) # should exist now
32+
assert temp_path.exists()
33+
assert temp_path.read_text() == priv_key_content
34+
35+
# Reset provider triggers cleanup
36+
reset_config_provider()
37+
38+
# File should be gone after cleanup
39+
assert not temp_path.exists()

0 commit comments

Comments
 (0)