Skip to content

Commit 86c217d

Browse files
committed
Rewrite test codes to replace httpretty with responses
1 parent f5e8f5c commit 86c217d

File tree

3 files changed

+115
-115
lines changed

3 files changed

+115
-115
lines changed

tests/unit/oauth_test_utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import uuid
1515
from collections import namedtuple
1616

17-
import httpretty
17+
import responses
1818

1919
from trino import constants
2020

@@ -91,14 +91,14 @@ def __call__(self, request, uri, response_headers):
9191

9292
def _get_token_requests(challenge_id):
9393
return list(filter(
94-
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
95-
httpretty.latest_requests()))
94+
lambda r: r.method == "GET" and r.url == f"{TOKEN_RESOURCE}/{challenge_id}",
95+
responses.calls))
9696

9797

9898
def _post_statement_requests():
9999
return list(filter(
100-
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
101-
httpretty.latest_requests()))
100+
lambda r: r.method == "POST" and r.url == f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
101+
responses.calls))
102102

103103

104104
class MultithreadedTokenServer:
@@ -111,15 +111,15 @@ def __init__(self, sample_post_response_data, attempts=1):
111111
self.attempts = attempts
112112

113113
# bind post statement
114-
httpretty.register_uri(
115-
method=httpretty.POST,
116-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
114+
responses.add(
115+
method=responses.POST,
116+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
117117
body=self.post_statement_callback)
118118

119119
# bind get token
120-
httpretty.register_uri(
121-
method=httpretty.GET,
122-
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
120+
responses.add(
121+
method=responses.GET,
122+
url=re.compile(rf"{TOKEN_RESOURCE}/.*"),
123123
body=self.get_token_callback)
124124

125125
# noinspection PyUnusedLocal

tests/unit/test_client.py

Lines changed: 67 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,10 @@
2424
from zoneinfo import ZoneInfoNotFoundError
2525

2626
import gssapi
27-
import httpretty
2827
import keyring
2928
import pytest
3029
import requests
31-
from httpretty import httprettified
30+
import responses
3231
from requests_gssapi.exceptions import SPNEGOExchangeError
3332
from requests_kerberos.exceptions import KerberosExchangeError
3433
from tzlocal import get_localzone_name # type: ignore
@@ -347,9 +346,9 @@ def long_call(request, uri, headers):
347346
time.sleep(timeout * 2)
348347
return (200, headers, "delayed success")
349348

350-
httpretty.enable()
351-
for method in [httpretty.POST, httpretty.GET]:
352-
httpretty.register_uri(method, url, body=long_call)
349+
responses.start()
350+
for method in [responses.POST, responses.GET]:
351+
responses.add_callback(method, url, callback=long_call)
353352

354353
# timeout without retry
355354
for request_timeout in [timeout, (timeout, timeout)]:
@@ -370,12 +369,12 @@ def long_call(request, uri, headers):
370369
with pytest.raises(requests.exceptions.Timeout):
371370
req.post("select 1")
372371

373-
httpretty.disable()
374-
httpretty.reset()
372+
responses.stop()
373+
responses.reset()
375374

376375

377376
@pytest.mark.parametrize("attempts", [1, 3, 5])
378-
@httprettified
377+
@responses.activate
379378
def test_oauth2_authentication_flow(attempts, sample_post_response_data):
380379
token = str(uuid.uuid4())
381380
challenge_id = str(uuid.uuid4())
@@ -386,17 +385,17 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
386385
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
387386

388387
# bind post statement
389-
httpretty.register_uri(
390-
method=httpretty.POST,
391-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
392-
body=post_statement_callback)
388+
responses.add_callback(
389+
method=responses.POST,
390+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
391+
callback=post_statement_callback)
393392

394393
# bind get token
395394
get_token_callback = GetTokenCallback(token_server, token, attempts)
396-
httpretty.register_uri(
397-
method=httpretty.GET,
398-
uri=token_server,
399-
body=get_token_callback)
395+
responses.add_callback(
396+
method=responses.GET,
397+
url=token_server,
398+
callback=get_token_callback)
400399

401400
redirect_handler = RedirectHandler()
402401

@@ -417,7 +416,7 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
417416
assert len(_get_token_requests(challenge_id)) == attempts
418417

419418

420-
@httprettified
419+
@responses.activate
421420
def test_oauth2_refresh_token_flow(sample_post_response_data):
422421
token = str(uuid.uuid4())
423422
challenge_id = str(uuid.uuid4())
@@ -427,17 +426,17 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
427426
post_statement_callback = PostStatementCallback(None, token_server, [token], sample_post_response_data)
428427

429428
# bind post statement
430-
httpretty.register_uri(
431-
method=httpretty.POST,
432-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
433-
body=post_statement_callback)
429+
responses.add_callback(
430+
method=responses.POST,
431+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
432+
callback=post_statement_callback)
434433

435434
# bind get token
436435
get_token_callback = GetTokenCallback(token_server, token)
437-
httpretty.register_uri(
438-
method=httpretty.GET,
439-
uri=token_server,
440-
body=get_token_callback)
436+
responses.add_callback(
437+
method=responses.GET,
438+
url=token_server,
439+
callback=get_token_callback)
441440

442441
redirect_handler = RedirectHandlerWithException(
443442
trino.exceptions.TrinoAuthError(
@@ -460,7 +459,7 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
460459

461460

462461
@pytest.mark.parametrize("attempts", [6, 10])
463-
@httprettified
462+
@responses.activate
464463
def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
465464
token = str(uuid.uuid4())
466465
challenge_id = str(uuid.uuid4())
@@ -471,17 +470,17 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
471470
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
472471

473472
# bind post statement
474-
httpretty.register_uri(
475-
method=httpretty.POST,
476-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
477-
body=post_statement_callback)
473+
responses.add_callback(
474+
method=responses.POST,
475+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
476+
callback=post_statement_callback)
478477

479478
# bind get token
480479
get_token_callback = GetTokenCallback(token_server, token, attempts)
481-
httpretty.register_uri(
482-
method=httpretty.GET,
483-
uri=f"{TOKEN_RESOURCE}/{challenge_id}",
484-
body=get_token_callback)
480+
responses.add_callback(
481+
method=responses.GET,
482+
url=f"{TOKEN_RESOURCE}/{challenge_id}",
483+
callback=get_token_callback)
485484

486485
redirect_handler = RedirectHandler()
487486

@@ -509,13 +508,13 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
509508
('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'), # noqa: E501
510509
('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'),
511510
])
512-
@httprettified
511+
@responses.activate
513512
def test_oauth2_authentication_missing_headers(header, error):
514513
# bind post statement
515-
httpretty.register_uri(
516-
method=httpretty.POST,
517-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
518-
adding_headers={'WWW-Authenticate': header},
514+
responses.add(
515+
method=responses.POST,
516+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
517+
headers={'WWW-Authenticate': header},
519518
status=401)
520519

521520
request = TrinoRequest(
@@ -543,7 +542,7 @@ def test_oauth2_authentication_missing_headers(header, error):
543542
'x_token_server="{token_server}"'
544543
'Bearer x_redirect_server="{redirect_server}",x_token_server="{token_server}",additional_challenge',
545544
])
546-
@httprettified
545+
@responses.activate
547546
def test_oauth2_header_parsing(header, sample_post_response_data):
548547
token = str(uuid.uuid4())
549548
challenge_id = str(uuid.uuid4())
@@ -552,25 +551,26 @@ def test_oauth2_header_parsing(header, sample_post_response_data):
552551
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
553552

554553
# noinspection PyUnusedLocal
555-
def post_statement(request, uri, response_headers):
554+
# def post_statement(request, uri, response_headers): # FIXME
555+
def post_statement(request):
556556
authorization = request.headers.get("Authorization")
557557
if authorization and authorization.replace("Bearer ", "") in token:
558-
return [200, response_headers, json.dumps(sample_post_response_data)]
559-
return [401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server),
560-
'Basic realm': '"Trino"'}, ""]
558+
return (200, {}, json.dumps(sample_post_response_data))
559+
return (401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server),
560+
'Basic realm': '"Trino"'}, "")
561561

