diff --git a/stream_reader.go b/stream_reader.go index 6faefe0a7..deb3dad49 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -3,6 +3,8 @@ package openai import ( "bufio" "bytes" + "encoding/json" + "errors" "fmt" "io" "net/http" @@ -40,6 +42,14 @@ func (stream *streamReader[T]) Recv() (response T, err error) { err = stream.unmarshaler.Unmarshal(rawLine, &response) if err != nil { + // If we get a JSON parsing error, it might be because we got an error event + // Check if we have accumulated error data + var syntaxErr *json.SyntaxError + if errors.As(err, &syntaxErr) && len(stream.errAccumulator.Bytes()) > 0 { + // We have error data, return a more informative error + return response, fmt.Errorf("failed to parse response (error event received): %s", + string(stream.errAccumulator.Bytes())) + } return } return response, nil @@ -65,7 +75,18 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { if readErr != nil || hasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { - return nil, fmt.Errorf("error, %w", respErr.Error) + return nil, respErr.Error + } + // If we detected an error event but couldn't parse it, and the stream ended, + // return a more informative error. This handles cases where providers send + // error events that don't match the expected format and immediately close. + if hasErrorPrefix && readErr == io.EOF { + // Check if we have error data that failed to parse + errBytes := stream.errAccumulator.Bytes() + if len(errBytes) > 0 { + return nil, fmt.Errorf("failed to parse error event: %s", string(errBytes)) + } + return nil, fmt.Errorf("stream ended after error event") } return nil, readErr } @@ -73,20 +94,24 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { noSpaceLine := bytes.TrimSpace(rawLine) if errorPrefix.Match(noSpaceLine) { hasErrorPrefix = true - } - if !headerData.Match(noSpaceLine) || hasErrorPrefix { - if hasErrorPrefix { - noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) - } - writeErr := stream.errAccumulator.Write(noSpaceLine) + // Extract just the JSON part after "data: " prefix + // This handles both OpenAI format (data: {"error": ...}) and + // Groq format (event: error\ndata: {"error": ...}) + jsonData := headerData.ReplaceAll(noSpaceLine, nil) + writeErr := stream.errAccumulator.Write(jsonData) if writeErr != nil { return nil, writeErr } + continue + } + + // Skip non-data lines (e.g., "event: error" from Groq) + // This allows us to handle SSE streams that use explicit event types + if !headerData.Match(noSpaceLine) { emptyMessagesCount++ if emptyMessagesCount > stream.emptyMessagesLimit { return nil, ErrTooManyEmptyStreamMessages } - continue } @@ -111,6 +136,10 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { errResp = nil } + // Reset the error accumulator for future error events + // A new accumulator is created to avoid potential interface issues + stream.errAccumulator = utils.NewErrorAccumulator() + return } diff --git a/stream_reader_test.go b/stream_reader_test.go index 449a14b43..49dec7ee5 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -54,11 +54,12 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) { func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { stream := &streamReader[ChatCompletionStreamResponse]{ - reader: bufio.NewReader(bytes.NewReader([]byte("\n"))), + reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"error\": {\"message\": \"test error\"}}\n"))), errAccumulator: &utils.DefaultErrorAccumulator{ Buffer: &test.FailingErrorBuffer{}, }, - unmarshaler: &utils.JSONUnmarshaler{}, + unmarshaler: &utils.JSONUnmarshaler{}, + emptyMessagesLimit: 5, } _, err := stream.Recv() checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) @@ -76,3 +77,126 @@ func TestStreamReaderRecvRaw(t *testing.T) { t.Fatalf("Did not return raw line: %v", string(rawLine)) } } + +func TestStreamReaderParsesErrorEvents(t *testing.T) { + // Test case simulating Groq's error event format + errorEvent := `event: error +data: {"error":{"message":"Invalid tool_call: tool \"name_unknown\" does not exist.",` + + `"type":"invalid_request_error","code":"invalid_tool_call"}} + +` + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte(errorEvent))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + emptyMessagesLimit: 5, + } + + // Process the error event + _, err := stream.Recv() + if err == nil { + t.Fatal("Expected error but got nil") + } + + // Verify it's an APIError + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Expected APIError type but got %T: %v", err, err) + } + + // Verify the error fields are correctly parsed + if apiErr.Message != "Invalid tool_call: tool \"name_unknown\" does not exist." { + t.Errorf("Unexpected error message: %s", apiErr.Message) + } + if apiErr.Type != "invalid_request_error" { + t.Errorf("Unexpected error type: %s", apiErr.Type) + } + if apiErr.Code != "invalid_tool_call" { + t.Errorf("Unexpected error code: %v", apiErr.Code) + } +} + +func TestStreamReaderHandlesErrorEventWithExtraData(t *testing.T) { + // Test case with error event followed by more data + errorEvent := `data: {"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]} +event: error +data: {"error":{"message":"Stream interrupted","type":"server_error"}} +data: [DONE] +` + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte(errorEvent))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + emptyMessagesLimit: 5, + } + + // First recv should return the chat completion + resp, err := stream.Recv() + if err != nil { + t.Fatalf("First recv failed: %v", err) + } + if resp.ID != "chatcmpl-123" { + t.Errorf("Unexpected response ID: %s", resp.ID) + } + + // Second recv should return the error + _, err = stream.Recv() + if err == nil { + t.Fatal("Expected error but got nil") + } + + // Verify it's an APIError + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Expected APIError type but got %T: %v", err, err) + } + + if apiErr.Message != "Stream interrupted" { + t.Errorf("Unexpected error message: %s", apiErr.Message) + } +} + +func TestStreamReaderResetsErrorAccumulator(t *testing.T) { + // Test that error accumulator is reset after processing an error + multipleErrors := `event: error +data: {"error":{"message":"First error","type":"error_type_1"}} + +event: error +data: {"error":{"message":"Second error","type":"error_type_2"}} +` + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte(multipleErrors))), + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + emptyMessagesLimit: 5, + } + + // First recv should return the first error + _, err1 := stream.Recv() + if err1 == nil { + t.Fatal("Expected first error but got nil") + } + var apiErr1 *APIError + if !errors.As(err1, &apiErr1) { + t.Fatalf("Expected APIError type but got %T: %v", err1, err1) + } + if apiErr1.Message != "First error" { + t.Errorf("Unexpected first error message: %s", apiErr1.Message) + } + + // Second recv should return the second error (not a concatenation) + _, err2 := stream.Recv() + if err2 == nil { + t.Fatal("Expected second error but got nil") + } + var apiErr2 *APIError + if !errors.As(err2, &apiErr2) { + t.Fatalf("Expected APIError type but got %T: %v", err2, err2) + } + if apiErr2.Message != "Second error" { + t.Errorf("Unexpected second error message: %s", apiErr2.Message) + } + if apiErr2.Type != "error_type_2" { + t.Errorf("Unexpected second error type: %s", apiErr2.Type) + } +}