Skip to content

Commit 5871def

Browse files
committed
feat: 添加登录请求速率限制中间件,防止暴力破解密码
1 parent b36f6e7 commit 5871def

File tree

2 files changed

+183
-1
lines changed

2 files changed

+183
-1
lines changed

server/main.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import asyncio
2+
import time
3+
from collections import defaultdict, deque
4+
15
import uvicorn
2-
from fastapi import FastAPI, Request
6+
from fastapi import FastAPI, Request, status
37
from fastapi.middleware.cors import CORSMiddleware
8+
from fastapi.responses import JSONResponse
49
from starlette.middleware.base import BaseHTTPMiddleware
510

611
from server.routers import router
@@ -11,6 +16,14 @@
1116
# 设置日志配置
1217
setup_logging()
1318

19+
RATE_LIMIT_MAX_ATTEMPTS = 10
20+
RATE_LIMIT_WINDOW_SECONDS = 60
21+
RATE_LIMIT_ENDPOINTS = {("/api/auth/token", "POST")}
22+
23+
# In-memory login attempt tracker to reduce brute-force exposure per worker
24+
_login_attempts: defaultdict[str, deque[float]] = defaultdict(deque)
25+
_attempt_lock = asyncio.Lock()
26+
1427
app = FastAPI()
1528
app.include_router(router, prefix="/api")
1629

@@ -24,6 +37,51 @@
2437
)
2538

2639

40+
def _extract_client_ip(request: Request) -> str:
41+
forwarded_for = request.headers.get("x-forwarded-for")
42+
if forwarded_for:
43+
return forwarded_for.split(",")[0].strip()
44+
if request.client:
45+
return request.client.host
46+
return "unknown"
47+
48+
49+
class LoginRateLimitMiddleware(BaseHTTPMiddleware):
50+
async def dispatch(self, request: Request, call_next):
51+
normalized_path = request.url.path.rstrip("/") or "/"
52+
request_signature = (normalized_path, request.method.upper())
53+
54+
if request_signature in RATE_LIMIT_ENDPOINTS:
55+
client_ip = _extract_client_ip(request)
56+
now = time.monotonic()
57+
58+
async with _attempt_lock:
59+
attempt_history = _login_attempts[client_ip]
60+
61+
while attempt_history and now - attempt_history[0] > RATE_LIMIT_WINDOW_SECONDS:
62+
attempt_history.popleft()
63+
64+
if len(attempt_history) >= RATE_LIMIT_MAX_ATTEMPTS:
65+
retry_after = int(max(1, RATE_LIMIT_WINDOW_SECONDS - (now - attempt_history[0])))
66+
return JSONResponse(
67+
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
68+
content={"detail": "登录尝试过于频繁,请稍后再试"},
69+
headers={"Retry-After": str(retry_after)},
70+
)
71+
72+
attempt_history.append(now)
73+
74+
response = await call_next(request)
75+
76+
if response.status_code < 400:
77+
async with _attempt_lock:
78+
_login_attempts.pop(client_ip, None)
79+
80+
return response
81+
82+
return await call_next(request)
83+
84+
2785
# 鉴权中间件
2886
class AuthMiddleware(BaseHTTPMiddleware):
2987
async def dispatch(self, request: Request, call_next):
@@ -58,6 +116,7 @@ async def dispatch(self, request: Request, call_next):
58116

59117

60118
# 添加鉴权中间件
119+
app.add_middleware(LoginRateLimitMiddleware)
61120
app.add_middleware(AuthMiddleware)
62121

63122

