8
8
9
9
from codegate .config import Config
10
10
from codegate .ca .codegate_ca import CertificateAuthority
11
- from codegate .providers .copilot .mapping import VALIDATED_ROUTES
11
+ from codegate .providers .copilot .mapping import VALIDATED_ROUTES
12
12
13
13
logger = structlog .get_logger ("codegate" )
14
14
@@ -59,7 +59,7 @@ def extract_path(self, full_path: str) -> str:
59
59
elif full_path .startswith ('/' ):
60
60
return full_path .lstrip ('/' )
61
61
return full_path
62
-
62
+
63
63
def get_headers (self ) -> Dict [str , str ]:
64
64
"""Get request headers as a dictionary"""
65
65
logger .debug ("Getting headers as dictionary fn: get_headers" )
@@ -78,7 +78,7 @@ def get_headers(self) -> Dict[str, str]:
78
78
headers_dict [name .strip ().lower ()] = value .strip ()
79
79
except ValueError :
80
80
continue
81
-
81
+
82
82
return headers_dict
83
83
except Exception as e :
84
84
logger .error (f"Error getting headers: { e } " )
@@ -222,37 +222,58 @@ async def handle_http_request(self):
222
222
logger .error (f"Error handling HTTP request: { e } " )
223
223
self .send_error_response (502 , str (e ).encode ())
224
224
225
- def data_received (self , data : bytes ):
225
+ def _check_buffer_size (self , new_data : bytes ) -> bool :
226
+ """Check if adding new data would exceed the maximum buffer size"""
227
+ return len (self .buffer ) + len (new_data ) <= MAX_BUFFER_SIZE
228
+
229
+ def _handle_parsed_headers (self ) -> None :
230
+ """Handle the request based on parsed headers"""
231
+ if self .method == 'CONNECT' :
232
+ logger .debug ("Handling CONNECT request" )
233
+ self .handle_connect ()
234
+ else :
235
+ logger .debug ("Handling HTTP request" )
236
+ asyncio .create_task (self .handle_http_request ())
237
+
238
+ def _forward_data_to_target (self , data : bytes ) -> None :
239
+ """Forward data to target if connection is established"""
240
+ if self .target_transport and not self .target_transport .is_closing ():
241
+ self .log_decrypted_data (data , "Client to Server" )
242
+ self .target_transport .write (data )
243
+
244
+ def data_received (self , data : bytes ) -> None :
245
+ """Handle received data from the client"""
226
246
logger .debug (f"Data received from { self .peername } fn: data_received" )
227
247
228
248
try :
229
- if len (self .buffer ) + len (data ) > MAX_BUFFER_SIZE :
230
- logger .error ("Request too large" )
249
+ # Check buffer size limit
250
+ if not self ._check_buffer_size (data ):
251
+ logger .error ("Request exceeds maximum buffer size" )
231
252
self .send_error_response (413 , b"Request body too large" )
232
253
return
233
254
255
+ # Append new data to buffer
234
256
self .buffer .extend (data )
235
257
236
258
if not self .headers_parsed :
259
+ # Try to parse headers
237
260
self .headers_parsed = self .parse_headers ()
238
261
if not self .headers_parsed :
239
262
return
240
263
241
- if self .method == 'CONNECT' :
242
- logger .debug ("Handling CONNECT request" )
243
- self .handle_connect ()
244
- else :
245
- logger .debug ("Handling HTTP request" )
246
- asyncio .create_task (self .handle_http_request ())
247
- elif self .target_transport and not self .target_transport .is_closing ():
248
- self .log_decrypted_data (data , "Client to Server" )
249
- self .target_transport .write (data )
264
+ # Handle the request based on parsed headers
265
+ self ._handle_parsed_headers ()
266
+ else :
267
+ # Forward data to target if headers are already parsed
268
+ self ._forward_data_to_target (data )
250
269
270
+ except asyncio .CancelledError :
271
+ logger .warning ("Operation cancelled" )
272
+ raise
251
273
except Exception as e :
252
- logger .error (f"Error in data_received : { e } " )
274
+ logger .error (f"Error processing received data : { e } " )
253
275
self .send_error_response (502 , str (e ).encode ())
254
276
255
-
256
277
def handle_connect (self ):
257
278
'''
258
279
This where requests are sent directly via the tunnel created during
@@ -294,7 +315,7 @@ def handle_connect(self):
294
315
except Exception as e :
295
316
logger .error (f"Error handling CONNECT: { e } " )
296
317
self .send_error_response (502 , str (e ).encode ())
297
-
318
+
298
319
def send_error_response (self , status : int , message : bytes ):
299
320
logger .debug (f"Sending error response: { status } { message } fn: send_error_response" )
300
321
response = (
0 commit comments