Skip to content

Commit 8cc898b

Browse files
authored
Merge pull request #379 from sfu-db/oauth2
feat(connector): implement authorization code
2 parents 3019027 + e6838ca commit 8cc898b

File tree

8 files changed

+199
-55
lines changed

8 files changed

+199
-55
lines changed

dataprep/connector/schema/defs.py

Lines changed: 162 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,35 @@
11
"""Strong typed schema definition."""
22
from __future__ import annotations
33

4+
import http.server
5+
import random
6+
import socketserver
7+
import string
48
from base64 import b64encode
59
from enum import Enum
10+
from pathlib import Path
11+
from threading import Thread
612
from time import time
7-
from typing import Any, Dict, Optional, Union
8-
13+
from typing import Any, Dict, List, Optional, Union
14+
from urllib.parse import parse_qs, urlparse
15+
import socket
916
import requests
1017
from pydantic import Field
1118

19+
from ...utils import is_notebook
1220
from .base import BaseDef, BaseDefT
1321

14-
1522
# pylint: disable=missing-class-docstring,missing-function-docstring
23+
FILE_PATH: Path = Path(__file__).resolve().parent
24+
25+
with open(f"{FILE_PATH}/oauth2.html", "rb") as f:
26+
OAUTH2_TEMPLATE = f.read()
27+
28+
29+
def get_random_string(length: int) -> str:
30+
letters = string.ascii_lowercase
31+
result_str = "".join(random.choice(letters) for _ in range(length))
32+
return result_str
1633

1734

1835
class OffsetPaginationDef(BaseDef):
@@ -67,9 +84,126 @@ class FieldDef(BaseDef):
6784
FieldDefUnion = Union[FieldDef, bool, str] # Put bool before str
6885

6986

