Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit eaf9dbd

Browse files
committed
Fix issue causing copilot to hang after creating multiple sessions
1 parent 246c9cd commit eaf9dbd

File tree

2 files changed

+48
-23
lines changed

2 files changed

+48
-23
lines changed

src/codegate/pipeline/output.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ async def _record_to_db(self):
115115
await self._db_recorder.record_context(self._input_context)
116116

117117
async def process_stream(
118-
self, stream: AsyncIterator[ModelResponse]
118+
self, stream: AsyncIterator[ModelResponse], cleanup_sensitive: bool = True
119119
) -> AsyncIterator[ModelResponse]:
120120
"""
121121
Process a stream through all pipeline steps
@@ -182,7 +182,7 @@ async def process_stream(
182182
self._context.buffer.clear()
183183

184184
# Cleanup sensitive data through the input context
185-
if self._input_context and self._input_context.sensitive:
185+
if cleanup_sensitive and self._input_context and self._input_context.sensitive:
186186
self._input_context.sensitive.secure_cleanup()
187187

188188

src/codegate/providers/copilot/provider.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -656,10 +656,15 @@ def __init__(self, proxy: CopilotProvider):
656656
self.stream_queue: Optional[asyncio.Queue] = None
657657
self.processing_task: Optional[asyncio.Task] = None
658658

659+
self.finish_stream = False
660+
661+
# For debugging only
662+
# self.data_sent = []
663+
659664
def connection_made(self, transport: asyncio.Transport) -> None:
660665
"""Handle successful connection to target"""
661666
self.transport = transport
662-
logger.debug(f"Target transport peer: {transport.get_extra_info('peername')}")
667+
logger.debug(f"Connection established to target: {transport.get_extra_info('peername')}")
663668
self.proxy.target_transport = transport
664669

665670
def _ensure_output_processor(self) -> None:
@@ -688,7 +693,7 @@ async def _process_stream(self):
688693
try:
689694

690695
async def stream_iterator():
691-
while True:
696+
while not self.stream_queue.empty():
692697
incoming_record = await self.stream_queue.get()
693698

694699
record_content = incoming_record.get("content", {})
@@ -701,6 +706,9 @@ async def stream_iterator():
701706
else:
702707
content = choice.get("delta", {}).get("content")
703708

709+
if choice.get("finish_reason", None) == "stop":
710+
self.finish_stream = True
711+
704712
streaming_choices.append(
705713
StreamingChoices(
706714
finish_reason=choice.get("finish_reason", None),
@@ -722,22 +730,18 @@ async def stream_iterator():
722730
)
723731
yield mr
724732

725-
async for record in self.output_pipeline_instance.process_stream(stream_iterator()):
733+
async for record in self.output_pipeline_instance.process_stream(
734+
stream_iterator(), cleanup_sensitive=False
735+
):
726736
chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
727737
sse_data = f"data: {chunk}\n\n".encode("utf-8")
728738
chunk_size = hex(len(sse_data))[2:] + "\r\n"
729739
self._proxy_transport_write(chunk_size.encode())
730740
self._proxy_transport_write(sse_data)
731741
self._proxy_transport_write(b"\r\n")
732742

733-
sse_data = b"data: [DONE]\n\n"
734-
# Add chunk size for DONE message too
735-
chunk_size = hex(len(sse_data))[2:] + "\r\n"
736-
self._proxy_transport_write(chunk_size.encode())
737-
self._proxy_transport_write(sse_data)
738-
self._proxy_transport_write(b"\r\n")
739-
# Now send the final zero chunk
740-
self._proxy_transport_write(b"0\r\n\r\n")
743+
if self.finish_stream:
744+
self.finish_data()
741745

742746
except asyncio.CancelledError:
743747
logger.debug("Stream processing cancelled")
@@ -746,12 +750,37 @@ async def stream_iterator():
746750
logger.error(f"Error processing stream: {e}")
747751
finally:
748752
# Clean up
753+
self.stream_queue = None
749754
if self.processing_task and not self.processing_task.done():
750755
self.processing_task.cancel()
751-
if self.proxy.context_tracking and self.proxy.context_tracking.sensitive:
752-
self.proxy.context_tracking.sensitive.secure_cleanup()
756+
757+
def finish_data(self):
758+
logger.debug("Finishing data stream")
759+
sse_data = b"data: [DONE]\n\n"
760+
# Add chunk size for DONE message too
761+
chunk_size = hex(len(sse_data))[2:] + "\r\n"
762+
self._proxy_transport_write(chunk_size.encode())
763+
self._proxy_transport_write(sse_data)
764+
self._proxy_transport_write(b"\r\n")
765+
# Now send the final zero chunk
766+
self._proxy_transport_write(b"0\r\n\r\n")
767+
768+
# For debugging only
769+
# print("===========START DATA SENT====================")
770+
# for data in self.data_sent:
771+
# print(data)
772+
# self.data_sent = []
773+
# print("===========START DATA SENT====================")
774+
775+
self.finish_stream = False
776+
self.headers_sent = False
753777

754778
def _process_chunk(self, chunk: bytes):
779+
# For debugging only
780+
# print("===========START DATA RECVD====================")
781+
# print(chunk)
782+
# print("===========END DATA RECVD======================")
783+
755784
records = self.sse_processor.process_chunk(chunk)
756785

757786
for record in records:
@@ -763,14 +792,12 @@ def _process_chunk(self, chunk: bytes):
763792
self.stream_queue.put_nowait(record)
764793

765794
def _proxy_transport_write(self, data: bytes):
795+
# For debugging only
796+
# self.data_sent.append(data)
766797
if not self.proxy.transport or self.proxy.transport.is_closing():
767798
logger.error("Proxy transport not available")
768799
return
769800
self.proxy.transport.write(data)
770-
# print("DEBUG =================================")
771-
# print(data)
772-
# print("DEBUG =================================")
773-
774801

775802
def data_received(self, data: bytes) -> None:
776803
"""Handle data received from target"""
@@ -788,7 +815,7 @@ def data_received(self, data: bytes) -> None:
788815
if header_end != -1:
789816
self.headers_sent = True
790817
# Send headers first
791-
headers = data[: header_end]
818+
headers = data[:header_end]
792819

793820
# If Transfer-Encoding is not present, add it
794821
if b"Transfer-Encoding:" not in headers:
@@ -800,15 +827,13 @@ def data_received(self, data: bytes) -> None:
800827
logger.debug(f"Headers sent: {headers}")
801828

802829
data = data[header_end + 4 :]
803-
# print("DEBUG =================================")
804-
# print(data)
805-
# print("DEBUG =================================")
806830

807831
self._process_chunk(data)
808832

809833
def connection_lost(self, exc: Optional[Exception]) -> None:
810834
"""Handle connection loss to target"""
811835

836+
logger.debug("Lost connection to target")
812837
if (
813838
not self.proxy._closing
814839
and self.proxy.transport

0 commit comments

Comments
 (0)