11"""Strong typed schema definition."""
22from __future__ import annotations
33
4+ import http .server
5+ import random
6+ import socketserver
7+ import string
48from base64 import b64encode
59from enum import Enum
10+ from pathlib import Path
11+ from threading import Thread
612from time import time
7- from typing import Any , Dict , Optional , Union
8-
13+ from typing import Any , Dict , List , Optional , Union
14+ from urllib .parse import parse_qs , urlparse
15+ import socket
916import requests
1017from pydantic import Field
1118
19+ from ...utils import is_notebook
1220from .base import BaseDef , BaseDefT
1321
14-
1522# pylint: disable=missing-class-docstring,missing-function-docstring
23+ FILE_PATH : Path = Path (__file__ ).resolve ().parent
24+
25+ with open (f"{ FILE_PATH } /oauth2.html" , "rb" ) as f :
26+ OAUTH2_TEMPLATE = f .read ()
27+
28+
29+ def get_random_string (length : int ) -> str :
30+ letters = string .ascii_lowercase
31+ result_str = "" .join (random .choice (letters ) for _ in range (length ))
32+ return result_str
1633
1734
1835class OffsetPaginationDef (BaseDef ):
@@ -67,9 +84,126 @@ class FieldDef(BaseDef):
6784FieldDefUnion = Union [FieldDef , bool , str ] # Put bool before str
6885
6986
70- class OAuth2AuthorizationDef (BaseDef ):
87+ class TCPServer (socketserver .TCPServer ):
88+ def server_bind (self ) -> None :
89+ self .socket .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
90+ self .socket .bind (self .server_address )
91+
92+
93+ class HTTPServer (http .server .BaseHTTPRequestHandler ):
94+ def do_GET (self ) -> None : # pylint: disable=invalid-name
95+ # pylint: disable=protected-access
96+ query = urlparse (self .path ).query
97+ parsed = parse_qs (query )
98+
99+ (code ,) = parsed ["code" ]
100+ (state ,) = parsed ["state" ]
101+
102+ self .send_response (200 )
103+ self .send_header ("Content-type" , "text/html" )
104+ self .end_headers ()
105+ self .wfile .write (OAUTH2_TEMPLATE )
106+
107+ Thread (target = self .server .shutdown ).start ()
108+
109+ # Hacky way to pass data out
110+ self .server ._oauth2_code = code # type: ignore
111+ self .server ._oauth2_state = state # type: ignore
112+
113+ def log_request (
114+ self , code : Union [str , int ] = "-" , size : Union [str , int ] = "-"
115+ ) -> None :
116+ pass
117+
118+
119+ class OAuth2AuthorizationCodeAuthorizationDef (BaseDef ):
120+ type : str = Field ("OAuth2" , const = True )
121+ grant_type : str = Field ("AuthorizationCode" , const = True )
122+ scopes : List [str ]
123+ auth_server_url : str
124+ token_server_url : str
125+
126+ def build (
127+ self ,
128+ req_data : Dict [str , Any ],
129+ params : Dict [str , Any ],
130+ storage : Optional [Dict [str , Any ]] = None ,
131+ ) -> None :
132+ if storage is None :
133+ raise ValueError ("storage is required for OAuth2" )
134+
135+ if "access_token" not in storage or storage .get ("expires_at" , 0 ) < time ():
136+ port = params .get ("port" , 9999 )
137+ code = self ._auth (params ["client_id" ], port )
138+
139+ ckey = params ["client_id" ]
140+ csecret = params ["client_secret" ]
141+ b64cred = b64encode (f"{ ckey } :{ csecret } " .encode ("ascii" )).decode ()
142+
143+ resp : Dict [str , Any ] = requests .post (
144+ self .token_server_url ,
145+ headers = {"Authorization" : f"Basic { b64cred } " },
146+ data = {
147+ "grant_type" : "authorization_code" ,
148+ "code" : code ,
149+ "redirect_uri" : f"http://localhost:{ port } /" ,
150+ },
151+ ).json ()
152+
153+ if resp ["token_type" ].lower () != "bearer" :
154+ raise RuntimeError ("token_type is not bearer" )
155+
156+ access_token = resp ["access_token" ]
157+ storage ["access_token" ] = access_token
158+ if "expires_in" in resp :
159+ storage ["expires_at" ] = (
160+ time () + resp ["expires_in" ] - 60
161+ ) # 60 seconds grace period to avoid clock lag
162+
163+ req_data ["headers" ]["Authorization" ] = f"Bearer { storage ['access_token' ]} "
164+
165+ def _auth (self , client_id : str , port : int = 9999 ) -> str :
166+ # pylint: disable=protected-access
167+
168+ state = get_random_string (23 )
169+ scope = "," .join (self .scopes )
170+ authurl = (
171+ f"{ self .auth_server_url } ?"
172+ f"response_type=code&client_id={ client_id } &"
173+ f"redirect_uri=http%3A%2F%2Flocalhost:{ port } /&scope={ scope } &"
174+ f"state={ state } "
175+ )
176+ if is_notebook ():
177+ from IPython .display import ( # pylint: disable=import-outside-toplevel
178+ Javascript ,
179+ display ,
180+ )
181+
182+ display (Javascript (f"window.open('{ authurl } ');" ))
183+ else :
184+ import webbrowser # pylint: disable=import-outside-toplevel
185+
186+ webbrowser .open_new_tab (authurl )
187+
188+ with TCPServer (("" , 9999 ), HTTPServer ) as httpd :
189+ try :
190+ httpd .serve_forever ()
191+ finally :
192+ httpd .server_close ()
193+
194+ if httpd ._oauth2_state != state : # type: ignore
195+ raise RuntimeError ("OAuth2 state does not match" )
196+
197+ if httpd ._oauth2_code is None : # type: ignore
198+ raise RuntimeError (
199+ "OAuth2 authorization code auth failed, no code acquired."
200+ )
201+ return httpd ._oauth2_code # type: ignore
202+
203+
204+ class OAuth2ClientCredentialsAuthorizationDef (BaseDef ):
71205 type : str = Field ("OAuth2" , const = True )
72- grant_type : str
206+ grant_type : str = Field ( "ClientCredentials" , const = True )
73207 token_server_url : str
74208
75209 def build (
@@ -81,32 +215,27 @@ def build(
81215 if storage is None :
82216 raise ValueError ("storage is required for OAuth2" )
83217
84- if self .grant_type == "ClientCredentials" :
85- if "access_token" not in storage or storage .get ("expires_at" , 0 ) < time ():
86- # Not yet authorized
87- ckey = params ["client_id" ]
88- csecret = params ["client_secret" ]
89- b64cred = b64encode (f"{ ckey } :{ csecret } " .encode ("ascii" )).decode ()
90- resp : Dict [str , Any ] = requests .post (
91- self .token_server_url ,
92- headers = {"Authorization" : f"Basic { b64cred } " },
93- data = {"grant_type" : "client_credentials" },
94- ).json ()
95- if resp ["token_type" ].lower () != "bearer" :
96- raise RuntimeError ("token_type is not bearer" )
97-
98- access_token = resp ["access_token" ]
99- storage ["access_token" ] = access_token
100- if "expires_in" in resp :
101- storage ["expires_at" ] = (
102- time () + resp ["expires_in" ] - 60
103- ) # 60 seconds grace period to avoid clock lag
104-
105- req_data ["headers" ]["Authorization" ] = f"Bearer { storage ['access_token' ]} "
106-
107- # TODO: handle auto refresh
108- elif self .grant_type == "AuthorizationCode" :
109- raise NotImplementedError
218+ if "access_token" not in storage or storage .get ("expires_at" , 0 ) < time ():
219+ # Not yet authorized
220+ ckey = params ["client_id" ]
221+ csecret = params ["client_secret" ]
222+ b64cred = b64encode (f"{ ckey } :{ csecret } " .encode ("ascii" )).decode ()
223+ resp : Dict [str , Any ] = requests .post (
224+ self .token_server_url ,
225+ headers = {"Authorization" : f"Basic { b64cred } " },
226+ data = {"grant_type" : "client_credentials" },
227+ ).json ()
228+ if resp ["token_type" ].lower () != "bearer" :
229+ raise RuntimeError ("token_type is not bearer" )
230+
231+ access_token = resp ["access_token" ]
232+ storage ["access_token" ] = access_token
233+ if "expires_in" in resp :
234+ storage ["expires_at" ] = (
235+ time () + resp ["expires_in" ] - 60
236+ ) # 60 seconds grace period to avoid clock lag
237+
238+ req_data ["headers" ]["Authorization" ] = f"Bearer { storage ['access_token' ]} "
110239
111240
112241class QueryParamAuthorizationDef (BaseDef ):
@@ -156,7 +285,8 @@ def build(
156285
157286
158287AuthorizationDef = Union [
159- OAuth2AuthorizationDef ,
288+ OAuth2ClientCredentialsAuthorizationDef ,
289+ OAuth2AuthorizationCodeAuthorizationDef ,
160290 QueryParamAuthorizationDef ,
161291 BearerAuthorizationDef ,
162292 HeaderAuthorizationDef ,
0 commit comments