70-
class OAuth2AuthorizationDef(BaseDef):
87+
class TCPServer(socketserver.TCPServer):
88+
def server_bind(self) -> None:
89+
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
90+
self.socket.bind(self.server_address)
91+
92+
93+
class HTTPServer(http.server.BaseHTTPRequestHandler):
94+
def do_GET(self) -> None: # pylint: disable=invalid-name
95+
# pylint: disable=protected-access
96+
query = urlparse(self.path).query
97+
parsed = parse_qs(query)
98+
99+
(code,) = parsed["code"]
100+
(state,) = parsed["state"]
101+
102+
self.send_response(200)
103+
self.send_header("Content-type", "text/html")
104+
self.end_headers()
105+
self.wfile.write(OAUTH2_TEMPLATE)
106+
107+
Thread(target=self.server.shutdown).start()
108+
109+
# Hacky way to pass data out
110+
self.server._oauth2_code = code # type: ignore
111+
self.server._oauth2_state = state # type: ignore
112+
113+
def log_request(
114+
self, code: Union[str, int] = "-", size: Union[str, int] = "-"
115+
) -> None:
116+
pass
117+
118+
119+
class OAuth2AuthorizationCodeAuthorizationDef(BaseDef):
120+
type: str = Field("OAuth2", const=True)
121+
grant_type: str = Field("AuthorizationCode", const=True)
122+
scopes: List[str]
123+
auth_server_url: str
124+
token_server_url: str
125+
126+
def build(
127+
self,
128+
req_data: Dict[str, Any],
129+
params: Dict[str, Any],
130+
storage: Optional[Dict[str, Any]] = None,
131+
) -> None:
132+
if storage is None:
133+
raise ValueError("storage is required for OAuth2")
134+
135+
if "access_token" not in storage or storage.get("expires_at", 0) < time():
136+
port = params.get("port", 9999)
137+
code = self._auth(params["client_id"], port)
138+
139+
ckey = params["client_id"]
140+
csecret = params["client_secret"]
141+
b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()
142+
143+
resp: Dict[str, Any] = requests.post(
144+
self.token_server_url,
145+
headers={"Authorization": f"Basic {b64cred}"},
146+
data={
147+
"grant_type": "authorization_code",
148+
"code": code,
149+
"redirect_uri": f"http://localhost:{port}/",
150+
},
151+
).json()
152+
153+
if resp["token_type"].lower() != "bearer":
154+
raise RuntimeError("token_type is not bearer")
155+
156+
access_token = resp["access_token"]
157+
storage["access_token"] = access_token
158+
if "expires_in" in resp:
159+
storage["expires_at"] = (
160+
time() + resp["expires_in"] - 60
161+
) # 60 seconds grace period to avoid clock lag
162+
163+
req_data["headers"]["Authorization"] = f"Bearer {storage['access_token']}"
164+
165+
def _auth(self, client_id: str, port: int = 9999) -> str:
166+
# pylint: disable=protected-access
167+
168+
state = get_random_string(23)
169+
scope = ",".join(self.scopes)
170+
authurl = (
171+
f"{self.auth_server_url}?"
172+
f"response_type=code&client_id={client_id}&"
173+
f"redirect_uri=http%3A%2F%2Flocalhost:{port}/&scope={scope}&"
174+
f"state={state}"
175+
)
176+
if is_notebook():
177+
from IPython.display import ( # pylint: disable=import-outside-toplevel
178+
Javascript,
179+
display,
180+
)
181+
182+
display(Javascript(f"window.open('{authurl}');"))
183+
else:
184+
import webbrowser # pylint: disable=import-outside-toplevel
185+
186+
webbrowser.open_new_tab(authurl)
187+
188+
with TCPServer(("", 9999), HTTPServer) as httpd:
189+
try:
190+
httpd.serve_forever()
191+
finally:
192+
httpd.server_close()
193+
194+
if httpd._oauth2_state != state: # type: ignore
195+
raise RuntimeError("OAuth2 state does not match")
196+
197+
if httpd._oauth2_code is None: # type: ignore
198+
raise RuntimeError(
199+
"OAuth2 authorization code auth failed, no code acquired."
200+
)
201+
return httpd._oauth2_code # type: ignore
202+
203+
204+
class OAuth2ClientCredentialsAuthorizationDef(BaseDef):
71205
type: str = Field("OAuth2", const=True)
72-
grant_type: str
206+
grant_type: str = Field("ClientCredentials", const=True)
73207
token_server_url: str
74208

