1616
1717from _pytest .fixtures import fixture
1818from fastapi import FastAPI
19+ from typing import Optional , Dict , Any
1920from pytest import mark
2021from starlette .requests import Request
2122from starlette .testclient import TestClient
2223from supertokens_python import InputAppInfo , SupertokensConfig , init
2324from supertokens_python .framework .fastapi import get_middleware
2425from supertokens_python .recipe import jwt
25- from supertokens_python .recipe .jwt .interfaces import APIInterface
26+ from supertokens_python .recipe .jwt .interfaces import APIInterface , RecipeInterface
2627from supertokens_python .recipe .session .asyncio import create_new_session
2728from tests .utils import clean_st , reset , setup_st , start_st
2829
@@ -83,6 +84,20 @@ async def test_that_default_getJWKS_api_does_not_work_when_disabled(
8384
8485
8586async def test_that_default_getJWKS_works_fine (driver_config_client : TestClient ):
87+ custom_validity : Optional [int ] = - 1 # -1 means no override
88+
89+ def func_override (oi : RecipeInterface ):
90+ oi_get_jwks = oi .get_jwks
91+
92+ async def get_jwks (user_context : Dict [str , Any ]):
93+ res = await oi_get_jwks (user_context )
94+ if custom_validity != - 1 :
95+ res .validity_in_secs = custom_validity
96+ return res
97+
98+ oi .get_jwks = get_jwks
99+ return oi
100+
86101 init (
87102 supertokens_config = SupertokensConfig ("http://localhost:3567" ),
88103 app_info = InputAppInfo (
@@ -91,12 +106,37 @@ async def test_that_default_getJWKS_works_fine(driver_config_client: TestClient)
91106 website_domain = "supertokens.io" ,
92107 ),
93108 framework = "fastapi" ,
94- recipe_list = [jwt .init ()],
109+ recipe_list = [jwt .init (override = jwt . OverrideConfig ( functions = func_override ) )],
95110 )
96111 start_st ()
97112
98113 response = driver_config_client .get (url = "/auth/jwt/jwks.json" )
99114
115+ # Default:
100116 assert response .status_code == 200
101117 data = response .json ()
118+ assert data .keys () == {"keys" }
102119 assert len (data ["keys" ]) > 0
120+ assert data ["keys" ][0 ].keys () == {"kty" , "kid" , "n" , "e" , "alg" , "use" }
121+
122+ assert response .headers ["cache-control" ] == "max-age=60, must-revalidate"
123+
124+ # Override cache control:
125+ custom_validity = 1
126+ response = driver_config_client .get (url = "/auth/jwt/jwks.json" )
127+
128+ assert response .status_code == 200
129+ data = response .json ()
130+ assert len (data ["keys" ]) > 0
131+
132+ assert response .headers ["cache-control" ] == "max-age=1, must-revalidate"
133+
134+ # Disable cache control:
135+ custom_validity = None
136+ response = driver_config_client .get (url = "/auth/jwt/jwks.json" )
137+
138+ assert response .status_code == 200
139+ data = response .json ()
140+ assert len (data ["keys" ]) > 0
141+
142+ assert "cache-control" not in response .headers
0 commit comments