@@ -559,31 +559,52 @@ def __init__(self, proxy: CopilotProvider):
559
559
self .headers_sent = False
560
560
self .sse_processor : Optional [SSEProcessor ] = None
561
561
self .output_pipeline_instance : Optional [OutputPipelineInstance ] = None
562
+ self .stream_queue : Optional [asyncio .Queue ] = None
562
563
563
564
def connection_made(self, transport: asyncio.Transport) -> None:
    """Record the transport of the newly established target connection."""
    # Keep both this protocol and the owning proxy pointing at the same
    # outbound transport so either side can write to the target.
    self.transport = transport
    self.proxy.target_transport = transport
569
+ async def _process_stream (self ):
570
+ try :
571
+ async def stream_iterator ():
572
+ while True :
573
+ incoming_record = await self .stream_queue .get ()
574
+ yield incoming_record
575
+
576
+ async for record in stream_iterator ():
577
+ print ("received from stream" )
578
+ print (record )
579
+ if record ["type" ] == "done" :
580
+ sse_data = b"data: [DONE]\n \n "
581
+ # Add chunk size for DONE message too
582
+ chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
583
+ self ._proxy_transport_write (chunk_size .encode ())
584
+ self ._proxy_transport_write (sse_data )
585
+ self ._proxy_transport_write (b"\r \n " )
586
+ # Now send the final zero chunk
587
+ self ._proxy_transport_write (b"0\r \n \r \n " )
588
+ else :
589
+ sse_data = f"data: { json .dumps (record ['content' ])} \n \n " .encode ("utf-8" )
590
+ chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
591
+ self ._proxy_transport_write (chunk_size .encode ())
592
+ self ._proxy_transport_write (sse_data )
593
+ self ._proxy_transport_write (b"\r \n " )
594
+
595
+ except Exception as e :
596
+ logger .error (f"Error processing stream: { e } " )
597
+
568
598
def _process_chunk (self , chunk : bytes ):
569
599
records = self .sse_processor .process_chunk (chunk )
570
600
571
601
for record in records :
572
- if record ["type" ] == "done" :
573
- sse_data = b"data: [DONE]\n \n "
574
- # Add chunk size for DONE message too
575
- chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
576
- self ._proxy_transport_write (chunk_size .encode ())
577
- self ._proxy_transport_write (sse_data )
578
- self ._proxy_transport_write (b"\r \n " )
579
- # Now send the final zero chunk
580
- self ._proxy_transport_write (b"0\r \n \r \n " )
581
- else :
582
- sse_data = f"data: { json .dumps (record ['content' ])} \n \n " .encode ("utf-8" )
583
- chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
584
- self ._proxy_transport_write (chunk_size .encode ())
585
- self ._proxy_transport_write (sse_data )
586
- self ._proxy_transport_write (b"\r \n " )
602
+ if self .stream_queue is None :
603
+ # Initialize queue and start processing task on first record
604
+ self .stream_queue = asyncio .Queue ()
605
+ self .processing_task = asyncio .create_task (self ._process_stream ())
606
+
607
+ self .stream_queue .put_nowait (record )
587
608
588
609
def _proxy_transport_write (self , data : bytes ):
589
610
self .proxy .transport .write (data )
0 commit comments