562562
# bind post statement
563-
httpretty.register_uri(
564-
method=httpretty.POST,
565-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
566-
body=post_statement)
563+
responses.add_callback(
564+
method=responses.POST,
565+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
566+
callback=post_statement)
567567

568568
# bind get token
569569
get_token_callback = GetTokenCallback(token_server, token)
570-
httpretty.register_uri(
571-
method=httpretty.GET,
572-
uri=token_server,
573-
body=get_token_callback)
570+
responses.add_callback(
571+
method=responses.GET,
572+
url=token_server,
573+
callback=get_token_callback)
574574

575575
redirect_handler = RedirectHandler()
576576

@@ -591,8 +591,7 @@ def post_statement(request, uri, response_headers):
591591
assert len(_get_token_requests(challenge_id)) == 1
592592

593593

594-
@pytest.mark.parametrize("http_status", [400, 401, 500])
595-
@httprettified
594+
@responses.activate
596595
def test_oauth2_authentication_fail_token_server(http_status, sample_post_response_data):
597596
token = str(uuid.uuid4())
598597
challenge_id = str(uuid.uuid4())
@@ -603,16 +602,18 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
603602
post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)
604603

605604
# bind post statement
606-
httpretty.register_uri(
607-
method=httpretty.POST,
608-
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
609-
body=post_statement_callback)
610-
611-
httpretty.register_uri(
612-
method=httpretty.GET,
613-
uri=f"{TOKEN_RESOURCE}/{challenge_id}",
605+
responses.add(
606+
method=responses.POST,
607+
url=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
608+
json=post_statement_callback,
609+
)
610+
611+
responses.add(
612+
method=responses.GET,
613+
url=f"{TOKEN_RESOURCE}/{challenge_id}",
614614
status=http_status,
615-
body="error")
615+
body="error",
616+
)
616617

617618
redirect_handler = RedirectHandler()
618619

@@ -623,7 +624,8 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
623624
user="test",
624625
),
625626
http_scheme=constants.HTTPS,
626-
auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))
627+
auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
628+
)
627629

628630
with pytest.raises(trino.exceptions.TrinoAuthError) as exp:
629631
request.post("select 1")
@@ -634,7 +636,7 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
634636
assert len(_get_token_requests(challenge_id)) == 1
635637

636638

637-
@httprettified
639+
@responses.activate
638640
def test_multithreaded_oauth2_authentication_flow(sample_post_response_data):
639641
redirect_handler = RedirectHandler()
640642
auth = trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler)

0 commit comments

Comments
 (0)