Skip to content

Commit 0ec1bab

Browse files
szyhfmeguminnnnnnnnn
authored andcommitted
chat: 新增ExtraFields接口
1 parent b107e16 commit 0ec1bab

File tree

8 files changed

+211
-13
lines changed

8 files changed

+211
-13
lines changed

README.md

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
44
[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)
55

6-
This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:
6+
This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:
77

88
* ChatGPT 4o, o1
99
* GPT-3, GPT-4
@@ -720,7 +720,7 @@ if errors.As(err, &e) {
720720
case 401:
721721
// invalid auth or key (do not retry)
722722
case 429:
723-
// rate limiting or engine overload (wait and retry)
723+
// rate limiting or engine overload (wait and retry)
724724
case 500:
725725
// openai server error (retry)
726726
default:
@@ -867,6 +867,58 @@ func main() {
867867
}
868868
```
869869
</details>
870+
871+
<details>
872+
<summary>Using ExtraFields</summary>
873+
874+
```go
875+
package main
876+
877+
import (
878+
"context"
879+
"fmt"
880+
openai "github.com/sashabaranov/go-openai"
881+
)
882+
883+
func main() {
884+
client := openai.NewClient("your token")
885+
ctx := context.Background()
886+
887+
// Create chat request
888+
req := openai.ChatCompletionRequest{
889+
Model: openai.GPT3Dot5Turbo,
890+
Messages: []openai.ChatCompletionMessage{
891+
{
892+
Role: openai.ChatMessageRoleUser,
893+
Content: "Hello!",
894+
},
895+
},
896+
}
897+
898+
// Add custom fields
899+
extraFields := map[string]any{
900+
"custom_field": "test_value",
901+
"numeric_field": 42,
902+
"bool_field": true,
903+
}
904+
req.SetExtraFields(extraFields)
905+
906+
// Get custom fields
907+
gotFields := req.GetExtraFields()
908+
fmt.Printf("Extra fields: %v\n", gotFields)
909+
910+
// Send request
911+
resp, err := client.CreateChatCompletion(ctx, req)
912+
if err != nil {
913+
fmt.Printf("ChatCompletion error: %v\n", err)
914+
return
915+
}
916+
917+
fmt.Println(resp.Choices[0].Message.Content)
918+
}
919+
```
920+
</details>
921+
870922
See the `examples/` folder for more.
871923

872924
## Frequently Asked Questions
@@ -887,18 +939,18 @@ Due to the factors mentioned above, different answers may be returned even for t
887939

888940
By adopting these strategies, you can expect more consistent results.
889941

890-
**Related Issues:**
942+
**Related Issues:**
891943
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9)
892944

893945
### Does Go OpenAI provide a method to count tokens?
894946

895947
No, Go OpenAI does not offer a feature to count tokens, and there are no plans to provide such a feature in the future. However, if there's a way to implement a token counting feature with zero dependencies, it might be possible to merge that feature into Go OpenAI. Otherwise, it would be more appropriate to implement it in a dedicated library or repository.
896948

897-
For counting tokens, you might find the following links helpful:
949+
For counting tokens, you might find the following links helpful:
898950
- [Counting Tokens For Chat API Calls](https://github.com/pkoukk/tiktoken-go#counting-tokens-for-chat-api-calls)
899951
- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)
900952

901-
**Related Issues:**
953+
**Related Issues:**
902954
[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62)
903955

904956
## Contributing

chat.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,13 +286,23 @@ type ChatCompletionRequest struct {
286286
ReasoningEffort string `json:"reasoning_effort,omitempty"`
287287
// Metadata to store with the completion.
288288
Metadata map[string]string `json:"metadata,omitempty"`
289-
// Configuration for a predicted output.
290-
Prediction *Prediction `json:"prediction,omitempty"`
291-
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
292-
// Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
293-
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
294-
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
295-
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
289+
290+
// Extra fields to be sent in the request.
291+
// Useful for experimental features not yet officially supported.
292+
extraFields map[string]any
293+
}
294+
295+
// SetExtraFields adds extra fields to the JSON object.
296+
//
297+
// SetExtraFields will override any existing fields with the same key.
298+
// For security reasons, ensure this is only used with trusted input data.
299+
func (r *ChatCompletionRequest) SetExtraFields(extraFields map[string]any) {
300+
r.extraFields = extraFields
301+
}
302+
303+
// GetExtraFields returns the extra fields set in the request.
304+
func (r *ChatCompletionRequest) GetExtraFields() map[string]any {
305+
return r.extraFields
296306
}
297307

298308
type StreamOptions struct {

chat_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,3 +950,52 @@ func TestFinishReason(t *testing.T) {
950950
}
951951
}
952952
}
953+
954+
func TestChatCompletionRequestExtraFields(t *testing.T) {
955+
req := openai.ChatCompletionRequest{
956+
Model: "gpt-4",
957+
}
958+
959+
// 测试设置额外字段
960+
extraFields := map[string]any{
961+
"custom_field": "test_value",
962+
"numeric_field": 42,
963+
"bool_field": true,
964+
}
965+
req.SetExtraFields(extraFields)
966+
967+
// 测试获取额外字段
968+
gotFields := req.GetExtraFields()
969+
970+
// 验证字段数量
971+
if len(gotFields) != len(extraFields) {
972+
t.Errorf("Expected %d extra fields, got %d", len(extraFields), len(gotFields))
973+
}
974+
975+
// 验证字段值
976+
for key, expectedValue := range extraFields {
977+
gotValue, exists := gotFields[key]
978+
if !exists {
979+
t.Errorf("Expected field %s not found", key)
980+
continue
981+
}
982+
if gotValue != expectedValue {
983+
t.Errorf("Field %s: expected %v, got %v", key, expectedValue, gotValue)
984+
}
985+
}
986+
987+
// 测试覆盖已存在的字段
988+
newFields := map[string]any{
989+
"custom_field": "new_value",
990+
}
991+
req.SetExtraFields(newFields)
992+
gotFields = req.GetExtraFields()
993+
994+
if len(gotFields) != len(newFields) {
995+
t.Errorf("Expected %d extra fields after override, got %d", len(newFields), len(gotFields))
996+
}
997+
998+
if gotFields["custom_field"] != "new_value" {
999+
t.Errorf("Expected overridden value 'new_value', got %v", gotFields["custom_field"])
1000+
}
1001+
}

go.mod

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
module github.com/meguminnnnnnnnn/go-openai
22

33
go 1.18
4+
5+
require github.com/evanphx/json-patch v0.5.2
6+
7+
require github.com/pkg/errors v0.9.1 // indirect

go.sum

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k=
2+
github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ=
3+
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
4+
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
5+
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

internal/marshaller.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package openai
22

33
import (
44
"encoding/json"
5+
"fmt"
6+
7+
jsonpatch "github.com/evanphx/json-patch"
58
)
69

710
type Marshaller interface {
@@ -11,5 +14,30 @@ type Marshaller interface {
1114
type JSONMarshaller struct{}
1215

1316
func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
14-
return json.Marshal(value)
17+
originalBytes, err := json.Marshal(value)
18+
if err != nil {
19+
return nil, err
20+
}
21+
// Check if the value implements the GetExtraFields interface
22+
getExtraFieldsBody, ok := value.(interface {
23+
GetExtraFields() map[string]any
24+
})
25+
if !ok {
26+
// If not, return the original bytes
27+
return originalBytes, nil
28+
}
29+
extraFields := getExtraFieldsBody.GetExtraFields()
30+
if len(extraFields) == 0 {
31+
// If there are no extra fields, return the original bytes
32+
return originalBytes, nil
33+
}
34+
patchBytes, err := json.Marshal(extraFields)
35+
if err != nil {
36+
return nil, fmt.Errorf("Marshal extraFields(%+v) err: %w", extraFields, err)
37+
}
38+
finalBytes, err := jsonpatch.MergePatch(originalBytes, patchBytes)
39+
if err != nil {
40+
return nil, fmt.Errorf("MergePatch originalBytes(%s) patchBytes(%s) err: %w", originalBytes, patchBytes, err)
41+
}
42+
return finalBytes, nil
1543
}

internal/request_builder.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func (b *HTTPRequestBuilder) Build(
3838
if err != nil {
3939
return
4040
}
41+
4142
bodyReader = bytes.NewBuffer(reqBytes)
4243
}
4344
}

internal/request_builder_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai //nolint:testpackage // testing private field
33
import (
44
"bytes"
55
"context"
6+
"encoding/json"
67
"errors"
78
"net/http"
89
"reflect"
@@ -59,3 +60,51 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
5960
t.Errorf("Build() got = %v, want %v", got, want)
6061
}
6162
}
63+
64+
type testExtraFieldsRequest struct {
65+
Model string `json:"model"`
66+
extraFields map[string]any
67+
}
68+
69+
func (r *testExtraFieldsRequest) GetExtraFields() map[string]any {
70+
return r.extraFields
71+
}
72+
73+
func TestRequestBuilderReturnsRequestWhenRequestHasExtraFields(t *testing.T) {
74+
b := NewRequestBuilder()
75+
var (
76+
ctx = context.Background()
77+
method = http.MethodPost
78+
url = "/foo"
79+
request = &testExtraFieldsRequest{
80+
Model: "test-model",
81+
}
82+
)
83+
request.extraFields = map[string]any{"extra_field": "extra_value"}
84+
85+
reqBytes, err := b.marshaller.Marshal(request)
86+
if err != nil {
87+
t.Fatalf("Marshal failed: %v", err)
88+
}
89+
90+
// 验证序列化结果包含原始字段和额外字段
91+
var result map[string]interface{}
92+
if err := json.Unmarshal(reqBytes, &result); err != nil {
93+
t.Fatalf("Unmarshal failed: %v", err)
94+
}
95+
96+
if result["model"] != "test-model" {
97+
t.Errorf("Expected model to be 'test-model', got %v", result["model"])
98+
}
99+
if result["extra_field"] != "extra_value" {
100+
t.Errorf("Expected extra_field to be 'extra_value', got %v", result["extra_field"])
101+
}
102+
103+
want, _ := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes))
104+
got, _ := b.Build(ctx, method, url, request, nil)
105+
if !reflect.DeepEqual(got.Body, want.Body) ||
106+
!reflect.DeepEqual(got.URL, want.URL) ||
107+
!reflect.DeepEqual(got.Method, want.Method) {
108+
t.Errorf("Build() got = %v, want %v", got, want)
109+
}
110+
}

0 commit comments

Comments
 (0)