Skip to content

Commit 9257488

Browse files
Merge branch '0.15' into no-panic
2 parents aa9cddf + a311ca0 commit 9257488

File tree

7 files changed

+176
-17
lines changed

7 files changed

+176
-17
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [unreleased]
99

10+
## [0.15.0] - 2023-09-26
11+
12+
- Added a `Cache-Control` header to `/jwt/jwks.json` (`GetJWKSGET`)
13+
- Added `ValidityInSeconds` to the return value of the overrideable `GetJWKS` function.
14+
- This can be used to control the `Cache-Control` header mentioned above.
15+
- It defaults to `60` or the value set in the cache-control header returned by the core
16+
- This is optional (so you are not required to update your overrides). Returning undefined means that the header is not set.
1017
- Handle AWS Public URLs (ending with `.amazonaws.com`) separately while extracting TLDs for SameSite attribute.
1118
- Return `500` status instead of panic when `supertokens.Middleware` is used without initializing the SDK.
1219
- Updates fiber adaptor package in the fiber example.

recipe/jwt/api/implementation.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package api
1717

1818
import (
19+
"fmt"
1920
"github.com/supertokens/supertokens-golang/recipe/jwt/jwtmodels"
2021
"github.com/supertokens/supertokens-golang/supertokens"
2122
)
@@ -26,8 +27,13 @@ func MakeAPIImplementation() jwtmodels.APIInterface {
2627
if err != nil {
2728
return jwtmodels.GetJWKSAPIResponse{}, err
2829
}
30+
options.Res.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", resp.OK.ValidityInSeconds))
2931
return jwtmodels.GetJWKSAPIResponse{
30-
OK: resp.OK,
32+
OK: &struct {
33+
Keys []jwtmodels.JsonWebKeys
34+
}{
35+
Keys: resp.OK.Keys,
36+
},
3137
}, nil
3238
}
3339
return jwtmodels.APIInterface{

recipe/jwt/getJWKS_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,84 @@ func TestDefaultGetJWKSWorksFine(t *testing.T) {
127127
result := *unittesting.HttpResponseToConsumableInformation(resp.Body)
128128
assert.NotNil(t, result)
129129
assert.Greater(t, len(result["keys"].([]interface{})), 0)
130+
131+
cacheControl := resp.Header.Get("Cache-Control")
132+
assert.Equal(t, cacheControl, "max-age=60, must-revalidate")
133+
}
134+
135+
func TestThatWeCanOverrideCacheControlThroughRecipeFunction(t *testing.T) {
136+
configValue := supertokens.TypeInput{
137+
Supertokens: &supertokens.ConnectionInfo{
138+
ConnectionURI: "http://localhost:8080",
139+
},
140+
AppInfo: supertokens.AppInfo{
141+
APIDomain: "api.supertokens.io",
142+
AppName: "SuperTokens",
143+
WebsiteDomain: "supertokens.io",
144+
},
145+
RecipeList: []supertokens.Recipe{
146+
Init(&jwtmodels.TypeInput{
147+
Override: &jwtmodels.OverrideStruct{
148+
Functions: func(originalImplementation jwtmodels.RecipeInterface) jwtmodels.RecipeInterface {
149+
originalGetJWKS := *originalImplementation.GetJWKS
150+
151+
getJWKs := func(userContext supertokens.UserContext) (jwtmodels.GetJWKSResponse, error) {
152+
result, err := originalGetJWKS(userContext)
153+
154+
if err != nil {
155+
return jwtmodels.GetJWKSResponse{}, err
156+
}
157+
158+
return jwtmodels.GetJWKSResponse{
159+
OK: &struct {
160+
Keys []jwtmodels.JsonWebKeys
161+
ValidityInSeconds int
162+
}{Keys: result.OK.Keys, ValidityInSeconds: 1234},
163+
}, nil
164+
}
165+
166+
*originalImplementation.GetJWKS = getJWKs
167+
168+
return originalImplementation
169+
},
170+
},
171+
}),
172+
},
173+
}
174+
175+
BeforeEach()
176+
unittesting.StartUpST("localhost", "8080")
177+
defer AfterEach()
178+
err := supertokens.Init(configValue)
179+
if err != nil {
180+
t.Error(err.Error())
181+
}
182+
183+
q, err := supertokens.GetNewQuerierInstanceOrThrowError("")
184+
if err != nil {
185+
t.Error(err.Error())
186+
}
187+
apiV, err := q.GetQuerierAPIVersion()
188+
if err != nil {
189+
t.Error(err.Error())
190+
}
191+
192+
if unittesting.MaxVersion(apiV, "2.8") == "2.8" {
193+
return
194+
}
195+
mux := http.NewServeMux()
196+
testServer := httptest.NewServer(supertokens.Middleware(mux))
197+
defer testServer.Close()
198+
199+
resp, err := http.Get(testServer.URL + "/auth/jwt/jwks.json")
200+
if err != nil {
201+
t.Error(err.Error())
202+
}
203+
204+
result := *unittesting.HttpResponseToConsumableInformation(resp.Body)
205+
assert.NotNil(t, result)
206+
assert.Greater(t, len(result["keys"].([]interface{})), 0)
207+
208+
cacheControl := resp.Header.Get("Cache-Control")
209+
assert.Equal(t, cacheControl, "max-age=1234, must-revalidate")
130210
}

recipe/jwt/jwtmodels/recipeInterface.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type CreateJWTResponse struct {
3131

3232
type GetJWKSResponse struct {
3333
OK *struct {
34-
Keys []JsonWebKeys
34+
Keys []JsonWebKeys
35+
ValidityInSeconds int
3536
}
3637
}

recipe/jwt/recipeimplementation.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ package jwt
1818
import (
1919
"github.com/supertokens/supertokens-golang/recipe/jwt/jwtmodels"
2020
"github.com/supertokens/supertokens-golang/supertokens"
21+
"regexp"
22+
"strconv"
2123
)
2224

25+
var defaultJWKSMaxAge = 60 // This corresponds to the dynamicSigningKeyOverlapMS in the core
26+
2327
func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.TypeNormalisedInput, appInfo supertokens.NormalisedAppinfo) jwtmodels.RecipeInterface {
2428
createJWT := func(payload map[string]interface{}, validitySecondsPointer *uint64, useStaticSigningKey *bool, userContext supertokens.UserContext) (jwtmodels.CreateJWTResponse, error) {
2529
validitySeconds := config.JwtValiditySeconds
@@ -61,7 +65,7 @@ func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.Type
6165
}
6266
}
6367
getJWKS := func(userContext supertokens.UserContext) (jwtmodels.GetJWKSResponse, error) {
64-
response, err := querier.SendGetRequest("/.well-known/jwks.json", map[string]string{})
68+
response, headers, err := querier.SendGetRequestWithResponseHeaders("/.well-known/jwks.json", map[string]string{})
6569
if err != nil {
6670
return jwtmodels.GetJWKSResponse{}, err
6771
}
@@ -79,9 +83,29 @@ func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.Type
7983
})
8084
}
8185

