Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit d2c6069

Browse files
authored
Merge pull request #353 from jhrozek/fix_fim
A hotfix for the FIM pipeline
2 parents d9e2ddb + b6af423 commit d2c6069

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

src/codegate/pipeline/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def __init__(
267267
self.secret_manager = secret_manager
268268
self.is_fim = is_fim
269269
self.context = PipelineContext()
270+
self.context.metadata["is_fim"] = is_fim
270271

271272
async def process_request(
272273
self,

src/codegate/providers/copilot/provider.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,35 @@ def connection_made(self, transport: asyncio.Transport) -> None:
566566
self.transport = transport
567567
self.proxy.target_transport = transport
568568

569+
def _ensure_output_processor(self) -> None:
570+
if self.proxy.context_tracking is None:
571+
# No context tracking, no need to process pipeline
572+
return
573+
574+
if self.sse_processor is not None:
575+
# Already initialized, no need to reinitialize
576+
return
577+
578+
# this is a hotfix - we shortcut before selecting the output pipeline for FIM
579+
# because our FIM output pipeline is actually empty as of now. We should fix this
580+
# but don't have any immediate need.
581+
is_fim = self.proxy.context_tracking.metadata.get("is_fim", False)
582+
if is_fim:
583+
return
584+
585+
logger.debug("Tracking context for pipeline processing")
586+
self.sse_processor = SSEProcessor()
587+
is_fim = self.proxy.context_tracking.metadata.get("is_fim", False)
588+
if is_fim:
589+
out_pipeline_processor = self.proxy.pipeline_factory.create_fim_output_pipeline()
590+
else:
591+
out_pipeline_processor = self.proxy.pipeline_factory.create_output_pipeline()
592+
593+
self.output_pipeline_instance = OutputPipelineInstance(
594+
pipeline_steps=out_pipeline_processor.pipeline_steps,
595+
input_context=self.proxy.context_tracking,
596+
)
597+
569598
async def _process_stream(self):
570599
try:
571600

@@ -633,14 +662,7 @@ def _proxy_transport_write(self, data: bytes):
633662

634663
def data_received(self, data: bytes) -> None:
635664
"""Handle data received from target"""
636-
if self.proxy.context_tracking is not None and self.sse_processor is None:
637-
logger.debug("Tracking context for pipeline processing")
638-
self.sse_processor = SSEProcessor()
639-
out_pipeline_processor = self.proxy.pipeline_factory.create_output_pipeline()
640-
self.output_pipeline_instance = OutputPipelineInstance(
641-
pipeline_steps=out_pipeline_processor.pipeline_steps,
642-
input_context=self.proxy.context_tracking,
643-
)
665+
self._ensure_output_processor()
644666

645667
if self.proxy.transport and not self.proxy.transport.is_closing():
646668
if not self.sse_processor:

0 commit comments

Comments
 (0)