7
7
import structlog
8
8
9
9
from codegate .config import Config
10
-
11
- # from codegate.codegate_logging import log_error, log_proxy_forward, logger
12
10
from codegate .ca .codegate_ca import CertificateAuthority
13
11
from codegate .providers .copilot .mapping import VALIDATED_ROUTES
14
12
@@ -39,6 +37,7 @@ def __init__(self, loop):
39
37
self .target = None
40
38
self .original_path = None
41
39
self .ssl_context = None
40
+ self .proxy_ep = None
42
41
self .decrypted_data = bytearray ()
43
42
# Get the singleton instance of CertificateAuthority
44
43
self .ca = CertificateAuthority .get_instance ()
@@ -60,6 +59,30 @@ def extract_path(self, full_path: str) -> str:
60
59
elif full_path .startswith ('/' ):
61
60
return full_path .lstrip ('/' )
62
61
return full_path
62
+
63
+ def get_headers (self ) -> Dict [str , str ]:
64
+ """Get request headers as a dictionary"""
65
+ logger .debug ("Getting headers as dictionary fn: get_headers" )
66
+ headers_dict = {}
67
+
68
+ try :
69
+ if b'\r \n \r \n ' not in self .buffer :
70
+ return {}
71
+
72
+ headers_end = self .buffer .index (b'\r \n \r \n ' )
73
+ headers = self .buffer [:headers_end ].split (b'\r \n ' )[1 :] # Skip request line
74
+
75
+ for header in headers :
76
+ try :
77
+ name , value = header .decode ('utf-8' ).split (':' , 1 )
78
+ headers_dict [name .strip ().lower ()] = value .strip ()
79
+ except ValueError :
80
+ continue
81
+
82
+ return headers_dict
83
+ except Exception as e :
84
+ logger .error (f"Error getting headers: { e } " )
85
+ return {}
63
86
64
87
def parse_headers (self ) -> bool :
65
88
logger .debug ("Parsing headers fn: parse_headers" )
@@ -84,67 +107,69 @@ def parse_headers(self) -> bool:
84
107
self .path = self .extract_path (full_path )
85
108
86
109
self .headers = [header .decode ('utf-8' ) for header in headers [1 :]]
87
-
88
- logger .debug ("=" * 40 )
89
- logger .debug ("=== Inbound Request ===" )
90
- logger .debug (f"Method: { self .method } " )
91
- logger .debug (f"Original Path: { self .original_path } " )
92
- logger .debug (f"Extracted Path: { self .path } " )
93
- logger .debug (f"Version: { self .version } " )
94
- logger .debug ("Headers:" )
95
-
96
- logger .debug ("=" * 40 )
97
-
98
- logger .debug ("Searching for proxy-ep header value" )
99
- proxy_ep_value = None
100
-
101
- for header in self .headers :
102
- logger .debug (f" { header } " )
103
- if header .lower ().startswith ("authorization:" ):
104
- match = re .search (r"proxy-ep=([^;]+)" , header )
105
- if match :
106
- proxy_ep_value = match .group (1 )
107
-
108
- if proxy_ep_value :
109
- logger .debug (f"Extracted proxy-ep value: { proxy_ep_value } " )
110
- else :
111
- logger .debug ("proxy-ep value not found." )
112
- logger .debug ("=" * 40 )
113
-
114
110
return True
115
111
except Exception as e :
116
112
logger .error (f"Error parsing headers: { e } " )
117
113
return False
118
114
119
115
def log_decrypted_data (self , data : bytes , direction : str ):
116
+ '''
117
+ Uncomment to log data from payload
118
+ '''
120
119
try :
121
- decoded = data .decode ('utf-8' )
122
- logger .debug (f"=== Decrypted { direction } Data ===" )
123
- logger .debug (decoded )
124
- logger .debug ("=" * 40 )
120
+ # decoded = data.decode('utf-8')
121
+ # logger.debug(f"=== Decrypted {direction} Data ===")
122
+ # logger.debug(decoded)
123
+ # logger.debug("=" * 40)
124
+ pass
125
125
except UnicodeDecodeError :
126
- logger .debug (f"=== Decrypted { direction } Data (hex) ===" )
127
- logger .debug (data .hex ())
128
- logger .debug ("=" * 40 )
126
+ # pass
127
+ # logger.debug(f"=== Decrypted {direction} Data (hex) ===")
128
+ # logger.debug(data.hex())
129
+ # logger.debug("=" * 40)
130
+ pass
129
131
130
132
async def handle_http_request (self ):
131
133
logger .debug ("Handling HTTP request fn: handle_http_request" )
132
134
logger .debug ("=" * 40 )
133
135
logger .debug (f"Method: { self .method } " )
134
136
logger .debug (f"Searched Path: { self .path } in target URL" )
135
137
try :
136
- target_url = await self .get_target_url (self .path )
137
- logger .debug (f"Target URL: { target_url } " )
138
+ # Extract proxy endpoint from authorization header if present
139
+ headers_dict = self .get_headers ()
140
+ auth_header = headers_dict .get ('authorization' , '' )
141
+ if auth_header :
142
+ match = re .search (r"proxy-ep=([^;]+)" , auth_header )
143
+ if match :
144
+ self .proxy_ep = match .group (1 )
145
+ logger .debug (f"Extracted proxy-ep value: { self .proxy_ep } " )
146
+
147
+ # Check if the proxy_ep includes a scheme
148
+ parsed_proxy_ep = urlparse (self .proxy_ep )
149
+ if not parsed_proxy_ep .scheme :
150
+ # Default to https if no scheme is provided
151
+ self .proxy_ep = f"https://{ self .proxy_ep } "
152
+ logger .debug (f"Added default scheme to proxy-ep: { self .proxy_ep } " )
153
+
154
+ target_url = f"{ self .proxy_ep } /{ self .path } "
155
+ else :
156
+ target_url = await self .get_target_url (self .path )
157
+ else :
158
+ target_url = await self .get_target_url (self .path )
159
+
138
160
if not target_url :
139
161
self .send_error_response (404 , b"Not Found" )
140
162
return
141
- logger .debug (f"target URL { target_url } " )
163
+ logger .debug (f"Target URL: { target_url } " )
164
+
165
+ logger .debug (f"Target URL: { target_url } " )
142
166
143
167
parsed_url = urlparse (target_url )
144
168
logger .debug (f"Parsed URL { parsed_url } " )
169
+
145
170
self .target_host = parsed_url .hostname
146
171
self .target_port = parsed_url .port or (443 if parsed_url .scheme == 'https' else 80 )
147
- logger . debug ( "=" * 40 )
172
+
148
173
target_protocol = CopilotProxyTargetProtocol (self )
149
174
logger .debug (f"Connecting to { self .target_host } :{ self .target_port } " )
150
175
await self .loop .create_connection (
@@ -227,7 +252,16 @@ def data_received(self, data: bytes):
227
252
logger .error (f"Error in data_received: { e } " )
228
253
self .send_error_response (502 , str (e ).encode ())
229
254
255
+
230
256
def handle_connect (self ):
257
+ '''
258
+ This where requests are sent directly via the tunnel created during
259
+ a CONNECT request. This is where the SSL context is created and the
260
+ internal connection is made to the target host.
261
+
262
+ We do not need to do a URL to mapping, as this passes through the
263
+ tunnel with a FQDN already set by the source (client) request.
264
+ '''
231
265
try :
232
266
path = unquote (self .target )
233
267
if ':' in path :
@@ -260,7 +294,7 @@ def handle_connect(self):
260
294
except Exception as e :
261
295
logger .error (f"Error handling CONNECT: { e } " )
262
296
self .send_error_response (502 , str (e ).encode ())
263
-
297
+
264
298
def send_error_response (self , status : int , message : bytes ):
265
299
logger .debug (f"Sending error response: { status } { message } fn: send_error_response" )
266
300
response = (
@@ -397,16 +431,15 @@ async def get_target_url(cls, path: str) -> Optional[str]:
397
431
398
432
# Then check for prefix match
399
433
for route in VALIDATED_ROUTES :
400
- if path .startswith (route .path ):
401
- # For prefix matches, keep the rest of the path
402
- remaining_path = path [len (route .path ):]
403
- logger .debug (f"Remaining path: { remaining_path } " )
404
- # Make sure we don't end up with double slashes
405
- if remaining_path and remaining_path .startswith ('/' ):
406
- remaining_path = remaining_path [1 :]
407
- target = urljoin (str (route .target ), remaining_path )
408
- logger .debug (f"Found prefix match: { path } -> { target } (using route { route .path } -> { route .target } )" )
409
- return target
434
+ # For prefix matches, keep the rest of the path
435
+ remaining_path = path [len (route .path ):]
436
+ logger .debug (f"Remaining path: { remaining_path } " )
437
+ # Make sure we don't end up with double slashes
438
+ if remaining_path and remaining_path .startswith ('/' ):
439
+ remaining_path = remaining_path [1 :]
440
+ target = urljoin (str (route .target ), remaining_path )
441
+ logger .debug (f"Found prefix match: { path } -> { target } (using route { route .path } -> { route .target } )" )
442
+ return target
410
443
411
444
logger .warning (f"No route found for path: { path } " )
412
445
return None
0 commit comments