86+
validityInSeconds := defaultJWKSMaxAge
87+
cacheControlHeader := headers.Get("Cache-Control")
88+
89+
if cacheControlHeader != "" {
90+
regex := regexp.MustCompile(`/,?\s*max-age=(\d+)(?:,|$)/`)
91+
maxAgeHeader := regex.FindAllString(cacheControlHeader, -1)
92+
93+
if maxAgeHeader != nil && len(maxAgeHeader) > 0 {
94+
validityInSeconds, err = strconv.Atoi(maxAgeHeader[1])
95+
96+
if err != nil {
97+
validityInSeconds = defaultJWKSMaxAge
98+
}
99+
}
100+
}
101+
82102
return jwtmodels.GetJWKSResponse{
83-
OK: &struct{ Keys []jwtmodels.JsonWebKeys }{
84-
Keys: keys,
103+
OK: &struct {
104+
Keys []jwtmodels.JsonWebKeys
105+
ValidityInSeconds int
106+
}{
107+
Keys: keys,
108+
ValidityInSeconds: validityInSeconds,
85109
},
86110
}, nil
87111
}

supertokens/constants.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const (
2121
)
2222

2323
// VERSION current version of the lib
24-
const VERSION = "0.14.0"
24+
const VERSION = "0.15.0"
2525

