Skip to content
This repository was archived by the owner on Jun 29, 2019. It is now read-only.

Commit 33bda0a

Browse files
committed
Fix error in ClientCredentialsGrant when no expiration of a token is set in the handler
1 parent 9485b8d commit 33bda0a

File tree

3 files changed

+38
-27
lines changed

3 files changed

+38
-27
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
## 1.1.0 (unreleased)
22

3+
## 1.0.1
4+
5+
Bugfixes:
6+
7+
- Fix error in `ClientCredentialsGrant` when no expiration of a token is set in the handler ([@wndhydrnt][])
8+
39
## 1.0.0
410

511
Features:

oauth2/grant.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,14 @@ def create_token(self, client_id, data, grant_type, scopes, user_id):
343343
raise UserIdentifierMissingError
344344

345345
try:
346-
access_token = self.access_token_store.\
346+
access_token = self.access_token_store. \
347347
fetch_existing_token_of_user(
348-
client_id,
349-
grant_type,
350-
user_id)
348+
client_id,
349+
grant_type,
350+
user_id)
351351

352352
if (access_token.scopes == scopes
353-
and access_token.is_expired() is False):
353+
and access_token.is_expired() is False):
354354
token_data = {"access_token": access_token.token,
355355
"token_type": "Bearer"}
356356

@@ -396,7 +396,7 @@ def __init__(self, site_adapter, **kwargs):
396396
if isinstance(site_adapter, self.site_adapter_class) is False:
397397
raise InvalidSiteAdapter(
398398
"Site adapter must inherit from class '{0}'"
399-
.format(self.site_adapter_class.__name__)
399+
.format(self.site_adapter_class.__name__)
400400
)
401401

402402
self.site_adapter = site_adapter
@@ -621,8 +621,8 @@ def __call__(self, request, server):
621621
Check the HTTP method of a request
622622
"""
623623
if (request.method == "POST"
624-
and request.post_param("grant_type") == "authorization_code"
625-
and request.path == server.token_path):
624+
and request.post_param("grant_type") == "authorization_code"
625+
and request.path == server.token_path):
626626
return AuthorizationCodeTokenHandler(
627627
access_token_store=server.access_token_store,
628628
auth_token_store=server.auth_code_store,
@@ -631,8 +631,8 @@ def __call__(self, request, server):
631631
unique_token=self.unique_token)
632632

633633
if (request.method == "GET"
634-
and request.get_param("response_type") == "code"
635-
and request.path == server.authorize_path):
634+
and request.get_param("response_type") == "code"
635+
and request.path == server.authorize_path):
636636
scope_handler = self._create_scope_handler()
637637

638638
return AuthorizationCodeAuthHandler(
@@ -670,7 +670,7 @@ def __call__(self, request, server):
670670
response_type = request.get_param("response_type")
671671

672672
if (response_type == "token"
673-
and request.path == server.authorize_path):
673+
and request.path == server.authorize_path):
674674
return ImplicitGrantHandler(
675675
access_token_store=server.access_token_store,
676676
client_authenticator=server.client_authenticator,
@@ -1002,7 +1002,7 @@ def read_validate_params(self, request):
10021002
self.refresh_grant_type = access_token.grant_type
10031003

10041004
if refresh_token_expires_at != 0 and \
1005-
refresh_token_expires_at < int(time.time()):
1005+
refresh_token_expires_at < int(time.time()):
10061006
raise OAuthInvalidError(error="invalid_request",
10071007
explanation="Invalid refresh token")
10081008

@@ -1048,8 +1048,11 @@ def process(self, request, response, environ):
10481048
body = {"token_type": "Bearer"}
10491049

10501050
token = self.token_generator.generate()
1051-
expires_in = self.token_generator.expires_in[ClientCredentialsGrant.grant_type]
1052-
expires_at = int(time.time()) + expires_in
1051+
expires_in = self.token_generator.expires_in.get(ClientCredentialsGrant.grant_type, None)
1052+
if expires_in is None:
1053+
expires_at = None
1054+
else:
1055+
expires_at = int(time.time()) + expires_in
10531056

10541057
access_token = AccessToken(
10551058
client_id=self.client.identifier,
@@ -1061,7 +1064,7 @@ def process(self, request, response, environ):
10611064

10621065
body["access_token"] = token
10631066

1064-
if expires_in > 0:
1067+
if expires_in is not None:
10651068
body["expires_in"] = expires_in
10661069

10671070
if self.scope_handler.send_back:

oauth2/test/test_grant.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ def test_process_no_refresh_token(self):
579579
call("Cache-Control",
580580
"no-store"),
581581
call("Pragma", "no-cache")])
582+
582583
@patch("time.time", mock_time)
583584
def test_process_with_refresh_token(self):
584585
token_data = {"access_token": "abcd", "token_type": "Bearer",
@@ -670,12 +671,12 @@ def test_process_with_unique_access_token_not_found(self):
670671
response = Response()
671672

672673
access_token_store_mock = Mock(spec=AccessTokenStore)
673-
access_token_store_mock.fetch_existing_token_of_user.\
674+
access_token_store_mock.fetch_existing_token_of_user. \
674675
side_effect = AccessTokenNotFound
675676

676677
token_generator_mock = Mock(spec=TokenGenerator)
677678
token_generator_mock.refresh_expires_in = 10000
678-
token_generator_mock.create_access_token_data.\
679+
token_generator_mock.create_access_token_data. \
679680
return_value = token_data
680681

681682
handler = AuthorizationCodeTokenHandler(
@@ -866,7 +867,6 @@ def test_process_redirect_with_token(self):
866867
self.assertEqual(responseMock.content, "")
867868
self.assertEqual(result_response, responseMock)
868869

869-
870870
def test_process_redirect_with_state(self):
871871
"""
872872
ImplicitGrantHandler should include the value of the "state" query parameter from request in redirect
@@ -916,7 +916,8 @@ def test_process_with_scope(self):
916916
state = "XHGFI"
917917
token = "tokencode"
918918

