24
24
from zoneinfo import ZoneInfoNotFoundError
25
25
26
26
import gssapi
27
- import httpretty
28
27
import keyring
29
28
import pytest
30
29
import requests
31
- from httpretty import httprettified
30
+ import responses
32
31
from requests_gssapi .exceptions import SPNEGOExchangeError
33
32
from requests_kerberos .exceptions import KerberosExchangeError
34
33
from tzlocal import get_localzone_name # type: ignore
@@ -347,9 +346,9 @@ def long_call(request, uri, headers):
347
346
time .sleep (timeout * 2 )
348
347
return (200 , headers , "delayed success" )
349
348
350
- httpretty . enable ()
351
- for method in [httpretty .POST , httpretty .GET ]:
352
- httpretty . register_uri (method , url , body = long_call )
349
+ responses . start ()
350
+ for method in [responses .POST , responses .GET ]:
351
+ responses . add_callback (method , url , callback = long_call )
353
352
354
353
# timeout without retry
355
354
for request_timeout in [timeout , (timeout , timeout )]:
@@ -370,12 +369,12 @@ def long_call(request, uri, headers):
370
369
with pytest .raises (requests .exceptions .Timeout ):
371
370
req .post ("select 1" )
372
371
373
- httpretty . disable ()
374
- httpretty .reset ()
372
+ responses . stop ()
373
+ responses .reset ()
375
374
376
375
377
376
@pytest .mark .parametrize ("attempts" , [1 , 3 , 5 ])
378
- @httprettified
377
+ @responses . activate
379
378
def test_oauth2_authentication_flow (attempts , sample_post_response_data ):
380
379
token = str (uuid .uuid4 ())
381
380
challenge_id = str (uuid .uuid4 ())
@@ -386,17 +385,17 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
386
385
post_statement_callback = PostStatementCallback (redirect_server , token_server , [token ], sample_post_response_data )
387
386
388
387
# bind post statement
389
- httpretty . register_uri (
390
- method = httpretty .POST ,
391
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
392
- body = post_statement_callback )
388
+ responses . add_callback (
389
+ method = responses .POST ,
390
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
391
+ callback = post_statement_callback )
393
392
394
393
# bind get token
395
394
get_token_callback = GetTokenCallback (token_server , token , attempts )
396
- httpretty . register_uri (
397
- method = httpretty .GET ,
398
- uri = token_server ,
399
- body = get_token_callback )
395
+ responses . add_callback (
396
+ method = responses .GET ,
397
+ url = token_server ,
398
+ callback = get_token_callback )
400
399
401
400
redirect_handler = RedirectHandler ()
402
401
@@ -417,7 +416,7 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
417
416
assert len (_get_token_requests (challenge_id )) == attempts
418
417
419
418
420
- @httprettified
419
+ @responses . activate
421
420
def test_oauth2_refresh_token_flow (sample_post_response_data ):
422
421
token = str (uuid .uuid4 ())
423
422
challenge_id = str (uuid .uuid4 ())
@@ -427,17 +426,17 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
427
426
post_statement_callback = PostStatementCallback (None , token_server , [token ], sample_post_response_data )
428
427
429
428
# bind post statement
430
- httpretty . register_uri (
431
- method = httpretty .POST ,
432
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
433
- body = post_statement_callback )
429
+ responses . add_callback (
430
+ method = responses .POST ,
431
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
432
+ callback = post_statement_callback )
434
433
435
434
# bind get token
436
435
get_token_callback = GetTokenCallback (token_server , token )
437
- httpretty . register_uri (
438
- method = httpretty .GET ,
439
- uri = token_server ,
440
- body = get_token_callback )
436
+ responses . add_callback (
437
+ method = responses .GET ,
438
+ url = token_server ,
439
+ callback = get_token_callback )
441
440
442
441
redirect_handler = RedirectHandlerWithException (
443
442
trino .exceptions .TrinoAuthError (
@@ -460,7 +459,7 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
460
459
461
460
462
461
@pytest .mark .parametrize ("attempts" , [6 , 10 ])
463
- @httprettified
462
+ @responses . activate
464
463
def test_oauth2_exceed_max_attempts (attempts , sample_post_response_data ):
465
464
token = str (uuid .uuid4 ())
466
465
challenge_id = str (uuid .uuid4 ())
@@ -471,17 +470,17 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
471
470
post_statement_callback = PostStatementCallback (redirect_server , token_server , [token ], sample_post_response_data )
472
471
473
472
# bind post statement
474
- httpretty . register_uri (
475
- method = httpretty .POST ,
476
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
477
- body = post_statement_callback )
473
+ responses . add_callback (
474
+ method = responses .POST ,
475
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
476
+ callback = post_statement_callback )
478
477
479
478
# bind get token
480
479
get_token_callback = GetTokenCallback (token_server , token , attempts )
481
- httpretty . register_uri (
482
- method = httpretty .GET ,
483
- uri = f"{ TOKEN_RESOURCE } /{ challenge_id } " ,
484
- body = get_token_callback )
480
+ responses . add_callback (
481
+ method = responses .GET ,
482
+ url = f"{ TOKEN_RESOURCE } /{ challenge_id } " ,
483
+ callback = get_token_callback )
485
484
486
485
redirect_handler = RedirectHandler ()
487
486
@@ -509,13 +508,13 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
509
508
('x_redirect_server="redirect_server", x_token_server="token_server"' , 'Error: header info didn\' t match x_redirect_server="redirect_server", x_token_server="token_server"' ), # noqa: E501
510
509
('Bearer x_redirect_server="redirect_server"' , 'Error: header info didn\' t have x_token_server' ),
511
510
])
512
- @httprettified
511
+ @responses . activate
513
512
def test_oauth2_authentication_missing_headers (header , error ):
514
513
# bind post statement
515
- httpretty . register_uri (
516
- method = httpretty .POST ,
517
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
518
- adding_headers = {'WWW-Authenticate' : header },
514
+ responses . add (
515
+ method = responses .POST ,
516
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
517
+ headers = {'WWW-Authenticate' : header },
519
518
status = 401 )
520
519
521
520
request = TrinoRequest (
@@ -543,7 +542,7 @@ def test_oauth2_authentication_missing_headers(header, error):
543
542
'x_token_server="{token_server}"'
544
543
'Bearer x_redirect_server="{redirect_server}",x_token_server="{token_server}",additional_challenge' ,
545
544
])
546
- @httprettified
545
+ @responses . activate
547
546
def test_oauth2_header_parsing (header , sample_post_response_data ):
548
547
token = str (uuid .uuid4 ())
549
548
challenge_id = str (uuid .uuid4 ())
@@ -552,25 +551,26 @@ def test_oauth2_header_parsing(header, sample_post_response_data):
552
551
token_server = f"{ TOKEN_RESOURCE } /{ challenge_id } "
553
552
554
553
# noinspection PyUnusedLocal
555
- def post_statement (request , uri , response_headers ):
554
+ # def post_statement(request, uri, response_headers): # FIXME
555
+ def post_statement (request ):
556
556
authorization = request .headers .get ("Authorization" )
557
557
if authorization and authorization .replace ("Bearer " , "" ) in token :
558
- return [ 200 , response_headers , json .dumps (sample_post_response_data )]
559
- return [ 401 , {'Www-Authenticate' : header .format (redirect_server = redirect_server , token_server = token_server ),
560
- 'Basic realm' : '"Trino"' }, "" ]
558
+ return ( 200 , {} , json .dumps (sample_post_response_data ))
559
+ return ( 401 , {'Www-Authenticate' : header .format (redirect_server = redirect_server , token_server = token_server ),
560
+ 'Basic realm' : '"Trino"' }, "" )
561
561
562
562
# bind post statement
563
- httpretty . register_uri (
564
- method = httpretty .POST ,
565
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
566
- body = post_statement )
563
+ responses . add_callback (
564
+ method = responses .POST ,
565
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
566
+ callback = post_statement )
567
567
568
568
# bind get token
569
569
get_token_callback = GetTokenCallback (token_server , token )
570
- httpretty . register_uri (
571
- method = httpretty .GET ,
572
- uri = token_server ,
573
- body = get_token_callback )
570
+ responses . add_callback (
571
+ method = responses .GET ,
572
+ url = token_server ,
573
+ callback = get_token_callback )
574
574
575
575
redirect_handler = RedirectHandler ()
576
576
@@ -591,8 +591,7 @@ def post_statement(request, uri, response_headers):
591
591
assert len (_get_token_requests (challenge_id )) == 1
592
592
593
593
594
- @pytest .mark .parametrize ("http_status" , [400 , 401 , 500 ])
595
- @httprettified
594
+ @responses .activate
596
595
def test_oauth2_authentication_fail_token_server (http_status , sample_post_response_data ):
597
596
token = str (uuid .uuid4 ())
598
597
challenge_id = str (uuid .uuid4 ())
@@ -603,16 +602,18 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
603
602
post_statement_callback = PostStatementCallback (redirect_server , token_server , [token ], sample_post_response_data )
604
603
605
604
# bind post statement
606
- httpretty .register_uri (
607
- method = httpretty .POST ,
608
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
609
- body = post_statement_callback )
610
-
611
- httpretty .register_uri (
612
- method = httpretty .GET ,
613
- uri = f"{ TOKEN_RESOURCE } /{ challenge_id } " ,
605
+ responses .add (
606
+ method = responses .POST ,
607
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
608
+ json = post_statement_callback ,
609
+ )
610
+
611
+ responses .add (
612
+ method = responses .GET ,
613
+ url = f"{ TOKEN_RESOURCE } /{ challenge_id } " ,
614
614
status = http_status ,
615
- body = "error" )
615
+ body = "error" ,
616
+ )
616
617
617
618
redirect_handler = RedirectHandler ()
618
619
@@ -623,7 +624,8 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
623
624
user = "test" ,
624
625
),
625
626
http_scheme = constants .HTTPS ,
626
- auth = trino .auth .OAuth2Authentication (redirect_auth_url_handler = redirect_handler ))
627
+ auth = trino .auth .OAuth2Authentication (redirect_auth_url_handler = redirect_handler ),
628
+ )
627
629
628
630
with pytest .raises (trino .exceptions .TrinoAuthError ) as exp :
629
631
request .post ("select 1" )
@@ -634,7 +636,7 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
634
636
assert len (_get_token_requests (challenge_id )) == 1
635
637
636
638
637
- @httprettified
639
+ @responses . activate
638
640
def test_multithreaded_oauth2_authentication_flow (sample_post_response_data ):
639
641
redirect_handler = RedirectHandler ()
640
642
auth = trino .auth .OAuth2Authentication (redirect_auth_url_handler = redirect_handler )
0 commit comments