Skip to content

Commit cbe2d18

Browse files
adding a helper function for this
1 parent 3b9ea86 commit cbe2d18

File tree

4 files changed

+40
-8
lines changed

4 files changed

+40
-8
lines changed

src/snowflake/connector/auth/_auth.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import copy
44
import json
55
import logging
6-
import os
76
import uuid
87
from datetime import datetime, timezone
98
from threading import Thread
@@ -63,6 +62,7 @@
6362
from ..session_manager import SessionManagerFactory
6463
from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
6564
from ..token_cache import TokenCache, TokenKey, TokenType
65+
from ..util_text import expand_tilde
6666
from ..version import VERSION
6767
from .no_auth import AuthNoAuth
6868
from .oauth import AuthByOAuth
@@ -642,8 +642,7 @@ def get_token_from_private_key(
642642

643643
def get_public_key_fingerprint(private_key_file: str, password: str) -> str:
644644
"""Helper function to generate the public key fingerprint from the private key file"""
645-
# expand tilde
646-
private_key_file = os.path.expanduser(private_key_file)
645+
private_key_file = expand_tilde(private_key_file)
647646

648647
with open(private_key_file, "rb") as key:
649648
p_key = load_pem_private_key(

src/snowflake/connector/connection.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@
142142
from .telemetry import TelemetryClient, TelemetryData, TelemetryField
143143
from .time_util import HeartBeatTimer, get_time_millis
144144
from .url_util import extract_top_level_domain_from_hostname
145-
from .util_text import construct_hostname, parse_account, split_statements
145+
from .util_text import construct_hostname, expand_tilde, parse_account, split_statements
146146
from .wif_util import AttestationProvider
147147

148148
if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
@@ -174,9 +174,7 @@ def _get_private_bytes_from_file(
174174
if private_key_file_pwd is not None and isinstance(private_key_file_pwd, str):
175175
private_key_file_pwd = private_key_file_pwd.encode("utf-8")
176176

177-
# expand tilde
178-
if isinstance(private_key_file, str):
179-
private_key_file = os.path.expanduser(private_key_file)
177+
private_key_file = expand_tilde(private_key_file)
180178

181179
with open(private_key_file, "rb") as key:
182180
private_key = serialization.load_pem_private_key(

src/snowflake/connector/util_text.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import re
99
import string
1010
from io import StringIO
11-
from typing import Sequence
11+
from pathlib import Path
12+
from typing import Any, Sequence
1213

1314
COMMENT_PATTERN_RE = re.compile(r"^\s*\-\-")
1415
EMPTY_LINE_RE = re.compile(r"^\s*$")
@@ -301,3 +302,17 @@ def get_md5_for_integrity(text: str | bytes) -> bytes:
301302
md5 = hashlib.md5(usedforsecurity=False)
302303
md5.update(text)
303304
return md5.digest()
305+
306+
307+
def expand_tilde(path_to_expand: Any) -> Any:
308+
try:
309+
(
310+
path_to_expand == str(Path(path_to_expand).expanduser())
311+
if isinstance(path_to_expand, str)
312+
else path_to_expand
313+
)
314+
except RuntimeError:
315+
# home could not be resolved
316+
_logger.debug("User home could not be determined, not expanding tilde.")
317+
318+
return path_to_expand

test/unit/test_auth_keypair.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,23 @@ def generate_key_pair(key_length: int, *, passphrase: bytes | None = None):
276276
public_key_der_encoded = "".join(public_key_pem.split("\n")[1:-2])
277277

278278
return private_key_der, public_key_der_encoded
279+
280+
281+
@pytest.mark.skipolddriver
282+
def test_expand_tilde():
283+
from os import environ
284+
285+
from snowflake.connector.util_text import expand_tilde
286+
287+
old_home = environ["HOME"]
288+
environ["HOME"] = "/home/myuser"
289+
290+
assert expand_tilde("/path/to/key.p8") == "/path/to/key.p8"
291+
assert expand_tilde("~/key.p8") == "/home/myuser/key.p8"
292+
293+
del environ["HOME"]
294+
assert isinstance(
295+
expand_tilde("~/key.p8"), str
296+
) # should still resolve from /etc/passwd
297+
298+
environ["HOME"] = old_home

0 commit comments

Comments
 (0)