Skip to content

Commit 5681fb7

Browse files
mdesmethashhar
authored andcommitted
Launch webbrowser for oauth2 authentication
1 parent 1714795 commit 5681fb7

File tree

5 files changed

+452
-138
lines changed

5 files changed

+452
-138
lines changed

README.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,14 @@ the [`JWT` authentication type](https://trino.io/docs/current/security/jwt.html)
167167

168168
### OAuth2 Authentication
169169

170-
- `OAuth2Authentication` class can be used to connect to a Trino cluster configured with
170+
The `OAuth2Authentication` class can be used to connect to a Trino cluster configured with
171171
the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.html).
172-
- A callback to handle the redirect url can be provided via param `redirect_auth_url_handler`, by default it just outputs the redirect url to stdout.
173172

174-
* DBAPI
173+
A callback to handle the redirect url can be provided via param `redirect_auth_url_handler` of the `trino.auth.OAuth2Authentication` class. By default, it will try to launch a web browser (`trino.auth.WebBrowserRedirectHandler`) to go through the authentication flow and output the redirect url to stdout (`trino.auth.ConsoleRedirectHandler`). Multiple redirect handlers are combined using the `trino.auth.CompositeRedirectHandler` class.
174+
175+
The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance.
176+
177+
- DBAPI
175178

176179
```python
177180
from trino.dbapi import connect
@@ -185,7 +188,7 @@ the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.h
185188
)
186189
```
187190

188-
* SQLAlchemy
191+
- SQLAlchemy
189192

