Skip to content

Commit 3769b43

Browse files
sfc-gh-jszczerbinskisfc-gh-jrosesfc-gh-aling
authored
SNOW-1902019: Python CVEs january batch (#2154)
Co-authored-by: Jamison <[email protected]> Co-authored-by: Adam Ling <[email protected]>
1 parent ec3002d commit 3769b43

File tree

9 files changed

+354
-47
lines changed

9 files changed

+354
-47
lines changed

src/snowflake/connector/auth/_auth.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
ProgrammingError,
5353
ServiceUnavailableError,
5454
)
55+
from ..file_util import owner_rw_opener
5556
from ..network import (
5657
ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
5758
CONTENT_TYPE_APPLICATION_JSON,
@@ -625,7 +626,11 @@ def flush_temporary_credentials() -> None:
625626
)
626627
try:
627628
with open(
628-
TEMPORARY_CREDENTIAL_FILE, "w", encoding="utf-8", errors="ignore"
629+
TEMPORARY_CREDENTIAL_FILE,
630+
"w",
631+
encoding="utf-8",
632+
errors="ignore",
633+
opener=owner_rw_opener,
629634
) as f:
630635
json.dump(TEMPORARY_CREDENTIAL, f)
631636
except Exception as ex:

src/snowflake/connector/cache.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def __init__(
388388
file_path: str | dict[str, str],
389389
entry_lifetime: int = constants.DAY_IN_SECONDS,
390390
file_timeout: int = 0,
391+
load_if_file_exists: bool = True,
391392
) -> None:
392393
"""Inits an SFDictFileCache with path, lifetime.
393394
@@ -445,7 +446,7 @@ def __init__(
445446
self._file_lock_path = f"{self.file_path}.lock"
446447
self._file_lock = FileLock(self._file_lock_path, timeout=self.file_timeout)
447448
self.last_loaded: datetime.datetime | None = None
448-
if os.path.exists(self.file_path):
449+
if os.path.exists(self.file_path) and load_if_file_exists:
449450
with self._lock:
450451
self._load()
451452
# indicate whether the cache is modified or not, this variable is for
@@ -498,7 +499,7 @@ def _load(self) -> bool:
498499
"""Load cache from disk if possible, returns whether it was able to load."""
499500
try:
500501
with open(self.file_path, "rb") as r_file:
501-
other: SFDictFileCache = pickle.load(r_file)
502+
other: SFDictFileCache = self._deserialize(r_file)
502503
# Since we want to know whether we are dirty after loading
503504
# we have to know whether the file could learn anything from self
504505
# so instead of calling self.update we call other.update and swap
@@ -529,6 +530,13 @@ def load(self) -> bool:
529530
with self._lock:
530531
return self._load()
531532

533+
def _serialize(self):
534+
return pickle.dumps(self)
535+
536+
@classmethod
537+
def _deserialize(cls, r_file):
538+
return pickle.load(r_file)
539+
532540
def _save(self, load_first: bool = True, force_flush: bool = False) -> bool:
533541
"""Save cache to disk if possible, returns whether it was able to save.
534542
@@ -559,7 +567,7 @@ def _save(self, load_first: bool = True, force_flush: bool = False) -> bool:
559567
# python program.
560568
# thus we fall back to the approach using the normal open() method to open a file and write.
561569
with open(tmp_file, "wb") as w_file:
562-
w_file.write(pickle.dumps(self))
570+
w_file.write(self._serialize())
563571
# We write to a tmp file and then move it to have atomic write
564572
os.replace(tmp_file_path, self.file_path)
565573
self.last_loaded = datetime.datetime.fromtimestamp(

src/snowflake/connector/encryption_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from .compat import PKCS5_OFFSET, PKCS5_PAD, PKCS5_UNPAD
1919
from .constants import UTF8, EncryptionMetadata, MaterialDescriptor, kilobyte
20+
from .file_util import owner_rw_opener
2021
from .util_text import random_string
2122

2223
block_size = int(algorithms.AES.block_size / 8) # in bytes
@@ -213,7 +214,7 @@ def decrypt_file(
213214

214215
logger.debug("encrypted file: %s, tmp file: %s", in_filename, temp_output_file)
215216
with open(in_filename, "rb") as infile:
216-
with open(temp_output_file, "wb") as outfile:
217+
with open(temp_output_file, "wb", opener=owner_rw_opener) as outfile:
217218
SnowflakeEncryptionUtil.decrypt_stream(
218219
metadata, encryption_material, infile, outfile, chunk_size
219220
)

src/snowflake/connector/file_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
logger = getLogger(__name__)
2222

2323

24+
def owner_rw_opener(path, flags) -> int:
25+
return os.open(path, flags, mode=0o600)
26+
27+
2428
class SnowflakeFileUtil:
2529
@staticmethod
2630
def get_digest_and_size(src: IO[bytes]) -> tuple[str, int]:

src/snowflake/connector/ocsp_snowflake.py

Lines changed: 161 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import codecs
9+
import importlib
910
import json
1011
import os
1112
import platform
@@ -30,6 +31,7 @@
3031
from asn1crypto.x509 import Certificate
3132
from OpenSSL.SSL import Connection
3233

34+
from snowflake.connector import SNOWFLAKE_CONNECTOR_VERSION
3335
from snowflake.connector.compat import OK, urlsplit, urlunparse
3436
from snowflake.connector.constants import HTTP_HEADER_USER_AGENT
3537
from snowflake.connector.errorcode import (
@@ -58,9 +60,10 @@
5860

5961
from . import constants
6062
from .backoff_policies import exponential_backoff
61-
from .cache import SFDictCache, SFDictFileCache
63+
from .cache import CacheEntry, SFDictCache, SFDictFileCache
6264
from .telemetry import TelemetryField, generate_telemetry_data_dict
6365
from .url_util import extract_top_level_domain_from_hostname, url_encode_str
66+
from .util_text import _base64_bytes_to_str
6467

6568

6669
class OCSPResponseValidationResult(NamedTuple):
@@ -72,27 +75,180 @@ class OCSPResponseValidationResult(NamedTuple):
7275
ts: int | None = None
7376
validated: bool = False
7477

78+
def _serialize(self):
79+
def serialize_exception(exc):
80+
# serialization exception is not supported for all exceptions
81+
# in the ocsp_snowflake.py, most exceptions are RevocationCheckError which is easy to serialize.
82+
# however, it would require non-trivial effort to serialize other exceptions especially 3rd part errors
83+
# as there can be un-serializable members and nondeterministic constructor arguments.
84+
# here we do a general best efforts serialization for other exceptions recording only the error message.
85+
if not exc:
86+
return None
87+
88+
exc_type = type(exc)
89+
ret = {"class": exc_type.__name__, "module": exc_type.__module__}
90+
if isinstance(exc, RevocationCheckError):
91+
ret.update({"errno": exc.errno, "msg": exc.raw_msg})
92+
else:
93+
ret.update({"msg": str(exc)})
94+
return ret
95+
96+
return json.dumps(
97+
{
98+
"exception": serialize_exception(self.exception),
99+
"issuer": (
100+
_base64_bytes_to_str(self.issuer.dump()) if self.issuer else None
101+
),
102+
"subject": (
103+
_base64_bytes_to_str(self.subject.dump()) if self.subject else None
104+
),
105+
"cert_id": (
106+
_base64_bytes_to_str(self.cert_id.dump()) if self.cert_id else None
107+
),
108+
"ocsp_response": _base64_bytes_to_str(self.ocsp_response),
109+
"ts": self.ts,
110+
"validated": self.validated,
111+
}
112+
)
113+
114+
@classmethod
115+
def _deserialize(cls, json_str: str) -> OCSPResponseValidationResult:
116+
json_obj = json.loads(json_str)
117+
118+
def deserialize_exception(exception_dict: dict | None) -> Exception | None:
119+
# as pointed out in the serialization method, here we do the best effort deserialization
120+
# for non-RevocationCheckError exceptions. If we can not deserialize the exception, we will
121+
# return a RevocationCheckError with a message indicating the failure.
122+
if not exception_dict:
123+
return
124+
exc_class = exception_dict.get("class")
125+
exc_module = exception_dict.get("module")
126+
try:
127+
if (
128+
exc_class == "RevocationCheckError"
129+
and exc_module == "snowflake.connector.errors"
130+
):
131+
return RevocationCheckError(
132+
msg=exception_dict["msg"],
133+
errno=exception_dict["errno"],
134+
)
135+
else:
136+
module = importlib.import_module(exc_module)
137+
exc_cls = getattr(module, exc_class)
138+
return exc_cls(exception_dict["msg"])
139+
except Exception as deserialize_exc:
140+
logger.debug(
141+
f"hitting error {str(deserialize_exc)} while deserializing exception,"
142+
f" the original error error class and message are {exc_class} and {exception_dict['msg']}"
143+
)
144+
return RevocationCheckError(
145+
f"Got error {str(deserialize_exc)} while deserializing ocsp cache, please try "
146+
f"cleaning up the "
147+
f"OCSP cache under directory {OCSP_RESPONSE_VALIDATION_CACHE.file_path}",
148+
errno=ER_OCSP_RESPONSE_LOAD_FAILURE,
149+
)
150+
151+
return OCSPResponseValidationResult(
152+
exception=deserialize_exception(json_obj.get("exception")),
153+
issuer=(
154+
Certificate.load(b64decode(json_obj.get("issuer")))
155+
if json_obj.get("issuer")
156+
else None
157+
),
158+
subject=(
159+
Certificate.load(b64decode(json_obj.get("subject")))
160+
if json_obj.get("subject")
161+
else None
162+
),
163+
cert_id=(
164+
CertId.load(b64decode(json_obj.get("cert_id")))
165+
if json_obj.get("cert_id")
166+
else None
167+
),
168+
ocsp_response=(
169+
b64decode(json_obj.get("ocsp_response"))
170+
if json_obj.get("ocsp_response")
171+
else None
172+
),
173+
ts=json_obj.get("ts"),
174+
validated=json_obj.get("validated"),
175+
)
176+
177+
178+
class _OCSPResponseValidationResultCache(SFDictFileCache):
179+
def _serialize(self) -> bytes:
180+
entries = {
181+
(
182+
_base64_bytes_to_str(k[0]),
183+
_base64_bytes_to_str(k[1]),
184+
_base64_bytes_to_str(k[2]),
185+
): (v.expiry.isoformat(), v.entry._serialize())
186+
for k, v in self._cache.items()
187+
}
188+
189+
return json.dumps(
190+
{
191+
"cache_keys": list(entries.keys()),
192+
"cache_items": list(entries.values()),
193+
"entry_lifetime": self._entry_lifetime.total_seconds(),
194+
"file_path": str(self.file_path),
195+
"file_timeout": self.file_timeout,
196+
"last_loaded": (
197+
self.last_loaded.isoformat() if self.last_loaded else None
198+
),
199+
"telemetry": self.telemetry,
200+
"connector_version": SNOWFLAKE_CONNECTOR_VERSION, # reserved for schema version control
201+
}
202+
).encode()
203+
204+
@classmethod
205+
def _deserialize(cls, opened_fd) -> _OCSPResponseValidationResultCache:
206+
data = json.loads(opened_fd.read().decode())
207+
cache_instance = cls(
208+
file_path=data["file_path"],
209+
entry_lifetime=int(data["entry_lifetime"]),
210+
file_timeout=data["file_timeout"],
211+
load_if_file_exists=False,
212+
)
213+
cache_instance.file_path = os.path.expanduser(data["file_path"])
214+
cache_instance.telemetry = data["telemetry"]
215+
cache_instance.last_loaded = (
216+
datetime.fromisoformat(data["last_loaded"]) if data["last_loaded"] else None
217+
)
218+
for k, v in zip(data["cache_keys"], data["cache_items"]):
219+
cache_instance._cache[
220+
(b64decode(k[0]), b64decode(k[1]), b64decode(k[2]))
221+
] = CacheEntry(
222+
datetime.fromisoformat(v[0]),
223+
OCSPResponseValidationResult._deserialize(v[1]),
224+
)
225+
return cache_instance
226+
75227

76228
try:
77229
OCSP_RESPONSE_VALIDATION_CACHE: SFDictFileCache[
78230
tuple[bytes, bytes, bytes],
79231
OCSPResponseValidationResult,
80-
] = SFDictFileCache(
232+
] = _OCSPResponseValidationResultCache(
81233
entry_lifetime=constants.DAY_IN_SECONDS,
82234
file_path={
83235
"linux": os.path.join(
84-
"~", ".cache", "snowflake", "ocsp_response_validation_cache"
236+
"~", ".cache", "snowflake", "ocsp_response_validation_cache.json"
85237
),
86238
"darwin": os.path.join(
87-
"~", "Library", "Caches", "Snowflake", "ocsp_response_validation_cache"
239+
"~",
240+
"Library",
241+
"Caches",
242+
"Snowflake",
243+
"ocsp_response_validation_cache.json",
88244
),
89245
"windows": os.path.join(
90246
"~",
91247
"AppData",
92248
"Local",
93249
"Snowflake",
94250
"Caches",
95-
"ocsp_response_validation_cache",
251+
"ocsp_response_validation_cache.json",
96252
),
97253
},
98254
)

src/snowflake/connector/storage_client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,11 @@ def _send_request_with_retry(
329329
f"{verb} with url {url} failed for exceeding maximum retries."
330330
)
331331

332+
def _open_intermediate_dst_path(self, mode):
333+
if not self.intermediate_dst_path.exists():
334+
self.intermediate_dst_path.touch(mode=0o600)
335+
return self.intermediate_dst_path.open(mode)
336+
332337
def prepare_download(self) -> None:
333338
# TODO: add nicer error message for when target directory is not writeable
334339
# but this should be done before we get here
@@ -352,13 +357,13 @@ def prepare_download(self) -> None:
352357
self.num_of_chunks = ceil(file_header.content_length / self.chunk_size)
353358

354359
# Preallocate encrypted file.
355-
with self.intermediate_dst_path.open("wb+") as fd:
360+
with self._open_intermediate_dst_path("wb+") as fd:
356361
fd.truncate(self.meta.src_file_size)
357362

358363
def write_downloaded_chunk(self, chunk_id: int, data: bytes) -> None:
359364
"""Writes given data to the temp location starting at chunk_id * chunk_size."""
360365
# TODO: should we use chunking and write content in smaller chunks?
361-
with self.intermediate_dst_path.open("rb+") as fd:
366+
with self._open_intermediate_dst_path("rb+") as fd:
362367
fd.seek(self.chunk_size * chunk_id)
363368
fd.write(data)
364369

src/snowflake/connector/util_text.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import base64
89
import hashlib
910
import logging
1011
import random
@@ -292,6 +293,10 @@ def random_string(
292293
return "".join([prefix, random_part, suffix])
293294

294295

296+
def _base64_bytes_to_str(x) -> str | None:
297+
return base64.b64encode(x).decode("utf-8") if x else None
298+
299+
295300
def get_md5(text: str | bytes) -> bytes:
296301
if isinstance(text, str):
297302
text = text.encode("utf-8")

test/extras/run.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,18 @@
3535
assert (
3636
cache_files
3737
== {
38-
"ocsp_response_validation_cache.lock",
39-
"ocsp_response_validation_cache",
38+
"ocsp_response_validation_cache.json.lock",
39+
"ocsp_response_validation_cache.json",
4040
"ocsp_response_cache.json",
4141
}
4242
and not platform.system() == "Windows"
4343
) or (
4444
cache_files
4545
== {
46-
"ocsp_response_validation_cache",
46+
"ocsp_response_validation_cache.json",
4747
"ocsp_response_cache.json",
4848
}
4949
and platform.system() == "Windows"
50+
), str(
51+
cache_files
5052
)

0 commit comments

Comments
 (0)