10 | 10 | Callable, |
11 | 11 | Dict, |
12 | 12 | List, |
| 13 | + Literal, |
13 | 14 | Optional, |
14 | 15 | Tuple, |
15 | 16 | Union, |
@@ -498,6 +499,11 @@ def __init__( |
498 | 499 | # Track if we've converted any response_format tools (affects finish_reason) |
499 | 500 | self.converted_response_format_tool: bool = False |
500 | 501 |
|
| 502 | + # For handling partial JSON chunks from fragmentation |
| 503 | + # See: https://github.com/BerriAI/litellm/issues/17473 |
| 504 | + self.accumulated_json: str = "" |
| 505 | + self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json" |
| 506 | + |
501 | 507 | def check_empty_tool_call_args(self) -> bool: |
502 | 508 | """ |
503 | 509 | Check if the tool call block so far has been an empty string |
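The two attributes added in the hunk above drive a small two-state parser: while chunk_type is "valid_json", each "data:" payload is decoded on its own, and the first payload that fails to decode flips the handler into "accumulated_json" mode, where fragments are concatenated until json.loads succeeds (implemented by _parse_sse_data and _handle_accumulated_json_chunk in the next hunk). A rough standalone sketch of that state machine; the class name FragmentAccumulator and the method feed are illustrative only, not part of this change:

import json
from typing import Literal, Optional


class FragmentAccumulator:
    """Standalone sketch of the two-state partial-JSON handling added above."""

    def __init__(self) -> None:
        self.accumulated_json: str = ""
        self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"

    def feed(self, data_str: str) -> Optional[dict]:
        # In "valid_json" mode, each payload is expected to parse on its own.
        if self.chunk_type == "valid_json":
            try:
                return json.loads(data_str)
            except json.JSONDecodeError:
                # First partial payload seen: switch to accumulation mode.
                self.chunk_type = "accumulated_json"
        # In "accumulated_json" mode, concatenate fragments until they parse.
        self.accumulated_json += data_str
        try:
            parsed = json.loads(self.accumulated_json)
            self.accumulated_json = ""  # Reset after a successful parse.
            return parsed
        except json.JSONDecodeError:
            return None  # Still incomplete; the caller keeps reading.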
@@ -866,80 +872,154 @@ def _handle_message_delta(self, chunk: dict) -> Tuple[str, Optional[Usage]]: |
866 | 872 | usage = self._handle_usage(anthropic_usage_chunk=message_delta["usage"]) |
867 | 873 | return finish_reason, usage |
868 | 874 |
|
| 875 | + def _handle_accumulated_json_chunk( |
| 876 | + self, data_str: str |
| 877 | + ) -> Optional[GenericStreamingChunk]: |
| 878 | + """ |
| 879 | + Handle partial JSON chunks by accumulating them until valid JSON is received. |
| 880 | +
|
| 881 | + This fixes network fragmentation issues where SSE data chunks may be split |
| 882 | + across TCP packets. See: https://github.com/BerriAI/litellm/issues/17473 |
| 883 | +
|
| 884 | + Args: |
| 885 | + data_str: The JSON string to parse (without "data:" prefix) |
| 886 | +
|
| 887 | + Returns: |
| 888 | + GenericStreamingChunk if JSON is complete, None if still accumulating |
| 889 | + """ |
| 890 | + # Accumulate JSON data |
| 891 | + self.accumulated_json += data_str |
| 892 | + |
| 893 | + # Try to parse the accumulated JSON |
| 894 | + try: |
| 895 | + data_json = json.loads(self.accumulated_json) |
| 896 | + self.accumulated_json = "" # Reset after successful parsing |
| 897 | + return self.chunk_parser(chunk=data_json) |
| 898 | + except json.JSONDecodeError: |
| 899 | + # If it's not valid JSON yet, continue to the next chunk |
| 900 | + return None |
| 901 | + |
| 902 | + def _parse_sse_data(self, str_line: str) -> Optional[GenericStreamingChunk]: |
| 903 | + """ |
| 904 | + Parse SSE data line, handling both complete and partial JSON chunks. |
| 905 | +
|
| 906 | + Args: |
| 907 | + str_line: The SSE line starting with "data:" |
| 908 | +
|
| 909 | + Returns: |
| 910 | + GenericStreamingChunk if parsing succeeded, None if accumulating partial JSON |
| 911 | + """ |
| 912 | + data_str = str_line[5:] # Remove "data:" prefix |
| 913 | + |
| 914 | + if self.chunk_type == "accumulated_json": |
| 915 | + # Already in accumulation mode, keep accumulating |
| 916 | + return self._handle_accumulated_json_chunk(data_str) |
| 917 | + |
| 918 | + # Try to parse as valid JSON first |
| 919 | + try: |
| 920 | + data_json = json.loads(data_str) |
| 921 | + return self.chunk_parser(chunk=data_json) |
| 922 | + except json.JSONDecodeError: |
| 923 | + # Switch to accumulation mode and start accumulating |
| 924 | + self.chunk_type = "accumulated_json" |
| 925 | + return self._handle_accumulated_json_chunk(data_str) |
| 926 | + |
869 | 927 | # Sync iterator |
870 | 928 | def __iter__(self): |
871 | 929 | return self |
872 | 930 |
|
873 | 931 | def __next__(self): |
874 | | - try: |
875 | | - chunk = self.response_iterator.__next__() |
876 | | - except StopIteration: |
877 | | - raise StopIteration |
878 | | - except ValueError as e: |
879 | | - raise RuntimeError(f"Error receiving chunk from stream: {e}") |
880 | | - |
881 | | - try: |
882 | | - str_line = chunk |
883 | | - if isinstance(chunk, bytes): # Handle binary data |
884 | | - str_line = chunk.decode("utf-8") # Convert bytes to string |
885 | | - index = str_line.find("data:") |
886 | | - if index != -1: |
887 | | - str_line = str_line[index:] |
888 | | - |
889 | | - if str_line.startswith("data:"): |
890 | | - data_json = json.loads(str_line[5:]) |
891 | | - return self.chunk_parser(chunk=data_json) |
892 | | - else: |
893 | | - return GenericStreamingChunk( |
894 | | - text="", |
895 | | - is_finished=False, |
896 | | - finish_reason="", |
897 | | - usage=None, |
898 | | - index=0, |
899 | | - tool_use=None, |
900 | | - ) |
901 | | - except StopIteration: |
902 | | - raise StopIteration |
903 | | - except ValueError as e: |
904 | | - raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") |
| 932 | + while True: |
| 933 | + try: |
| 934 | + chunk = self.response_iterator.__next__() |
| 935 | + except StopIteration: |
| 936 | + # If we have accumulated JSON when stream ends, try to parse it |
| 937 | + if self.accumulated_json: |
| 938 | + try: |
| 939 | + data_json = json.loads(self.accumulated_json) |
| 940 | + self.accumulated_json = "" |
| 941 | + return self.chunk_parser(chunk=data_json) |
| 942 | + except json.JSONDecodeError: |
| 943 | + pass |
| 944 | + raise StopIteration |
| 945 | + except ValueError as e: |
| 946 | + raise RuntimeError(f"Error receiving chunk from stream: {e}") |
| 947 | + |
| 948 | + try: |
| 949 | + str_line = chunk |
| 950 | + if isinstance(chunk, bytes): # Handle binary data |
| 951 | + str_line = chunk.decode("utf-8") # Convert bytes to string |
| 952 | + index = str_line.find("data:") |
| 953 | + if index != -1: |
| 954 | + str_line = str_line[index:] |
| 955 | + |
| 956 | + if str_line.startswith("data:"): |
| 957 | + result = self._parse_sse_data(str_line) |
| 958 | + if result is not None: |
| 959 | + return result |
| 960 | + # If None, continue loop to get more chunks for accumulation |
| 961 | + else: |
| 962 | + return GenericStreamingChunk( |
| 963 | + text="", |
| 964 | + is_finished=False, |
| 965 | + finish_reason="", |
| 966 | + usage=None, |
| 967 | + index=0, |
| 968 | + tool_use=None, |
| 969 | + ) |
| 970 | + except StopIteration: |
| 971 | + raise StopIteration |
| 972 | + except ValueError as e: |
| 973 | + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") |
905 | 974 |
|
906 | 975 | # Async iterator |
907 | 976 | def __aiter__(self): |
908 | 977 | self.async_response_iterator = self.streaming_response.__aiter__() |
909 | 978 | return self |
910 | 979 |
|
911 | 980 | async def __anext__(self): |
912 | | - try: |
913 | | - chunk = await self.async_response_iterator.__anext__() |
914 | | - except StopAsyncIteration: |
915 | | - raise StopAsyncIteration |
916 | | - except ValueError as e: |
917 | | - raise RuntimeError(f"Error receiving chunk from stream: {e}") |
918 | | - |
919 | | - try: |
920 | | - str_line = chunk |
921 | | - if isinstance(chunk, bytes): # Handle binary data |
922 | | - str_line = chunk.decode("utf-8") # Convert bytes to string |
923 | | - index = str_line.find("data:") |
924 | | - if index != -1: |
925 | | - str_line = str_line[index:] |
926 | | - |
927 | | - if str_line.startswith("data:"): |
928 | | - data_json = json.loads(str_line[5:]) |
929 | | - return self.chunk_parser(chunk=data_json) |
930 | | - else: |
931 | | - return GenericStreamingChunk( |
932 | | - text="", |
933 | | - is_finished=False, |
934 | | - finish_reason="", |
935 | | - usage=None, |
936 | | - index=0, |
937 | | - tool_use=None, |
938 | | - ) |
939 | | - except StopAsyncIteration: |
940 | | - raise StopAsyncIteration |
941 | | - except ValueError as e: |
942 | | - raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") |
| 981 | + while True: |
| 982 | + try: |
| 983 | + chunk = await self.async_response_iterator.__anext__() |
| 984 | + except StopAsyncIteration: |
| 985 | + # If we have accumulated JSON when stream ends, try to parse it |
| 986 | + if self.accumulated_json: |
| 987 | + try: |
| 988 | + data_json = json.loads(self.accumulated_json) |
| 989 | + self.accumulated_json = "" |
| 990 | + return self.chunk_parser(chunk=data_json) |
| 991 | + except json.JSONDecodeError: |
| 992 | + pass |
| 993 | + raise StopAsyncIteration |
| 994 | + except ValueError as e: |
| 995 | + raise RuntimeError(f"Error receiving chunk from stream: {e}") |
| 996 | + |
| 997 | + try: |
| 998 | + str_line = chunk |
| 999 | + if isinstance(chunk, bytes): # Handle binary data |
| 1000 | + str_line = chunk.decode("utf-8") # Convert bytes to string |
| 1001 | + index = str_line.find("data:") |
| 1002 | + if index != -1: |
| 1003 | + str_line = str_line[index:] |
| 1004 | + |
| 1005 | + if str_line.startswith("data:"): |
| 1006 | + result = self._parse_sse_data(str_line) |
| 1007 | + if result is not None: |
| 1008 | + return result |
| 1009 | + # If None, continue loop to get more chunks for accumulation |
| 1010 | + else: |
| 1011 | + return GenericStreamingChunk( |
| 1012 | + text="", |
| 1013 | + is_finished=False, |
| 1014 | + finish_reason="", |
| 1015 | + usage=None, |
| 1016 | + index=0, |
| 1017 | + tool_use=None, |
| 1018 | + ) |
| 1019 | + except StopAsyncIteration: |
| 1020 | + raise StopAsyncIteration |
| 1021 | + except ValueError as e: |
| 1022 | + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") |
943 | 1023 |
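Putting the pieces together: the rewritten __next__ and __anext__ loops keep pulling chunks whenever _parse_sse_data returns None, so callers only ever see fully reassembled events, and the StopIteration / StopAsyncIteration branches make one last json.loads attempt over accumulated_json before the stream is closed. A hedged walk-through using the FragmentAccumulator sketch from above; the payloads are shown already stripped of their "data:" prefix, as _parse_sse_data does before parsing:

acc = FragmentAccumulator()

# A complete event still parses immediately, as it did before this change.
assert acc.feed('{"type": "message_start"}') == {"type": "message_start"}

# A fragmented event returns None until the JSON is complete; instead of surfacing
# the fragment, the while-True loops above simply pull the next chunk.
assert acc.feed('{"type": "content_block_delta", "delta": {"te') is None
assert acc.feed('xt": "hello"}}') == {
    "type": "content_block_delta",
    "delta": {"text": "hello"},
}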
|
944 | 1024 | def convert_str_chunk_to_generic_chunk(self, chunk: str) -> ModelResponseStream: |
945 | 1025 | """ |
|