|
1 | 1 | package openai_test
|
2 | 2 |
|
3 | 3 | import (
|
4 |
| - . "github.com/sashabaranov/go-openai" |
5 |
| - "github.com/sashabaranov/go-openai/internal/test/checks" |
6 |
| - |
7 | 4 | "context"
|
8 | 5 | "encoding/json"
|
9 | 6 | "errors"
|
| 7 | + "fmt" |
10 | 8 | "io"
|
11 | 9 | "net/http"
|
| 10 | + "strconv" |
12 | 11 | "testing"
|
| 12 | + |
| 13 | + . "github.com/sashabaranov/go-openai" |
| 14 | + "github.com/sashabaranov/go-openai/internal/test/checks" |
13 | 15 | )
|
14 | 16 |
|
15 | 17 | func TestChatCompletionsStreamWrongModel(t *testing.T) {
|
@@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
|
178 | 180 | t.Logf("%+v\n", apiErr)
|
179 | 181 | }
|
180 | 182 |
|
| 183 | +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { |
| 184 | + client, server, teardown := setupOpenAITestServer() |
| 185 | + defer teardown() |
| 186 | + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { |
| 187 | + w.Header().Set("Content-Type", "text/event-stream") |
| 188 | + w.Header().Set(xCustomHeader, xCustomHeaderValue) |
| 189 | + |
| 190 | + // Send test responses |
| 191 | + //nolint:lll |
| 192 | + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) |
| 193 | + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) |
| 194 | + |
| 195 | + _, err := w.Write(dataBytes) |
| 196 | + checks.NoError(t, err, "Write error") |
| 197 | + }) |
| 198 | + |
| 199 | + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ |
| 200 | + MaxTokens: 5, |
| 201 | + Model: GPT3Dot5Turbo, |
| 202 | + Messages: []ChatCompletionMessage{ |
| 203 | + { |
| 204 | + Role: ChatMessageRoleUser, |
| 205 | + Content: "Hello!", |
| 206 | + }, |
| 207 | + }, |
| 208 | + Stream: true, |
| 209 | + }) |
| 210 | + checks.NoError(t, err, "CreateCompletionStream returned error") |
| 211 | + defer stream.Close() |
| 212 | + |
| 213 | + value := stream.Header().Get(xCustomHeader) |
| 214 | + if value != xCustomHeaderValue { |
| 215 | + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) |
| 216 | + } |
| 217 | +} |
| 218 | + |
| 219 | +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { |
| 220 | + client, server, teardown := setupOpenAITestServer() |
| 221 | + defer teardown() |
| 222 | + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { |
| 223 | + w.Header().Set("Content-Type", "text/event-stream") |
| 224 | + for k, v := range rateLimitHeaders { |
| 225 | + switch val := v.(type) { |
| 226 | + case int: |
| 227 | + w.Header().Set(k, strconv.Itoa(val)) |
| 228 | + default: |
| 229 | + w.Header().Set(k, fmt.Sprintf("%s", v)) |
| 230 | + } |
| 231 | + } |
| 232 | + |
| 233 | + // Send test responses |
| 234 | + //nolint:lll |
| 235 | + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) |
| 236 | + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) |
| 237 | + |
| 238 | + _, err := w.Write(dataBytes) |
| 239 | + checks.NoError(t, err, "Write error") |
| 240 | + }) |
| 241 | + |
| 242 | + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ |
| 243 | + MaxTokens: 5, |
| 244 | + Model: GPT3Dot5Turbo, |
| 245 | + Messages: []ChatCompletionMessage{ |
| 246 | + { |
| 247 | + Role: ChatMessageRoleUser, |
| 248 | + Content: "Hello!", |
| 249 | + }, |
| 250 | + }, |
| 251 | + Stream: true, |
| 252 | + }) |
| 253 | + checks.NoError(t, err, "CreateCompletionStream returned error") |
| 254 | + defer stream.Close() |
| 255 | + |
| 256 | + headers := stream.GetRateLimitHeaders() |
| 257 | + bs1, _ := json.Marshal(headers) |
| 258 | + bs2, _ := json.Marshal(rateLimitHeaders) |
| 259 | + if string(bs1) != string(bs2) { |
| 260 | + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) |
| 261 | + } |
| 262 | +} |
| 263 | + |
181 | 264 | func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
|
182 | 265 | client, server, teardown := setupOpenAITestServer()
|
183 | 266 | defer teardown()
|
|
0 commit comments