Skip to content

Commit 493efad

Browse files
SNOW-1825621 OAuth code flow PKCE support (#2137)
1 parent b85a824 commit 493efad

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/snowflake/connector/auth/oauth_code.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import base64
8+
import hashlib
89
import json
910
import logging
1011
import secrets
@@ -55,6 +56,7 @@ def __init__(
5556
token_request_url: str,
5657
redirect_uri: str,
5758
scope: str,
59+
pkce: bool = False,
5860
**kwargs,
5961
) -> None:
6062
super().__init__(**kwargs)
@@ -72,6 +74,10 @@ def __init__(
7274
logger.debug("chose oauth state: %s", "".join("*" for _ in self._state))
7375
self._oauth_token = None
7476
self._protocol = "http"
77+
self.pkce = pkce
78+
if pkce:
79+
logger.debug("oauth pkce is going to be used")
80+
self._verifier: str | None = None
7581

7682
def reset_secrets(self) -> None:
7783
self._oauth_token = None
@@ -104,6 +110,18 @@ def construct_url(self) -> str:
104110
}
105111
if self.scope:
106112
params["scope"] = self.scope
113+
if self.pkce:
114+
self._verifier = secrets.token_urlsafe(43)
115+
# calculate challenge and verifier
116+
challenge = (
117+
base64.urlsafe_b64encode(
118+
hashlib.sha256(self._verifier.encode("utf-8")).digest()
119+
)
120+
.decode("utf-8")
121+
.rstrip("=")
122+
)
123+
params["code_challenge"] = challenge
124+
params["code_challenge_method"] = "S256"
107125
url_params = urllib.parse.urlencode(params)
108126
url = f"{self.authentication_url}?{url_params}"
109127
return url
@@ -186,6 +204,10 @@ def prepare(
186204
}
187205
if self.client_secret:
188206
fields["client_secret"] = self.client_secret
207+
if self.pkce:
208+
assert self._verifier is not None
209+
fields["code_verifier"] = self._verifier
210+
189211
resp = urllib3.PoolManager().request_encode_body( # TODO: use network pool to gain use of proxy settings and so on
190212
"POST",
191213
self.token_request_url,

src/snowflake/connector/connection.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import atexit
9+
import collections.abc
910
import logging
1011
import os
1112
import pathlib
@@ -334,6 +335,11 @@ def _get_private_bytes_from_file(
334335
str,
335336
# SNOW-1825621: OAUTH implementation
336337
),
338+
"oauth_security_features": (
339+
("pkce",),
340+
collections.abc.Iterable, # of strings
341+
# SNOW-1825621: OAUTH PKCE
342+
),
337343
}
338344

339345
APPLICATION_RE = re.compile(r"[\w\d_]+")
@@ -1088,7 +1094,7 @@ def __open_connection(self):
10881094
self.auth_class = AuthByWebBrowser(
10891095
application=self.application,
10901096
protocol=self._protocol,
1091-
host=self.host,
1097+
host=self.host, # TODO: delete this?
10921098
port=self.port,
10931099
timeout=self.login_timeout,
10941100
backoff_generator=self._backoff_generator,
@@ -1125,6 +1131,8 @@ def __open_connection(self):
11251131
backoff_generator=self._backoff_generator,
11261132
)
11271133
elif self._authenticator == OAUTH_AUTHORIZATION_CODE:
1134+
pkce = "pkce" in map(lambda e: e.lower(), self._oauth_security_features)
1135+
11281136
if self._oauth_client_id is None:
11291137
Error.errorhandler_wrapper(
11301138
self,
@@ -1150,6 +1158,7 @@ def __open_connection(self):
11501158
),
11511159
redirect_uri="http://127.0.0.1:{port}/",
11521160
scope=self._oauth_scope,
1161+
pkce=pkce,
11531162
)
11541163
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
11551164
self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (

0 commit comments

Comments
 (0)