Skip to content

Commit 2628044

Browse files
committed
Add tests for rate limiting
1 parent 72757af commit 2628044

File tree

2 files changed

+287
-0
lines changed

2 files changed

+287
-0
lines changed

supertokens/querier.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ var (
4646
querierHostLock sync.Mutex
4747
)
4848

49+
func SetQuerierApiVersionForTests(version string) {
50+
querierAPIVersion = version
51+
}
52+
4953
func (q *Querier) GetQuerierAPIVersion() (string, error) {
5054
querierLock.Lock()
5155
defer querierLock.Unlock()

test/querier_test.go

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
package test
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"github.com/stretchr/testify/assert"
7+
"github.com/supertokens/supertokens-golang/recipe/session"
8+
"github.com/supertokens/supertokens-golang/supertokens"
9+
"net/http"
10+
"net/http/httptest"
11+
"strings"
12+
"sync"
13+
"testing"
14+
)
15+
16+
func resetQuerier() {
17+
supertokens.SetQuerierApiVersionForTests("")
18+
}
19+
20+
func TestThatNetworkCallIsRetried(t *testing.T) {
21+
mux := http.NewServeMux()
22+
23+
numberOfTimesCalled := 0
24+
numberOfTimesSecondCalled := 0
25+
numberOfTimesThirdCalled := 0
26+
27+
mux.HandleFunc("/testing", func(rw http.ResponseWriter, r *http.Request) {
28+
numberOfTimesCalled++
29+
rw.WriteHeader(supertokens.RateLimitStatusCode)
30+
rw.Header().Set("Content-Type", "application/json")
31+
response, err := json.Marshal(map[string]interface{}{})
32+
if err != nil {
33+
t.Error(err.Error())
34+
}
35+
rw.Write(response)
36+
})
37+
38+
mux.HandleFunc("/testing2", func(rw http.ResponseWriter, r *http.Request) {
39+
numberOfTimesSecondCalled++
40+
rw.Header().Set("Content-Type", "application/json")
41+
42+
if numberOfTimesSecondCalled == 3 {
43+
rw.WriteHeader(200)
44+
} else {
45+
rw.WriteHeader(supertokens.RateLimitStatusCode)
46+
}
47+
48+
response, err := json.Marshal(map[string]interface{}{})
49+
if err != nil {
50+
t.Error(err.Error())
51+
}
52+
rw.Write(response)
53+
})
54+
55+
mux.HandleFunc("/testing3", func(rw http.ResponseWriter, r *http.Request) {
56+
numberOfTimesThirdCalled++
57+
rw.Header().Set("Content-Type", "application/json")
58+
rw.WriteHeader(200)
59+
response, err := json.Marshal(map[string]interface{}{})
60+
if err != nil {
61+
t.Error(err.Error())
62+
}
63+
rw.Write(response)
64+
})
65+
66+
testServer := httptest.NewServer(mux)
67+
68+
defer func() {
69+
testServer.Close()
70+
}()
71+
72+
config := supertokens.TypeInput{
73+
Supertokens: &supertokens.ConnectionInfo{
74+
// We need the querier to call the test server and not the core
75+
ConnectionURI: testServer.URL,
76+
},
77+
AppInfo: supertokens.AppInfo{
78+
AppName: "SuperTokens",
79+
WebsiteDomain: "supertokens.io",
80+
APIDomain: "api.supertokens.io",
81+
},
82+
RecipeList: []supertokens.Recipe{
83+
session.Init(nil),
84+
},
85+
}
86+
87+
err := supertokens.Init(config)
88+
89+
if err != nil {
90+
t.Error(err.Error())
91+
}
92+
93+
q, err := supertokens.GetNewQuerierInstanceOrThrowError("")
94+
supertokens.SetQuerierApiVersionForTests("3.0")
95+
defer resetQuerier()
96+
97+
if err != nil {
98+
t.Error(err.Error())
99+
}
100+
101+
_, err = q.SendGetRequest("/testing", map[string]string{})
102+
if err == nil {
103+
t.Error(errors.New("request should have failed but didnt").Error())
104+
} else {
105+
if !strings.Contains(err.Error(), "with status code: 429") {
106+
t.Error(errors.New("request failed with an unexpected error").Error())
107+
}
108+
}
109+
110+
_, err = q.SendGetRequest("/testing2", map[string]string{})
111+
if err != nil {
112+
t.Error(err.Error())
113+
}
114+
115+
_, err = q.SendGetRequest("/testing3", map[string]string{})
116+
if err != nil {
117+
t.Error(err.Error())
118+
}
119+
120+
// One initial call + 5 retries
121+
assert.Equal(t, numberOfTimesCalled, 6)
122+
assert.Equal(t, numberOfTimesSecondCalled, 3)
123+
assert.Equal(t, numberOfTimesThirdCalled, 1)
124+
}
125+
126+
func TestThatRateLimitErrorsAreThrownBackToTheUser(t *testing.T) {
127+
mux := http.NewServeMux()
128+
129+
mux.HandleFunc("/testing", func(rw http.ResponseWriter, r *http.Request) {
130+
rw.WriteHeader(supertokens.RateLimitStatusCode)
131+
rw.Header().Set("Content-Type", "application/json")
132+
response, err := json.Marshal(map[string]interface{}{
133+
"status": "RATE_LIMIT_ERROR",
134+
})
135+
if err != nil {
136+
t.Error(err.Error())
137+
}
138+
rw.Write(response)
139+
})
140+
141+
testServer := httptest.NewServer(mux)
142+
143+
defer func() {
144+
testServer.Close()
145+
}()
146+
147+
config := supertokens.TypeInput{
148+
Supertokens: &supertokens.ConnectionInfo{
149+
// We need the querier to call the test server and not the core
150+
ConnectionURI: testServer.URL,
151+
},
152+
AppInfo: supertokens.AppInfo{
153+
AppName: "SuperTokens",
154+
WebsiteDomain: "supertokens.io",
155+
APIDomain: "api.supertokens.io",
156+
},
157+
RecipeList: []supertokens.Recipe{
158+
session.Init(nil),
159+
},
160+
}
161+
162+
err := supertokens.Init(config)
163+
164+
if err != nil {
165+
t.Error(err.Error())
166+
}
167+
168+
q, err := supertokens.GetNewQuerierInstanceOrThrowError("")
169+
supertokens.SetQuerierApiVersionForTests("3.0")
170+
defer resetQuerier()
171+
172+
if err != nil {
173+
t.Error(err.Error())
174+
}
175+
176+
_, err = q.SendGetRequest("/testing", map[string]string{})
177+
if err == nil {
178+
t.Error(errors.New("request should have failed but didnt").Error())
179+
} else {
180+
if !strings.Contains(err.Error(), "with status code: 429") {
181+
t.Error(errors.New("request failed with an unexpected error").Error())
182+
}
183+
184+
assert.True(t, strings.Contains(err.Error(), "message: {\"status\":\"RATE_LIMIT_ERROR\"}"))
185+
}
186+
}
187+
188+
func TestThatParallelCallsHaveIndependentRetryCounters(t *testing.T) {
189+
mux := http.NewServeMux()
190+
191+
numberOfTimesFirstCalled := 0
192+
numberOfTimesSecondCalled := 0
193+
194+
mux.HandleFunc("/testing", func(rw http.ResponseWriter, r *http.Request) {
195+
if r.URL.Query().Get("id") == "1" {
196+
numberOfTimesFirstCalled++
197+
} else {
198+
numberOfTimesSecondCalled++
199+
}
200+
201+
rw.WriteHeader(supertokens.RateLimitStatusCode)
202+
rw.Header().Set("Content-Type", "application/json")
203+
response, err := json.Marshal(map[string]interface{}{})
204+
if err != nil {
205+
t.Error(err.Error())
206+
}
207+
rw.Write(response)
208+
})
209+
210+
testServer := httptest.NewServer(mux)
211+
212+
defer func() {
213+
testServer.Close()
214+
}()
215+
216+
config := supertokens.TypeInput{
217+
Supertokens: &supertokens.ConnectionInfo{
218+
// We need the querier to call the test server and not the core
219+
ConnectionURI: testServer.URL,
220+
},
221+
AppInfo: supertokens.AppInfo{
222+
AppName: "SuperTokens",
223+
WebsiteDomain: "supertokens.io",
224+
APIDomain: "api.supertokens.io",
225+
},
226+
RecipeList: []supertokens.Recipe{
227+
session.Init(nil),
228+
},
229+
}
230+
231+
err := supertokens.Init(config)
232+
233+
if err != nil {
234+
t.Error(err.Error())
235+
}
236+
237+
q, err := supertokens.GetNewQuerierInstanceOrThrowError("")
238+
supertokens.SetQuerierApiVersionForTests("3.0")
239+
defer resetQuerier()
240+
241+
if err != nil {
242+
t.Error(err.Error())
243+
}
244+
245+
var wg sync.WaitGroup
246+
247+
wg.Add(2)
248+
249+
go func() {
250+
_, err = q.SendGetRequest("/testing", map[string]string{
251+
"id": "1",
252+
})
253+
if err == nil {
254+
t.Error(errors.New("request should have failed but didnt").Error())
255+
} else {
256+
if !strings.Contains(err.Error(), "with status code: 429") {
257+
t.Error(errors.New("request failed with an unexpected error").Error())
258+
}
259+
}
260+
261+
wg.Done()
262+
}()
263+
264+
go func() {
265+
_, err = q.SendGetRequest("/testing", map[string]string{
266+
"id": "2",
267+
})
268+
if err == nil {
269+
t.Error(errors.New("request should have failed but didnt").Error())
270+
} else {
271+
if !strings.Contains(err.Error(), "with status code: 429") {
272+
t.Error(errors.New("request failed with an unexpected error").Error())
273+
}
274+
}
275+
276+
wg.Done()
277+
}()
278+
279+
wg.Wait()
280+
281+
assert.Equal(t, numberOfTimesFirstCalled, 6)
282+
assert.Equal(t, numberOfTimesSecondCalled, 6)
283+
}

0 commit comments

Comments
 (0)