Skip to content

Commit 658e23e

Browse files
authored
Merge pull request temporal-sa#5 from robholland/rh-redis-refactor
Refactor redis architecture.
2 parents 29b71e7 + 74ebc06 commit 658e23e

File tree

17 files changed

+412
-286
lines changed

17 files changed

+412
-286
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ You can run through the scenarios with the Temporal version using a [Web Applica
2323
* [Poetry](https://python-poetry.org/docs/) - Python Dependency Management
2424
* [OpenAI API Key] (https://platform.openai.com/api-keys) - Your key to accessing OpenAI's LLM
2525
* [Temporal CLI](https://docs.temporal.io/cli#install) - Local Temporal service
26-
* [Redis](https://redis.io/downloads/) - Workflow writes conversation history, API reads from it
26+
* [Redis](https://redis.io/downloads/) - Stores conversation history and real-time status updates
2727

2828
## Set up Python Environment
2929
```bash
@@ -43,5 +43,7 @@ It should look something like this:
4343
export OPENAI_API_KEY=sk-proj-....
4444
```
4545

46+
## Getting Started
47+
4648
See the OpenAI Agents SDK Version [here](src/oai_supervisor/README.md)
4749
And the Temporal version of this example is located [here](src/temporal_supervisor/README.md)
-4.24 KB
Loading

poetry.lock

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ dependencies = [
1616
"redis (>=6.2.0,<7.0.0)",
1717
"aiohttp (>=3.12.14,<4.0.0)",
1818
"openai-agents (>=0.2.3,<0.3.0)",
19-
"temporalio (>=1.15.0,<2.0.0)",
19+
"temporalio (>=1.17.0,<2.0.0)",
2020
]
2121

2222
[tool.poetry]

src/api/main.py

Lines changed: 9 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
1-
import json
2-
3-
import asyncio
41
from contextlib import asynccontextmanager
52
from typing import Optional, AsyncGenerator
63

7-
from fastapi import FastAPI, HTTPException, Request
4+
from fastapi import FastAPI, HTTPException, Request, Query
85
from fastapi.middleware.cors import CORSMiddleware
9-
from fastapi.responses import StreamingResponse
10-
from temporalio.client import Client, WithStartWorkflowOperation
11-
from temporalio.common import WorkflowIDReusePolicy, WorkflowIDConflictPolicy
6+
from temporalio.client import Client
7+
from temporalio.common import WorkflowIDReusePolicy
128
from temporalio.exceptions import TemporalError
139
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
1410
from temporalio.service import RPCError
1511

12+
from common.event_stream_manager import EventStreamManager
1613
from common.client_helper import ClientHelper
17-
from common.db_manager import DBManager
1814
from common.user_message import ProcessUserMessageInput
1915
from temporal_supervisor.claim_check.claim_check_plugin import ClaimCheckPlugin
2016
from temporal_supervisor.workflows.supervisor_workflow import WealthManagementWorkflow
@@ -59,10 +55,12 @@ def root():
5955
return {"message": "OpenAI Agent SDK + Temporal Agent!"}
6056

6157
@app.get("/get-chat-history")
62-
async def get_chat_history():
58+
async def get_chat_history(
59+
from_index: int = Query(0, description="Get events starting from this index")
60+
):
6361
""" Retrieves the chat history from Redis """
6462
try:
65-
history = await DBManager().read(WORKFLOW_ID)
63+
history = await EventStreamManager().get_events_from_index(WORKFLOW_ID, from_index)
6664
if history is None:
6765
return ""
6866

@@ -80,14 +78,6 @@ async def get_chat_history():
8078

8179
@app.post("/send-prompt")
8280
async def send_prompt(prompt: str):
83-
# Start or update the workflow
84-
start_op = WithStartWorkflowOperation(
85-
WealthManagementWorkflow.run,
86-
id=WORKFLOW_ID,
87-
task_queue=task_queue,
88-
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
89-
)
90-
9181
print(f"Received prompt {prompt}")
9282

9383
message = ProcessUserMessageInput(
@@ -123,11 +113,10 @@ async def end_chat():
123113
@app.post("/start-workflow")
124114
async def start_workflow(request: Request):
125115
try:
126-
sse_url = str(request.url_for(UPDATE_STATUS_NAME))
127116
# start the workflow
128117
await temporal_client.start_workflow(
129118
WealthManagementWorkflow.run,
130-
args=[sse_url],
119+
args=[],
131120
id=WORKFLOW_ID,
132121
task_queue=task_queue,
133122
id_reuse_policy=WorkflowIDReusePolicy.ALLOW_DUPLICATE
@@ -141,33 +130,3 @@ async def start_workflow(request: Request):
141130
return {
142131
"message": f"An error occurred starting the workflow {e}"
143132
}
144-
145-
# In-memory list to hold active SSE client connections
146-
# Note that this does not scale past one instance of the API
147-
connected_clients = []
148-
149-
# SSE generator function
150-
async def event_generator(request: Request):
151-
client_queue = asyncio.Queue()
152-
connected_clients.append(client_queue)
153-
try:
154-
while True:
155-
# Wait for a new message to be put in the queue
156-
message = await client_queue.get()
157-
yield f"data: {message}\n\n"
158-
except asyncio.CancelledError:
159-
connected_clients.remove(client_queue)
160-
raise
161-
162-
# Endpoint for clients to connect and receive events
163-
@app.get("/sse/status/stream")
164-
async def sse_endpoint(request: Request):
165-
return StreamingResponse(event_generator(request), media_type="text/event-stream")
166-
167-
# Endpoint for the Temporal Workflow to send updates
168-
@app.post("/sse/status/update", name=UPDATE_STATUS_NAME)
169-
async def update_status(data: dict):
170-
message = json.dumps(data)
171-
for queue in connected_clients:
172-
await queue.put(message)
173-
return {"message": "Status updated"}

src/common/db_manager.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/common/event_stream_manager.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import json
2+
import time
3+
import os
4+
from typing import List, Dict, Any, Union
5+
from enum import Enum
6+
from dataclasses import asdict
7+
8+
import redis.asyncio as redis
9+
10+
from .user_message import ChatInteraction
11+
from .status_update import StatusUpdate
12+
13+
class EventType(str, Enum):
14+
"""Event types for the chat event stream"""
15+
CHAT_INTERACTION = "chat_interaction"
16+
STATUS_UPDATE = "status_update"
17+
18+
class EventStreamManager:
19+
"""
20+
Manages a Redis list-based event stream for chat conversations.
21+
22+
Uses LPUSH for O(1) appends and LRANGE for efficient range queries.
23+
Each event has an implicit sequence number based on its position in the list.
24+
"""
25+
26+
def __init__(self, redis_host: str = None, redis_port: int = None):
27+
self.redis_host = redis_host or os.getenv("REDIS_HOST", "localhost")
28+
self.redis_port = redis_port or int(os.getenv("REDIS_PORT", "6379"))
29+
self.redis_client = redis.Redis(
30+
host=self.redis_host,
31+
port=self.redis_port,
32+
decode_responses=True
33+
)
34+
35+
def _get_stream_key(self, workflow_id: str) -> str:
36+
"""Get the Redis key for the event stream"""
37+
return f"events:{workflow_id}"
38+
39+
def _get_meta_key(self, workflow_id: str) -> str:
40+
"""Get the Redis key for stream metadata"""
41+
return f"events:{workflow_id}:meta"
42+
43+
async def append_chat_interaction(
44+
self,
45+
workflow_id: str,
46+
chat_interaction: ChatInteraction
47+
) -> int:
48+
"""
49+
Append a chat interaction to the stream.
50+
51+
Returns the new total length of the event stream.
52+
"""
53+
return await self._append_domain_event(
54+
workflow_id,
55+
EventType.CHAT_INTERACTION,
56+
chat_interaction
57+
)
58+
59+
async def append_status_update(
60+
self,
61+
workflow_id: str,
62+
status_update: StatusUpdate
63+
) -> int:
64+
"""
65+
Append a status update to the stream.
66+
67+
Returns the new total length of the event stream.
68+
"""
69+
return await self._append_domain_event(
70+
workflow_id,
71+
EventType.STATUS_UPDATE,
72+
status_update
73+
)
74+
75+
async def _append_domain_event(
76+
self,
77+
workflow_id: str,
78+
event_type: EventType,
79+
domain_object: Union[ChatInteraction, StatusUpdate]
80+
) -> int:
81+
"""
82+
Internal method to append domain objects to the stream.
83+
84+
Returns the new total length of the event stream.
85+
"""
86+
stream_key = self._get_stream_key(workflow_id)
87+
88+
# Convert domain object to dict
89+
content_dict = asdict(domain_object)
90+
91+
# Build the event with structured content
92+
event = {
93+
"type": event_type.value,
94+
"content": content_dict
95+
}
96+
97+
event_json = json.dumps(event)
98+
99+
# Use RPUSH to add to the end (chronological order)
100+
# RPUSH returns the new length of the list after insertion
101+
new_length = await self.redis_client.rpush(stream_key, event_json)
102+
103+
return new_length
104+
105+
async def get_events_from_index(
106+
self,
107+
workflow_id: str,
108+
from_index: int = 0
109+
) -> List[Dict[str, Any]]:
110+
"""
111+
Get events starting from a specific index.
112+
113+
Args:
114+
workflow_id: The workflow ID
115+
from_index: Start from this index (0-based)
116+
117+
Returns:
118+
List of events in chronological order
119+
"""
120+
stream_key = self._get_stream_key(workflow_id)
121+
122+
# Get all events from the specified index to the end
123+
event_strings = await self.redis_client.lrange(stream_key, from_index, -1)
124+
125+
# Parse events
126+
events = []
127+
for event_str in event_strings:
128+
try:
129+
event = json.loads(event_str)
130+
events.append(event)
131+
except json.JSONDecodeError:
132+
continue # Skip malformed events
133+
134+
return events
135+
136+
async def get_all_events(self, workflow_id: str) -> List[Dict[str, Any]]:
137+
"""
138+
Get all events in the stream.
139+
140+
Returns events in chronological order.
141+
"""
142+
stream_key = self._get_stream_key(workflow_id)
143+
144+
# Get all events
145+
event_strings = await self.redis_client.lrange(stream_key, 0, -1)
146+
147+
# Parse events (already in chronological order due to RPUSH)
148+
events = []
149+
for event_str in event_strings:
150+
try:
151+
events.append(json.loads(event_str))
152+
except json.JSONDecodeError:
153+
continue # Skip malformed events
154+
155+
return events
156+
157+
async def get_total_events(self, workflow_id: str) -> int:
158+
"""Get the total number of events in the stream"""
159+
stream_key = self._get_stream_key(workflow_id)
160+
return await self.redis_client.llen(stream_key)
161+
162+
async def delete_stream(self, workflow_id: str) -> bool:
163+
"""
164+
Delete the entire event stream for a workflow.
165+
166+
Returns True if the stream was deleted, False if it didn't exist.
167+
"""
168+
stream_key = self._get_stream_key(workflow_id)
169+
meta_key = self._get_meta_key(workflow_id)
170+
171+
# Delete both the stream and metadata
172+
deleted = await self.redis_client.delete(stream_key, meta_key)
173+
174+
return deleted > 0
175+
176+
async def close(self):
177+
"""Close the Redis connection"""
178+
await self.redis_client.aclose()

src/common/status_update.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from dataclasses import dataclass
2+
3+
@dataclass
4+
class StatusUpdate:
5+
status: str
6+
7+
def __str__(self):
8+
return f"Status: {self.status}"

0 commit comments

Comments
 (0)