Commit b6b155d

fix(anthropic): handle partial JSON chunks in streaming responses (BerriAI#17493)
Fixes BerriAI#17473 - Anthropic streaming fails with JSONDecodeError when network fragmentation causes SSE data to arrive in partial chunks.

Changes:
- Add accumulated_json buffer and chunk_type to ModelResponseIterator
- Add _handle_accumulated_json_chunk() to accumulate partial JSON
- Add _parse_sse_data() to handle both complete and partial chunks
- Modify __next__ and __anext__ to use accumulation logic
- Add unit tests for partial chunk handling
1 parent 0f5694c commit b6b155d
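The fix boils down to an accumulate-until-parse buffer: each "data:" payload fragment is appended to a string buffer, the buffer is re-parsed after every fragment, and it is cleared once json.loads succeeds. Below is a minimal standalone sketch of that idea (illustrative only; the JSONAccumulator name is not part of this commit, and the real logic lives in ModelResponseIterator in the diff further down):

import json
from typing import Optional


class JSONAccumulator:
    """Buffers SSE "data:" payload fragments until they parse as one JSON event."""

    def __init__(self) -> None:
        self.buffer = ""

    def feed(self, data_str: str) -> Optional[dict]:
        self.buffer += data_str
        try:
            event = json.loads(self.buffer)
        except json.JSONDecodeError:
            return None  # incomplete fragment: keep accumulating
        self.buffer = ""  # reset after a successful parse
        return event


acc = JSONAccumulator()
assert acc.feed('{"type":"content_block_delta","delta":{"text":"Hel') is None
assert acc.feed('lo"}}') == {"type": "content_block_delta", "delta": {"text": "Hello"}}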

File tree

2 files changed (+214, -62 lines):

- litellm/llms/anthropic/chat/handler.py
- tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py
litellm/llms/anthropic/chat/handler.py

Lines changed: 142 additions & 62 deletions
@@ -10,6 +10,7 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     Tuple,
     Union,
@@ -498,6 +499,11 @@ def __init__(
         # Track if we've converted any response_format tools (affects finish_reason)
         self.converted_response_format_tool: bool = False
 
+        # For handling partial JSON chunks from fragmentation
+        # See: https://github.com/BerriAI/litellm/issues/17473
+        self.accumulated_json: str = ""
+        self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"
+
     def check_empty_tool_call_args(self) -> bool:
         """
         Check if the tool call block so far has been an empty string
@@ -866,80 +872,154 @@ def _handle_message_delta(self, chunk: dict) -> Tuple[str, Optional[Usage]]:
         usage = self._handle_usage(anthropic_usage_chunk=message_delta["usage"])
         return finish_reason, usage
 
+    def _handle_accumulated_json_chunk(
+        self, data_str: str
+    ) -> Optional[GenericStreamingChunk]:
+        """
+        Handle partial JSON chunks by accumulating them until valid JSON is received.
+
+        This fixes network fragmentation issues where SSE data chunks may be split
+        across TCP packets. See: https://github.com/BerriAI/litellm/issues/17473
+
+        Args:
+            data_str: The JSON string to parse (without "data:" prefix)
+
+        Returns:
+            GenericStreamingChunk if JSON is complete, None if still accumulating
+        """
+        # Accumulate JSON data
+        self.accumulated_json += data_str
+
+        # Try to parse the accumulated JSON
+        try:
+            data_json = json.loads(self.accumulated_json)
+            self.accumulated_json = ""  # Reset after successful parsing
+            return self.chunk_parser(chunk=data_json)
+        except json.JSONDecodeError:
+            # If it's not valid JSON yet, continue to the next chunk
+            return None
+
+    def _parse_sse_data(self, str_line: str) -> Optional[GenericStreamingChunk]:
+        """
+        Parse SSE data line, handling both complete and partial JSON chunks.
+
+        Args:
+            str_line: The SSE line starting with "data:"
+
+        Returns:
+            GenericStreamingChunk if parsing succeeded, None if accumulating partial JSON
+        """
+        data_str = str_line[5:]  # Remove "data:" prefix
+
+        if self.chunk_type == "accumulated_json":
+            # Already in accumulation mode, keep accumulating
+            return self._handle_accumulated_json_chunk(data_str)
+
+        # Try to parse as valid JSON first
+        try:
+            data_json = json.loads(data_str)
+            return self.chunk_parser(chunk=data_json)
+        except json.JSONDecodeError:
+            # Switch to accumulation mode and start accumulating
+            self.chunk_type = "accumulated_json"
+            return self._handle_accumulated_json_chunk(data_str)
+
     # Sync iterator
     def __iter__(self):
         return self
 
     def __next__(self):
-        try:
-            chunk = self.response_iterator.__next__()
-        except StopIteration:
-            raise StopIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
-        try:
-            str_line = chunk
-            if isinstance(chunk, bytes):  # Handle binary data
-                str_line = chunk.decode("utf-8")  # Convert bytes to string
-                index = str_line.find("data:")
-                if index != -1:
-                    str_line = str_line[index:]
-
-            if str_line.startswith("data:"):
-                data_json = json.loads(str_line[5:])
-                return self.chunk_parser(chunk=data_json)
-            else:
-                return GenericStreamingChunk(
-                    text="",
-                    is_finished=False,
-                    finish_reason="",
-                    usage=None,
-                    index=0,
-                    tool_use=None,
-                )
-        except StopIteration:
-            raise StopIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+        while True:
+            try:
+                chunk = self.response_iterator.__next__()
+            except StopIteration:
+                # If we have accumulated JSON when stream ends, try to parse it
+                if self.accumulated_json:
+                    try:
+                        data_json = json.loads(self.accumulated_json)
+                        self.accumulated_json = ""
+                        return self.chunk_parser(chunk=data_json)
+                    except json.JSONDecodeError:
+                        pass
+                raise StopIteration
+            except ValueError as e:
+                raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+            try:
+                str_line = chunk
+                if isinstance(chunk, bytes):  # Handle binary data
+                    str_line = chunk.decode("utf-8")  # Convert bytes to string
+                    index = str_line.find("data:")
+                    if index != -1:
+                        str_line = str_line[index:]
+
+                if str_line.startswith("data:"):
+                    result = self._parse_sse_data(str_line)
+                    if result is not None:
+                        return result
+                    # If None, continue loop to get more chunks for accumulation
+                else:
+                    return GenericStreamingChunk(
+                        text="",
+                        is_finished=False,
+                        finish_reason="",
+                        usage=None,
+                        index=0,
+                        tool_use=None,
+                    )
+            except StopIteration:
+                raise StopIteration
+            except ValueError as e:
+                raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
 
     # Async iterator
     def __aiter__(self):
         self.async_response_iterator = self.streaming_response.__aiter__()
         return self
 
     async def __anext__(self):
-        try:
-            chunk = await self.async_response_iterator.__anext__()
-        except StopAsyncIteration:
-            raise StopAsyncIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
-        try:
-            str_line = chunk
-            if isinstance(chunk, bytes):  # Handle binary data
-                str_line = chunk.decode("utf-8")  # Convert bytes to string
-                index = str_line.find("data:")
-                if index != -1:
-                    str_line = str_line[index:]
-
-            if str_line.startswith("data:"):
-                data_json = json.loads(str_line[5:])
-                return self.chunk_parser(chunk=data_json)
-            else:
-                return GenericStreamingChunk(
-                    text="",
-                    is_finished=False,
-                    finish_reason="",
-                    usage=None,
-                    index=0,
-                    tool_use=None,
-                )
-        except StopAsyncIteration:
-            raise StopAsyncIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+        while True:
+            try:
+                chunk = await self.async_response_iterator.__anext__()
+            except StopAsyncIteration:
+                # If we have accumulated JSON when stream ends, try to parse it
+                if self.accumulated_json:
+                    try:
+                        data_json = json.loads(self.accumulated_json)
+                        self.accumulated_json = ""
+                        return self.chunk_parser(chunk=data_json)
+                    except json.JSONDecodeError:
+                        pass
+                raise StopAsyncIteration
+            except ValueError as e:
+                raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+            try:
+                str_line = chunk
+                if isinstance(chunk, bytes):  # Handle binary data
+                    str_line = chunk.decode("utf-8")  # Convert bytes to string
+                    index = str_line.find("data:")
+                    if index != -1:
+                        str_line = str_line[index:]
+
+                if str_line.startswith("data:"):
+                    result = self._parse_sse_data(str_line)
+                    if result is not None:
+                        return result
+                    # If None, continue loop to get more chunks for accumulation
+                else:
+                    return GenericStreamingChunk(
+                        text="",
+                        is_finished=False,
+                        finish_reason="",
+                        usage=None,
+                        index=0,
+                        tool_use=None,
+                    )
+            except StopAsyncIteration:
+                raise StopAsyncIteration
+            except ValueError as e:
+                raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
 
     def convert_str_chunk_to_generic_chunk(self, chunk: str) -> ModelResponseStream:
         """

tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py

Lines changed: 72 additions & 0 deletions
@@ -460,3 +460,75 @@ def test_streaming_chunks_have_stable_ids():
     response_two = iterator.chunk_parser(chunk=second_chunk)
 
     assert response_one.id == response_two.id == iterator.response_id
+
+
+def test_partial_json_chunk_accumulation():
+    """
+    Test that partial JSON chunks are accumulated correctly.
+
+    This tests the fix for https://github.com/BerriAI/litellm/issues/17473
+    where network fragmentation can cause SSE data to arrive in partial chunks.
+    """
+    iterator = ModelResponseIterator(
+        streaming_response=MagicMock(), sync_stream=True, json_mode=False
+    )
+
+    # Simulate a complete JSON chunk being split into two parts
+    partial_chunk_1 = '{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hel'
+    partial_chunk_2 = 'lo"}}'
+
+    # First partial chunk should return None (still accumulating)
+    result1 = iterator._parse_sse_data(f"data:{partial_chunk_1}")
+    assert result1 is None, "First partial chunk should return None while accumulating"
+    assert iterator.chunk_type == "accumulated_json", "Should switch to accumulated_json mode"
+    assert iterator.accumulated_json == partial_chunk_1, "Should have accumulated first part"
+
+    # Second partial chunk should complete the JSON and return a parsed result
+    result2 = iterator._parse_sse_data(f"data:{partial_chunk_2}")
+    assert result2 is not None, "Second chunk should return parsed result"
+    assert iterator.accumulated_json == "", "Buffer should be cleared after successful parse"
+    assert result2.choices[0].delta.content == "Hello", f"Expected 'Hello', got '{result2.choices[0].delta.content}'"
+
+
+def test_complete_json_chunk_no_accumulation():
+    """
+    Test that complete JSON chunks are parsed immediately without accumulation.
+    """
+    iterator = ModelResponseIterator(
+        streaming_response=MagicMock(), sync_stream=True, json_mode=False
+    )
+
+    complete_chunk = '{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}'
+
+    result = iterator._parse_sse_data(f"data:{complete_chunk}")
+    assert result is not None, "Complete chunk should return parsed result immediately"
+    assert iterator.chunk_type == "valid_json", "Should remain in valid_json mode"
+    assert iterator.accumulated_json == "", "Buffer should remain empty"
+    assert result.choices[0].delta.content == "Hello", f"Expected 'Hello', got '{result.choices[0].delta.content}'"
+
+
+def test_multiple_partial_chunks_accumulation():
+    """
+    Test that multiple partial chunks can be accumulated across several iterations.
+    """
+    iterator = ModelResponseIterator(
+        streaming_response=MagicMock(), sync_stream=True, json_mode=False
+    )
+
+    # Split a JSON chunk into three parts
+    part1 = '{"type":"content_block_del'
+    part2 = 'ta","index":0,"delta":{"type":"text_del'
+    part3 = 'ta","text":"Hello"}}'
+
+    result1 = iterator._parse_sse_data(f"data:{part1}")
+    assert result1 is None
+    assert iterator.accumulated_json == part1
+
+    result2 = iterator._parse_sse_data(f"data:{part2}")
+    assert result2 is None
+    assert iterator.accumulated_json == part1 + part2
+
+    result3 = iterator._parse_sse_data(f"data:{part3}")
+    assert result3 is not None
+    assert iterator.accumulated_json == ""
+    assert result3.choices[0].delta.content == "Hello"