Skip to content

Commit 6cfecdb

Browse files
Merge remote-tracking branch 'upstream/master'
2 parents 0d508a1 + c4273cb commit 6cfecdb

30 files changed

+1754
-147
lines changed

.codecov.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
coverage:
2+
ignore:
3+
- "examples/**"
4+
- "internal/test/**"

.github/workflows/pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
go-version: '1.24'
1717
- name: Run vet
1818
run: |
19-
go vet .
19+
go vet -stdversion ./...
2020
- name: Run golangci-lint
2121
uses: golangci/golangci-lint-action@v7
2222
with:

chat.go

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"encoding/json"
66
"errors"
77
"net/http"
8+
9+
"github.com/meguminnnnnnnnn/go-openai/jsonschema"
810
)
911

1012
// Chat message role defined by the OpenAI API.
@@ -234,13 +236,49 @@ type ChatCompletionResponseFormatJSONSchema struct {
234236
Strict bool `json:"strict"`
235237
}
236238

239+
func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error {
240+
type rawJSONSchema struct {
241+
Name string `json:"name"`
242+
Description string `json:"description,omitempty"`
243+
Schema json.RawMessage `json:"schema"`
244+
Strict bool `json:"strict"`
245+
}
246+
var raw rawJSONSchema
247+
if err := json.Unmarshal(data, &raw); err != nil {
248+
return err
249+
}
250+
r.Name = raw.Name
251+
r.Description = raw.Description
252+
r.Strict = raw.Strict
253+
if len(raw.Schema) > 0 && string(raw.Schema) != "null" {
254+
var d jsonschema.Definition
255+
err := json.Unmarshal(raw.Schema, &d)
256+
if err != nil {
257+
return err
258+
}
259+
r.Schema = &d
260+
}
261+
return nil
262+
}
263+
264+
// ChatCompletionRequestExtensions contains third-party OpenAI API extensions
265+
// (e.g., vendor-specific implementations like vLLM).
266+
type ChatCompletionRequestExtensions struct {
267+
// GuidedChoice is a vLLM-specific extension that restricts the model's output
268+
// to one of the predefined string choices provided in this field. This feature
269+
// is used to constrain the model's responses to a controlled set of options,
270+
// ensuring predictable and consistent outputs in scenarios where specific
271+
// choices are required.
272+
GuidedChoice []string `json:"guided_choice,omitempty"`
273+
}
274+
237275
// ChatCompletionRequest represents a request structure for chat completion API.
238276
type ChatCompletionRequest struct {
239277
Model string `json:"model"`
240278
Messages []ChatCompletionMessage `json:"messages"`
241279
// MaxTokens The maximum number of tokens that can be generated in the chat completion.
242280
// This value can be used to control costs for text generated via API.
243-
// This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models.
281+
// Deprecated: use MaxCompletionTokens. Not compatible with o1-series models.
244282
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
245283
MaxTokens int `json:"max_tokens,omitempty"`
246284
// MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion,
@@ -286,7 +324,15 @@ type ChatCompletionRequest struct {
286324
ReasoningEffort string `json:"reasoning_effort,omitempty"`
287325
// Metadata to store with the completion.
288326
Metadata map[string]string `json:"metadata,omitempty"`
289-
327+
// Configuration for a predicted output.
328+
Prediction *Prediction `json:"prediction,omitempty"`
329+
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
330+
// Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
331+
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
332+
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
333+
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
334+
// Specifies the latency tier to use for processing the request.
335+
ServiceTier ServiceTier `json:"service_tier,omitempty"`
290336
// Extra fields to be sent in the request.
291337
// Useful for experimental features not yet officially supported.
292338
extraFields map[string]any
@@ -386,6 +432,15 @@ const (
386432
FinishReasonNull FinishReason = "null"
387433
)
388434

435+
type ServiceTier string
436+
437+
const (
438+
ServiceTierAuto ServiceTier = "auto"
439+
ServiceTierDefault ServiceTier = "default"
440+
ServiceTierFlex ServiceTier = "flex"
441+
ServiceTierPriority ServiceTier = "priority"
442+
)
443+
389444
func (r FinishReason) MarshalJSON() ([]byte, error) {
390445
if r == FinishReasonNull || r == "" {
391446
return []byte("null"), nil
@@ -418,6 +473,7 @@ type ChatCompletionResponse struct {
418473
Usage Usage `json:"usage"`
419474
SystemFingerprint string `json:"system_fingerprint"`
420475
PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"`
476+
ServiceTier ServiceTier `json:"service_tier,omitempty"`
421477

422478
httpHeader
423479
}

chat_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,3 +999,142 @@ func TestChatCompletionRequestExtraFields(t *testing.T) {
999999
t.Errorf("Expected overridden value 'new_value', got %v", gotFields["custom_field"])
10001000
}
10011001
}
1002+
1003+
func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) {
1004+
type args struct {
1005+
data []byte
1006+
}
1007+
tests := []struct {
1008+
name string
1009+
args args
1010+
wantErr bool
1011+
}{
1012+
{
1013+
"",
1014+
args{
1015+
data: []byte(`{
1016+
"name": "math_response",
1017+
"strict": true,
1018+
"schema": {
1019+
"type": "object",
1020+
"properties": {
1021+
"steps": {
1022+
"type": "array",
1023+
"items": {
1024+
"type": "object",
1025+
"properties": {
1026+
"explanation": { "type": "string" },
1027+
"output": { "type": "string" }
1028+
},
1029+
"required": ["explanation","output"],
1030+
"additionalProperties": false
1031+
}
1032+
},
1033+
"final_answer": { "type": "string" }
1034+
},
1035+
"required": ["steps","final_answer"],
1036+
"additionalProperties": false
1037+
}
1038+
}`),
1039+
},
1040+
false,
1041+
},
1042+
{
1043+
"",
1044+
args{
1045+
data: []byte(`{
1046+
"name": "math_response",
1047+
"strict": true,
1048+
"schema": null
1049+
}`),
1050+
},
1051+
false,
1052+
},
1053+
{
1054+
"",
1055+
args{
1056+
data: []byte(`[123,456]`),
1057+
},
1058+
true,
1059+
},
1060+
{
1061+
"",
1062+
args{
1063+
data: []byte(`{
1064+
"name": "math_response",
1065+
"strict": true,
1066+
"schema": 123456
1067+
}`),
1068+
},
1069+
true,
1070+
},
1071+
}
1072+
for _, tt := range tests {
1073+
t.Run(tt.name, func(t *testing.T) {
1074+
var r openai.ChatCompletionResponseFormatJSONSchema
1075+
err := r.UnmarshalJSON(tt.args.data)
1076+
if (err != nil) != tt.wantErr {
1077+
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
1078+
}
1079+
})
1080+
}
1081+
}
1082+
1083+
func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) {
1084+
type args struct {
1085+
bs []byte
1086+
}
1087+
tests := []struct {
1088+
name string
1089+
args args
1090+
wantErr bool
1091+
}{
1092+
{
1093+
"",
1094+
args{bs: []byte(`{
1095+
"model": "llama3-1b",
1096+
"messages": [
1097+
{ "role": "system", "content": "You are a helpful math tutor." },
1098+
{ "role": "user", "content": "solve 8x + 31 = 2" }
1099+
],
1100+
"response_format": {
1101+
"type": "json_schema",
1102+
"json_schema": {
1103+
"name": "math_response",
1104+
"strict": true,
1105+
"schema": {
1106+
"type": "object",
1107+
"properties": {
1108+
"steps": {
1109+
"type": "array",
1110+
"items": {
1111+
"type": "object",
1112+
"properties": {
1113+
"explanation": { "type": "string" },
1114+
"output": { "type": "string" }
1115+
},
1116+
"required": ["explanation","output"],
1117+
"additionalProperties": false
1118+
}
1119+
},
1120+
"final_answer": { "type": "string" }
1121+
},
1122+
"required": ["steps","final_answer"],
1123+
"additionalProperties": false
1124+
}
1125+
}
1126+
}
1127+
}`)},
1128+
false,
1129+
},
1130+
}
1131+
for _, tt := range tests {
1132+
t.Run(tt.name, func(t *testing.T) {
1133+
var m openai.ChatCompletionRequest
1134+
err := json.Unmarshal(tt.args.bs, &m)
1135+
if err != nil {
1136+
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
1137+
}
1138+
})
1139+
}
1140+
}

client.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,20 @@ func withBody(body any) requestOption {
8484
}
8585
}
8686