190193
```python
191194
from sqlalchemy import create_engine

tests/unit/oauth_test_utils.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import json
14+
import re
15+
import uuid
16+
from collections import namedtuple
17+
18+
import httpretty
19+
20+
from trino import constants
21+
22+
SERVER_ADDRESS = "https://coordinator"
23+
REDIRECT_PATH = "oauth2/initiate"
24+
TOKEN_PATH = "oauth2/token"
25+
REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}"
26+
TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}"
27+
28+
29+
class RedirectHandler:
30+
def __init__(self):
31+
self.redirect_server = ""
32+
33+
def __call__(self, url):
34+
self.redirect_server += url
35+
36+
37+
class PostStatementCallback:
38+
def __init__(self, redirect_server, token_server, tokens, sample_post_response_data):
39+
self.redirect_server = redirect_server
40+
self.token_server = token_server
41+
self.tokens = tokens
42+
self.sample_post_response_data = sample_post_response_data
43+
44+
def __call__(self, request, uri, response_headers):
45+
authorization = request.headers.get("Authorization")
46+
if authorization and authorization.replace("Bearer ", "") in self.tokens:
47+
return [200, response_headers, json.dumps(self.sample_post_response_data)]
48+
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
49+
f'x_token_server="{self.token_server}"',
50+
'Basic realm': '"Trino"'}, ""]
51+
52+
53+
class GetTokenCallback:
54+
def __init__(self, token_server, token, attempts=1):
55+
self.token_server = token_server
56+
self.token = token
57+
self.attempts = attempts
58+
59+
def __call__(self, request, uri, response_headers):
60+
self.attempts -= 1
61+
if self.attempts < 0:
62+
return [404, response_headers, "{}"]
63+
if self.attempts == 0:
64+
return [200, response_headers, f'{{"token": "{self.token}"}}']
65+
return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}']
66+
67+
68+
def _get_token_requests(challenge_id):
69+
return list(filter(
70+
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
71+
httpretty.latest_requests()))
72+
73+
74+
def _post_statement_requests():
75+
return list(filter(
76+
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
77+
httpretty.latest_requests()))
78+
79+
80+
class MultithreadedTokenServer:
81+
Challenge = namedtuple('Challenge', ['token', 'attempts'])
82+
83+
def __init__(self, sample_post_response_data, attempts=1):
84+
self.tokens = set()
85+
self.challenges = {}
86+
self.sample_post_response_data = sample_post_response_data
87+
self.attempts = attempts
88+
89+
# bind post statement
90+
httpretty.register_uri(
91+
method=httpretty.POST,
92+
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
93+
body=self.post_statement_callback)
94+
95+
# bind get token
96+
httpretty.register_uri(
97+
method=httpretty.GET,
98+
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
99+
body=self.get_token_callback)
100+
101+
# noinspection PyUnusedLocal
102+
def post_statement_callback(self, request, uri, response_headers):
103+
authorization = request.headers.get("Authorization")
104+
105+
if authorization and authorization.replace("Bearer ", "") in self.tokens:
106+
return [200, response_headers, json.dumps(self.sample_post_response_data)]
107+
108+
challenge_id = str(uuid.uuid4())
109+
token = str(uuid.uuid4())
110+
self.tokens.add(token)
111+
self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts)
112+
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
113+
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
114+
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", '
115+
f'x_token_server="{token_server}"',
116+
'Basic realm': '"Trino"'}, ""]
117+
118+
# noinspection PyUnusedLocal
119+
def get_token_callback(self, request, uri, response_headers):
120+
challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "")
121+
challenge = self.challenges[challenge_id]
122+
challenge = challenge._replace(attempts=challenge.attempts - 1)
123+
self.challenges[challenge_id] = challenge
124+
if challenge.attempts < 0:
125+
return [404, response_headers, "{}"]
126+
if challenge.attempts == 0:
127+
return [200, response_headers, f'{{"token": "{challenge.token}"}}']
128+
return [200, response_headers, f'{{"nextUri": "{uri}"}}']

tests/unit/test_client.py

Lines changed: 8 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
import json
13-
import re
1413
import threading
1514
import time
1615
import uuid
17-
from collections import namedtuple
1816
from unittest import mock
1917
from urllib.parse import urlparse
2018

@@ -25,6 +23,9 @@
2523
from requests_kerberos.exceptions import KerberosExchangeError
2624

2725
import trino.exceptions
26+
from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \
27+
MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \
28+
SERVER_ADDRESS
2829
from trino import constants
2930
from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
3031
from trino.client import TrinoQuery, TrinoRequest, TrinoResult
@@ -259,52 +260,6 @@ def long_call(request, uri, headers):
259260
httpretty.reset()
260261

261262

262-
SERVER_ADDRESS = "https://coordinator"
263-
REDIRECT_PATH = "oauth2/initiate"
264-
TOKEN_PATH = "oauth2/token"
265-
REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}"
266-
TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}"
267-
268-
269-
class RedirectHandler:
270-
def __init__(self):
271-
self.redirect_server = ""
272-
273-
def __call__(self, url):
274-
self.redirect_server += url
275-
276-
277-
class PostStatementCallback:
278-
def __init__(self, redirect_server, token_server, tokens, sample_post_response_data):
279-
self.redirect_server = redirect_server
280-
self.token_server = token_server
281-
self.tokens = tokens
282-
self.sample_post_response_data = sample_post_response_data
283-
284-
def __call__(self, request, uri, response_headers):
285-
authorization = request.headers.get("Authorization")
286-
if authorization and authorization.replace("Bearer ", "") in self.tokens:
287-
return [200, response_headers, json.dumps(self.sample_post_response_data)]
288-
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
289-
f'x_token_server="{self.token_server}"',
290-
'Basic realm': '"Trino"'}, ""]
291-
292-
293-
class GetTokenCallback:
294-
def __init__(self, token_server, token, attempts=1):
295-
self.token_server = token_server
296-
self.token = token
297-
self.attempts = attempts
298-
299-
def __call__(self, request, uri, response_headers):
300-
self.attempts -= 1
301-
if self.attempts < 0:
302-
return [404, response_headers, "{}"]
303-
if self.attempts == 0:
304-
return [200, response_headers, f'{{"token": "{self.token}"}}']
305-
return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}']
306-
307-
308263
@pytest.mark.parametrize("attempts", [1, 3, 5])
309264
@httprettified
310265
def test_oauth2_authentication_flow(attempts, sample_post_response_data):
@@ -511,57 +466,6 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
511466
assert len(_get_token_requests(challenge_id)) == 1
512467

513468

514-
class MultithreadedTokenServer:
515-
Challenge = namedtuple('Challenge', ['token', 'attempts'])
516-
517-
def __init__(self, sample_post_response_data, attempts=1):
518-
self.tokens = set()
519-
self.challenges = {}
520-
self.sample_post_response_data = sample_post_response_data
521-
self.attempts = attempts
522-
523-
# bind post statement
524-
httpretty.register_uri(
525-
method=httpretty.POST,
526-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
527-
body=self.post_statement_callback)
528-
529-
# bind get token
530-
httpretty.register_uri(
531-
method=httpretty.GET,
532-
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
533-
body=self.get_token_callback)
534-
535-
# noinspection PyUnusedLocal
536-
def post_statement_callback(self, request, uri, response_headers):
537-
authorization = request.headers.get("Authorization")
538-
539-
if authorization and authorization.replace("Bearer ", "") in self.tokens:
540-
return [200, response_headers, json.dumps(self.sample_post_response_data)]
541-
542-
challenge_id = str(uuid.uuid4())
543-
token = str(uuid.uuid4())
544-
self.tokens.add(token)
545-
self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts)
546-
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
547-
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
548-
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", '
549-
f'x_token_server="{token_server}"',
550-
'Basic realm': '"Trino"'}, ""]
551-
552-
# noinspection PyUnusedLocal
553-
def get_token_callback(self, request, uri, response_headers):
554-
challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "")
555-
challenge = self.challenges[challenge_id]
556-
challenge = challenge._replace(attempts=challenge.attempts - 1)
557-
self.challenges[challenge_id] = challenge
558-
if challenge.attempts < 0:
559-
return [404, response_headers, "{}"]
560-
if challenge.attempts == 0:
561-
return [200, response_headers, f'{{"token": "{challenge.token}"}}']
562-
return [200, response_headers, f'{{"nextUri": "{uri}"}}']
563-
564-
565469
@httprettified
566470
def test_multithreaded_oauth2_authentication_flow(sample_post_response_data):
567471
redirect_handler = RedirectHandler()
@@ -598,31 +502,19 @@ def run(self) -> None:
598502
for thread in threads:
599503
thread.join()
600504

601-
# should issue only 3 tokens and each thread should get one
602-
assert len(token_server.tokens) == 3
505+
# should issue only 1 token and each thread should reuse it
506+
assert len(token_server.tokens) == 1
603507
for thread in threads:
604508
assert thread.token in token_server.tokens
605509

606-
# should start only 3 challenges and every token should be obtained
607-
assert len(token_server.challenges.keys()) == 3
510+
# should start only 1 challenge
511+
assert len(token_server.challenges.keys()) == 1
608512
for challenge_id, challenge in token_server.challenges.items():
609513
assert f"{REDIRECT_RESOURCE}/{challenge_id}" in redirect_handler.redirect_server
610514
assert challenge.attempts == 0
611515
assert len(_get_token_requests(challenge_id)) == 1
612516
# 3 threads * (10 POST /statement each + 1 replied request by authentication)
613-
assert len(_post_statement_requests()) == 33
614-
615-
616-
def _get_token_requests(challenge_id):
617-
return list(filter(
618-
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
619-
httpretty.latest_requests()))
620-
621-
622-
def _post_statement_requests():
623-
return list(filter(
624-
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
625-
httpretty.latest_requests()))
517+
assert len(_post_statement_requests()) == 31
626518

627519

628520
@mock.patch("trino.client.TrinoRequest.http")

0 commit comments

Comments
 (0)