Skip to content

Commit 3211aa9

Browse files
OriNachum and claude
authored
fix: add content_index to streamed text deltas and improve stream reliability (#52)
* fix: add content_index to streamed text delta events and improve stream reliability (#44) Add the missing `content_index` field to `response.output_text.delta` SSE events so clients that validate against the OpenAI Responses API spec (e.g. ChatKit SDK) no longer fail with a Pydantic validation error. Also introduces SSE heartbeat keepalives, configurable stream timeout, and structured stream timing logs to improve reliability with slow backends. Closes #44 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: address PR review - heartbeat task cleanup and trailing whitespace - Add try/finally to _with_heartbeat() to cancel in-flight tasks and close the underlying async iterator on cancellation/exit - Guard against interval <= 0 to prevent tight heartbeat loops - Fix trailing whitespace and missing newline in config.py Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: resolve SonarCloud quality gate failures in _with_heartbeat - S7497: Re-raise asyncio.CancelledError after cleanup instead of swallowing it - B110: Replace bare except/pass with logger.debug for aclose errors - S5806: Rename `aiter` to `inner` to avoid shadowing the builtin - S3776: Extract cleanup logic to _cleanup_heartbeat() to reduce cognitive complexity below threshold Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3a62e62 commit 3211aa9

File tree

9 files changed

+177
-19
lines changed

9 files changed

+177
-19
lines changed

.env.example

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ OPENAI_API_KEY=sk-mockapikey123456789abcdefghijklmnopqrstuvwxyz
77
API_ADAPTER_HOST=0.0.0.0
88
API_ADAPTER_PORT=8080
99

10+
# Streaming Configuration
11+
STREAM_TIMEOUT=120.0
12+
HEARTBEAT_INTERVAL=15.0
13+
1014
# Logging Configuration (optional)
1115
LOG_LEVEL=INFO
1216
LOG_FILE_PATH=./log/api_adapter.log

src/open_responses_server/api_controller.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,61 @@
11
import json
2+
import asyncio
23
from fastapi import FastAPI, Request, HTTPException
34
from fastapi.responses import StreamingResponse, Response
45
from fastapi.middleware.cors import CORSMiddleware
56

6-
from open_responses_server.common.config import logger
7+
from open_responses_server.common.config import logger, HEARTBEAT_INTERVAL, STREAM_TIMEOUT
78
from open_responses_server.common.llm_client import startup_llm_client, shutdown_llm_client, LLMClient
89
from open_responses_server.common.mcp_manager import mcp_manager
910
from open_responses_server.responses_service import convert_responses_to_chat_completions, process_chat_completions_stream
1011
from open_responses_server.chat_completions_service import handle_chat_completions
1112

13+
_HEARTBEAT = object()
14+
15+
16+
async def _with_heartbeat(async_gen, interval):
17+
"""Wrap an async generator to yield _HEARTBEAT sentinels during idle periods.
18+
19+
Uses asyncio.wait with timeout so the underlying task is never cancelled.
20+
This keeps SSE connections alive when the backend LLM is slow to respond.
21+
"""
22+
if not interval or interval <= 0:
23+
interval = 1.0
24+
25+
inner = async_gen.__aiter__()
26+
task = None
27+
try:
28+
while True:
29+
task = asyncio.ensure_future(inner.__anext__())
30+
while not task.done():
31+
done, _ = await asyncio.wait({task}, timeout=interval)
32+
if not done:
33+
yield _HEARTBEAT
34+
try:
35+
yield task.result()
36+
except StopAsyncIteration:
37+
return
38+
finally:
39+
task = None
40+
finally:
41+
await _cleanup_heartbeat(task, inner)
42+
43+
44+
async def _cleanup_heartbeat(task, inner):
45+
"""Cancel in-flight task and close the underlying async iterator."""
46+
if task is not None and not task.done():
47+
task.cancel()
48+
try:
49+
await task
50+
except asyncio.CancelledError:
51+
raise
52+
if hasattr(inner, "aclose"):
53+
try:
54+
await inner.aclose()
55+
except Exception:
56+
logger.debug("Error closing heartbeat inner iterator", exc_info=True)
57+
58+
1259
app = FastAPI(
1360
title="Open Responses Server",
1461
description="A proxy server that converts between different OpenAI-compatible API formats.",
@@ -249,7 +296,7 @@ async def stream_response():
249296
"POST",
250297
"/v1/chat/completions",
251298
json=chat_request,
252-
timeout=120.0
299+
timeout=STREAM_TIMEOUT
253300
) as response:
254301
logger.info(f"Stream request status: {response.status_code}")
255302

@@ -259,8 +306,15 @@ async def stream_response():
259306
yield f"data: {json.dumps({'type': 'error', 'error': {'message': f'Error from LLM API: {response.status_code}'}})}\n\n"
260307
return
261308

262-
async for event in process_chat_completions_stream(response, chat_request):
263-
yield event
309+
async for event in _with_heartbeat(
310+
process_chat_completions_stream(response, chat_request),
311+
HEARTBEAT_INTERVAL
312+
):
313+
if event is _HEARTBEAT:
314+
logger.debug("[STREAM-HEARTBEAT] Sending SSE keepalive")
315+
yield ": heartbeat\n\n"
316+
else:
317+
yield event
264318
except Exception as e:
265319
logger.error(f"Error in stream_response: {str(e)}")
266320
yield f"data: {json.dumps({'type': 'error', 'error': {'message': str(e)}})}\n\n"
@@ -346,7 +400,7 @@ async def stream_response():
346400

347401
# async def stream_response():
348402
# try:
349-
# async with client.stream("POST", "/v1/chat/completions", json=chat_request, timeout=120.0) as response:
403+
# async with client.stream("POST", "/v1/chat/completions", json=chat_request, timeout=STREAM_TIMEOUT) as response:
350404
# if response.status_code != 200:
351405
# error_content = await response.aread()
352406
# logger.error(f"Error from LLM API: {error_content.decode()}")
@@ -411,12 +465,12 @@ async def proxy_endpoint(request: Request, path_name: str):
411465

412466
if is_stream:
413467
async def stream_proxy():
414-
async with client.stream(request.method, url, headers=headers, content=body, timeout=120.0) as response:
468+
async with client.stream(request.method, url, headers=headers, content=body, timeout=STREAM_TIMEOUT) as response:
415469
async for chunk in response.aiter_bytes():
416470
yield chunk
417471
return StreamingResponse(stream_proxy(), media_type=request.headers.get('accept', 'application/json'))
418472
else:
419-
response = await client.request(request.method, url, headers=headers, content=body, timeout=120.0)
473+
response = await client.request(request.method, url, headers=headers, content=body, timeout=STREAM_TIMEOUT)
420474
return Response(content=response.content, status_code=response.status_code, headers=response.headers)
421475

422476
except Exception as e:

src/open_responses_server/chat_completions_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from fastapi import Request
33
from fastapi.responses import StreamingResponse, Response, JSONResponse
44
from open_responses_server.common.llm_client import LLMClient
5-
from open_responses_server.common.config import logger, OPENAI_BASE_URL_INTERNAL, OPENAI_API_KEY, MAX_TOOL_CALL_ITERATIONS
5+
from open_responses_server.common.config import logger, OPENAI_BASE_URL_INTERNAL, OPENAI_API_KEY, MAX_TOOL_CALL_ITERATIONS, STREAM_TIMEOUT
66
from open_responses_server.common.mcp_manager import mcp_manager, serialize_tool_result
77

88
async def _handle_non_streaming_request(client: LLMClient, request_data: dict):
@@ -25,7 +25,7 @@ async def _handle_non_streaming_request(client: LLMClient, request_data: dict):
2525
response = await client.post(
2626
"/v1/chat/completions",
2727
json=current_request_data,
28-
timeout=120.0
28+
timeout=STREAM_TIMEOUT
2929
)
3030
response.raise_for_status()
3131
response_data = response.json()
@@ -102,7 +102,7 @@ async def _handle_streaming_request(client: LLMClient, request_data: dict) -> St
102102
for _ in range(MAX_TOOL_CALL_ITERATIONS):
103103
try:
104104
# Make a non-streaming request first to check for tool calls
105-
response = await client.post("/v1/chat/completions", json={**non_stream_request_data, "messages": messages}, timeout=120.0)
105+
response = await client.post("/v1/chat/completions", json={**non_stream_request_data, "messages": messages}, timeout=STREAM_TIMEOUT)
106106
response.raise_for_status()
107107
response_data = response.json()
108108

@@ -170,7 +170,7 @@ async def stream_proxy():
170170
"POST",
171171
"/v1/chat/completions",
172172
json=stream_request_data,
173-
timeout=120.0
173+
timeout=STREAM_TIMEOUT
174174
) as stream_response:
175175
async for chunk in stream_response.aiter_bytes():
176176
yield chunk

src/open_responses_server/common/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
MAX_CONVERSATION_HISTORY = int(os.environ.get("MAX_CONVERSATION_HISTORY", "100"))
2323
MAX_TOOL_CALL_ITERATIONS = int(os.environ.get("MAX_TOOL_CALL_ITERATIONS", "25"))
2424

25+
# Streaming Configuration
26+
STREAM_TIMEOUT = float(os.environ.get("STREAM_TIMEOUT", "120.0"))
27+
HEARTBEAT_INTERVAL = float(os.environ.get("HEARTBEAT_INTERVAL", "15.0"))
28+
2529

2630
# --- Logging Configuration ---
2731

@@ -54,4 +58,6 @@ def setup_logging():
5458
logger.info(f" MCP_TOOL_REFRESH_INTERVAL: {MCP_TOOL_REFRESH_INTERVAL}")
5559
logger.info(f" MCP_SERVERS_CONFIG_PATH: {MCP_SERVERS_CONFIG_PATH}")
5660
logger.info(f" MAX_CONVERSATION_HISTORY: {MAX_CONVERSATION_HISTORY}")
57-
logger.info(f" MAX_TOOL_CALL_ITERATIONS: {MAX_TOOL_CALL_ITERATIONS}")
61+
logger.info(f" MAX_TOOL_CALL_ITERATIONS: {MAX_TOOL_CALL_ITERATIONS}")
62+
logger.info(f" STREAM_TIMEOUT: {STREAM_TIMEOUT}")
63+
logger.info(f" HEARTBEAT_INTERVAL: {HEARTBEAT_INTERVAL}")

src/open_responses_server/common/llm_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import httpx
2-
from .config import OPENAI_BASE_URL_INTERNAL, OPENAI_API_KEY, logger
2+
from .config import OPENAI_BASE_URL_INTERNAL, OPENAI_API_KEY, STREAM_TIMEOUT, logger
33

44
class LLMClient:
55
"""
@@ -18,7 +18,7 @@ async def get_client(cls) -> httpx.AsyncClient:
1818
cls._client = httpx.AsyncClient(
1919
base_url=OPENAI_BASE_URL_INTERNAL,
2020
headers={"Authorization": f"Bearer {OPENAI_API_KEY}"},
21-
timeout=httpx.Timeout(120.0)
21+
timeout=httpx.Timeout(STREAM_TIMEOUT)
2222
)
2323
return cls._client
2424

src/open_responses_server/models/responses_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class OutputTextDelta(BaseModel):
9191
type: str = "response.output_text.delta"
9292
item_id: str
9393
output_index: int
94+
content_index: int
9495
delta: str
9596

9697
class ResponseCreated(BaseModel):

src/open_responses_server/responses_service.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,9 @@ async def process_chat_completions_stream(response, chat_request=None):
324324
tool_call_counter = 0
325325
message_id = f"msg_{uuid.uuid4().hex}"
326326
output_text_content = "" # Track the full text content for logging
327-
logger.info(f"Processing streaming response from chat.completions API response_id {response_id}; message_id {message_id}")
327+
request_start_time = time.time()
328+
last_chunk_time = request_start_time
329+
logger.info(f"[STREAM-START] response_id={response_id} message_id={message_id}")
328330

329331
# Create and yield the initial response.created event
330332
response_obj = ResponseModel(
@@ -354,12 +356,25 @@ async def process_chat_completions_stream(response, chat_request=None):
354356
try:
355357
async for chunk in response.aiter_lines():
356358
chunk_counter += 1
359+
now = time.time()
360+
chunk_gap = now - last_chunk_time
361+
last_chunk_time = now
362+
if chunk_gap > 2.0:
363+
logger.info(
364+
f"[STREAM-TIMING] response_id={response_id} "
365+
f"chunk_gap={chunk_gap:.1f}s chunk={chunk_counter}"
366+
)
357367
if not chunk.strip():
358368
continue
359369

360370
# Handle [DONE] message
361371
if chunk.strip() == "data: [DONE]" or chunk.strip() == "[DONE]":
362-
logger.info(f"Received [DONE] message after {chunk_counter} chunks (status: {response_obj.status})")
372+
total_time = time.time() - request_start_time
373+
logger.info(
374+
f"[STREAM-DONE] response_id={response_id} "
375+
f"chunks={chunk_counter} total_time={total_time:.1f}s "
376+
f"status={response_obj.status}"
377+
)
363378

364379
# If we haven't already completed the response, do it now
365380
if response_obj.status != "completed":
@@ -544,6 +559,7 @@ async def process_chat_completions_stream(response, chat_request=None):
544559
type="response.output_text.delta",
545560
item_id=message_id,
546561
output_index=0,
562+
content_index=0,
547563
delta=content_delta
548564
)
549565

@@ -595,6 +611,7 @@ async def process_chat_completions_stream(response, chat_request=None):
595611
type="response.output_text.delta",
596612
item_id=tool_call["id"],
597613
output_index=0,
614+
content_index=0,
598615
delta=text
599616
)
600617
yield f"data: {json.dumps(text_event.dict())}\n\n"
@@ -725,10 +742,11 @@ async def process_chat_completions_stream(response, chat_request=None):
725742
type="response.output_text.delta",
726743
item_id=tool_call["id"],
727744
output_index=0,
745+
content_index=0,
728746
delta=text
729747
)
730748
yield f"data: {json.dumps(text_event.dict())}\n\n"
731-
749+
732750
logger.info(f"[TOOL-CALLS-FINISH] Added function_call_output for MCP tool '{tool_call['function']['name']}'")
733751

734752
else:
@@ -885,7 +903,12 @@ async def process_chat_completions_stream(response, chat_request=None):
885903
continue
886904

887905
except Exception as e:
888-
logger.error(f"Error processing streaming response: {str(e)}")
906+
total_time = time.time() - request_start_time
907+
logger.error(
908+
f"[STREAM-ERROR] response_id={response_id} "
909+
f"error={str(e)} total_time={total_time:.1f}s "
910+
f"chunks={chunk_counter}"
911+
)
889912
# Emit a completion event if we haven't already
890913
if response_obj.status != "completed":
891914
response_obj.status = "completed"

tests/test_api_controller_endpoints.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""
22
Tests for api_controller.py endpoints.
33
"""
4+
import asyncio
45
import json
56
import pytest
67
from unittest.mock import patch, MagicMock, AsyncMock
78
from fastapi.testclient import TestClient
89
from fastapi.responses import StreamingResponse
910

10-
from open_responses_server.api_controller import app
11+
from open_responses_server.api_controller import app, _with_heartbeat, _HEARTBEAT
1112

1213

1314
class TestResponsesEndpoint:
@@ -274,3 +275,66 @@ def test_proxy_invalid_json_body(self, client, mock_llm_client_fixture):
274275
headers={"content-type": "text/plain"},
275276
)
276277
assert response.status_code == 200
278+
279+
280+
@pytest.mark.asyncio
281+
class TestWithHeartbeat:
282+
"""Tests for the _with_heartbeat async generator wrapper."""
283+
284+
async def test_fast_generator_no_heartbeats(self):
285+
"""Fast generators produce no heartbeat sentinels."""
286+
async def fast_gen():
287+
yield "a"
288+
yield "b"
289+
yield "c"
290+
291+
results = [item async for item in _with_heartbeat(fast_gen(), interval=10.0)]
292+
assert results == ["a", "b", "c"]
293+
assert _HEARTBEAT not in results
294+
295+
async def test_slow_generator_emits_heartbeats(self):
296+
"""Slow generators trigger heartbeat sentinels between items."""
297+
async def slow_gen():
298+
yield "first"
299+
await asyncio.sleep(0.6)
300+
yield "second"
301+
302+
results = [item async for item in _with_heartbeat(slow_gen(), interval=0.2)]
303+
# Should have at least one heartbeat between "first" and "second"
304+
heartbeats = [r for r in results if r is _HEARTBEAT]
305+
data = [r for r in results if r is not _HEARTBEAT]
306+
assert len(heartbeats) >= 1
307+
assert data == ["first", "second"]
308+
309+
async def test_empty_generator(self):
310+
"""Empty generator produces no output."""
311+
async def empty_gen():
312+
return
313+
yield # noqa: unreachable - makes this an async generator
314+
315+
results = [item async for item in _with_heartbeat(empty_gen(), interval=1.0)]
316+
assert results == []
317+
318+
async def test_generator_exception_propagates(self):
319+
"""Exceptions from the wrapped generator propagate through."""
320+
async def error_gen():
321+
yield "ok"
322+
raise ValueError("test error")
323+
324+
results = []
325+
with pytest.raises(ValueError, match="test error"):
326+
async for item in _with_heartbeat(error_gen(), interval=1.0):
327+
results.append(item)
328+
assert results == ["ok"]
329+
330+
async def test_heartbeat_count_scales_with_delay(self):
331+
"""Longer delays produce more heartbeats."""
332+
async def very_slow_gen():
333+
yield "start"
334+
await asyncio.sleep(1.0)
335+
yield "end"
336+
337+
results = [item async for item in _with_heartbeat(very_slow_gen(), interval=0.2)]
338+
heartbeats = [r for r in results if r is _HEARTBEAT]
339+
# ~1.0s delay / 0.2s interval = ~5 heartbeats (allow some variance)
340+
assert len(heartbeats) >= 3

0 commit comments

Comments
 (0)