87+
func withExtraBody(extraBody map[string]any) requestOption {
88+
return func(args *requestOptions) {
89+
// Assert that args.body is a map[string]any.
90+
bodyMap, ok := args.body.(map[string]any)
91+
if ok {
92+
// If it's a map[string]any then only add extraBody
93+
// fields to args.body otherwise keep only fields in request struct.
94+
for key, value := range extraBody {
95+
bodyMap[key] = value
96+
}
97+
}
98+
}
99+
}
100+
87101
func withContentType(contentType string) requestOption {
88102
return func(args *requestOptions) {
89103
args.header.Set("Content-Type", contentType)

completion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ type CompletionResponse struct {
242242
Created int64 `json:"created"`
243243
Model string `json:"model"`
244244
Choices []CompletionChoice `json:"choices"`
245-
Usage Usage `json:"usage"`
245+
Usage *Usage `json:"usage,omitempty"`
246246

247247
httpHeader
248248
}

completion_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
192192
}
193193
inputTokens *= n
194194
completionTokens := completionReq.MaxTokens * len(prompts) * n
195-
res.Usage = openai.Usage{
195+
res.Usage = &openai.Usage{
196196
PromptTokens: inputTokens,
197197
CompletionTokens: completionTokens,
198198
TotalTokens: inputTokens + completionTokens,

embeddings.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/base64"
66
"encoding/binary"
7+
"encoding/json"
78
"errors"
89
"math"
910
"net/http"
@@ -160,6 +161,9 @@ type EmbeddingRequest struct {
160161
// Dimensions The number of dimensions the resulting output embeddings should have.
161162
// Only supported in text-embedding-3 and later models.
162163
Dimensions int `json:"dimensions,omitempty"`
164+
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
165+
// in the request body that may not be explicitly defined in this struct.
166+
ExtraBody map[string]any `json:"extra_body,omitempty"`
163167
}
164168

165169
func (r EmbeddingRequest) Convert() EmbeddingRequest {
@@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct {
187191
// Dimensions The number of dimensions the resulting output embeddings should have.
188192
// Only supported in text-embedding-3 and later models.
189193
Dimensions int `json:"dimensions,omitempty"`
194+
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
195+
// in the request body that may not be explicitly defined in this struct.
196+
ExtraBody map[string]any `json:"extra_body,omitempty"`
190197
}
191198

192199
func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
@@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
196203
User: r.User,
197204
EncodingFormat: r.EncodingFormat,
198205
Dimensions: r.Dimensions,
206+
ExtraBody: r.ExtraBody,
199207
}
200208
}
201209

@@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct {
219227
// Dimensions The number of dimensions the resulting output embeddings should have.
220228
// Only supported in text-embedding-3 and later models.
221229
Dimensions int `json:"dimensions,omitempty"`
230+
// The ExtraBody field allows for the inclusion of arbitrary key-value pairs
231+
// in the request body that may not be explicitly defined in this struct.
232+
ExtraBody map[string]any `json:"extra_body,omitempty"`
222233
}
223234

224235
func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
@@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
228239
User: r.User,
229240
EncodingFormat: r.EncodingFormat,
230241
Dimensions: r.Dimensions,
242+
ExtraBody: r.ExtraBody,
231243
}
232244
}
233245

@@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings(
241253
conv EmbeddingRequestConverter,
242254
) (res EmbeddingResponse, err error) {
243255
baseReq := conv.Convert()
256+
257+
// The body map is used to dynamically construct the request payload for the embedding API.
258+
// Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields
259+
// based on their presence, avoiding unnecessary or empty fields in the request.
260+
extraBody := baseReq.ExtraBody
261+
baseReq.ExtraBody = nil
262+
263+
// Serialize baseReq to JSON
264+
jsonData, err := json.Marshal(baseReq)
265+
if err != nil {
266+
return
267+
}
268+
269+
// Deserialize JSON to map[string]any
270+
var body map[string]any
271+
_ = json.Unmarshal(jsonData, &body)
272+
244273
req, err := c.newRequest(
245274
ctx,
246275
http.MethodPost,
247276
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
248-
withBody(baseReq),
277+
withBody(body), // Main request body.
278+
withExtraBody(extraBody), // Merge ExtraBody fields.
249279
)
250280
if err != nil {
251281
return

0 commit comments

Comments
 (0)