24
24
from zoneinfo import ZoneInfoNotFoundError
25
25
26
26
import gssapi
27
- import httpretty
28
27
import keyring
29
28
import pytest
30
29
import requests
31
- from httpretty import httprettified
30
+ import responses
32
31
from requests_gssapi .exceptions import SPNEGOExchangeError
33
32
from requests_kerberos .exceptions import KerberosExchangeError
34
33
from tzlocal import get_localzone_name # type: ignore
@@ -347,9 +346,9 @@ def long_call(request, uri, headers):
347
346
time .sleep (timeout * 2 )
348
347
return (200 , headers , "delayed success" )
349
348
350
- httpretty . enable ()
351
- for method in [httpretty .POST , httpretty .GET ]:
352
- httpretty . register_uri (method , url , body = long_call )
349
+ responses . start ()
350
+ for method in [responses .POST , responses .GET ]:
351
+ responses . add_callback (method , url , callback = long_call )
353
352
354
353
# timeout without retry
355
354
for request_timeout in [timeout , (timeout , timeout )]:
@@ -370,12 +369,11 @@ def long_call(request, uri, headers):
370
369
with pytest .raises (requests .exceptions .Timeout ):
371
370
req .post ("select 1" )
372
371
373
- httpretty .disable ()
374
- httpretty .reset ()
375
-
372
+ responses .stop ()
373
+ responses .reset ()
376
374
377
375
@pytest .mark .parametrize ("attempts" , [1 , 3 , 5 ])
378
- @httprettified
376
+ @responses . activate
379
377
def test_oauth2_authentication_flow (attempts , sample_post_response_data ):
380
378
token = str (uuid .uuid4 ())
381
379
challenge_id = str (uuid .uuid4 ())
@@ -386,17 +384,17 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
386
384
post_statement_callback = PostStatementCallback (redirect_server , token_server , [token ], sample_post_response_data )
387
385
388
386
# bind post statement
389
- httpretty . register_uri (
390
- method = httpretty .POST ,
391
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
392
- body = post_statement_callback )
387
+ responses . add_callback (
388
+ method = responses .POST ,
389
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
390
+ callback = post_statement_callback )
393
391
394
392
# bind get token
395
393
get_token_callback = GetTokenCallback (token_server , token , attempts )
396
- httpretty . register_uri (
397
- method = httpretty .GET ,
398
- uri = token_server ,
399
- body = get_token_callback )
394
+ responses . add_callback (
395
+ method = responses .GET ,
396
+ url = token_server ,
397
+ callback = get_token_callback )
400
398
401
399
redirect_handler = RedirectHandler ()
402
400
@@ -417,7 +415,7 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
417
415
assert len (_get_token_requests (challenge_id )) == attempts
418
416
419
417
420
- @httprettified
418
+ @responses . activate
421
419
def test_oauth2_refresh_token_flow (sample_post_response_data ):
422
420
token = str (uuid .uuid4 ())
423
421
challenge_id = str (uuid .uuid4 ())
@@ -427,17 +425,17 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
427
425
post_statement_callback = PostStatementCallback (None , token_server , [token ], sample_post_response_data )
428
426
429
427
# bind post statement
430
- httpretty . register_uri (
431
- method = httpretty .POST ,
432
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
433
- body = post_statement_callback )
428
+ responses . add_callback (
429
+ method = responses .POST ,
430
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
431
+ callback = post_statement_callback )
434
432
435
433
# bind get token
436
434
get_token_callback = GetTokenCallback (token_server , token )
437
- httpretty . register_uri (
438
- method = httpretty .GET ,
439
- uri = token_server ,
440
- body = get_token_callback )
435
+ responses . add_callback (
436
+ method = responses .GET ,
437
+ url = token_server ,
438
+ callback = get_token_callback )
441
439
442
440
redirect_handler = RedirectHandlerWithException (
443
441
trino .exceptions .TrinoAuthError (
@@ -460,7 +458,7 @@ def test_oauth2_refresh_token_flow(sample_post_response_data):
460
458
461
459
462
460
@pytest .mark .parametrize ("attempts" , [6 , 10 ])
463
- @httprettified
461
+ @responses . activate
464
462
def test_oauth2_exceed_max_attempts (attempts , sample_post_response_data ):
465
463
token = str (uuid .uuid4 ())
466
464
challenge_id = str (uuid .uuid4 ())
@@ -471,17 +469,17 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
471
469
post_statement_callback = PostStatementCallback (redirect_server , token_server , [token ], sample_post_response_data )
472
470
473
471
# bind post statement
474
- httpretty . register_uri (
475
- method = httpretty .POST ,
476
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
477
- body = post_statement_callback )
472
+ responses . add_callback (
473
+ method = responses .POST ,
474
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
475
+ callback = post_statement_callback )
478
476
479
477
# bind get token
480
478
get_token_callback = GetTokenCallback (token_server , token , attempts )
481
- httpretty . register_uri (
482
- method = httpretty .GET ,
483
- uri = f"{ TOKEN_RESOURCE } /{ challenge_id } " ,
484
- body = get_token_callback )
479
+ responses . add_callback (
480
+ method = responses .GET ,
481
+ url = f"{ TOKEN_RESOURCE } /{ challenge_id } " ,
482
+ callback = get_token_callback )
485
483
486
484
redirect_handler = RedirectHandler ()
487
485
@@ -509,13 +507,13 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
509
507
('x_redirect_server="redirect_server", x_token_server="token_server"' , 'Error: header info didn\' t match x_redirect_server="redirect_server", x_token_server="token_server"' ), # noqa: E501
510
508
('Bearer x_redirect_server="redirect_server"' , 'Error: header info didn\' t have x_token_server' ),
511
509
])
512
- @httprettified
510
+ @responses . activate
513
511
def test_oauth2_authentication_missing_headers (header , error ):
514
512
# bind post statement
515
- httpretty . register_uri (
516
- method = httpretty .POST ,
517
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
518
- adding_headers = {'WWW-Authenticate' : header },
513
+ responses . add (
514
+ method = responses .POST ,
515
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
516
+ headers = {'WWW-Authenticate' : header },
519
517
status = 401 )
520
518
521
519
request = TrinoRequest (
@@ -543,7 +541,7 @@ def test_oauth2_authentication_missing_headers(header, error):
543
541
'x_token_server="{token_server}"'
544
542
'Bearer x_redirect_server="{redirect_server}",x_token_server="{token_server}",additional_challenge' ,
545
543
])
546
- @httprettified
544
+ @responses . activate
547
545
def test_oauth2_header_parsing (header , sample_post_response_data ):
548
546
token = str (uuid .uuid4 ())
549
547
challenge_id = str (uuid .uuid4 ())
@@ -552,25 +550,26 @@ def test_oauth2_header_parsing(header, sample_post_response_data):
552
550
token_server = f"{ TOKEN_RESOURCE } /{ challenge_id } "
553
551
554
552
# noinspection PyUnusedLocal
555
- def post_statement (request , uri , response_headers ):
553
+ # def post_statement(request, uri, response_headers): # FIXME
554
+ def post_statement (request ):
556
555
authorization = request .headers .get ("Authorization" )
557
556
if authorization and authorization .replace ("Bearer " , "" ) in token :
558
- return [ 200 , response_headers , json .dumps (sample_post_response_data )]
559
- return [ 401 , {'Www-Authenticate' : header .format (redirect_server = redirect_server , token_server = token_server ),
560
- 'Basic realm' : '"Trino"' }, "" ]
557
+ return ( 200 , {} , json .dumps (sample_post_response_data ))
558
+ return ( 401 , {'Www-Authenticate' : header .format (redirect_server = redirect_server , token_server = token_server ),
559
+ 'Basic realm' : '"Trino"' }, "" )
561
560
562
561
# bind post statement
563
- httpretty . register_uri (
564
- method = httpretty .POST ,
565
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
566
- body = post_statement )
562
+ responses . add_callback (
563
+ method = responses .POST ,
564
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
565
+ callback = post_statement )
567
566
568
567
# bind get token
569
568
get_token_callback = GetTokenCallback (token_server , token )
570
- httpretty . register_uri (
571
- method = httpretty .GET ,
572
- uri = token_server ,
573
- body = get_token_callback )
569
+ responses . add_callback (
570
+ method = responses .GET ,
571
+ url = token_server ,
572
+ callback = get_token_callback )
574
573
575
574
redirect_handler = RedirectHandler ()
576
575
@@ -590,9 +589,7 @@ def post_statement(request, uri, response_headers):
590
589
assert len (_post_statement_requests ()) == 2
591
590
assert len (_get_token_requests (challenge_id )) == 1
592
591
593
-
594
- @pytest .mark .parametrize ("http_status" , [400 , 401 , 500 ])
595
- @httprettified
592
+ @responses .activate
596
593
def test_oauth2_authentication_fail_token_server (http_status , sample_post_response_data ):
597
594
token = str (uuid .uuid4 ())
598
595
challenge_id = str (uuid .uuid4 ())
@@ -603,16 +600,18 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
603
600
post_statement_callback = PostStatementCallback (redirect_server , token_server , [token ], sample_post_response_data )
604
601
605
602
# bind post statement
606
- httpretty .register_uri (
607
- method = httpretty .POST ,
608
- uri = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
609
- body = post_statement_callback )
610
-
611
- httpretty .register_uri (
612
- method = httpretty .GET ,
613
- uri = f"{ TOKEN_RESOURCE } /{ challenge_id } " ,
603
+ responses .add (
604
+ method = responses .POST ,
605
+ url = f"{ SERVER_ADDRESS } { constants .URL_STATEMENT_PATH } " ,
606
+ json = post_statement_callback ,
607
+ )
608
+
609
+ responses .add (
610
+ method = responses .GET ,
611
+ url = f"{ TOKEN_RESOURCE } /{ challenge_id } " ,
614
612
status = http_status ,
615
- body = "error" )
613
+ body = "error" ,
614
+ )
616
615
617
616
redirect_handler = RedirectHandler ()
618
617
@@ -623,7 +622,8 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
623
622
user = "test" ,
624
623
),
625
624
http_scheme = constants .HTTPS ,
626
- auth = trino .auth .OAuth2Authentication (redirect_auth_url_handler = redirect_handler ))
625
+ auth = trino .auth .OAuth2Authentication (redirect_auth_url_handler = redirect_handler ),
626
+ )
627
627
628
628
with pytest .raises (trino .exceptions .TrinoAuthError ) as exp :
629
629
request .post ("select 1" )
@@ -633,8 +633,7 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
633
633
assert len (_post_statement_requests ()) == 1
634
634
assert len (_get_token_requests (challenge_id )) == 1
635
635
636
-
637
- @httprettified
636
+ @responses .activate
638
637
def test_multithreaded_oauth2_authentication_flow (sample_post_response_data ):
639
638
redirect_handler = RedirectHandler ()
640
639
auth = trino .auth .OAuth2Authentication (redirect_auth_url_handler = redirect_handler )
0 commit comments