|
1 | 1 | package openai_test
|
2 | 2 |
|
3 | 3 | import (
|
4 |
| - . "github.com/sashabaranov/go-openai" |
5 |
| - "github.com/sashabaranov/go-openai/internal/test" |
6 |
| - "github.com/sashabaranov/go-openai/internal/test/checks" |
7 |
| - |
8 | 4 | "context"
|
9 | 5 | "errors"
|
10 | 6 | "io"
|
11 | 7 | "net/http"
|
12 | 8 | "net/http/httptest"
|
13 | 9 | "testing"
|
| 10 | + |
| 11 | + . "github.com/sashabaranov/go-openai" |
| 12 | + "github.com/sashabaranov/go-openai/internal/test" |
| 13 | + "github.com/sashabaranov/go-openai/internal/test/checks" |
14 | 14 | )
|
15 | 15 |
|
16 | 16 | func TestCompletionsStreamWrongModel(t *testing.T) {
|
@@ -171,6 +171,52 @@ func TestCreateCompletionStreamError(t *testing.T) {
|
171 | 171 | t.Logf("%+v\n", apiErr)
|
172 | 172 | }
|
173 | 173 |
|
| 174 | +func TestCreateCompletionStreamRateLimitError(t *testing.T) { |
| 175 | + server := test.NewTestServer() |
| 176 | + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { |
| 177 | + w.Header().Set("Content-Type", "application/json") |
| 178 | + w.WriteHeader(429) |
| 179 | + |
| 180 | + // Send test responses |
| 181 | + dataBytes := []byte(`{"error":{` + |
| 182 | + `"message": "You are sending requests too quickly.",` + |
| 183 | + `"type":"rate_limit_reached",` + |
| 184 | + `"param":null,` + |
| 185 | + `"code":"rate_limit_reached"}}`) |
| 186 | + |
| 187 | + _, err := w.Write(dataBytes) |
| 188 | + checks.NoError(t, err, "Write error") |
| 189 | + }) |
| 190 | + ts := server.OpenAITestServer() |
| 191 | + ts.Start() |
| 192 | + defer ts.Close() |
| 193 | + |
| 194 | + // Client portion of the test |
| 195 | + config := DefaultConfig(test.GetTestToken()) |
| 196 | + config.BaseURL = ts.URL + "/v1" |
| 197 | + config.HTTPClient.Transport = &tokenRoundTripper{ |
| 198 | + test.GetTestToken(), |
| 199 | + http.DefaultTransport, |
| 200 | + } |
| 201 | + |
| 202 | + client := NewClientWithConfig(config) |
| 203 | + ctx := context.Background() |
| 204 | + |
| 205 | + request := CompletionRequest{ |
| 206 | + MaxTokens: 5, |
| 207 | + Model: GPT3Ada, |
| 208 | + Prompt: "Hello!", |
| 209 | + Stream: true, |
| 210 | + } |
| 211 | + |
| 212 | + var apiErr *APIError |
| 213 | + _, err := client.CreateCompletionStream(ctx, request) |
| 214 | + if !errors.As(err, &apiErr) { |
| 215 | + t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError") |
| 216 | + } |
| 217 | + t.Logf("%+v\n", apiErr) |
| 218 | +} |
| 219 | + |
174 | 220 | // A "tokenRoundTripper" is a struct that implements the RoundTripper
|
175 | 221 | // interface, specifically to handle the authentication token by adding a token
|
176 | 222 | // to the request header. We need this because the API requires that each
|
|
0 commit comments