919-
expected_redirect_uri = "%s#access_token=%s&token_type=bearer&state=%s&scope=%s" % (redirect_uri, token, state, scopes_uri)
919+
expected_redirect_uri = "%s#access_token=%s&token_type=bearer&state=%s&scope=%s" % (
920+
redirect_uri, token, state, scopes_uri)
920921

921922
response_mock = Mock(spec=Response)
922923

@@ -1516,6 +1517,7 @@ def test_call_other_grant_type(self):
15161517

15171518
self.assertEqual(grant_handler, None)
15181519

1520+
15191521
class RefreshTokenHandlerTestCase(unittest.TestCase):
15201522
@patch("time.time", mock_time)
15211523
def test_process_no_reissue(self):
@@ -1537,7 +1539,7 @@ def test_process_no_reissue(self):
15371539
scope_handler_mock = Mock(spec=Scope)
15381540
scope_handler_mock.scopes = scopes
15391541

1540-
token_data = {"access_token": token, "expires_in":expires_in, "token_type": "Bearer", "refresh_token":"gafc"}
1542+
token_data = {"access_token": token, "expires_in": expires_in, "token_type": "Bearer", "refresh_token": "gafc"}
15411543
token_generator_mock = Mock(spec=TokenGenerator)
15421544
token_generator_mock.create_access_token_data.return_value = token_data
15431545
token_generator_mock.refresh_expires_in = 1200
@@ -1588,7 +1590,8 @@ def test_process_with_reissue(self):
15881590
scope_handler_mock = Mock(spec=Scope)
15891591
scope_handler_mock.scopes = scopes
15901592

1591-
token_data = {"access_token": token, "expires_in":expires_in, "token_type": "Bearer", "refresh_token":refresh_token}
1593+
token_data = {"access_token": token, "expires_in": expires_in, "token_type": "Bearer",
1594+
"refresh_token": refresh_token}
15921595
token_generator_mock = Mock(spec=TokenGenerator)
15931596
token_generator_mock.create_access_token_data.return_value = token_data
15941597
token_generator_mock.refresh_expires_in = 1200
@@ -1618,7 +1621,6 @@ def test_process_with_reissue(self):
16181621
self.assertDictContainsSubset(expected_headers, result.headers)
16191622
self.assertDictEqual(expected_response_body, json.loads(result.body))
16201623

1621-
16221624
@patch("time.time", mock_time)
16231625
def test_read_validate_params(self):
16241626
client_id = "client"
@@ -1645,7 +1647,7 @@ def test_read_validate_params(self):
16451647
request_mock = Mock(spec=Request)
16461648
request_mock.post_param.side_effect = [refresh_token]
16471649

1648-
token_generator_mock = Mock(expires_in={'test_grant_type':600})
1650+
token_generator_mock = Mock(expires_in={'test_grant_type': 600})
16491651
token_generator_mock.refresh_expires_in = 0
16501652

16511653
scope_handler_mock = Mock(spec=Scope)
@@ -1738,7 +1740,7 @@ def test_read_validate_params_expired_refresh_token(self):
17381740
access_token_store=access_token_store_mock,
17391741
client_authenticator=client_auth_mock,
17401742
scope_handler=Mock(),
1741-
token_generator=Mock(expires_in={'test_grant_type':600}))
1743+
token_generator=Mock(expires_in={'test_grant_type': 600}))
17421744

17431745
with self.assertRaises(OAuthInvalidError) as expected:
17441746
handler.read_validate_params(request_mock)
@@ -1811,7 +1813,6 @@ def test_call_other_grant_type(self):
18111813
class ClientCredentialsHandlerTestCase(unittest.TestCase):
18121814
def test_process(self):
18131815
client_id = "abc"
1814-
expires_in = 0
18151816
token = "abcd"
18161817

18171818
expected_response_body = {"access_token": token,
@@ -1827,7 +1828,7 @@ def test_process(self):
18271828

18281829
token_generator_mock = Mock(spec=TokenGenerator)
18291830
token_generator_mock.generate.return_value = token
1830-
token_generator_mock.expires_in = {ClientCredentialsGrant.grant_type:expires_in}
1831+
token_generator_mock.expires_in = {}
18311832
handler = ClientCredentialsHandler(
18321833
access_token_store=access_token_store_mock,
18331834
client_authenticator=Mock(),
@@ -1862,7 +1863,7 @@ def test_process_with_refresh_token(self):
18621863

18631864
token_generator_mock = Mock(spec=TokenGenerator)
18641865
token_generator_mock.generate.return_value = token
1865-
token_generator_mock.expires_in = {ClientCredentialsGrant.grant_type:expires_in}
1866+
token_generator_mock.expires_in = {ClientCredentialsGrant.grant_type: expires_in}
18661867

18671868
handler = ClientCredentialsHandler(
18681869
access_token_store=access_token_store_mock,
@@ -1916,5 +1917,6 @@ def test_read_validate_params(self):
19161917
scope_handler_mock.parse.assert_called_with(request=request_mock,
19171918
source="body")
19181919

1920+
19191921
if __name__ == "__main__":
19201922
unittest.main()

0 commit comments

Comments
 (0)