Skip to content

Commit 293df07

Browse files
committed
rate limit
1 parent 2d6cb4d commit 293df07

File tree

8 files changed

+124
-3
lines changed

8 files changed

+124
-3
lines changed

docker-compose.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ services:
2222
- ES_USE_SSL=false
2323
- ES_VERIFY_CERTS=false
2424
- BACKEND=elasticsearch
25+
- STAC_FASTAPI_RATE_LIMIT=2/5second
2526
ports:
2627
- "8080:8080"
2728
volumes:
@@ -54,6 +55,7 @@ services:
5455
- ES_USE_SSL=false
5556
- ES_VERIFY_CERTS=false
5657
- BACKEND=opensearch
58+
- STAC_FASTAPI_RATE_LIMIT=200/minute
5759
ports:
5860
- "8082:8082"
5961
volumes:

stac_fastapi/core/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"pygeofilter==0.2.1",
2020
"typing_extensions==4.8.0",
2121
"jsonschema",
22+
"slowapi==0.1.9",
2223
]
2324

2425
setup(
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
from fastapi import FastAPI, Request
3+
from slowapi import Limiter, _rate_limit_exceeded_handler
4+
from slowapi.util import get_remote_address
5+
from slowapi.errors import RateLimitExceeded
6+
from slowapi.middleware import SlowAPIMiddleware
7+
import logging
8+
9+
logger = logging.getLogger(__name__)
10+
11+
limiter = Limiter(key_func=get_remote_address)
12+
13+
def setup_rate_limit(app: FastAPI):
14+
RATE_LIMIT = os.getenv("STAC_FASTAPI_RATE_LIMIT")
15+
logger.info(f"Setting up rate limit with RATE_LIMIT={RATE_LIMIT}")
16+
if RATE_LIMIT:
17+
app.state.limiter = limiter
18+
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
19+
app.add_middleware(SlowAPIMiddleware)
20+
21+
@app.middleware("http")
22+
@limiter.limit(RATE_LIMIT)
23+
async def rate_limit_middleware(request: Request, call_next):
24+
response = await call_next(request)
25+
return response
26+
27+
logger.info("Rate limit setup complete")

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""FastAPI application."""
22

33
import os
4+
from stac_fastapi.core.rate_limit import setup_rate_limit
45

56
from stac_fastapi.api.app import StacApi
67
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
@@ -97,6 +98,8 @@
9798
app = api.app
9899
app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "")
99100

101+
# Add rate limit
102+
setup_rate_limit(app)
100103

101104
@app.on_event("startup")
102105
async def _startup_event() -> None:

stac_fastapi/opensearch/stac_fastapi/opensearch/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""FastAPI application."""
22

33
import os
4+
from stac_fastapi.core.rate_limit import setup_rate_limit
45

56
from stac_fastapi.api.app import StacApi
67
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
@@ -97,6 +98,8 @@
9798
app = api.app
9899
app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "")
99100

101+
# Add rate limit
102+
setup_rate_limit(app)
100103

101104
@app.on_event("startup")
102105
async def _startup_event() -> None:

stac_fastapi/tests/basic_auth/test_basic_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async def test_get_search_not_authenticated(app_client_basic_auth, ctx):
1818

1919
@pytest.mark.asyncio
2020
async def test_post_search_authenticated(app_client_basic_auth, ctx):
21-
"""Test protected endpoint [POST /search] with reader auhtentication"""
21+
"""Test protected endpoint [POST /search] with reader authentication"""
2222
if not os.getenv("BASIC_AUTH"):
2323
pytest.skip()
2424
params = {"id": ctx.item["id"]}
@@ -34,7 +34,7 @@ async def test_post_search_authenticated(app_client_basic_auth, ctx):
3434
async def test_delete_resource_anonymous(
3535
app_client_basic_auth,
3636
):
37-
"""Test protected endpoint [DELETE /collections/{collection_id}] without auhtentication"""
37+
"""Test protected endpoint [DELETE /collections/{collection_id}] without authentication"""
3838
if not os.getenv("BASIC_AUTH"):
3939
pytest.skip()
4040

