|
10 | 10 | # See the License for the specific language governing permissions and
|
11 | 11 | # limitations under the License.
|
12 | 12 | import json
|
13 |
| -import re |
14 | 13 | import threading
|
15 | 14 | import time
|
16 | 15 | import uuid
|
17 |
| -from collections import namedtuple |
18 | 16 | from unittest import mock
|
19 | 17 | from urllib.parse import urlparse
|
20 | 18 |
|
|
25 | 23 | from requests_kerberos.exceptions import KerberosExchangeError
|
26 | 24 |
|
27 | 25 | import trino.exceptions
|
| 26 | +from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \ |
| 27 | + MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \ |
| 28 | + SERVER_ADDRESS |
28 | 29 | from trino import constants
|
29 | 30 | from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
|
30 | 31 | from trino.client import TrinoQuery, TrinoRequest, TrinoResult
|
@@ -259,52 +260,6 @@ def long_call(request, uri, headers):
|
259 | 260 | httpretty.reset()
|
260 | 261 |
|
261 | 262 |
|
262 |
| -SERVER_ADDRESS = "https://coordinator" |
263 |
| -REDIRECT_PATH = "oauth2/initiate" |
264 |
| -TOKEN_PATH = "oauth2/token" |
265 |
| -REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}" |
266 |
| -TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}" |
267 |
| - |
268 |
| - |
269 |
| -class RedirectHandler: |
270 |
| - def __init__(self): |
271 |
| - self.redirect_server = "" |
272 |
| - |
273 |
| - def __call__(self, url): |
274 |
| - self.redirect_server += url |
275 |
| - |
276 |
| - |
277 |
| -class PostStatementCallback: |
278 |
| - def __init__(self, redirect_server, token_server, tokens, sample_post_response_data): |
279 |
| - self.redirect_server = redirect_server |
280 |
| - self.token_server = token_server |
281 |
| - self.tokens = tokens |
282 |
| - self.sample_post_response_data = sample_post_response_data |
283 |
| - |
284 |
| - def __call__(self, request, uri, response_headers): |
285 |
| - authorization = request.headers.get("Authorization") |
286 |
| - if authorization and authorization.replace("Bearer ", "") in self.tokens: |
287 |
| - return [200, response_headers, json.dumps(self.sample_post_response_data)] |
288 |
| - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", ' |
289 |
| - f'x_token_server="{self.token_server}"', |
290 |
| - 'Basic realm': '"Trino"'}, ""] |
291 |
| - |
292 |
| - |
293 |
| -class GetTokenCallback: |
294 |
| - def __init__(self, token_server, token, attempts=1): |
295 |
| - self.token_server = token_server |
296 |
| - self.token = token |
297 |
| - self.attempts = attempts |
298 |
| - |
299 |
| - def __call__(self, request, uri, response_headers): |
300 |
| - self.attempts -= 1 |
301 |
| - if self.attempts < 0: |
302 |
| - return [404, response_headers, "{}"] |
303 |
| - if self.attempts == 0: |
304 |
| - return [200, response_headers, f'{{"token": "{self.token}"}}'] |
305 |
| - return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}'] |
306 |
| - |
307 |
| - |
308 | 263 | @pytest.mark.parametrize("attempts", [1, 3, 5])
|
309 | 264 | @httprettified
|
310 | 265 | def test_oauth2_authentication_flow(attempts, sample_post_response_data):
|
@@ -511,57 +466,6 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
|
511 | 466 | assert len(_get_token_requests(challenge_id)) == 1
|
512 | 467 |
|
513 | 468 |
|
514 |
| -class MultithreadedTokenServer: |
515 |
| - Challenge = namedtuple('Challenge', ['token', 'attempts']) |
516 |
| - |
517 |
| - def __init__(self, sample_post_response_data, attempts=1): |
518 |
| - self.tokens = set() |
519 |
| - self.challenges = {} |
520 |
| - self.sample_post_response_data = sample_post_response_data |
521 |
| - self.attempts = attempts |
522 |
| - |
523 |
| - # bind post statement |
524 |
| - httpretty.register_uri( |
525 |
| - method=httpretty.POST, |
526 |
| - uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", |
527 |
| - body=self.post_statement_callback) |
528 |
| - |
529 |
| - # bind get token |
530 |
| - httpretty.register_uri( |
531 |
| - method=httpretty.GET, |
532 |
| - uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), |
533 |
| - body=self.get_token_callback) |
534 |
| - |
535 |
| - # noinspection PyUnusedLocal |
536 |
| - def post_statement_callback(self, request, uri, response_headers): |
537 |
| - authorization = request.headers.get("Authorization") |
538 |
| - |
539 |
| - if authorization and authorization.replace("Bearer ", "") in self.tokens: |
540 |
| - return [200, response_headers, json.dumps(self.sample_post_response_data)] |
541 |
| - |
542 |
| - challenge_id = str(uuid.uuid4()) |
543 |
| - token = str(uuid.uuid4()) |
544 |
| - self.tokens.add(token) |
545 |
| - self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts) |
546 |
| - redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" |
547 |
| - token_server = f"{TOKEN_RESOURCE}/{challenge_id}" |
548 |
| - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", ' |
549 |
| - f'x_token_server="{token_server}"', |
550 |
| - 'Basic realm': '"Trino"'}, ""] |
551 |
| - |
552 |
| - # noinspection PyUnusedLocal |
553 |
| - def get_token_callback(self, request, uri, response_headers): |
554 |
| - challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "") |
555 |
| - challenge = self.challenges[challenge_id] |
556 |
| - challenge = challenge._replace(attempts=challenge.attempts - 1) |
557 |
| - self.challenges[challenge_id] = challenge |
558 |
| - if challenge.attempts < 0: |
559 |
| - return [404, response_headers, "{}"] |
560 |
| - if challenge.attempts == 0: |
561 |
| - return [200, response_headers, f'{{"token": "{challenge.token}"}}'] |
562 |
| - return [200, response_headers, f'{{"nextUri": "{uri}"}}'] |
563 |
| - |
564 |
| - |
565 | 469 | @httprettified
|
566 | 470 | def test_multithreaded_oauth2_authentication_flow(sample_post_response_data):
|
567 | 471 | redirect_handler = RedirectHandler()
|
@@ -598,31 +502,19 @@ def run(self) -> None:
|
598 | 502 | for thread in threads:
|
599 | 503 | thread.join()
|
600 | 504 |
|
601 |
| - # should issue only 3 tokens and each thread should get one |
602 |
| - assert len(token_server.tokens) == 3 |
| 505 | + # should issue only 1 token and each thread should reuse it |
| 506 | + assert len(token_server.tokens) == 1 |
603 | 507 | for thread in threads:
|
604 | 508 | assert thread.token in token_server.tokens
|
605 | 509 |
|
606 |
| - # should start only 3 challenges and every token should be obtained |
607 |
| - assert len(token_server.challenges.keys()) == 3 |
| 510 | + # should start only 1 challenge |
| 511 | + assert len(token_server.challenges.keys()) == 1 |
608 | 512 | for challenge_id, challenge in token_server.challenges.items():
|
609 | 513 | assert f"{REDIRECT_RESOURCE}/{challenge_id}" in redirect_handler.redirect_server
|
610 | 514 | assert challenge.attempts == 0
|
611 | 515 | assert len(_get_token_requests(challenge_id)) == 1
|
612 | 516 | # 3 threads * (10 POST /statement each + 1 replied request by authentication)
|
613 |
| - assert len(_post_statement_requests()) == 33 |
614 |
| - |
615 |
| - |
616 |
| -def _get_token_requests(challenge_id): |
617 |
| - return list(filter( |
618 |
| - lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", |
619 |
| - httpretty.latest_requests())) |
620 |
| - |
621 |
| - |
622 |
| -def _post_statement_requests(): |
623 |
| - return list(filter( |
624 |
| - lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, |
625 |
| - httpretty.latest_requests())) |
| 517 | + assert len(_post_statement_requests()) == 31 |
626 | 518 |
|
627 | 519 |
|
628 | 520 | @mock.patch("trino.client.TrinoRequest.http")
|
|
0 commit comments