@@ -656,10 +656,15 @@ def __init__(self, proxy: CopilotProvider):
656
656
self .stream_queue : Optional [asyncio .Queue ] = None
657
657
self .processing_task : Optional [asyncio .Task ] = None
658
658
659
+ self .finish_stream = False
660
+
661
+ # For debugging only
662
+ # self.data_sent = []
663
+
659
664
def connection_made (self , transport : asyncio .Transport ) -> None :
660
665
"""Handle successful connection to target"""
661
666
self .transport = transport
662
- logger .debug (f"Target transport peer : { transport .get_extra_info ('peername' )} " )
667
+ logger .debug (f"Connection established to target : { transport .get_extra_info ('peername' )} " )
663
668
self .proxy .target_transport = transport
664
669
665
670
def _ensure_output_processor (self ) -> None :
@@ -688,7 +693,7 @@ async def _process_stream(self):
688
693
try :
689
694
690
695
async def stream_iterator ():
691
- while True :
696
+ while not self . stream_queue . empty () :
692
697
incoming_record = await self .stream_queue .get ()
693
698
694
699
record_content = incoming_record .get ("content" , {})
@@ -701,6 +706,9 @@ async def stream_iterator():
701
706
else :
702
707
content = choice .get ("delta" , {}).get ("content" )
703
708
709
+ if choice .get ("finish_reason" , None ) == "stop" :
710
+ self .finish_stream = True
711
+
704
712
streaming_choices .append (
705
713
StreamingChoices (
706
714
finish_reason = choice .get ("finish_reason" , None ),
@@ -722,22 +730,18 @@ async def stream_iterator():
722
730
)
723
731
yield mr
724
732
725
- async for record in self .output_pipeline_instance .process_stream (stream_iterator ()):
733
+ async for record in self .output_pipeline_instance .process_stream (
734
+ stream_iterator (), cleanup_sensitive = False
735
+ ):
726
736
chunk = record .model_dump_json (exclude_none = True , exclude_unset = True )
727
737
sse_data = f"data: { chunk } \n \n " .encode ("utf-8" )
728
738
chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
729
739
self ._proxy_transport_write (chunk_size .encode ())
730
740
self ._proxy_transport_write (sse_data )
731
741
self ._proxy_transport_write (b"\r \n " )
732
742
733
- sse_data = b"data: [DONE]\n \n "
734
- # Add chunk size for DONE message too
735
- chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
736
- self ._proxy_transport_write (chunk_size .encode ())
737
- self ._proxy_transport_write (sse_data )
738
- self ._proxy_transport_write (b"\r \n " )
739
- # Now send the final zero chunk
740
- self ._proxy_transport_write (b"0\r \n \r \n " )
743
+ if self .finish_stream :
744
+ self .finish_data ()
741
745
742
746
except asyncio .CancelledError :
743
747
logger .debug ("Stream processing cancelled" )
@@ -746,12 +750,37 @@ async def stream_iterator():
746
750
logger .error (f"Error processing stream: { e } " )
747
751
finally :
748
752
# Clean up
753
+ self .stream_queue = None
749
754
if self .processing_task and not self .processing_task .done ():
750
755
self .processing_task .cancel ()
751
- if self .proxy .context_tracking and self .proxy .context_tracking .sensitive :
752
- self .proxy .context_tracking .sensitive .secure_cleanup ()
756
+
757
+ def finish_data (self ):
758
+ logger .debug ("Finishing data stream" )
759
+ sse_data = b"data: [DONE]\n \n "
760
+ # Add chunk size for DONE message too
761
+ chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
762
+ self ._proxy_transport_write (chunk_size .encode ())
763
+ self ._proxy_transport_write (sse_data )
764
+ self ._proxy_transport_write (b"\r \n " )
765
+ # Now send the final zero chunk
766
+ self ._proxy_transport_write (b"0\r \n \r \n " )
767
+
768
+ # For debugging only
769
+ # print("===========START DATA SENT====================")
770
+ # for data in self.data_sent:
771
+ # print(data)
772
+ # self.data_sent = []
773
+ # print("===========START DATA SENT====================")
774
+
775
+ self .finish_stream = False
776
+ self .headers_sent = False
753
777
754
778
def _process_chunk (self , chunk : bytes ):
779
+ # For debugging only
780
+ # print("===========START DATA RECVD====================")
781
+ # print(chunk)
782
+ # print("===========END DATA RECVD======================")
783
+
755
784
records = self .sse_processor .process_chunk (chunk )
756
785
757
786
for record in records :
@@ -763,14 +792,12 @@ def _process_chunk(self, chunk: bytes):
763
792
self .stream_queue .put_nowait (record )
764
793
765
794
def _proxy_transport_write (self , data : bytes ):
795
+ # For debugging only
796
+ # self.data_sent.append(data)
766
797
if not self .proxy .transport or self .proxy .transport .is_closing ():
767
798
logger .error ("Proxy transport not available" )
768
799
return
769
800
self .proxy .transport .write (data )
770
- # print("DEBUG =================================")
771
- # print(data)
772
- # print("DEBUG =================================")
773
-
774
801
775
802
def data_received (self , data : bytes ) -> None :
776
803
"""Handle data received from target"""
@@ -788,7 +815,7 @@ def data_received(self, data: bytes) -> None:
788
815
if header_end != - 1 :
789
816
self .headers_sent = True
790
817
# Send headers first
791
- headers = data [: header_end ]
818
+ headers = data [:header_end ]
792
819
793
820
# If Transfer-Encoding is not present, add it
794
821
if b"Transfer-Encoding:" not in headers :
@@ -800,15 +827,13 @@ def data_received(self, data: bytes) -> None:
800
827
logger .debug (f"Headers sent: { headers } " )
801
828
802
829
data = data [header_end + 4 :]
803
- # print("DEBUG =================================")
804
- # print(data)
805
- # print("DEBUG =================================")
806
830
807
831
self ._process_chunk (data )
808
832
809
833
def connection_lost (self , exc : Optional [Exception ]) -> None :
810
834
"""Handle connection loss to target"""
811
835
836
+ logger .debug ("Lost connection to target" )
812
837
if (
813
838
not self .proxy ._closing
814
839
and self .proxy .transport
0 commit comments