Skip to content

fix: handle SSE error events from providers like Groq #1057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package openai
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -40,6 +42,14 @@ func (stream *streamReader[T]) Recv() (response T, err error) {

err = stream.unmarshaler.Unmarshal(rawLine, &response)
if err != nil {
// If we get a JSON parsing error, it might be because we got an error event
// Check if we have accumulated error data
var syntaxErr *json.SyntaxError
if errors.As(err, &syntaxErr) && len(stream.errAccumulator.Bytes()) > 0 {
// We have error data, return a more informative error
return response, fmt.Errorf("failed to parse response (error event received): %s",
string(stream.errAccumulator.Bytes()))
}
return
}
return response, nil
Expand All @@ -65,28 +75,43 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
if readErr != nil || hasErrorPrefix {
respErr := stream.unmarshalError()
if respErr != nil {
return nil, fmt.Errorf("error, %w", respErr.Error)
return nil, respErr.Error
}
// If we detected an error event but couldn't parse it, and the stream ended,
// return a more informative error. This handles cases where providers send
// error events that don't match the expected format and immediately close.
if hasErrorPrefix && readErr == io.EOF {
// Check if we have error data that failed to parse
errBytes := stream.errAccumulator.Bytes()
if len(errBytes) > 0 {
return nil, fmt.Errorf("failed to parse error event: %s", string(errBytes))
}
return nil, fmt.Errorf("stream ended after error event")
}
return nil, readErr
}

noSpaceLine := bytes.TrimSpace(rawLine)
if errorPrefix.Match(noSpaceLine) {
hasErrorPrefix = true
}
if !headerData.Match(noSpaceLine) || hasErrorPrefix {
if hasErrorPrefix {
noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil)
}
writeErr := stream.errAccumulator.Write(noSpaceLine)
// Extract just the JSON part after "data: " prefix
// This handles both OpenAI format (data: {"error": ...}) and
// Groq format (event: error\ndata: {"error": ...})
jsonData := headerData.ReplaceAll(noSpaceLine, nil)
writeErr := stream.errAccumulator.Write(jsonData)
if writeErr != nil {
return nil, writeErr
}
continue
}

// Skip non-data lines (e.g., "event: error" from Groq)
// This allows us to handle SSE streams that use explicit event types
if !headerData.Match(noSpaceLine) {
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
return nil, ErrTooManyEmptyStreamMessages
}

continue
}

Expand All @@ -111,6 +136,10 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
errResp = nil
}

// Reset the error accumulator for future error events
// A new accumulator is created to avoid potential interface issues
stream.errAccumulator = utils.NewErrorAccumulator()

return
}

Expand Down
128 changes: 126 additions & 2 deletions stream_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) {

func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
reader: bufio.NewReader(bytes.NewReader([]byte("\n"))),
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"error\": {\"message\": \"test error\"}}\n"))),
errAccumulator: &utils.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
},
unmarshaler: &utils.JSONUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
emptyMessagesLimit: 5,
}
_, err := stream.Recv()
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
Expand All @@ -76,3 +77,126 @@ func TestStreamReaderRecvRaw(t *testing.T) {
t.Fatalf("Did not return raw line: %v", string(rawLine))
}
}

func TestStreamReaderParsesErrorEvents(t *testing.T) {
// Test case simulating Groq's error event format
errorEvent := `event: error
data: {"error":{"message":"Invalid tool_call: tool \"name_unknown\" does not exist.",` +
`"type":"invalid_request_error","code":"invalid_tool_call"}}

`
stream := &streamReader[ChatCompletionStreamResponse]{
reader: bufio.NewReader(bytes.NewReader([]byte(errorEvent))),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
emptyMessagesLimit: 5,
}

// Process the error event
_, err := stream.Recv()
if err == nil {
t.Fatal("Expected error but got nil")
}

// Verify it's an APIError
var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("Expected APIError type but got %T: %v", err, err)
}

// Verify the error fields are correctly parsed
if apiErr.Message != "Invalid tool_call: tool \"name_unknown\" does not exist." {
t.Errorf("Unexpected error message: %s", apiErr.Message)
}
if apiErr.Type != "invalid_request_error" {
t.Errorf("Unexpected error type: %s", apiErr.Type)
}
if apiErr.Code != "invalid_tool_call" {
t.Errorf("Unexpected error code: %v", apiErr.Code)
}
}

func TestStreamReaderHandlesErrorEventWithExtraData(t *testing.T) {
// Test case with error event followed by more data
errorEvent := `data: {"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}
event: error
data: {"error":{"message":"Stream interrupted","type":"server_error"}}
data: [DONE]
`
stream := &streamReader[ChatCompletionStreamResponse]{
reader: bufio.NewReader(bytes.NewReader([]byte(errorEvent))),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
emptyMessagesLimit: 5,
}

// First recv should return the chat completion
resp, err := stream.Recv()
if err != nil {
t.Fatalf("First recv failed: %v", err)
}
if resp.ID != "chatcmpl-123" {
t.Errorf("Unexpected response ID: %s", resp.ID)
}

// Second recv should return the error
_, err = stream.Recv()
if err == nil {
t.Fatal("Expected error but got nil")
}

// Verify it's an APIError
var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("Expected APIError type but got %T: %v", err, err)
}

if apiErr.Message != "Stream interrupted" {
t.Errorf("Unexpected error message: %s", apiErr.Message)
}
}

func TestStreamReaderResetsErrorAccumulator(t *testing.T) {
// Test that error accumulator is reset after processing an error
multipleErrors := `event: error
data: {"error":{"message":"First error","type":"error_type_1"}}

event: error
data: {"error":{"message":"Second error","type":"error_type_2"}}
`
stream := &streamReader[ChatCompletionStreamResponse]{
reader: bufio.NewReader(bytes.NewReader([]byte(multipleErrors))),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
emptyMessagesLimit: 5,
}

// First recv should return the first error
_, err1 := stream.Recv()
if err1 == nil {
t.Fatal("Expected first error but got nil")
}
var apiErr1 *APIError
if !errors.As(err1, &apiErr1) {
t.Fatalf("Expected APIError type but got %T: %v", err1, err1)
}
if apiErr1.Message != "First error" {
t.Errorf("Unexpected first error message: %s", apiErr1.Message)
}

// Second recv should return the second error (not a concatenation)
_, err2 := stream.Recv()
if err2 == nil {
t.Fatal("Expected second error but got nil")
}
var apiErr2 *APIError
if !errors.As(err2, &apiErr2) {
t.Fatalf("Expected APIError type but got %T: %v", err2, err2)
}
if apiErr2.Message != "Second error" {
t.Errorf("Unexpected second error message: %s", apiErr2.Message)
}
if apiErr2.Type != "error_type_2" {
t.Errorf("Unexpected second error type: %s", apiErr2.Type)
}
}
Loading