Skip to content

Commit 5178926

Browse files
authored
Merge pull request #5 from meguminnnnnnnnn/feat/option
feat: add WithRequestBodySetter Option
2 parents 5668862 + 4a7937d commit 5178926

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

chat.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package openai
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
8+
"io"
79
"net/http"
810

11+
openai "github.com/meguminnnnnnnnn/go-openai/internal"
12+
913
"github.com/meguminnnnnnnnn/go-openai/jsonschema"
1014
)
1115

@@ -482,6 +486,7 @@ type ChatCompletionResponse struct {
482486
func (c *Client) CreateChatCompletion(
483487
ctx context.Context,
484488
request ChatCompletionRequest,
489+
opts ...ChatCompletionRequestOption,
485490
) (response ChatCompletionResponse, err error) {
486491
if request.Stream {
487492
err = ErrChatCompletionStreamNotSupported
@@ -499,11 +504,26 @@ func (c *Client) CreateChatCompletion(
499504
return
500505
}
501506

507+
ccOpts := &chatCompletionRequestOptions{}
508+
for _, opt := range opts {
509+
opt(ccOpts)
510+
}
511+
512+
body := any(request)
513+
if ccOpts.RequestBodySetter != nil {
514+
var newBody io.Reader
515+
newBody, err = c.getNewRequestBody(request, ccOpts.RequestBodySetter)
516+
if err != nil {
517+
return response, err
518+
}
519+
body = newBody
520+
}
521+
502522
req, err := c.newRequest(
503523
ctx,
504524
http.MethodPost,
505525
c.fullURL(urlSuffix, withModel(request.Model)),
506-
withBody(request),
526+
withBody(body),
507527
)
508528
if err != nil {
509529
return
@@ -512,3 +532,19 @@ func (c *Client) CreateChatCompletion(
512532
err = c.sendRequest(req, &response)
513533
return
514534
}
535+
536+
func (c *Client) getNewRequestBody(request ChatCompletionRequest, setter RequestBodySetter) (io.Reader, error) {
537+
marshaller := openai.JSONMarshaller{}
538+
539+
body, err := marshaller.Marshal(request)
540+
if err != nil {
541+
return nil, err
542+
}
543+
544+
newBody, err := setter(body)
545+
if err != nil {
546+
return nil, err
547+
}
548+
549+
return bytes.NewBuffer(newBody), nil
550+
}

chat_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,28 @@ func TestChatCompletionsWrongModel(t *testing.T) {
5252
checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
5353
}
5454

55+
func TestChatCompletionRequestWithRequestBodySetter(t *testing.T) {
56+
client, server, teardown := setupOpenAITestServer()
57+
defer teardown()
58+
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
59+
60+
opt := openai.WithRequestBodySetter(func(rawBody []byte) ([]byte, error) {
61+
return rawBody, nil
62+
})
63+
64+
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
65+
Model: openai.O1Preview,
66+
MaxCompletionTokens: 1000,
67+
Messages: []openai.ChatCompletionMessage{
68+
{
69+
Role: openai.ChatMessageRoleUser,
70+
Content: "Hello!",
71+
},
72+
},
73+
}, opt)
74+
checks.NoError(t, err)
75+
}
76+
5577
func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
5678
tests := []struct {
5779
name string

option.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package openai
2+
3+
type chatCompletionRequestOptions struct {
4+
RequestBodySetter RequestBodySetter
5+
}
6+
7+
type ChatCompletionRequestOption func(*chatCompletionRequestOptions)
8+
9+
type RequestBodySetter func(rawBody []byte) ([]byte, error)
10+
11+
func WithRequestBodySetter(setter RequestBodySetter) ChatCompletionRequestOption {
12+
return func(opts *chatCompletionRequestOptions) {
13+
opts.RequestBodySetter = setter
14+
}
15+
}

0 commit comments

Comments
 (0)