Commit 370f947

Add dependency injection to ChatCompletionStream for improved testability
**Describe the change**

This PR refactors `ChatCompletionStream` to use dependency injection by introducing a `ChatStreamReader` interface. This allows injecting custom stream readers, primarily for testing, making the streaming functionality more testable and maintainable.

**Provide OpenAI documentation link**

https://platform.openai.com/docs/api-reference/chat/create

**Describe your solution**

The changes include:

- Added a `ChatStreamReader` interface that defines the contract for reading chat completion streams
- Refactored `ChatCompletionStream` to hold a `ChatStreamReader` by composition instead of embedding `streamReader`
- Added a `NewChatCompletionStream()` constructor function to enable dependency injection
- Implemented explicit delegation methods (`Recv()`, `Close()`, `Header()`, `GetRateLimitHeaders()`) on `ChatCompletionStream`
- Added an interface compliance check via `var _ ChatStreamReader = (*streamReader[ChatCompletionStreamResponse])(nil)`

This approach maintains backward compatibility while enabling easier mocking and testing of streaming functionality.

**Tests**

Added comprehensive tests demonstrating the new functionality:

- `TestChatCompletionStream_MockInjection`: tests basic mock injection with the new constructor
- `mock_streaming_demo_test.go`: a complete demonstration file showing how to create mock clients and stream readers for testing, including:
  - `MockOpenAIStreamClient`: a full mock client implementation
  - `mockStreamReader`: a custom stream reader for controlled test responses
  - `TestMockOpenAIStreamClient_Demo`: demonstrates assembling multiple stream chunks
  - `TestMockOpenAIStreamClient_ErrorHandling`: shows error-handling patterns

**Additional context**

This refactoring improves the testability of code that depends on go-openai streaming without introducing breaking changes. The existing public API remains unchanged but now supports dependency injection for testing scenarios. The new demo test file serves as documentation for users who want to mock streaming responses in their own tests.
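For reference, the consumer-side wiring this enables looks roughly like the sketch below. This is a minimal illustration, not code from this PR: the `chatbot` package, the `ChatStreamer` interface, and `Reply` are hypothetical names a downstream project might choose. In tests, a mock in the style of `MockOpenAIStreamClient` satisfies the same interface and returns streams built with `NewChatCompletionStream`.

```go
package chatbot

import (
	"context"
	"errors"
	"io"

	openai "github.com/sashabaranov/go-openai"
)

// ChatStreamer is a hypothetical consumer-defined interface. It is satisfied
// by *openai.Client and by a test mock that returns streams built with
// openai.NewChatCompletionStream.
type ChatStreamer interface {
	CreateChatCompletionStream(
		ctx context.Context, req openai.ChatCompletionRequest,
	) (*openai.ChatCompletionStream, error)
}

// Reply drains a completion stream into a single string.
func Reply(ctx context.Context, client ChatStreamer, prompt string) (string, error) {
	stream, err := client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: prompt},
		},
	})
	if err != nil {
		return "", err
	}
	defer stream.Close()

	var reply string
	for {
		resp, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			return reply, nil // end of stream
		}
		if recvErr != nil {
			return "", recvErr
		}
		if len(resp.Choices) > 0 {
			reply += resp.Choices[0].Delta.Content
		}
	}
}
```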
1 parent ff9d83a commit 370f947

File tree

4 files changed: +270 −2 lines changed

chat_stream.go

Lines changed: 43 additions & 2 deletions
```diff
@@ -65,10 +65,21 @@ type ChatCompletionStreamResponse struct {
 	Usage *Usage `json:"usage,omitempty"`
 }
 
+// ChatStreamReader is an interface for reading chat completion streams.
+type ChatStreamReader interface {
+	Recv() (ChatCompletionStreamResponse, error)
+	Close() error
+}
+
 // ChatCompletionStream
 // Note: Perhaps it is more elegant to abstract Stream using generics.
 type ChatCompletionStream struct {
-	*streamReader[ChatCompletionStreamResponse]
+	reader ChatStreamReader
+}
+
+// NewChatCompletionStream allows injecting a custom ChatStreamReader (for testing).
+func NewChatCompletionStream(reader ChatStreamReader) *ChatCompletionStream {
+	return &ChatCompletionStream{reader: reader}
 }
 
 // CreateChatCompletionStream — API call to create a chat completion w/ streaming
@@ -106,7 +117,37 @@ func (c *Client) CreateChatCompletionStream(
 		return
 	}
 	stream = &ChatCompletionStream{
-		streamReader: resp,
+		reader: resp,
 	}
 	return
 }
+
+func (s *ChatCompletionStream) Recv() (ChatCompletionStreamResponse, error) {
+	return s.reader.Recv()
+}
+
+func (s *ChatCompletionStream) Close() error {
+	return s.reader.Close()
+}
+
+func (s *ChatCompletionStream) Header() http.Header {
+	if h, ok := s.reader.(interface{ Header() http.Header }); ok {
+		return h.Header()
+	}
+	return http.Header{}
+}
+
+func (s *ChatCompletionStream) GetRateLimitHeaders() map[string]interface{} {
+	if h, ok := s.reader.(interface{ GetRateLimitHeaders() RateLimitHeaders }); ok {
+		headers := h.GetRateLimitHeaders()
+		return map[string]interface{}{
+			"x-ratelimit-limit-requests":     headers.LimitRequests,
+			"x-ratelimit-limit-tokens":       headers.LimitTokens,
+			"x-ratelimit-remaining-requests": headers.RemainingRequests,
+			"x-ratelimit-remaining-tokens":   headers.RemainingTokens,
+			"x-ratelimit-reset-requests":     headers.ResetRequests.String(),
+			"x-ratelimit-reset-tokens":       headers.ResetTokens.String(),
+		}
+	}
+	return map[string]interface{}{}
+}
```
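Because `ChatStreamReader` deliberately requires only `Recv` and `Close`, the `Header` and `GetRateLimitHeaders` delegations probe the injected reader with type assertions and fall back to empty values when the method is absent. A mock can opt in by implementing the extra method. A minimal sketch of that opt-in, assuming the `mockStreamReader` from the demo file below; `headerMock` and `TestHeaderDelegation` are hypothetical, not part of this commit:

```go
package openai

import (
	"net/http"
	"testing"
)

// headerMock is a hypothetical reader that opts in to header support by
// implementing Header() on top of mockStreamReader's Recv/Close.
type headerMock struct {
	mockStreamReader
	header http.Header
}

func (m *headerMock) Header() http.Header { return m.header }

func TestHeaderDelegation(t *testing.T) {
	stream := NewChatCompletionStream(&headerMock{
		header: http.Header{"X-Request-Id": []string{"test-123"}},
	})
	if got := stream.Header().Get("X-Request-Id"); got != "test-123" {
		t.Errorf(`expected "test-123", got %q`, got)
	}

	// A reader without Header() hits the empty http.Header fallback.
	bare := NewChatCompletionStream(&mockStreamReader{})
	if len(bare.Header()) != 0 {
		t.Errorf("expected empty fallback header, got %v", bare.Header())
	}
}
```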

chat_stream_test.go

Lines changed: 28 additions & 0 deletions
```diff
@@ -767,6 +767,34 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
 	}
 }
 
+type mockStream struct {
+	calls int
+}
+
+// Implement ChatStreamReader
+func (m *mockStream) Recv() (openai.ChatCompletionStreamResponse, error) {
+	m.calls++
+	if m.calls == 1 {
+		return openai.ChatCompletionStreamResponse{ID: "mock1"}, nil
+	}
+	return openai.ChatCompletionStreamResponse{}, io.EOF
+}
+func (m *mockStream) Close() error { return nil }
+
+func TestChatCompletionStream_MockInjection(t *testing.T) {
+	mock := &mockStream{}
+	stream := openai.NewChatCompletionStream(mock)
+
+	resp, err := stream.Recv()
+	if err != nil || resp.ID != "mock1" {
+		t.Errorf("expected mock1, got %v, err %v", resp.ID, err)
+	}
+	_, err = stream.Recv()
+	if err != io.EOF {
+		t.Errorf("expected EOF, got %v", err)
+	}
+}
+
 // Helper funcs.
 func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
```
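The mock above ends the stream cleanly with `io.EOF`; a variant that fails mid-stream is equally easy to inject for exercising caller error paths. A hypothetical sketch, not part of this commit, assuming the test file's existing `io` and `testing` imports plus `errors`:

```go
// failingStream is a hypothetical ChatStreamReader that fails after the
// first chunk, simulating a dropped connection mid-stream.
type failingStream struct {
	calls int
}

func (m *failingStream) Recv() (openai.ChatCompletionStreamResponse, error) {
	m.calls++
	if m.calls == 1 {
		return openai.ChatCompletionStreamResponse{ID: "chunk-1"}, nil
	}
	return openai.ChatCompletionStreamResponse{}, errors.New("connection reset")
}

func (m *failingStream) Close() error { return nil }

func TestChatCompletionStream_MidStreamError(t *testing.T) {
	stream := openai.NewChatCompletionStream(&failingStream{})

	if _, err := stream.Recv(); err != nil {
		t.Fatalf("first chunk should succeed, got %v", err)
	}
	if _, err := stream.Recv(); err == nil || errors.Is(err, io.EOF) {
		t.Errorf("expected a non-EOF error on the second Recv, got %v", err)
	}
}
```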

mock_streaming_demo_test.go

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@ (new file)

```go
package openai

import (
	"context"
	"errors"
	"io"
	"testing"
)

// This file demonstrates how to create mock clients for go-openai streaming
// functionality. This pattern is useful when testing code that depends on
// go-openai streaming but you want to control the responses for testing.

// MockOpenAIStreamClient demonstrates how to create a full mock client for go-openai
type MockOpenAIStreamClient struct {
	// Configure canned responses
	ChatCompletionResponse  ChatCompletionResponse
	ChatCompletionStreamErr error

	// Allow function overrides for more complex scenarios
	CreateChatCompletionStreamFn func(
		ctx context.Context, req ChatCompletionRequest) (*ChatCompletionStream, error)
}

func (m *MockOpenAIStreamClient) CreateChatCompletionStream(
	ctx context.Context,
	req ChatCompletionRequest,
) (*ChatCompletionStream, error) {
	if m.CreateChatCompletionStreamFn != nil {
		return m.CreateChatCompletionStreamFn(ctx, req)
	}
	return nil, m.ChatCompletionStreamErr
}

// mockStreamReader creates specific responses for testing
type mockStreamReader struct {
	responses []ChatCompletionStreamResponse
	index     int
}

func (m *mockStreamReader) Recv() (ChatCompletionStreamResponse, error) {
	if m.index >= len(m.responses) {
		return ChatCompletionStreamResponse{}, io.EOF
	}
	resp := m.responses[m.index]
	m.index++
	return resp, nil
}

func (m *mockStreamReader) Close() error {
	return nil
}

func TestMockOpenAIStreamClient_Demo(t *testing.T) {
	// Create expected responses that our mock stream will return
	expectedResponses := []ChatCompletionStreamResponse{
		{
			ID:     "test-1",
			Object: "chat.completion.chunk",
			Model:  "gpt-3.5-turbo",
			Choices: []ChatCompletionStreamChoice{
				{
					Index: 0,
					Delta: ChatCompletionStreamChoiceDelta{
						Role:    "assistant",
						Content: "Hello",
					},
				},
			},
		},
		{
			ID:     "test-2",
			Object: "chat.completion.chunk",
			Model:  "gpt-3.5-turbo",
			Choices: []ChatCompletionStreamChoice{
				{
					Index: 0,
					Delta: ChatCompletionStreamChoiceDelta{
						Content: " World",
					},
				},
			},
		},
		{
			ID:     "test-3",
			Object: "chat.completion.chunk",
			Model:  "gpt-3.5-turbo",
			Choices: []ChatCompletionStreamChoice{
				{
					Index:        0,
					Delta:        ChatCompletionStreamChoiceDelta{},
					FinishReason: "stop",
				},
			},
		},
	}

	// Create mock client with custom stream function
	mockClient := &MockOpenAIStreamClient{
		CreateChatCompletionStreamFn: func(
			ctx context.Context, req ChatCompletionRequest,
		) (*ChatCompletionStream, error) {
			// Create a mock stream reader with our expected responses
			mockStreamReader := &mockStreamReader{
				responses: expectedResponses,
				index:     0,
			}
			// Return a new ChatCompletionStream with our mock reader
			return NewChatCompletionStream(mockStreamReader), nil
		},
	}

	// Test the mock client
	stream, err := mockClient.CreateChatCompletionStream(
		context.Background(),
		ChatCompletionRequest{
			Model: GPT3Dot5Turbo,
			Messages: []ChatCompletionMessage{
				{
					Role:    ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
		},
	)
	if err != nil {
		t.Fatalf("CreateChatCompletionStream returned error: %v", err)
	}
	defer stream.Close()

	// Verify we get back exactly the responses we configured
	fullResponse := ""
	for i, expectedResponse := range expectedResponses {
		receivedResponse, streamErr := stream.Recv()
		if streamErr != nil {
			t.Fatalf("stream.Recv() failed at index %d: %v", i, streamErr)
		}

		// Additional specific checks
		if receivedResponse.ID != expectedResponse.ID {
			t.Errorf("Response %d ID mismatch. Expected: %s, Got: %s",
				i, expectedResponse.ID, receivedResponse.ID)
		}
		if len(receivedResponse.Choices) > 0 && len(expectedResponse.Choices) > 0 {
			expectedContent := expectedResponse.Choices[0].Delta.Content
			receivedContent := receivedResponse.Choices[0].Delta.Content
			if receivedContent != expectedContent {
				t.Errorf("Response %d content mismatch. Expected: %s, Got: %s",
					i, expectedContent, receivedContent)
			}
			fullResponse += receivedContent
		}
	}

	// Verify EOF at the end
	_, streamErr := stream.Recv()
	if streamErr != io.EOF {
		t.Errorf("Expected EOF at end of stream, got: %v", streamErr)
	}

	// Verify the full assembled response
	expectedFullResponse := "Hello World"
	if fullResponse != expectedFullResponse {
		t.Errorf("Full response mismatch. Expected: %s, Got: %s", expectedFullResponse, fullResponse)
	}

	t.Log("✅ Successfully demonstrated mock OpenAI client with streaming responses!")
	t.Logf("   Full response assembled: %q", fullResponse)
}

// TestMockOpenAIStreamClient_ErrorHandling demonstrates error handling
func TestMockOpenAIStreamClient_ErrorHandling(t *testing.T) {
	expectedError := errors.New("mock stream error")

	mockClient := &MockOpenAIStreamClient{
		ChatCompletionStreamErr: expectedError,
	}

	_, err := mockClient.CreateChatCompletionStream(
		context.Background(),
		ChatCompletionRequest{
			Model: GPT3Dot5Turbo,
			Messages: []ChatCompletionMessage{
				{
					Role:    ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
		},
	)

	if err != expectedError {
		t.Errorf("Expected error %v, got %v", expectedError, err)
	}

	t.Log("✅ Successfully demonstrated mock OpenAI client error handling!")
}
```

stream_reader.go

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@ var (
 	errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
 )
 
+var _ ChatStreamReader = (*streamReader[ChatCompletionStreamResponse])(nil)
+
 type streamable interface {
 	ChatCompletionStreamResponse | CompletionResponse
 }
```
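This blank-identifier assignment is Go's standard compile-time compliance check: the build fails immediately if `streamReader` ever stops satisfying `ChatStreamReader`, rather than at the first call site. Test authors can guard their own mocks the same way; a one-line suggestion, not part of this commit:

```go
// Build fails if mockStream's method set drifts from the ChatStreamReader
// contract (e.g. a signature change on Recv).
var _ openai.ChatStreamReader = (*mockStream)(nil)
```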
