|
| 1 | +"""访问日志中间件 - 记录请求处理时间""" |
| 2 | + |
| 3 | +import time |
| 4 | +import logging |
| 5 | +from collections.abc import Callable |
| 6 | + |
| 7 | +from fastapi import Request, Response |
| 8 | +from starlette.middleware.base import BaseHTTPMiddleware |
| 9 | + |
| 10 | +# 创建专用的访问日志记录器 |
| 11 | +access_logger = logging.getLogger("access_logger") |
| 12 | + |
| 13 | +# 设置访问日志记录器 |
| 14 | +if not access_logger.handlers: |
| 15 | + handler = logging.StreamHandler() |
| 16 | + formatter = logging.Formatter(fmt="%(asctime)s %(levelname)s: %(message)s", datefmt="%m-%d %H:%M:%S") |
| 17 | + handler.setFormatter(formatter) |
| 18 | + access_logger.addHandler(handler) |
| 19 | + access_logger.setLevel(logging.INFO) |
| 20 | + # 避免传播到根日志记录器,防止重复日志 |
| 21 | + access_logger.propagate = False |
| 22 | + |
| 23 | + |
| 24 | +def _extract_client_ip(request: Request) -> str: |
| 25 | + """提取客户端IP地址""" |
| 26 | + forwarded_for = request.headers.get("x-forwarded-for") |
| 27 | + if forwarded_for: |
| 28 | + return forwarded_for.split(",")[0].strip() |
| 29 | + if request.client: |
| 30 | + return request.client.host |
| 31 | + return "unknown" |
| 32 | + |
| 33 | + |
| 34 | +class AccessLogMiddleware(BaseHTTPMiddleware): |
| 35 | + """访问日志中间件 - 记录请求处理时间""" |
| 36 | + |
| 37 | + def __init__(self, app, logger: logging.Logger = None): |
| 38 | + super().__init__(app) |
| 39 | + self.logger = logger or access_logger |
| 40 | + |
| 41 | + async def dispatch(self, request: Request, call_next: Callable) -> Response: |
| 42 | + """处理请求并记录访问日志""" |
| 43 | + # 记录请求开始时间 |
| 44 | + start_time = time.perf_counter() |
| 45 | + |
| 46 | + # 获取客户端IP |
| 47 | + client_ip = _extract_client_ip(request) |
| 48 | + |
| 49 | + # 处理请求 |
| 50 | + response = await call_next(request) |
| 51 | + |
| 52 | + # 计算处理时间 |
| 53 | + process_time = time.perf_counter() - start_time |
| 54 | + process_time_ms = int(process_time * 1000) # 转换为毫秒 |
| 55 | + |
| 56 | + # 格式化日志消息,添加处理时间 |
| 57 | + log_message = ( |
| 58 | + f"{client_ip}:{request.client.port if request.client else 'unknown'} - " |
| 59 | + f'"{request.method} {request.url.path}{"?" + request.url.query if request.url.query else ""} ' |
| 60 | + f'HTTP/{request.scope["http_version"]}" ' |
| 61 | + f"{response.status_code} - {process_time_ms}ms" |
| 62 | + ) |
| 63 | + |
| 64 | + # 记录日志 |
| 65 | + self.logger.info(log_message) |
| 66 | + |
| 67 | + return response |
0 commit comments