Skip to content

Commit c0eae8e

Browse files
Fix error in ssl_wrap_socket_with_ocsp
1 parent c208c4b commit c0eae8e

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

src/snowflake/connector/ssl_wrap_socket.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,45 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
104104
ssl_context_index = get_args(ssl_.ssl_wrap_socket).args.index("ssl_context")
105105
context_in_args = len(args) > ssl_context_index
106106
ssl_context = (
107-
args[hostname_index] if context_in_args else kwargs.get("ssl_context", None)
107+
args[ssl_context_index] if context_in_args else kwargs.get("ssl_context", None)
108108
)
109-
if not isinstance(ssl_context, PyOpenSSLContext):
110-
# Create new default context
111-
if context_in_args:
112-
new_args = list(args)
113-
new_args[ssl_context_index] = None
114-
args = tuple(new_args)
115-
else:
116-
del kwargs["ssl_context"]
109+
117110
# Fix ca certs location
118111
ca_certs_index = get_args(ssl_.ssl_wrap_socket).args.index("ca_certs")
119112
ca_certs_in_args = len(args) > ca_certs_index
113+
114+
if not isinstance(ssl_context, PyOpenSSLContext):
115+
if FEATURE_OCSP_MODE != OCSPMode.DISABLE_OCSP_CHECKS:
116+
# Create PyOpenSSL context for OCSP validation
117+
ssl_context = PyOpenSSLContext(ssl_.ssl.PROTOCOL_TLS_CLIENT)
118+
ssl_context.check_hostname = False
119+
ssl_context.verify_mode = ssl_.ssl.CERT_REQUIRED
120+
121+
# Load CA certificates
122+
ca_certs = (
123+
kwargs.get("ca_certs")
124+
or (args[ca_certs_index] if ca_certs_in_args else None)
125+
or certifi.where()
126+
)
127+
ssl_context.load_verify_locations(ca_certs)
128+
129+
# Set the PyOpenSSL context in arguments
130+
if context_in_args:
131+
new_args = list(args)
132+
new_args[ssl_context_index] = ssl_context
133+
args = tuple(new_args)
134+
else:
135+
kwargs["ssl_context"] = ssl_context
136+
else:
137+
# Create new default context
138+
if context_in_args:
139+
new_args = list(args)
140+
new_args[ssl_context_index] = None
141+
args = tuple(new_args)
142+
else:
143+
if "ssl_context" in kwargs:
144+
del kwargs["ssl_context"]
145+
120146
if not ca_certs_in_args and not kwargs.get("ca_certs"):
121147
kwargs["ca_certs"] = certifi.where()
122148

src/snowflake/connector/vendored/urllib3/contrib/emscripten/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
import urllib3.connection
3+
import snowflake.connector.vendored.urllib3.connection
44

55
from ...connectionpool import HTTPConnectionPool, HTTPSConnectionPool
66
from .connection import EmscriptenHTTPConnection, EmscriptenHTTPSConnection
@@ -12,5 +12,5 @@ def inject_into_urllib3() -> None:
1212
# if it isn't ignored
1313
HTTPConnectionPool.ConnectionCls = EmscriptenHTTPConnection
1414
HTTPSConnectionPool.ConnectionCls = EmscriptenHTTPSConnection
15-
urllib3.connection.HTTPConnection = EmscriptenHTTPConnection # type: ignore[misc,assignment]
16-
urllib3.connection.HTTPSConnection = EmscriptenHTTPSConnection # type: ignore[misc,assignment]
15+
snowflake.connector.vendored.urllib3.connection.HTTPConnection = EmscriptenHTTPConnection # type: ignore[misc,assignment]
16+
snowflake.connector.vendored.urllib3.connection.HTTPSConnection = EmscriptenHTTPSConnection # type: ignore[misc,assignment]

0 commit comments

Comments
 (0)