Skip to content

Commit f7e98b5

Browse files
authored
Revert "Refactor redis architecture."
1 parent 658e23e commit f7e98b5

File tree

17 files changed

+286
-412
lines changed

17 files changed

+286
-412
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ You can run through the scenarios with the Temporal version using a [Web Applica
2323
* [Poetry](https://python-poetry.org/docs/) - Python Dependency Management
2424
* [OpenAI API Key] (https://platform.openai.com/api-keys) - Your key to accessing OpenAI's LLM
2525
* [Temporal CLI](https://docs.temporal.io/cli#install) - Local Temporal service
26-
* [Redis](https://redis.io/downloads/) - Stores conversation history and real-time status updates
26+
* [Redis](https://redis.io/downloads/) - Workflow writes conversation history, API reads from it
2727

2828
## Set up Python Environment
2929
```bash
@@ -43,7 +43,5 @@ It should look something like this:
4343
export OPENAI_API_KEY=sk-proj-....
4444
```
4545

46-
## Getting Started
47-
4846
See the OpenAI Agents SDK Version [here](src/oai_supervisor/README.md)
4947
And the Temporal version of this example is located [here](src/temporal_supervisor/README.md)
4.24 KB
Loading

poetry.lock

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ dependencies = [
1616
"redis (>=6.2.0,<7.0.0)",
1717
"aiohttp (>=3.12.14,<4.0.0)",
1818
"openai-agents (>=0.2.3,<0.3.0)",
19-
"temporalio (>=1.17.0,<2.0.0)",
19+
"temporalio (>=1.15.0,<2.0.0)",
2020
]
2121

2222
[tool.poetry]

src/api/main.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
import json
2+
3+
import asyncio
14
from contextlib import asynccontextmanager
25
from typing import Optional, AsyncGenerator
36

4-
from fastapi import FastAPI, HTTPException, Request, Query
7+
from fastapi import FastAPI, HTTPException, Request
58
from fastapi.middleware.cors import CORSMiddleware
6-
from temporalio.client import Client
7-
from temporalio.common import WorkflowIDReusePolicy
9+
from fastapi.responses import StreamingResponse
10+
from temporalio.client import Client, WithStartWorkflowOperation
11+
from temporalio.common import WorkflowIDReusePolicy, WorkflowIDConflictPolicy
812
from temporalio.exceptions import TemporalError
913
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
1014
from temporalio.service import RPCError
1115

12-
from common.event_stream_manager import EventStreamManager
1316
from common.client_helper import ClientHelper
17+
from common.db_manager import DBManager
1418
from common.user_message import ProcessUserMessageInput
1519
from temporal_supervisor.claim_check.claim_check_plugin import ClaimCheckPlugin
1620
from temporal_supervisor.workflows.supervisor_workflow import WealthManagementWorkflow
@@ -55,12 +59,10 @@ def root():
5559
return {"message": "OpenAI Agent SDK + Temporal Agent!"}
5660

5761
@app.get("/get-chat-history")
58-
async def get_chat_history(
59-
from_index: int = Query(0, description="Get events starting from this index")
60-
):
62+
async def get_chat_history():
6163
""" Retrieves the chat history from Redis """
6264
try:
63-
history = await EventStreamManager().get_events_from_index(WORKFLOW_ID, from_index)
65+
history = await DBManager().read(WORKFLOW_ID)
6466
if history is None:
6567
return ""
6668

@@ -78,6 +80,14 @@ async def get_chat_history(
7880

7981
@app.post("/send-prompt")
8082
async def send_prompt(prompt: str):
83+
# Start or update the workflow
84+
start_op = WithStartWorkflowOperation(
85+
WealthManagementWorkflow.run,
86+
id=WORKFLOW_ID,
87+
task_queue=task_queue,
88+
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
89+
)
90+
8191
print(f"Received prompt {prompt}")
8292

8393
message = ProcessUserMessageInput(
@@ -113,10 +123,11 @@ async def end_chat():
113123
@app.post("/start-workflow")
114124
async def start_workflow(request: Request):
115125
try:
126+
sse_url = str(request.url_for(UPDATE_STATUS_NAME))
116127
# start the workflow
117128
await temporal_client.start_workflow(
118129
WealthManagementWorkflow.run,
119-
args=[],
130+
args=[sse_url],
120131
id=WORKFLOW_ID,
121132
task_queue=task_queue,
122133
id_reuse_policy=WorkflowIDReusePolicy.ALLOW_DUPLICATE
@@ -130,3 +141,33 @@ async def start_workflow(request: Request):
130141
return {
131142
"message": f"An error occurred starting the workflow {e}"
132143
}
144+
145+
# In-memory list to hold active SSE client connections
146+
# Note that this does not scale past one instance of the API
147+
connected_clients = []
148+
149+
# SSE generator function
150+
async def event_generator(request: Request):
151+
client_queue = asyncio.Queue()
152+
connected_clients.append(client_queue)
153+
try:
154+
while True:
155+
# Wait for a new message to be put in the queue
156+
message = await client_queue.get()
157+
yield f"data: {message}\n\n"
158+
except asyncio.CancelledError:
159+
connected_clients.remove(client_queue)
160+
raise
161+
162+
# Endpoint for clients to connect and receive events
163+
@app.get("/sse/status/stream")
164+
async def sse_endpoint(request: Request):
165+
return StreamingResponse(event_generator(request), media_type="text/event-stream")
166+
167+
# Endpoint for the Temporal Workflow to send updates
168+
@app.post("/sse/status/update", name=UPDATE_STATUS_NAME)
169+
async def update_status(data: dict):
170+
message = json.dumps(data)
171+
for queue in connected_clients:
172+
await queue.put(message)
173+
return {"message": "Status updated"}

src/common/db_manager.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import json
2+
3+
import redis.asyncio as redis
4+
5+
class DBManager:
6+
def __init__(self, redis_host: str = "localhost", redis_port: int = 6379):
7+
self.redis_client = redis.Redis(host=redis_host, port=redis_port)
8+
9+
# may need to revisit this later
10+
async def save(self, key: str, value):
11+
value_as_json = json.dumps(value)
12+
await self.redis_client.set(key, value_as_json)
13+
14+
# may need to revisit this later
15+
async def read(self, key: str) -> any:
16+
print(f"Getting ready to retrieve the value for key {key}")
17+
value = await self.redis_client.get(key)
18+
print(f"The value read is {value}")
19+
if value is not None:
20+
return_value = json.loads(value)
21+
print(f"The return value after reading Redis and parsing json is {return_value}")
22+
return return_value
23+
# TODO: Validate this is okay
24+
return None
25+
26+
async def delete(self, key: str):
27+
print(f"Deleting key {key}")
28+
await self.redis_client.delete(key)

src/common/event_stream_manager.py

Lines changed: 0 additions & 178 deletions
This file was deleted.

src/common/status_update.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)