2626
var (
2727
cdiSupported = []string{"3.0"}

supertokens/querier.go

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func (q *Querier) GetQuerierAPIVersion() (string, error) {
5656
if querierAPIVersion != "" {
5757
return querierAPIVersion, nil
5858
}
59-
response, err := q.sendRequestHelper(NormalisedURLPath{value: "/apiversion"}, func(url string) (*http.Response, error) {
59+
response, _, err := q.sendRequestHelper(NormalisedURLPath{value: "/apiversion"}, func(url string) (*http.Response, error) {
6060
req, err := http.NewRequest("GET", url, nil)
6161
if err != nil {
6262
return nil, err
@@ -117,7 +117,7 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map
117117
if err != nil {
118118
return nil, err
119119
}
120-
return q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
120+
resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
121121
if data == nil {
122122
data = map[string]interface{}{}
123123
}
@@ -147,14 +147,15 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map
147147
client := &http.Client{}
148148
return client.Do(req)
149149
}, len(QuerierHosts), nil)
150+
return resp, err
150151
}
151152

152153
func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, params map[string]string) (map[string]interface{}, error) {
153154
nP, err := NewNormalisedURLPath(path)
154155
if err != nil {
155156
return nil, err
156157
}
157-
return q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
158+
resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
158159
jsonData, err := json.Marshal(data)
159160
if err != nil {
160161
return nil, err
@@ -188,13 +189,51 @@ func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, pa
188189
client := &http.Client{}
189190
return client.Do(req)
190191
}, len(QuerierHosts), nil)
192+
return resp, err
191193
}
192194

193195
func (q *Querier) SendGetRequest(path string, params map[string]string) (map[string]interface{}, error) {
194196
nP, err := NewNormalisedURLPath(path)
195197
if err != nil {
196198
return nil, err
197199
}
200+
resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
201+
req, err := http.NewRequest("GET", url, nil)
202+
if err != nil {
203+
return nil, err
204+
}
205+
206+
query := req.URL.Query()
207+
208+
for k, v := range params {
209+
query.Add(k, v)
210+
}
211+
req.URL.RawQuery = query.Encode()
212+
213+
apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion()
214+
if querierAPIVersionError != nil {
215+
return nil, querierAPIVersionError
216+
}
217+
req.Header.Set("cdi-version", apiVerion)
218+
if QuerierAPIKey != nil {
219+
req.Header.Set("api-key", *QuerierAPIKey)
220+
}
221+
if nP.IsARecipePath() && q.RIDToCore != "" {
222+
req.Header.Set("rid", q.RIDToCore)
223+
}
224+
225+
client := &http.Client{}
226+
return client.Do(req)
227+
}, len(QuerierHosts), nil)
228+
return resp, err
229+
}
230+
231+
func (q *Querier) SendGetRequestWithResponseHeaders(path string, params map[string]string) (map[string]interface{}, http.Header, error) {
232+
nP, err := NewNormalisedURLPath(path)
233+
if err != nil {
234+
return nil, nil, err
235+
}
236+
198237
return q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
199238
req, err := http.NewRequest("GET", url, nil)
200239
if err != nil {
@@ -230,7 +269,7 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[
230269
if err != nil {
231270
return nil, err
232271
}
233-
return q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
272+
resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
234273
jsonData, err := json.Marshal(data)
235274
if err != nil {
236275
return nil, err
@@ -257,6 +296,7 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[
257296
client := &http.Client{}
258297
return client.Do(req)
259298
}, len(QuerierHosts), nil)
299+
return resp, err
260300
}
261301

262302
type httpRequestFunction func(url string) (*http.Response, error)
@@ -279,9 +319,9 @@ func GetAllCoreUrlsForPath(path string) []string {
279319
return result
280320
}
281321

282-
func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequestFunction, numberOfTries int, retryInfoMap *map[string]int) (map[string]interface{}, error) {
322+
func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequestFunction, numberOfTries int, retryInfoMap *map[string]int) (map[string]interface{}, http.Header, error) {
283323
if numberOfTries == 0 {
284-
return nil, errors.New("no SuperTokens core available to query")
324+
return nil, nil, errors.New("no SuperTokens core available to query")
285325
}
286326

287327
querierHostLock.Lock()
@@ -316,14 +356,14 @@ func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequ
316356
if resp != nil {
317357
resp.Body.Close()
318358
}
319-
return nil, err
359+
return nil, nil, err
320360
}
321361

322362
defer resp.Body.Close()
323363

324364
body, readErr := ioutil.ReadAll(resp.Body)
325365
if readErr != nil {
326-
return nil, readErr
366+
return nil, nil, readErr
327367
}
328368
if resp.StatusCode != 200 {
329369
if resp.StatusCode == RateLimitStatusCode {
@@ -341,17 +381,18 @@ func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequ
341381
}
342382
}
343383

344-
return nil, fmt.Errorf("SuperTokens core threw an error for a request to path: '%s' with status code: %v and message: %s", path.GetAsStringDangerous(), resp.StatusCode, body)
384+
return nil, nil, fmt.Errorf("SuperTokens core threw an error for a request to path: '%s' with status code: %v and message: %s", path.GetAsStringDangerous(), resp.StatusCode, body)
345385
}
346386

387+
headers := resp.Header.Clone()
347388
finalResult := make(map[string]interface{})
348389
jsonError := json.Unmarshal(body, &finalResult)
349390
if jsonError != nil {
350391
return map[string]interface{}{
351392
"result": string(body),
352-
}, nil
393+
}, headers, nil
353394
}
354-
return finalResult, nil
395+
return finalResult, headers, nil
355396
}
356397

357398
func ResetQuerierForTest() {

0 commit comments

Comments
 (0)