75209
def build(
@@ -81,32 +215,27 @@ def build(
81215
if storage is None:
82216
raise ValueError("storage is required for OAuth2")
83217

84-
if self.grant_type == "ClientCredentials":
85-
if "access_token" not in storage or storage.get("expires_at", 0) < time():
86-
# Not yet authorized
87-
ckey = params["client_id"]
88-
csecret = params["client_secret"]
89-
b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()
90-
resp: Dict[str, Any] = requests.post(
91-
self.token_server_url,
92-
headers={"Authorization": f"Basic {b64cred}"},
93-
data={"grant_type": "client_credentials"},
94-
).json()
95-
if resp["token_type"].lower() != "bearer":
96-
raise RuntimeError("token_type is not bearer")
97-
98-
access_token = resp["access_token"]
99-
storage["access_token"] = access_token
100-
if "expires_in" in resp:
101-
storage["expires_at"] = (
102-
time() + resp["expires_in"] - 60
103-
) # 60 seconds grace period to avoid clock lag
104-
105-
req_data["headers"]["Authorization"] = f"Bearer {storage['access_token']}"
106-
107-
# TODO: handle auto refresh
108-
elif self.grant_type == "AuthorizationCode":
109-
raise NotImplementedError
218+
if "access_token" not in storage or storage.get("expires_at", 0) < time():
219+
# Not yet authorized
220+
ckey = params["client_id"]
221+
csecret = params["client_secret"]
222+
b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()
223+
resp: Dict[str, Any] = requests.post(
224+
self.token_server_url,
225+
headers={"Authorization": f"Basic {b64cred}"},
226+
data={"grant_type": "client_credentials"},
227+
).json()
228+
if resp["token_type"].lower() != "bearer":
229+
raise RuntimeError("token_type is not bearer")
230+
231+
access_token = resp["access_token"]
232+
storage["access_token"] = access_token
233+
if "expires_in" in resp:
234+
storage["expires_at"] = (
235+
time() + resp["expires_in"] - 60
236+
) # 60 seconds grace period to avoid clock lag
237+
238+
req_data["headers"]["Authorization"] = f"Bearer {storage['access_token']}"
110239

111240

112241
class QueryParamAuthorizationDef(BaseDef):
@@ -156,7 +285,8 @@ def build(
156285

157286

158287
AuthorizationDef = Union[
159-
OAuth2AuthorizationDef,
288+
OAuth2ClientCredentialsAuthorizationDef,
289+
OAuth2AuthorizationCodeAuthorizationDef,
160290
QueryParamAuthorizationDef,
161291
BearerAuthorizationDef,
162292
HeaderAuthorizationDef,
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<html>
2+
3+
<head>
4+
<title>OAuth2 Success</title>
5+
</head>
6+
7+
<body>
8+
<p>OAuth2 Success. This window can be closed now.</p>
9+
<script>window.close();</script>
10+
</body>
11+
12+
</html>

dataprep/eda/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
Text,
2020
)
2121
from .missing import compute_missing, plot_missing, render_missing
22-
from .utils import is_notebook
22+
from ..utils import is_notebook
2323

2424
__all__ = [
2525
"plot_correlation",

dataprep/eda/container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from bokeh.embed import components
1212
from bokeh.resources import INLINE
1313
from jinja2 import Environment, PackageLoader
14-
from .utils import is_notebook
14+
from ..utils import is_notebook
1515

1616
output_notebook(INLINE, hide_banner=True) # for offline usage
1717

dataprep/eda/progress_bar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from dask.callbacks import Callback
88

9-
from .utils import is_notebook
9+
from ..utils import is_notebook
1010

1111
if is_notebook():
1212
from tqdm.notebook import tqdm

dataprep/eda/report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from bokeh.resources import CDN
1515
from jinja2 import Template
1616

17-
from .utils import is_notebook
17+
from ..utils import is_notebook
1818

1919
INLINE_TEMPLATE = Template(
2020
"""

dataprep/eda/utils.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
import logging
44
from math import ceil
5-
from typing import Any, Union
5+
from typing import Union
66

77
import dask.dataframe as dd
88
import numpy as np
@@ -13,24 +13,6 @@
1313
LOGGER = logging.getLogger(__name__)
1414

1515

16-
def is_notebook() -> Any:
17-
"""
18-
:return: whether it is running in jupyter notebook
19-
"""
20-
try:
21-
# pytype: disable=import-error
22-
from IPython import get_ipython # pylint: disable=import-outside-toplevel
23-
24-
# pytype: enable=import-error
25-
26-
shell = get_ipython().__class__.__name__
27-
if shell == "ZMQInteractiveShell":
28-
return True
29-
return False
30-
except (NameError, ImportError):
31-
return False
32-
33-
3416
def to_dask(df: Union[pd.DataFrame, dd.DataFrame]) -> dd.DataFrame:
3517
"""Convert a dataframe to a dask dataframe."""
3618
if isinstance(df, dd.DataFrame):

dataprep/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Utility functions used by the whole library."""
2+
from typing import Any
3+
4+
5+
def is_notebook() -> Any:
6+
"""
7+
:return: whether it is running in jupyter notebook
8+
"""
9+
try:
10+
# pytype: disable=import-error
11+
from IPython import get_ipython # pylint: disable=import-outside-toplevel
12+
13+
# pytype: enable=import-error
14+
15+
shell = get_ipython().__class__.__name__
16+
if shell == "ZMQInteractiveShell":
17+
return True
18+
return False
19+
except (NameError, ImportError):
20+
return False

0 commit comments

Comments
 (0)