test/bruteforce_simulation.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
Quick script to hammer the login endpoint with invalid credentials and observe rate limiting.
3+
4+
Usage:
5+
uv run python test/bruteforce_simulation.py --username demo --attempts 20
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import argparse
11+
import asyncio
12+
import os
13+
import random
14+
import string
15+
import sys
16+
import time
17+
from collections import Counter
18+
19+
import httpx
20+
21+
22+
def parse_args() -> argparse.Namespace:
23+
parser = argparse.ArgumentParser(description="Simulate brute-force login attempts.")
24+
parser.add_argument("--base-url", default=os.getenv("TEST_BASE_URL", "http://localhost:5050"), help="API base URL")
25+
parser.add_argument("--username", default=os.getenv("TEST_USERNAME", "admin"), help="Login identifier to attack")
26+
parser.add_argument(
27+
"--attempts", type=int, default=20, help="Total number of attempts to issue (default: 20)"
28+
)
29+
parser.add_argument(
30+
"--concurrency",
31+
type=int,
32+
default=4,
33+
help="Concurrent request limit (default: 4)",
34+
)
35+
parser.add_argument(
36+
"--delay",
37+
type=float,
38+
default=0.05,
39+
help="Delay in seconds between scheduling attempts (default: 0.05)",
40+
)
41+
parser.add_argument(
42+
"--password",
43+
default=None,
44+
help="Explicit password to reuse for each request; random values are used when omitted",
45+
)
46+
return parser.parse_args()
47+
48+
49+
def build_payload(username: str, password: str) -> dict[str, str]:
50+
return {"username": username, "password": password}
51+
52+
53+
def random_password() -> str:
54+
base = string.ascii_letters + string.digits + "!@#$%^&*"
55+
return "".join(random.choices(base, k=12))
56+
57+
58+
async def attempt_login(
59+
client: httpx.AsyncClient,
60+
semaphore: asyncio.Semaphore,
61+
attempt_no: int,
62+
username: str,
63+
static_password: str | None,
64+
) -> tuple[int, float]:
65+
async with semaphore:
66+
password = static_password or random_password()
67+
payload = build_payload(username, password)
68+
started = time.perf_counter()
69+
response = await client.post("/api/auth/token", data=payload)
70+
elapsed = time.perf_counter() - started
71+
detail = response.json().get("detail") if response.headers.get("content-type", "").startswith("application/json") else response.text
72+
print(
73+
f"[{attempt_no:02d}] {response.status_code} in {elapsed*1000:.1f} ms "
74+
f"(pwd={password!r}) detail={detail!r}"
75+
)
76+
return response.status_code, elapsed
77+
78+
79+
async def run_simulation(args: argparse.Namespace) -> int:
80+
timeout = httpx.Timeout(10.0, connect=3.0)
81+
limits = httpx.Limits(max_connections=args.concurrency, max_keepalive_connections=args.concurrency)
82+
semaphore = asyncio.Semaphore(args.concurrency)
83+
status_counts: Counter[int] = Counter()
84+
85+
async with httpx.AsyncClient(base_url=args.base_url.rstrip("/"), timeout=timeout, limits=limits) as client:
86+
tasks = []
87+
for attempt_no in range(1, args.attempts + 1):
88+
tasks.append(
89+
asyncio.create_task(
90+
attempt_login(client, semaphore, attempt_no, args.username, args.password)
91+
)
92+
)
93+
if args.delay:
94+
await asyncio.sleep(args.delay)
95+
96+
for task in asyncio.as_completed(tasks):
97+
status_code, _ = await task
98+
status_counts[status_code] += 1
99+
100+
print("\nSummary:")
101+
for code, total in sorted(status_counts.items()):
102+
print(f" HTTP {code}: {total} hits")
103+
104+
if 429 in status_counts:
105+
print("Rate limiting engaged (HTTP 429 encountered).")
106+
return 0
107+
108+
print("No rate limiting observed; consider increasing attempt count or reducing delay.")
109+
return 1
110+
111+
112+
def main() -> None:
113+
args = parse_args()
114+
try:
115+
exit_code = asyncio.run(run_simulation(args))
116+
except KeyboardInterrupt:
117+
print("\nInterrupted.")
118+
exit_code = 130
119+
sys.exit(exit_code)
120+
121+
122+
if __name__ == "__main__":
123+
main()

0 commit comments

Comments
 (0)