Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 615e5cc

Browse files
committed
Automate URL processing from the proxy-ep value from auth header
1 parent 7f24d51 commit 615e5cc

File tree

2 files changed

+84
-59
lines changed

2 files changed

+84
-59
lines changed

src/codegate/providers/copilot/mapping.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,15 @@ class CoPilotMappings(BaseSettings):
1919
("/copilot/proxy", "https://copilot-proxy.githubusercontent.com"),
2020
("/origin-tracker", "https://origin-tracker.githubusercontent.com"),
2121
("/copilot/suggestions", "https://githubcopilot.com"),
22-
("/copilot/enterprise", "https://enterprise.githubcopilot.com"), # will need pr-proxy logic
23-
("/copilot/business", "https://business.githubcopilot.com"), # will need pr-proxy logic
24-
("/copilot/enterprise", "https://enterprise.githubcopilot.com"), # will need pr-proxy logic
25-
("/chat/completions", "https://api.enterprise.githubcopilot.com"), # will need pr-proxy logic
2622
("/copilot_internal/user", "https://api.github.com"),
2723
("/copilot_internal/v2/token", "https://api.github.com"),
28-
("/models", "https://api.enterprise.githubcopilot.com"),
29-
("/agents", "https://api.enterprise.githubcopilot.com"), # will need pr-proxy logic
30-
("/_ping", "https://api.enterprise.githubcopilot.com"), # will need pr-proxy logic
3124
("/telemety", "https://copilot-telemetry.githubusercontent.com"),
3225
("/", "https://github.com"),
3326
("/login/oauth/access_token", "https://github.com/login/oauth/access_token"),
3427
("/api/copilot", "https://api.github.com/copilot_internal"),
3528
("/api/copilot_internal", "https://api.github.com/copilot_internal"),
3629
("/v1/completions", "https://copilot-proxy.githubusercontent.com/v1/completions"),
3730
("/v1", "https://copilot-proxy.githubusercontent.com/v1"),
38-
("v1/engines/copilot-codex/completions", "https://proxy.enterprise.githubcopilot.com/v1/engines/copilot-codex/completions"), # will need pr-proxy logic
3931
]
4032

4133
# Headers configuration

src/codegate/providers/copilot/provider.py

Lines changed: 84 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import structlog
88

99
from codegate.config import Config
10-
11-
# from codegate.codegate_logging import log_error, log_proxy_forward, logger
1210
from codegate.ca.codegate_ca import CertificateAuthority
1311
from codegate.providers.copilot.mapping import VALIDATED_ROUTES
1412

@@ -39,6 +37,7 @@ def __init__(self, loop):
3937
self.target = None
4038
self.original_path = None
4139
self.ssl_context = None
40+
self.proxy_ep = None
4241
self.decrypted_data = bytearray()
4342
# Get the singleton instance of CertificateAuthority
4443
self.ca = CertificateAuthority.get_instance()
@@ -60,6 +59,30 @@ def extract_path(self, full_path: str) -> str:
6059
elif full_path.startswith('/'):
6160
return full_path.lstrip('/')
6261
return full_path
62+
63+
def get_headers(self) -> Dict[str, str]:
64+
"""Get request headers as a dictionary"""
65+
logger.debug("Getting headers as dictionary fn: get_headers")
66+
headers_dict = {}
67+
68+
try:
69+
if b'\r\n\r\n' not in self.buffer:
70+
return {}
71+
72+
headers_end = self.buffer.index(b'\r\n\r\n')
73+
headers = self.buffer[:headers_end].split(b'\r\n')[1:] # Skip request line
74+
75+
for header in headers:
76+
try:
77+
name, value = header.decode('utf-8').split(':', 1)
78+
headers_dict[name.strip().lower()] = value.strip()
79+
except ValueError:
80+
continue
81+
82+
return headers_dict
83+
except Exception as e:
84+
logger.error(f"Error getting headers: {e}")
85+
return {}
6386

6487
def parse_headers(self) -> bool:
6588
logger.debug("Parsing headers fn: parse_headers")
@@ -84,67 +107,69 @@ def parse_headers(self) -> bool:
84107
self.path = self.extract_path(full_path)
85108

86109
self.headers = [header.decode('utf-8') for header in headers[1:]]
87-
88-
logger.debug("=" * 40)
89-
logger.debug("=== Inbound Request ===")
90-
logger.debug(f"Method: {self.method}")
91-
logger.debug(f"Original Path: {self.original_path}")
92-
logger.debug(f"Extracted Path: {self.path}")
93-
logger.debug(f"Version: {self.version}")
94-
logger.debug("Headers:")
95-
96-
logger.debug("=" * 40)
97-
98-
logger.debug("Searching for proxy-ep header value")
99-
proxy_ep_value = None
100-
101-
for header in self.headers:
102-
logger.debug(f" {header}")
103-
if header.lower().startswith("authorization:"):
104-
match = re.search(r"proxy-ep=([^;]+)", header)
105-
if match:
106-
proxy_ep_value = match.group(1)
107-
108-
if proxy_ep_value:
109-
logger.debug(f"Extracted proxy-ep value: {proxy_ep_value}")
110-
else:
111-
logger.debug("proxy-ep value not found.")
112-
logger.debug("=" * 40)
113-
114110
return True
115111
except Exception as e:
116112
logger.error(f"Error parsing headers: {e}")
117113
return False
118114

