|
4 | 4 | from typing import Dict, Optional, Tuple, Union
|
5 | 5 | from urllib.parse import unquote, urljoin, urlparse
|
6 | 6 |
|
7 |
| -import httpx |
8 | 7 | import structlog
|
9 |
| -from fastapi import Request, Response, WebSocket |
10 | 8 |
|
11 | 9 | from codegate.config import Config
|
12 | 10 |
|
13 | 11 | # from codegate.codegate_logging import log_error, log_proxy_forward, logger
|
14 | 12 | from codegate.ca.codegate_ca import CertificateAuthority
|
15 |
| -from codegate.providers.copilot.mapping import VALIDATED_ROUTES |
| 13 | +from codegate.providers.copilot.mapping import VALIDATED_ROUTES |
16 | 14 |
|
17 | 15 | logger = structlog.get_logger("codegate")
|
18 | 16 |
|
@@ -413,177 +411,6 @@ async def get_target_url(cls, path: str) -> Optional[str]:
|
413 | 411 | logger.warning(f"No route found for path: {path}")
|
414 | 412 | return None
|
415 | 413 |
|
416 |
| - @classmethod |
417 |
| - def prepare_headers(cls, request: Union[Request, WebSocket], target_url: str) -> Dict[str, str]: |
418 |
| - """Prepare headers for the proxy request""" |
419 |
| - logger.debug(f"Preparing headers for {target_url}") |
420 |
| - headers = {} |
421 |
| - |
422 |
| - # Get headers from request |
423 |
| - logger.debug("Request headers:") |
424 |
| - if isinstance(request, Request): |
425 |
| - request_headers = request.headers |
426 |
| - else: # WebSocket |
427 |
| - request_headers = request.headers |
428 |
| - |
429 |
| - # Copy preserved headers from the original request |
430 |
| - logger.debug("=" * 40) |
431 |
| - logger.debug("Preserved headers:") |
432 |
| - for header, value in request_headers.items(): |
433 |
| - if header.lower() in [h.lower() for h in settings.PRESERVED_HEADERS]: |
434 |
| - headers[header] = value |
435 |
| - logger.debug("=" * 40) |
436 |
| - |
437 |
| - # Add endpoint-specific headers |
438 |
| - logger.debug("=" * 40) |
439 |
| - logger.debug("Endpoint headers:") |
440 |
| - if isinstance(request, Request): |
441 |
| - path = urlparse(str(request.url)).path |
442 |
| - if path in settings.ENDPOINT_HEADERS: |
443 |
| - headers.update(settings.ENDPOINT_HEADERS[path]) |
444 |
| - |
445 |
| - # Set the Host header to match the target |
446 |
| - target_parsed = urlparse(target_url) |
447 |
| - headers['Host'] = target_parsed.netloc |
448 |
| - |
449 |
| - # Remove any headers that shouldn't be forwarded |
450 |
| - for header in settings.REMOVED_HEADERS: |
451 |
| - headers.pop(header.lower(), None) |
452 |
| - |
453 |
| - # Log headers for debugging |
454 |
| - logger.debug(f"Prepared headers for {target_url}: {headers}") |
455 |
| - logger.debug("=" * 40) |
456 |
| - |
457 |
| - return headers |
458 |
| - |
459 |
| - @classmethod |
460 |
| - async def forward_request( |
461 |
| - cls, |
462 |
| - request: Request, |
463 |
| - target_url: str, |
464 |
| - client: httpx.AsyncClient |
465 |
| - ) -> Tuple[Response, int]: |
466 |
| - """Forward the request to the target URL""" |
467 |
| - logger.debug(f"Forwarding request to {target_url} fn: forward_request") |
468 |
| - try: |
469 |
| - # Prepare headers |
470 |
| - headers = cls.prepare_headers(request, target_url) |
471 |
| - |
472 |
| - # Get request body |
473 |
| - body = await request.body() |
474 |
| - |
475 |
| - logger.debug(f"Forwarding {request.method} request to {target_url}") |
476 |
| - logger.debug(f"Request headers: {headers}") |
477 |
| - if body: |
478 |
| - logger.debug(f"Request body length: {len(body)} bytes") |
479 |
| - |
480 |
| - # Forward the request |
481 |
| - logger.debug(f"Sending request to {target_url}") |
482 |
| - response = await client.request( |
483 |
| - method=request.method, |
484 |
| - url=target_url, |
485 |
| - headers=headers, |
486 |
| - content=body, |
487 |
| - follow_redirects=True |
488 |
| - ) |
489 |
| - |
490 |
| - logger.debug(f"Received response from {target_url}: status={response.status_code}") |
491 |
| - logger.debug(f"Response headers: {dict(response.headers)}") |
492 |
| - |
493 |
| - # Log the forwarded request |
494 |
| - logger.debug(f"Forwarded request to {target_url}: {response.status_code}") |
495 |
| - log_proxy_forward(target_url, request.method, response.status_code) |
496 |
| - |
497 |
| - # Create FastAPI response |
498 |
| - logger.debug(f"Creating FastAPI response for {target_url}") |
499 |
| - return Response( |
500 |
| - content=response.content, |
501 |
| - status_code=response.status_code, |
502 |
| - headers=dict(response.headers) |
503 |
| - ), response.status_code |
504 |
| - |
505 |
| - except httpx.RequestError as e: |
506 |
| - log_error("request_error", str(e), {"target_url": target_url}) |
507 |
| - logger.error(f"Request error for {target_url}: {str(e)}") |
508 |
| - return Response( |
509 |
| - content=str(e).encode(), |
510 |
| - status_code=502, |
511 |
| - media_type="text/plain" |
512 |
| - ), 502 |
513 |
| - |
514 |
| - except Exception as e: |
515 |
| - log_error("proxy_error", str(e), {"target_url": target_url}) |
516 |
| - logger.error(f"Proxy error for {target_url}: {str(e)}") |
517 |
| - return Response( |
518 |
| - content=str(e).encode(), |
519 |
| - status_code=500, |
520 |
| - media_type="text/plain" |
521 |
| - ), 500 |
522 |
| - |
523 |
| - @classmethod |
524 |
| - async def tunnel_websocket(cls, websocket: WebSocket, target_host: str, target_port: int): |
525 |
| - """Create a tunnel between WebSocket and target server""" |
526 |
| - logger.debug(f"Creating WebSocket tunnel to {target_host}:{target_port}") |
527 |
| - try: |
528 |
| - # Connect to target server |
529 |
| - reader, writer = await asyncio.open_connection(target_host, target_port) |
530 |
| - |
531 |
| - # Create bidirectional tunnel |
532 |
| - logger.debug("Creating bidirectional tunnel") |
533 |
| - |
534 |
| - async def forward_ws_to_target(): |
535 |
| - logger.debug("Forwarding WS to target") |
536 |
| - try: |
537 |
| - while True: |
538 |
| - data = await websocket.receive_bytes() |
539 |
| - writer.write(data) |
540 |
| - await writer.drain() |
541 |
| - except Exception as e: |
542 |
| - logger.error(f"WS to target error: {e}") |
543 |
| - |
544 |
| - async def forward_target_to_ws(): |
545 |
| - try: |
546 |
| - while True: |
547 |
| - data = await reader.read(8192) |
548 |
| - if not data: |
549 |
| - break |
550 |
| - await websocket.send_bytes(data) |
551 |
| - except Exception as e: |
552 |
| - logger.error(f"Target to WS error: {e}") |
553 |
| - |
554 |
| - # Run both forwarding tasks |
555 |
| - await asyncio.gather( |
556 |
| - forward_ws_to_target(), |
557 |
| - forward_target_to_ws(), |
558 |
| - return_exceptions=True |
559 |
| - ) |
560 |
| - |
561 |
| - except Exception as e: |
562 |
| - log_error("tunnel_error", str(e)) |
563 |
| - await websocket.close(code=1011, reason=str(e)) |
564 |
| - finally: |
565 |
| - writer.close() |
566 |
| - try: |
567 |
| - await writer.wait_closed() |
568 |
| - except Exception: |
569 |
| - pass |
570 |
| - |
571 |
| - @classmethod |
572 |
| - def create_error_response(cls, status_code: int, message: str) -> Response: |
573 |
| - """Create an error response""" |
574 |
| - content = { |
575 |
| - "error": { |
576 |
| - "message": message, |
577 |
| - "type": "proxy_error", |
578 |
| - "code": status_code |
579 |
| - } |
580 |
| - } |
581 |
| - return Response( |
582 |
| - content=str(content).encode(), |
583 |
| - status_code=status_code, |
584 |
| - media_type="application/json" |
585 |
| - ) |
586 |
| - |
587 | 414 | class CopilotProxyTargetProtocol(asyncio.Protocol):
|
588 | 415 | def __init__(self, proxy: CopilotProvider):
|
589 | 416 | logger.debug("Initializing CopilotProxyTargetProtocol class: CopilotProxyTargetProtocol")
|
|
0 commit comments