Skip to content

Commit 51f341f

Browse files
committed
feat: Add 429 rate limting from SaaS
1 parent a3dd752 commit 51f341f

File tree

3 files changed

+180
-3
lines changed

3 files changed

+180
-3
lines changed

supertokens_python/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@
2626
API_VERSION = "/apiversion"
2727
API_VERSION_HEADER = "cdi-version"
2828
DASHBOARD_VERSION = "0.3"
29+
RATE_LIMIT_STATUS_CODE = 429

supertokens_python/querier.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# under the License.
1414
from __future__ import annotations
1515

16+
import asyncio
17+
1618
from json import JSONDecodeError
1719
from os import environ
18-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
20+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
1921

2022
from httpx import AsyncClient, ConnectTimeout, NetworkError, Response
2123

@@ -25,6 +27,7 @@
2527
API_VERSION_HEADER,
2628
RID_KEY_HEADER,
2729
SUPPORTED_CDI_VERSIONS,
30+
RATE_LIMIT_STATUS_CODE,
2831
)
2932
from .normalised_url_path import NormalisedURLPath
3033

@@ -196,6 +199,7 @@ async def __send_request_helper(
196199
method: str,
197200
http_function: Callable[[str], Awaitable[Response]],
198201
no_of_tries: int,
202+
retry_info_map: Optional[Dict[str, int]] = None,
199203
) -> Any:
200204
if no_of_tries == 0:
201205
raise_general_exception("No SuperTokens core available to query")
@@ -212,6 +216,14 @@ async def __send_request_helper(
212216
Querier.__last_tried_index %= len(self.__hosts)
213217
url = current_host + path.get_as_string_dangerous()
214218

219+
max_retries = 5
220+
221+
if retry_info_map is None:
222+
retry_info_map = {}
223+
224+
if retry_info_map.get(url) is None:
225+
retry_info_map[url] = max_retries
226+
215227
ProcessState.get_instance().add_state(
216228
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
217229
)
@@ -221,6 +233,20 @@ async def __send_request_helper(
221233
):
222234
Querier.__hosts_alive_for_testing.add(current_host)
223235

236+
if response.status_code == RATE_LIMIT_STATUS_CODE:
237+
retries_left = retry_info_map[url]
238+
239+
if retries_left > 0:
240+
retry_info_map[url] = retries_left - 1
241+
242+
attempts_made = max_retries - retries_left
243+
delay = (10 + attempts_made * 250) / 1000
244+
245+
await asyncio.sleep(delay)
246+
return await self.__send_request_helper(
247+
path, method, http_function, no_of_tries, retry_info_map
248+
)
249+
224250
if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
225251
raise_general_exception(
226252
"SuperTokens core threw an error for a "
@@ -238,9 +264,9 @@ async def __send_request_helper(
238264
except JSONDecodeError:
239265
return response.text
240266

241-
except (ConnectionError, NetworkError, ConnectTimeout):
267+
except (ConnectionError, NetworkError, ConnectTimeout) as _:
242268
return await self.__send_request_helper(
243-
path, method, http_function, no_of_tries - 1
269+
path, method, http_function, no_of_tries - 1, retry_info_map
244270
)
245271
except Exception as e:
246272
raise_general_exception(e)

tests/test_querier.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
2+
#
3+
# This software is licensed under the Apache License, Version 2.0 (the
4+
# "License") as published by the Apache Software Foundation.
5+
#
6+
# You may not use this file except in compliance with the License. You may
7+
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
from pytest import mark
15+
from supertokens_python.recipe import (
16+
session,
17+
emailpassword,
18+
emailverification,
19+
dashboard,
20+
)
21+
import asyncio
22+
import respx
23+
import httpx
24+
from supertokens_python import init, SupertokensConfig
25+
from supertokens_python.querier import Querier, NormalisedURLPath
26+
27+
from tests.utils import get_st_init_args
28+
from tests.utils import (
29+
setup_function,
30+
teardown_function,
31+
start_st,
32+
)
33+
34+
_ = setup_function
35+
_ = teardown_function
36+
37+
pytestmark = mark.asyncio
38+
respx_mock = respx.MockRouter
39+
40+
41+
async def test_network_call_is_retried_as_expected():
42+
# Test that network call is retried properly
43+
# Test that rate limiting errors are thrown back to the user
44+
args = get_st_init_args(
45+
[
46+
session.init(),
47+
emailpassword.init(),
48+
emailverification.init(mode="OPTIONAL"),
49+
dashboard.init(),
50+
]
51+
)
52+
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
53+
init(**args) # type: ignore
54+
start_st()
55+
56+
Querier.api_version = "3.0"
57+
q = Querier.get_instance()
58+
59+
api2_call_count = 0
60+
61+
def api2_side_effect(_: httpx.Request):
62+
nonlocal api2_call_count
63+
api2_call_count += 1
64+
65+
if api2_call_count == 3:
66+
return httpx.Response(200)
67+
68+
return httpx.Response(429, json={})
69+
70+
with respx_mock() as mocker:
71+
api1 = mocker.get("http://localhost:6789/api1").mock(
72+
httpx.Response(429, json={"status": "RATE_ERROR"})
73+
)
74+
api2 = mocker.get("http://localhost:6789/api2").mock(
75+
side_effect=api2_side_effect
76+
)
77+
api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200))
78+
79+
try:
80+
await q.send_get_request(NormalisedURLPath("/api1"), {})
81+
except Exception as e:
82+
if "with status code: 429" in str(
83+
e
84+
) and 'message: {"status": "RATE_ERROR"}' in str(e):
85+
pass
86+
else:
87+
raise e
88+
89+
await q.send_get_request(NormalisedURLPath("/api2"), {})
90+
await q.send_get_request(NormalisedURLPath("/api3"), {})
91+
92+
# 1 initial request + 5 retries
93+
assert api1.call_count == 6
94+
# 2 403 and 1 200
95+
assert api2.call_count == 3
96+
# 200 in the first attempt
97+
assert api3.call_count == 1
98+
99+
100+
async def test_parallel_calls_have_independent_counters():
101+
args = get_st_init_args(
102+
[
103+
session.init(),
104+
emailpassword.init(),
105+
emailverification.init(mode="OPTIONAL"),
106+
dashboard.init(),
107+
]
108+
)
109+
init(**args) # type: ignore
110+
start_st()
111+
112+
Querier.api_version = "3.0"
113+
q = Querier.get_instance()
114+
115+
call_count1 = 0
116+
call_count2 = 0
117+
118+
def api_side_effect(r: httpx.Request):
119+
nonlocal call_count1, call_count2
120+
121+
id_ = int(r.url.params.get("id"))
122+
if id_ == 1:
123+
call_count1 += 1
124+
elif id_ == 2:
125+
call_count2 += 1
126+
127+
return httpx.Response(429, json={})
128+
129+
with respx_mock() as mocker:
130+
api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect)
131+
132+
async def call_api(id_: int):
133+
try:
134+
await q.send_get_request(NormalisedURLPath("/api"), {"id": id_})
135+
except Exception as e:
136+
if "with status code: 429" in str(e):
137+
pass
138+
else:
139+
raise e
140+
141+
_ = await asyncio.gather(
142+
call_api(1),
143+
call_api(2),
144+
)
145+
146+
# 1 initial request + 5 retries
147+
assert call_count1 == 6
148+
assert call_count2 == 6
149+
150+
assert api.call_count == 12

0 commit comments

Comments
 (0)