119115
def log_decrypted_data(self, data: bytes, direction: str):
116+
'''
117+
Uncomment to log data from payload
118+
'''
120119
try:
121-
decoded = data.decode('utf-8')
122-
logger.debug(f"=== Decrypted {direction} Data ===")
123-
logger.debug(decoded)
124-
logger.debug("=" * 40)
120+
# decoded = data.decode('utf-8')
121+
# logger.debug(f"=== Decrypted {direction} Data ===")
122+
# logger.debug(decoded)
123+
# logger.debug("=" * 40)
124+
pass
125125
except UnicodeDecodeError:
126-
logger.debug(f"=== Decrypted {direction} Data (hex) ===")
127-
logger.debug(data.hex())
128-
logger.debug("=" * 40)
126+
# pass
127+
# logger.debug(f"=== Decrypted {direction} Data (hex) ===")
128+
# logger.debug(data.hex())
129+
# logger.debug("=" * 40)
130+
pass
129131

130132
async def handle_http_request(self):
131133
logger.debug("Handling HTTP request fn: handle_http_request")
132134
logger.debug("=" * 40)
133135
logger.debug(f"Method: {self.method}")
134136
logger.debug(f"Searched Path: {self.path} in target URL")
135137
try:
136-
target_url = await self.get_target_url(self.path)
137-
logger.debug(f"Target URL: {target_url}")
138+
# Extract proxy endpoint from authorization header if present
139+
headers_dict = self.get_headers()
140+
auth_header = headers_dict.get('authorization', '')
141+
if auth_header:
142+
match = re.search(r"proxy-ep=([^;]+)", auth_header)
143+
if match:
144+
self.proxy_ep = match.group(1)
145+
logger.debug(f"Extracted proxy-ep value: {self.proxy_ep}")
146+
147+
# Check if the proxy_ep includes a scheme
148+
parsed_proxy_ep = urlparse(self.proxy_ep)
149+
if not parsed_proxy_ep.scheme:
150+
# Default to https if no scheme is provided
151+
self.proxy_ep = f"https://{self.proxy_ep}"
152+
logger.debug(f"Added default scheme to proxy-ep: {self.proxy_ep}")
153+
154+
target_url = f"{self.proxy_ep}/{self.path}"
155+
else:
156+
target_url = await self.get_target_url(self.path)
157+
else:
158+
target_url = await self.get_target_url(self.path)
159+
138160
if not target_url:
139161
self.send_error_response(404, b"Not Found")
140162
return
141-
logger.debug(f"target URL {target_url}")
163+
logger.debug(f"Target URL: {target_url}")
164+
165+
logger.debug(f"Target URL: {target_url}")
142166

143167
parsed_url = urlparse(target_url)
144168
logger.debug(f"Parsed URL {parsed_url}")
169+
145170
self.target_host = parsed_url.hostname
146171
self.target_port = parsed_url.port or (443 if parsed_url.scheme == 'https' else 80)
147-
logger.debug("=" * 40)
172+
148173
target_protocol = CopilotProxyTargetProtocol(self)
149174
logger.debug(f"Connecting to {self.target_host}:{self.target_port}")
150175
await self.loop.create_connection(
@@ -227,7 +252,16 @@ def data_received(self, data: bytes):
227252
logger.error(f"Error in data_received: {e}")
228253
self.send_error_response(502, str(e).encode())
229254

255+
230256
def handle_connect(self):
257+
'''
258+
This is where requests are sent directly via the tunnel created during
259+
a CONNECT request. This is where the SSL context is created and the
260+
internal connection is made to the target host.
261+
262+
We do not need to map the URL to a target, as this passes through the
263+
tunnel with a FQDN already set by the source (client) request.
264+
'''
231265
try:
232266
path = unquote(self.target)
233267
if ':' in path:
@@ -260,7 +294,7 @@ def handle_connect(self):
260294
except Exception as e:
261295
logger.error(f"Error handling CONNECT: {e}")
262296
self.send_error_response(502, str(e).encode())
263-
297+
264298
def send_error_response(self, status: int, message: bytes):
265299
logger.debug(f"Sending error response: {status} {message} fn: send_error_response")
266300
response = (
@@ -397,16 +431,15 @@ async def get_target_url(cls, path: str) -> Optional[str]:
397431

398432
# Then check for prefix match
399433
for route in VALIDATED_ROUTES:
400-
if path.startswith(route.path):
401-
# For prefix matches, keep the rest of the path
402-
remaining_path = path[len(route.path):]
403-
logger.debug(f"Remaining path: {remaining_path}")
404-
# Make sure we don't end up with double slashes
405-
if remaining_path and remaining_path.startswith('/'):
406-
remaining_path = remaining_path[1:]
407-
target = urljoin(str(route.target), remaining_path)
408-
logger.debug(f"Found prefix match: {path} -> {target} (using route {route.path} -> {route.target})")
409-
return target
434+
# For prefix matches, keep the rest of the path
435+
remaining_path = path[len(route.path):]
436+
logger.debug(f"Remaining path: {remaining_path}")
437+
# Make sure we don't end up with double slashes
438+
if remaining_path and remaining_path.startswith('/'):
439+
remaining_path = remaining_path[1:]
440+
target = urljoin(str(route.target), remaining_path)
441+
logger.debug(f"Found prefix match: {path} -> {target} (using route {route.path} -> {route.target})")
442+
return target
410443

411444
logger.warning(f"No route found for path: {path}")
412445
return None

0 commit comments

Comments
 (0)