Skip to content

Commit 92a2413

Browse files
committed
pre-commit
1 parent 293df07 commit 92a2413

File tree

5 files changed

+23
-13
lines changed

5 files changed

+23
-13
lines changed

stac_fastapi/core/stac_fastapi/core/rate_limit.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1+
"""Rate limiting middleware."""
2+
3+
import logging
14
import os
5+
26
from fastapi import FastAPI, Request
37
from slowapi import Limiter, _rate_limit_exceeded_handler
4-
from slowapi.util import get_remote_address
58
from slowapi.errors import RateLimitExceeded
69
from slowapi.middleware import SlowAPIMiddleware
7-
import logging
10+
from slowapi.util import get_remote_address
811

912
logger = logging.getLogger(__name__)
1013

1114
limiter = Limiter(key_func=get_remote_address)
1215

16+
1317
def setup_rate_limit(app: FastAPI):
18+
"""Set up rate limiting middleware."""
1419
RATE_LIMIT = os.getenv("STAC_FASTAPI_RATE_LIMIT")
1520
logger.info(f"Setting up rate limit with RATE_LIMIT={RATE_LIMIT}")
1621
if RATE_LIMIT:
@@ -24,4 +29,4 @@ async def rate_limit_middleware(request: Request, call_next):
2429
response = await call_next(request)
2530
return response
2631

27-
logger.info("Rate limit setup complete")
32+
logger.info("Rate limit setup complete")

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""FastAPI application."""
22

33
import os
4-
from stac_fastapi.core.rate_limit import setup_rate_limit
54

65
from stac_fastapi.api.app import StacApi
76
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
@@ -18,6 +17,7 @@
1817
EsAsyncAggregationClient,
1918
)
2019
from stac_fastapi.core.extensions.fields import FieldsExtension
20+
from stac_fastapi.core.rate_limit import setup_rate_limit
2121
from stac_fastapi.core.route_dependencies import get_route_dependencies
2222
from stac_fastapi.core.session import Session
2323
from stac_fastapi.elasticsearch.config import ElasticsearchSettings
@@ -101,6 +101,7 @@
101101
# Add rate limit
102102
setup_rate_limit(app)
103103

104+
104105
@app.on_event("startup")
105106
async def _startup_event() -> None:
106107
await create_index_templates()

stac_fastapi/opensearch/stac_fastapi/opensearch/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""FastAPI application."""
22

33
import os
4-
from stac_fastapi.core.rate_limit import setup_rate_limit
54

65
from stac_fastapi.api.app import StacApi
76
from stac_fastapi.api.models import create_get_request_model, create_post_request_model
@@ -18,6 +17,7 @@
1817
EsAsyncAggregationClient,
1918
)
2019
from stac_fastapi.core.extensions.fields import FieldsExtension
20+
from stac_fastapi.core.rate_limit import setup_rate_limit
2121
from stac_fastapi.core.route_dependencies import get_route_dependencies
2222
from stac_fastapi.core.session import Session
2323
from stac_fastapi.extensions.core import (
@@ -101,6 +101,7 @@
101101
# Add rate limit
102102
setup_rate_limit(app)
103103

104+
104105
@app.on_event("startup")
105106
async def _startup_event() -> None:
106107
await create_index_templates()

stac_fastapi/tests/conftest.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
EsAggregationExtensionPostRequest,
2525
EsAsyncAggregationClient,
2626
)
27-
from stac_fastapi.core.route_dependencies import get_route_dependencies
2827
from stac_fastapi.core.rate_limit import setup_rate_limit
28+
from stac_fastapi.core.route_dependencies import get_route_dependencies
2929

3030
if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch":
3131
from stac_fastapi.opensearch.config import AsyncOpensearchSettings as AsyncSettings
@@ -243,9 +243,9 @@ async def app():
243243
@pytest_asyncio.fixture(scope="function")
244244
async def app_rate_limit(monkeypatch):
245245
monkeypatch.setenv("STAC_FASTAPI_RATE_LIMIT", "2/minute")
246-
246+
247247
settings = AsyncSettings()
248-
248+
249249
aggregation_extension = AggregationExtension(
250250
client=EsAsyncAggregationClient(
251251
database=database, session=None, settings=settings
@@ -292,7 +292,6 @@ async def app_rate_limit(monkeypatch):
292292
return app
293293

294294

295-
296295
@pytest_asyncio.fixture(scope="session")
297296
async def app_client(app):
298297
await create_index_templates()
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
1+
import logging
2+
13
import pytest
24
from httpx import AsyncClient
3-
import logging
45
from slowapi.errors import RateLimitExceeded
56

67
logger = logging.getLogger(__name__)
78

9+
810
@pytest.mark.asyncio
911
async def test_rate_limit(app_client_rate_limit: AsyncClient):
1012
expected_status_codes = [200, 200, 429, 429, 429]
11-
13+
1214
for i, expected_status_code in enumerate(expected_status_codes):
1315
try:
1416
response = await app_client_rate_limit.get("/collections")
1517
status_code = response.status_code
1618
except RateLimitExceeded:
1719
status_code = 429
18-
20+
1921
logger.info(f"Request {i+1}: Status code {status_code}")
20-
assert status_code == expected_status_code, f"Expected status code {expected_status_code}, but got {status_code}"
22+
assert (
23+
status_code == expected_status_code
24+
), f"Expected status code {expected_status_code}, but got {status_code}"

0 commit comments

Comments
 (0)