Skip to content

Commit bf960d7

Browse files
na-ka-nagcarvelli
andauthored
Avoid decoding jwt twice (#440)
* Avoid decoding jwt twice Currently the Session::authenticate() function (which runs on every request and consumes CPU cycles) is decoding the jwt twice unnecessarily. This small refactor fixes that * Fix typo * remove unnecessary mock --------- Co-authored-by: Giovanni Carvelli <[email protected]>
1 parent 3731319 commit bf960d7

File tree

2 files changed

+38
-55
lines changed

2 files changed

+38
-55
lines changed

tests/test_session.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,7 @@ def test_authenticate_success(self, session_constants, mock_user_management):
236236
"entitlements": ["feature_1"],
237237
}
238238

239-
with patch.object(
240-
Session, "unseal_data", return_value=mock_session
241-
), patch.object(session, "_is_valid_jwt", return_value=True), patch(
239+
with patch.object(Session, "unseal_data", return_value=mock_session), patch(
242240
"jwt.decode", return_value=mock_jwt_payload
243241
), patch.object(
244242
session.jwks,
@@ -324,22 +322,21 @@ def test_refresh_success(self, session_constants, mock_user_management):
324322
cookie_password=session_constants["COOKIE_PASSWORD"],
325323
)
326324

327-
with patch.object(session, "_is_valid_jwt", return_value=True) as _:
328-
with patch(
329-
"jwt.decode",
330-
return_value={
331-
"sid": session_constants["SESSION_ID"],
332-
"org_id": session_constants["ORGANIZATION_ID"],
333-
"role": "admin",
334-
"permissions": ["read"],
335-
"entitlements": ["feature_1"],
336-
},
337-
):
338-
response = session.refresh()
325+
with patch(
326+
"jwt.decode",
327+
return_value={
328+
"sid": session_constants["SESSION_ID"],
329+
"org_id": session_constants["ORGANIZATION_ID"],
330+
"role": "admin",
331+
"permissions": ["read"],
332+
"entitlements": ["feature_1"],
333+
},
334+
):
335+
response = session.refresh()
339336

340-
assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
341-
assert response.authenticated is True
342-
assert response.user.id == session_constants["TEST_USER"]["id"]
337+
assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
338+
assert response.authenticated is True
339+
assert response.user.id == session_constants["TEST_USER"]["id"]
343340

344341
# Verify the refresh token was used correctly
345342
mock_user_management.authenticate_with_refresh_token.assert_called_once_with(
@@ -425,22 +422,21 @@ async def test_refresh_success(self, session_constants, mock_user_management):
425422
cookie_password=session_constants["COOKIE_PASSWORD"],
426423
)
427424

428-
with patch.object(session, "_is_valid_jwt", return_value=True) as _:
429-
with patch(
430-
"jwt.decode",
431-
return_value={
432-
"sid": session_constants["SESSION_ID"],
433-
"org_id": session_constants["ORGANIZATION_ID"],
434-
"role": "admin",
435-
"permissions": ["read"],
436-
"entitlements": ["feature_1"],
437-
},
438-
):
439-
response = await session.refresh()
425+
with patch(
426+
"jwt.decode",
427+
return_value={
428+
"sid": session_constants["SESSION_ID"],
429+
"org_id": session_constants["ORGANIZATION_ID"],
430+
"role": "admin",
431+
"permissions": ["read"],
432+
"entitlements": ["feature_1"],
433+
},
434+
):
435+
response = await session.refresh()
440436

441-
assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
442-
assert response.authenticated is True
443-
assert response.user.id == session_constants["TEST_USER"]["id"]
437+
assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
438+
assert response.authenticated is True
439+
assert response.user.id == session_constants["TEST_USER"]["id"]
444440

445441
# Verify the refresh token was used correctly
446442
mock_user_management.authenticate_with_refresh_token.assert_called_once_with(

workos/session.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,20 @@ def authenticate(
7777
reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE,
7878
)
7979

80-
if not self._is_valid_jwt(session["access_token"]):
80+
try:
81+
signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"])
82+
decoded = jwt.decode(
83+
session["access_token"],
84+
signing_key.key,
85+
algorithms=self.jwk_algorithms,
86+
options={"verify_aud": False},
87+
)
88+
except jwt.exceptions.InvalidTokenError:
8189
return AuthenticateWithSessionCookieErrorResponse(
8290
authenticated=False,
8391
reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT,
8492
)
8593

86-
signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"])
87-
decoded = jwt.decode(
88-
session["access_token"],
89-
signing_key.key,
90-
algorithms=self.jwk_algorithms,
91-
options={"verify_aud": False},
92-
)
93-
9494
return AuthenticateWithSessionCookieSuccessResponse(
9595
authenticated=True,
9696
session_id=decoded["sid"],
@@ -128,19 +128,6 @@ def get_logout_url(self, return_to: Optional[str] = None) -> str:
128128
)
129129
return str(result)
130130

131-
def _is_valid_jwt(self, token: str) -> bool:
132-
try:
133-
signing_key = self.jwks.get_signing_key_from_jwt(token)
134-
jwt.decode(
135-
token,
136-
signing_key.key,
137-
algorithms=self.jwk_algorithms,
138-
options={"verify_aud": False},
139-
)
140-
return True
141-
except jwt.exceptions.InvalidTokenError:
142-
return False
143-
144131
@staticmethod
145132
def seal_data(data: Dict[str, Any], key: str) -> str:
146133
fernet = Fernet(key)

0 commit comments

Comments
 (0)