Skip to content

Commit d4475fc

Browse files
committed
Rewrite test codes to replace httpretty with responses
1 parent f5e8f5c commit d4475fc

File tree

3 files changed

+115
-122
lines changed

3 files changed

+115
-122
lines changed

tests/unit/oauth_test_utils.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import uuid
1515
from collections import namedtuple
1616

17-
import httpretty
17+
import responses
1818

1919
from trino import constants
2020

@@ -91,15 +91,14 @@ def __call__(self, request, uri, response_headers):
9191

9292
def _get_token_requests(challenge_id):
9393
return list(filter(
94-
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
95-
httpretty.latest_requests()))
94+
lambda r: r.method == "GET" and r.url == f"{TOKEN_RESOURCE}/{challenge_id}",
95+
responses.calls))
9696

9797

9898
def _post_statement_requests():
9999
return list(filter(
100-
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
101-
httpretty.latest_requests()))
102-
100+
lambda r: r.method == "POST" and r.url == f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
101+
responses.calls))
103102

104103
class MultithreadedTokenServer:
105104
Challenge = namedtuple('Challenge', ['token', 'attempts'])
@@ -111,15 +110,15 @@ def __init__(self, sample_post_response_data, attempts=1):
111110
self.attempts = attempts
112111

113112
# bind post statement
114-
httpretty.register_uri(
115-
method=httpretty.POST,
116-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
113+
responses.add(
114+
method=responses.POST,
115+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
117116
body=self.post_statement_callback)
118117

119118
# bind get token
120-
httpretty.register_uri(
121-
method=httpretty.GET,
122-
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
119+
responses.add(
120+
method=responses.GET,
121+
url=re.compile(rf"{TOKEN_RESOURCE}/.*"),
123122
body=self.get_token_callback)
124123

125124
# noinspection PyUnusedLocal

tests/unit/test_client.py

Lines changed: 67 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,10 @@
2424
from zoneinfo import ZoneInfoNotFoundError
2525

2626
import gssapi
27-
import httpretty
2827
import keyring
2928
import pytest
3029
import requests
31-
from httpretty import httprettified
30+
import responses
3231
from requests_gssapi.exceptions import SPNEGOExchangeError
3332
from requests_kerberos.exceptions import KerberosExchangeError
3433
from tzlocal import get_localzone_name # type: ignore
@@ -347,9 +346,9 @@ def long_call(request, uri, headers):
347346
time.sleep(timeout * 2)
348347
return (200, headers, "delayed success")
349348

350-
httpretty.enable()
351-
for method in [httpretty.POST, httpretty.GET]:
352-
httpretty.register_uri(method, url, body=long_call)
349+
responses.start()
350+
for method in [responses.POST, responses.GET]:
351+
responses.add_callback(method, url, callback=long_call)
353352

354353
# timeout without retry
355354
for request_timeout in [timeout, (timeout, timeout)]:
@@ -370,12 +369,11 @@ def long_call(request, uri, headers):
370369
with pytest.raises(requests.exceptions.Timeout):
371370
req.post("select 1")
372371

373-
httpretty.disable()
374-
httpretty.reset()
375-
372+
responses.stop()
373+
responses.reset()
376374

377375
@pytest.mark.parametrize("attempts", [1, 3, 5])
378-
@httprettified
376+
@responses.activate
379377
def test_oauth2_authentication_flow(attempts, sample_post_response_data):
380378
token = str(uuid.uuid4())
381379
challenge_id = str(uuid.uuid4())
@@ -386,17 +384,17 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
386384
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
387385

388386
# bind post statement
389-
httpretty.register_uri(
390-
method=httpretty.POST,
391-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
392-
body=post_statement_callback)
387+
responses.add_callback(
388+
method=responses.POST,
389+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
390+
callback=post_statement_callback)
393391

394392
# bind get token
395393
get_token_callback = GetTokenCallback(token_server, token, attempts)
396-
httpretty.register_uri(
397-
method=httpretty.GET,
398-
uri=token_server,
399-
body=get_token_callback)
394+
responses.add_callback(
395+
method=responses.GET,
396+
url=token_server,
397+
callback=get_token_callback)
400398

401399
redirect_handler = RedirectHandler()
402400

@@ -417,7 +415,7 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
417415
assert len(_get_token_requests(challenge_id)) == attempts
418416

419417

420-
@httprettified
418+
@responses.activate
421419
def test_oauth2_refresh_token_flow(sample_post_response_data):
422420
token = str(uuid.uuid4())
423421
challenge_id = str(uuid.uuid4())
@@ -427,17 +425,17 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
427425
post_statement_callback = PostStatementCallback(None, token_server, [token], sample_post_response_data)
428426

429427
# bind post statement
430-
httpretty.register_uri(
431-
method=httpretty.POST,
432-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
433-
body=post_statement_callback)
428+
responses.add_callback(
429+
method=responses.POST,
430+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
431+
callback=post_statement_callback)
434432

435433
# bind get token
436434
get_token_callback = GetTokenCallback(token_server, token)
437-
httpretty.register_uri(
438-
method=httpretty.GET,
439-
uri=token_server,
440-
body=get_token_callback)
435+
responses.add_callback(
436+
method=responses.GET,
437+
url=token_server,
438+
callback=get_token_callback)
441439

442440
redirect_handler = RedirectHandlerWithException(
443441
trino.exceptions.TrinoAuthError(
@@ -460,7 +458,7 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
460458

461459

462460
@pytest.mark.parametrize("attempts", [6, 10])
463-
@httprettified
461+
@responses.activate
464462
def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
465463
token = str(uuid.uuid4())
466464
challenge_id = str(uuid.uuid4())
@@ -471,17 +469,17 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
471469
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
472470

473471
# bind post statement
474-
httpretty.register_uri(
475-
method=httpretty.POST,
476-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
477-
body=post_statement_callback)
472+
responses.add_callback(
473+
method=responses.POST,
474+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
475+
callback=post_statement_callback)
478476

479477
# bind get token
480478
get_token_callback = GetTokenCallback(token_server, token, attempts)
481-
httpretty.register_uri(
482-
method=httpretty.GET,
483-
uri=f"{TOKEN_RESOURCE}/{challenge_id}",
484-
body=get_token_callback)
479+
responses.add_callback(
480+
method=responses.GET,
481+
url=f"{TOKEN_RESOURCE}/{challenge_id}",
482+
callback=get_token_callback)
485483

486484
redirect_handler = RedirectHandler()
487485

@@ -509,13 +507,13 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
509507
('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'), # noqa: E501
510508
('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'),
511509
])
512-
@httprettified
510+
@responses.activate
513511
def test_oauth2_authentication_missing_headers(header, error):
514512
# bind post statement
515-
httpretty.register_uri(
516-
method=httpretty.POST,
517-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
518-
adding_headers={'WWW-Authenticate': header},
513+
responses.add(
514+
method=responses.POST,
515+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
516+
headers={'WWW-Authenticate': header},
519517
status=401)
520518

521519
request = TrinoRequest(
@@ -543,7 +541,7 @@ def test_oauth2_authentication_missing_headers(header, error):
543541
'x_token_server="{token_server}"'
544542
'Bearer x_redirect_server="{redirect_server}",x_token_server="{token_server}",additional_challenge',
545543
])
546-
@httprettified
544+
@responses.activate
547545
def test_oauth2_header_parsing(header, sample_post_response_data):
548546
token = str(uuid.uuid4())
549547
challenge_id = str(uuid.uuid4())
@@ -552,25 +550,26 @@ def test_oauth2_header_parsing(header, sample_post_response_data):
552550
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
553551

554552
# noinspection PyUnusedLocal
555-
def post_statement(request, uri, response_headers):
553+
# def post_statement(request, uri, response_headers): # FIXME
554+
def post_statement(request):
556555
authorization = request.headers.get("Authorization")
557556
if authorization and authorization.replace("Bearer ", "") in token:
558-
return [200, response_headers, json.dumps(sample_post_response_data)]
559-
return [401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server),
560-
'Basic realm': '"Trino"'}, ""]
557+
return (200, {}, json.dumps(sample_post_response_data))
558+
return (401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server),
559+
'Basic realm': '"Trino"'}, "")
561560

