Skip to content

Commit 0f402e5

Browse files
committed
cleaning and test fixes
1 parent 66d4dfb commit 0f402e5

File tree

4 files changed

+28
-19
lines changed

4 files changed

+28
-19
lines changed

stac_fastapi/core/stac_fastapi/core/rate_limit.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
import os
5+
from typing import Optional
56

67
from fastapi import FastAPI, Request
78
from slowapi import Limiter, _rate_limit_exceeded_handler
@@ -11,22 +12,32 @@
1112

1213
logger = logging.getLogger(__name__)
1314

14-
limiter = Limiter(key_func=get_remote_address)
15+
def get_limiter(key_func=get_remote_address):
16+
return Limiter(key_func=key_func)
1517

16-
17-
def setup_rate_limit(app: FastAPI):
18+
def setup_rate_limit(
19+
app: FastAPI,
20+
rate_limit: Optional[str] = None,
21+
key_func=get_remote_address
22+
):
1823
"""Set up rate limiting middleware."""
19-
RATE_LIMIT = os.getenv("STAC_FASTAPI_RATE_LIMIT")
24+
RATE_LIMIT = rate_limit or os.getenv("STAC_FASTAPI_RATE_LIMIT")
25+
26+
if not RATE_LIMIT:
27+
logger.info("Rate limiting is disabled")
28+
return
29+
2030
logger.info(f"Setting up rate limit with RATE_LIMIT={RATE_LIMIT}")
21-
if RATE_LIMIT:
22-
app.state.limiter = limiter
23-
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
24-
app.add_middleware(SlowAPIMiddleware)
25-
26-
@app.middleware("http")
27-
@limiter.limit(RATE_LIMIT)
28-
async def rate_limit_middleware(request: Request, call_next):
29-
response = await call_next(request)
30-
return response
31+
32+
limiter = get_limiter(key_func)
33+
app.state.limiter = limiter
34+
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
35+
app.add_middleware(SlowAPIMiddleware)
36+
37+
@app.middleware("http")
38+
@limiter.limit(RATE_LIMIT)
39+
async def rate_limit_middleware(request: Request, call_next):
40+
response = await call_next(request)
41+
return response
3142

3243
logger.info("Rate limit setup complete")

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "")
100100

101101
# Add rate limit
102-
setup_rate_limit(app)
102+
setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT"))
103103

104104

105105
@app.on_event("startup")

stac_fastapi/opensearch/stac_fastapi/opensearch/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
app.root_path = os.getenv("STAC_FASTAPI_ROOT_PATH", "")
100100

101101
# Add rate limit
102-
setup_rate_limit(app)
102+
setup_rate_limit(app, rate_limit=os.getenv("STAC_FASTAPI_RATE_LIMIT"))
103103

104104

105105
@app.on_event("startup")

stac_fastapi/tests/conftest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,7 @@ async def app_rate_limit():
292292
).app
293293

294294
# Set up rate limit
295-
os.environ["STAC_FASTAPI_RATE_LIMIT"] = "2/minute"
296-
setup_rate_limit(app)
297-
del os.environ["STAC_FASTAPI_RATE_LIMIT"]
295+
setup_rate_limit(app, rate_limit="2/minute")
298296

299297
return app
300298

0 commit comments

Comments
 (0)