1+ import json
2+
3+ import asyncio
14from contextlib import asynccontextmanager
25from typing import Optional , AsyncGenerator
36
4- from fastapi import FastAPI , HTTPException , Request , Query
7+ from fastapi import FastAPI , HTTPException , Request
58from fastapi .middleware .cors import CORSMiddleware
6- from temporalio .client import Client
7- from temporalio .common import WorkflowIDReusePolicy
9+ from fastapi .responses import StreamingResponse
10+ from temporalio .client import Client , WithStartWorkflowOperation
11+ from temporalio .common import WorkflowIDReusePolicy , WorkflowIDConflictPolicy
812from temporalio .exceptions import TemporalError
913from temporalio .contrib .openai_agents import OpenAIAgentsPlugin
1014from temporalio .service import RPCError
1115
12- from common .event_stream_manager import EventStreamManager
1316from common .client_helper import ClientHelper
17+ from common .db_manager import DBManager
1418from common .user_message import ProcessUserMessageInput
1519from temporal_supervisor .claim_check .claim_check_plugin import ClaimCheckPlugin
1620from temporal_supervisor .workflows .supervisor_workflow import WealthManagementWorkflow
@@ -55,12 +59,10 @@ def root():
5559 return {"message" : "OpenAI Agent SDK + Temporal Agent!" }
5660
5761@app .get ("/get-chat-history" )
58- async def get_chat_history (
59- from_index : int = Query (0 , description = "Get events starting from this index" )
60- ):
62+ async def get_chat_history ():
6163 """ Retrieves the chat history from Redis """
6264 try :
63- history = await EventStreamManager ().get_events_from_index (WORKFLOW_ID , from_index )
65+ history = await DBManager ().read (WORKFLOW_ID )
6466 if history is None :
6567 return ""
6668
@@ -78,6 +80,14 @@ async def get_chat_history(
7880
7981@app .post ("/send-prompt" )
8082async def send_prompt (prompt : str ):
83+ # Start or update the workflow
84+ start_op = WithStartWorkflowOperation (
85+ WealthManagementWorkflow .run ,
86+ id = WORKFLOW_ID ,
87+ task_queue = task_queue ,
88+ id_conflict_policy = WorkflowIDConflictPolicy .USE_EXISTING ,
89+ )
90+
8191 print (f"Received prompt { prompt } " )
8292
8393 message = ProcessUserMessageInput (
@@ -113,10 +123,11 @@ async def end_chat():
113123@app .post ("/start-workflow" )
114124async def start_workflow (request : Request ):
115125 try :
126+ sse_url = str (request .url_for (UPDATE_STATUS_NAME ))
116127 # start the workflow
117128 await temporal_client .start_workflow (
118129 WealthManagementWorkflow .run ,
119- args = [],
130+ args = [sse_url ],
120131 id = WORKFLOW_ID ,
121132 task_queue = task_queue ,
122133 id_reuse_policy = WorkflowIDReusePolicy .ALLOW_DUPLICATE
@@ -130,3 +141,33 @@ async def start_workflow(request: Request):
130141 return {
131142 "message" : f"An error occurred starting the workflow { e } "
132143 }
144+
145+ # In-memory list to hold active SSE client connections
146+ # Note that this does not scale past one instance of the API
147+ connected_clients = []
148+
149+ # SSE generator function
150+ async def event_generator (request : Request ):
151+ client_queue = asyncio .Queue ()
152+ connected_clients .append (client_queue )
153+ try :
154+ while True :
155+ # Wait for a new message to be put in the queue
156+ message = await client_queue .get ()
157+ yield f"data: { message } \n \n "
158+ except asyncio .CancelledError :
159+ connected_clients .remove (client_queue )
160+ raise
161+
162+ # Endpoint for clients to connect and receive events
163+ @app .get ("/sse/status/stream" )
164+ async def sse_endpoint (request : Request ):
165+ return StreamingResponse (event_generator (request ), media_type = "text/event-stream" )
166+
167+ # Endpoint for the Temporal Workflow to send updates
168+ @app .post ("/sse/status/update" , name = UPDATE_STATUS_NAME )
169+ async def update_status (data : dict ):
170+ message = json .dumps (data )
171+ for queue in connected_clients :
172+ await queue .put (message )
173+ return {"message" : "Status updated" }
0 commit comments