Skip to content

Commit 3bce976

Browse files
authored
Merge pull request #7 from meguminnnnnnnnn/feat/option
feat: add WithExtraHeader option
2 parents 285a738 + 3ed488d commit 3bce976

File tree

5 files changed

+29
-12
lines changed

5 files changed

+29
-12
lines changed

chat.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,9 +510,9 @@ func (c *Client) CreateChatCompletion(
510510
}
511511

512512
body := any(request)
513-
if ccOpts.RequestBodySetter != nil {
513+
if ccOpts.RequestBodyModifier != nil {
514514
var newBody io.Reader
515-
newBody, err = c.getNewRequestBody(request, ccOpts.RequestBodySetter)
515+
newBody, err = c.getNewRequestBody(request, ccOpts.RequestBodyModifier)
516516
if err != nil {
517517
return response, err
518518
}
@@ -524,6 +524,7 @@ func (c *Client) CreateChatCompletion(
524524
http.MethodPost,
525525
c.fullURL(urlSuffix, withModel(request.Model)),
526526
withBody(body),
527+
withExtraHeader(ccOpts.ExtraHeader),
527528
)
528529
if err != nil {
529530
return
@@ -533,15 +534,15 @@ func (c *Client) CreateChatCompletion(
533534
return
534535
}
535536

536-
func (c *Client) getNewRequestBody(request ChatCompletionRequest, setter RequestBodySetter) (io.Reader, error) {
537+
func (c *Client) getNewRequestBody(request ChatCompletionRequest, modifier RequestBodyModifier) (io.Reader, error) {
537538
marshaller := openai.JSONMarshaller{}
538539

539540
body, err := marshaller.Marshal(request)
540541
if err != nil {
541542
return nil, err
542543
}
543544

544-
newBody, err := setter(body)
545+
newBody, err := modifier(body)
545546
if err != nil {
546547
return nil, err
547548
}

chat_stream.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ func (c *Client) CreateChatCompletionStream(
9999
}
100100

101101
body := any(request)
102-
if ccOpts.RequestBodySetter != nil {
102+
if ccOpts.RequestBodyModifier != nil {
103103
var newBody io.Reader
104-
newBody, err = c.getNewRequestBody(request, ccOpts.RequestBodySetter)
104+
newBody, err = c.getNewRequestBody(request, ccOpts.RequestBodyModifier)
105105
if err != nil {
106106
return stream, err
107107
}
@@ -113,6 +113,7 @@ func (c *Client) CreateChatCompletionStream(
113113
http.MethodPost,
114114
c.fullURL(urlSuffix, withModel(request.Model)),
115115
withBody(body),
116+
withExtraHeader(ccOpts.ExtraHeader),
116117
)
117118
if err != nil {
118119
return nil, err

chat_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,12 @@ func TestChatCompletionsWrongModel(t *testing.T) {
5252
checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
5353
}
5454

55-
func TestChatCompletionRequestWithRequestBodySetter(t *testing.T) {
55+
func TestChatCompletionRequestWithRequestBodyModifier(t *testing.T) {
5656
client, server, teardown := setupOpenAITestServer()
5757
defer teardown()
5858
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
5959

60-
opt := openai.WithRequestBodySetter(func(rawBody []byte) ([]byte, error) {
60+
opt := openai.WithRequestBodyModifier(func(rawBody []byte) ([]byte, error) {
6161
return rawBody, nil
6262
})
6363

client.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ func withExtraBody(extraBody map[string]any) requestOption {
9898
}
9999
}
100100

101+
func withExtraHeader(header map[string]string) requestOption {
102+
return func(args *requestOptions) {
103+
for k, v := range header {
104+
args.header.Set(k, v)
105+
}
106+
}
107+
}
108+
101109
func withContentType(contentType string) requestOption {
102110
return func(args *requestOptions) {
103111
args.header.Set("Content-Type", contentType)

option.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
package openai
22

33
type chatCompletionRequestOptions struct {
4-
RequestBodySetter RequestBodySetter
4+
RequestBodyModifier RequestBodyModifier
5+
ExtraHeader map[string]string
56
}
67

78
type ChatCompletionRequestOption func(*chatCompletionRequestOptions)
89

9-
type RequestBodySetter func(rawBody []byte) ([]byte, error)
10+
type RequestBodyModifier func(rawBody []byte) ([]byte, error)
1011

11-
func WithRequestBodySetter(setter RequestBodySetter) ChatCompletionRequestOption {
12+
func WithRequestBodyModifier(modifier RequestBodyModifier) ChatCompletionRequestOption {
1213
return func(opts *chatCompletionRequestOptions) {
13-
opts.RequestBodySetter = setter
14+
opts.RequestBodyModifier = modifier
15+
}
16+
}
17+
18+
func WithExtraHeader(header map[string]string) ChatCompletionRequestOption {
19+
return func(opts *chatCompletionRequestOptions) {
20+
opts.ExtraHeader = header
1421
}
1522
}

0 commit comments

Comments
 (0)