562561
# bind post statement
563-
httpretty.register_uri(
564-
method=httpretty.POST,
565-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
566-
body=post_statement)
562+
responses.add_callback(
563+
method=responses.POST,
564+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
565+
callback=post_statement)
567566

568567
# bind get token
569568
get_token_callback = GetTokenCallback(token_server, token)
570-
httpretty.register_uri(
571-
method=httpretty.GET,
572-
uri=token_server,
573-
body=get_token_callback)
569+
responses.add_callback(
570+
method=responses.GET,
571+
url=token_server,
572+
callback=get_token_callback)
574573

575574
redirect_handler = RedirectHandler()
576575

@@ -590,9 +589,7 @@ def post_statement(request, uri, response_headers):
590589
assert len(_post_statement_requests()) == 2
591590
assert len(_get_token_requests(challenge_id)) == 1
592591

593-
594-
@pytest.mark.parametrize("http_status", [400, 401, 500])
595-
@httprettified
592+
@responses.activate
596593
def test_oauth2_authentication_fail_token_server(http_status, sample_post_response_data):
597594
token = str(uuid.uuid4())
598595
challenge_id = str(uuid.uuid4())
@@ -603,16 +600,18 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
603600
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
604601

605602
# bind post statement
606-
httpretty.register_uri(
607-
method=httpretty.POST,
608-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
609-
body=post_statement_callback)
610-
611-
httpretty.register_uri(
612-
method=httpretty.GET,
613-
uri=f"{TOKEN_RESOURCE}/{challenge_id}",
603+
responses.add(
604+
method=responses.POST,
605+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
606+
json=post_statement_callback,
607+
)
608+
609+
responses.add(
610+
method=responses.GET,
611+
url=f"{TOKEN_RESOURCE}/{challenge_id}",
614612
status=http_status,
615-
body="error")
613+
body="error",
614+
)
616615

617616
redirect_handler = RedirectHandler()
618617

@@ -623,7 +622,8 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
623622
user="test",
624623
),
625624
http_scheme=constants.HTTPS,
626-
auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))
625+
auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
626+
)
627627

628628
with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
629629
request.post("select 1")
@@ -633,8 +633,7 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
633633
assert len(_post_statement_requests()) == 1
634634
assert len(_get_token_requests(challenge_id)) == 1
635635

636-
637-
@httprettified
636+
@responses.activate
638637
def test_multithreaded_oauth2_authentication_flow(sample_post_response_data):
639638
redirect_handler = RedirectHandler()
640639
auth = trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)

0 commit comments

Comments
 (0)