Skip to content

Commit a24581d

Browse files
authored
maintain underlying error structs to allow for type conversion (#293)
* maintain underlying error structs to allow for type conversion and defensive error checking * allow Error.Is for Azure responses * update readme, add tests to ensure type conversion * fix whitespacing * read me * add import to readme example
1 parent 24aa200 commit a24581d

File tree

4 files changed

+46
-10
lines changed

4 files changed

+46
-10
lines changed

README.md

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/).
1010
* DALL·E 2
1111
* Whisper
1212

13-
Installation:
13+
### Installation:
1414
```
1515
go get github.com/sashabaranov/go-openai
1616
```
1717

1818

19-
ChatGPT example usage:
19+
### ChatGPT example usage:
2020

2121
```go
2222
package main
@@ -52,9 +52,7 @@ func main() {
5252

5353
```
5454

55-
56-
57-
Other examples:
55+
### Other examples:
5856

5957
<details>
6058
<summary>ChatGPT streaming completion</summary>
@@ -462,3 +460,29 @@ func main() {
462460
}
463461
```
464462
</details>
463+
464+
<details>
465+
<summary>Error handling</summary>
466+
467+
Open-AI maintains clear documentation on how to [handle API errors](https://platform.openai.com/docs/guides/error-codes/api-errors)
468+
469+
example:
470+
```
471+
e := &openai.APIError{}
472+
if errors.As(err, &e) {
473+
switch e.HTTPStatusCode {
474+
case 401:
475+
// invalid auth or key (do not retry)
476+
case 429:
477+
// rate limiting or engine overload (wait and retry)
478+
case 500:
479+
// openai server error (retry)
480+
default:
481+
// unhandled
482+
}
483+
}
484+
485+
```
486+
</details>
487+
488+

client.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,16 @@ func (c *Client) handleErrorResp(resp *http.Response) error {
149149
var errRes ErrorResponse
150150
err := json.NewDecoder(resp.Body).Decode(&errRes)
151151
if err != nil || errRes.Error == nil {
152-
reqErr := RequestError{
152+
reqErr := &RequestError{
153153
HTTPStatusCode: resp.StatusCode,
154154
Err: err,
155155
}
156156
if errRes.Error != nil {
157157
reqErr.Err = errRes.Error
158158
}
159-
return fmt.Errorf("error, %w", &reqErr)
159+
return reqErr
160160
}
161+
161162
errRes.Error.HTTPStatusCode = resp.StatusCode
162-
return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error)
163+
return errRes.Error
163164
}

client_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field
22

33
import (
44
"bytes"
5+
"errors"
56
"fmt"
67
"io"
78
"net/http"
@@ -106,7 +107,7 @@ func TestHandleErrorResp(t *testing.T) {
106107
}
107108
}`,
108109
)),
109-
expected: "error, status code 401, message: Access denied due to Virtual Network/Firewall rules.",
110+
expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.",
110111
},
111112
{
112113
name: "503 Model Overloaded",
@@ -135,6 +136,12 @@ func TestHandleErrorResp(t *testing.T) {
135136
t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
136137
t.Fail()
137138
}
139+
140+
e := &APIError{}
141+
if !errors.As(err, &e) {
142+
t.Errorf("(%s) Expected error to be of type APIError", tc.name)
143+
t.Fail()
144+
}
138145
})
139146
}
140147
}

error.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ type ErrorResponse struct {
2525
}
2626

2727
func (e *APIError) Error() string {
28+
if e.HTTPStatusCode > 0 {
29+
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message)
30+
}
31+
2832
return e.Message
2933
}
3034

@@ -70,7 +74,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
7074
}
7175

7276
func (e *RequestError) Error() string {
73-
return fmt.Sprintf("status code %d, message: %s", e.HTTPStatusCode, e.Err)
77+
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err)
7478
}
7579

7680
func (e *RequestError) Unwrap() error {

0 commit comments

Comments
 (0)