|
4 | 4 | from fastapi.responses import JSONResponse, Response |
5 | 5 | from httpx import AsyncClient |
6 | 6 |
|
| 7 | +from app.logger import logger |
| 8 | + |
| 9 | +HOP_BY_HOP = { |
| 10 | + "connection", |
| 11 | + "keep-alive", |
| 12 | + "proxy-authenticate", |
| 13 | + "proxy-authorization", |
| 14 | + "te", |
| 15 | + "trailers", |
| 16 | + "transfer-encoding", |
| 17 | + "upgrade", |
| 18 | +} |
| 19 | + |
| 20 | +def _filtered_request_headers(items: list[tuple[str, str]]) -> dict: |
| 21 | + skip = HOP_BY_HOP | {"host", "content-length"} |
| 22 | + return {k: v for k, v in items if k.lower() not in skip} |
| 23 | + |
7 | 24 |
|
8 | 25 | async def proxy_request(url: str, request: Request) -> Response: |
9 | | - async with AsyncClient() as client: |
10 | | - method = request.method |
11 | | - headers = dict(request.headers) |
12 | | - body = await request.body() |
13 | | - try: |
14 | | - proxied_response = await client.request( |
15 | | - method, url, headers=headers, content=body, params=request.query_params |
16 | | - ) |
17 | | - return Response( |
18 | | - content=proxied_response.content, |
19 | | - status_code=proxied_response.status_code, |
20 | | - headers=dict(proxied_response.headers), |
21 | | - media_type=proxied_response.headers.get("content-type"), |
| 26 | + try: |
| 27 | + async with AsyncClient(follow_redirects=False) as client: |
| 28 | + upstream = await client.request( |
| 29 | + request.method, |
| 30 | + url, |
| 31 | + headers=_filtered_request_headers(request.headers.items()), |
| 32 | + params=request.query_params, |
| 33 | + content=await request.body(), |
22 | 34 | ) |
23 | | - except Exception as e: |
24 | | - return JSONResponse(status_code=502, content={"error": str(e)}) |
| 35 | + resp = Response( |
| 36 | + content=upstream.content, |
| 37 | + status_code=upstream.status_code, |
| 38 | + media_type=upstream.headers.get("content-type"), |
| 39 | + ) |
| 40 | + skip_out = HOP_BY_HOP | {"content-length", "date", "server", "set-cookie"} |
| 41 | + for k, v in upstream.headers.items(): |
| 42 | + if k.lower() in skip_out: |
| 43 | + continue |
| 44 | + resp.headers[k] = v |
| 45 | + set_cookies: list[str] = [] |
| 46 | + if hasattr(upstream.headers, "get_list"): |
| 47 | + set_cookies = upstream.headers.get_list("set-cookie") or [] |
| 48 | + if not set_cookies: |
| 49 | + raw = getattr(upstream.headers, "raw", None) |
| 50 | + if raw is not None: |
| 51 | + for k, v in raw: |
| 52 | + name = k.decode("latin1") if isinstance(k, (bytes, bytearray)) else str(k) |
| 53 | + if name.lower() == "set-cookie": |
| 54 | + value = v.decode("latin1") if isinstance(v, (bytes, bytearray)) else str(v) |
| 55 | + set_cookies.append(value) |
| 56 | + if set_cookies: |
| 57 | + raw_headers = list(resp.raw_headers) |
| 58 | + for sc in set_cookies: |
| 59 | + raw_headers.append((b"set-cookie", sc.encode("latin1"))) |
| 60 | + resp.raw_headers = tuple(raw_headers) |
| 61 | + return resp |
| 62 | + except Exception as e: |
| 63 | + logger.error(f"Proxy request failed: {e}") |
| 64 | + return JSONResponse(status_code=502, content={"code":"internal_error"}) |
25 | 65 |
|
26 | 66 |
|
27 | 67 | def prefix_and_tag_paths( |
|
0 commit comments