stac_fastapi/tests/conftest.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
EsAsyncAggregationClient,
2626
)
2727
from stac_fastapi.core.route_dependencies import get_route_dependencies
28+
from stac_fastapi.core.rate_limit import setup_rate_limit
2829

2930
if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch":
3031
from stac_fastapi.opensearch.config import AsyncOpensearchSettings as AsyncSettings
@@ -223,7 +224,56 @@ async def app():
223224

224225
post_request_model = create_post_request_model(search_extensions)
225226

226-
return StacApi(
227+
app = StacApi(
228+
settings=settings,
229+
client=CoreClient(
230+
database=database,
231+
session=None,
232+
extensions=extensions,
233+
post_request_model=post_request_model,
234+
),
235+
extensions=extensions,
236+
search_get_request_model=create_get_request_model(search_extensions),
237+
search_post_request_model=post_request_model,
238+
).app
239+
240+
return app
241+
242+
243+
@pytest_asyncio.fixture(scope="function")
244+
async def app_rate_limit(monkeypatch):
245+
monkeypatch.setenv("STAC_FASTAPI_RATE_LIMIT", "2/minute")
246+
247+
settings = AsyncSettings()
248+
249+
aggregation_extension = AggregationExtension(
250+
client=EsAsyncAggregationClient(
251+
database=database, session=None, settings=settings
252+
)
253+
)
254+
aggregation_extension.POST = EsAggregationExtensionPostRequest
255+
aggregation_extension.GET = EsAggregationExtensionGetRequest
256+
257+
search_extensions = [
258+
TransactionExtension(
259+
client=TransactionsClient(
260+
database=database, session=None, settings=settings
261+
),
262+
settings=settings,
263+
),
264+
SortExtension(),
265+
FieldsExtension(),
266+
QueryExtension(),
267+
TokenPaginationExtension(),
268+
FilterExtension(),
269+
FreeTextExtension(),
270+
]
271+
272+
extensions = [aggregation_extension] + search_extensions
273+
274+
post_request_model = create_post_request_model(search_extensions)
275+
276+
app = StacApi(
227277
settings=settings,
228278
client=CoreClient(
229279
database=database,
@@ -236,6 +286,12 @@ async def app():
236286
search_post_request_model=post_request_model,
237287
).app
238288

289+
# Set up rate limit
290+
setup_rate_limit(app)
291+
292+
return app
293+
294+
239295

240296
@pytest_asyncio.fixture(scope="session")
241297
async def app_client(app):
@@ -246,6 +302,15 @@ async def app_client(app):
246302
yield c
247303

248304

305+
@pytest_asyncio.fixture(scope="function")
306+
async def app_client_rate_limit(app_rate_limit):
307+
await create_index_templates()
308+
await create_collection_index()
309+
310+
async with AsyncClient(app=app_rate_limit, base_url="http://test-server") as c:
311+
yield c
312+
313+
249314
@pytest_asyncio.fixture(scope="session")
250315
async def app_basic_auth():
251316

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
from httpx import AsyncClient
3+
import logging
4+
from slowapi.errors import RateLimitExceeded
5+
6+
logger = logging.getLogger(__name__)
7+
8+
@pytest.mark.asyncio
9+
async def test_rate_limit(app_client_rate_limit: AsyncClient):
10+
expected_status_codes = [200, 200, 429, 429, 429]
11+
12+
for i, expected_status_code in enumerate(expected_status_codes):
13+
try:
14+
response = await app_client_rate_limit.get("/collections")
15+
status_code = response.status_code
16+
except RateLimitExceeded:
17+
status_code = 429
18+
19+
logger.info(f"Request {i+1}: Status code {status_code}")
20+
assert status_code == expected_status_code, f"Expected status code {expected_status_code}, but got {status_code}"

0 commit comments

Comments
 (0)