@@ -150,8 +150,16 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
150
150
self .cert_manager = TLSCertDomainManager (self .ca )
151
151
self ._closing = False
152
152
self .pipeline_factory = PipelineFactory (SecretsManager ())
153
+ self .input_pipeline : Optional [CopilotPipeline ] = None
154
+ self .fim_pipeline : Optional [CopilotPipeline ] = None
155
+ # the context as provided by the pipeline
153
156
self .context_tracking : Optional [PipelineContext ] = None
154
157
158
+ def _ensure_pipelines (self ):
159
+ if not self .input_pipeline or not self .fim_pipeline :
160
+ self .input_pipeline = CopilotChatPipeline (self .pipeline_factory )
161
+ self .fim_pipeline = CopilotFimPipeline (self .pipeline_factory )
162
+
155
163
def _select_pipeline (self , method : str , path : str ) -> Optional [CopilotPipeline ]:
156
164
if method != "POST" :
157
165
logger .debug ("Not a POST request, no pipeline selected" )
@@ -161,10 +169,10 @@ def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
161
169
if path == route .path :
162
170
if route .pipeline_type == PipelineType .FIM :
163
171
logger .debug ("Selected FIM pipeline" )
164
- return CopilotFimPipeline ( self .pipeline_factory )
172
+ return self .fim_pipeline
165
173
elif route .pipeline_type == PipelineType .CHAT :
166
174
logger .debug ("Selected CHAT pipeline" )
167
- return CopilotChatPipeline ( self .pipeline_factory )
175
+ return self .input_pipeline
168
176
169
177
logger .debug ("No pipeline selected" )
170
178
return None
@@ -181,7 +189,6 @@ async def _body_through_pipeline(
181
189
# if we didn't select any strategy that would change the request
182
190
# let's just pass through the body as-is
183
191
return body , None
184
- logger .debug (f"Processing body through pipeline: { len (body )} bytes" )
185
192
return await strategy .process_body (headers , body )
186
193
187
194
async def _request_to_target (self , headers : list [str ], body : bytes ):
@@ -288,6 +295,9 @@ async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest
288
295
http_request .headers ,
289
296
http_request .body ,
290
297
)
298
+ # TODO: it's weird that we're overwriting the context.
299
+ # Should we set the context once? Maybe when
300
+ # creating the pipeline instance?
291
301
self .context_tracking = context
292
302
293
303
if context and context .shortcut_response :
@@ -442,6 +452,7 @@ def data_received(self, data: bytes) -> None:
442
452
if not self .headers_parsed :
443
453
self .headers_parsed = self .parse_headers ()
444
454
if self .headers_parsed :
455
+ self ._ensure_pipelines ()
445
456
if self .request .method == "CONNECT" :
446
457
self .handle_connect ()
447
458
self .buffer .clear ()
@@ -756,10 +767,12 @@ def connection_made(self, transport: asyncio.Transport) -> None:
756
767
757
768
def _ensure_output_processor (self ) -> None :
758
769
if self .proxy .context_tracking is None :
770
+ logger .debug ("No context tracking, no need to process pipeline" )
759
771
# No context tracking, no need to process pipeline
760
772
return
761
773
762
774
if self .sse_processor is not None :
775
+ logger .debug ("Already initialized, no need to reinitialize" )
763
776
# Already initialized, no need to reinitialize
764
777
return
765
778
0 commit comments