Skip to content

Commit 277a4ab

Browse files
mrh997mrh
authored andcommitted
feat: support usage extra unmarshaler
1 parent 3bce976 commit 277a4ab

File tree

6 files changed

+195
-7
lines changed

6 files changed

+195
-7
lines changed

common.go

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,47 @@
11
package openai
22

3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"reflect"
7+
8+
openai "github.com/meguminnnnnnnnn/go-openai/internal"
9+
)
10+
311
// common.go defines common types used throughout the OpenAI API.
412

513
// Usage Represents the total token usage per request to OpenAI.
614
type Usage struct {
7-
PromptTokens int `json:"prompt_tokens"`
8-
CompletionTokens int `json:"completion_tokens"`
9-
TotalTokens int `json:"total_tokens"`
10-
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"`
11-
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"`
15+
PromptTokens int `json:"prompt_tokens"`
16+
CompletionTokens int `json:"completion_tokens"`
17+
TotalTokens int `json:"total_tokens"`
18+
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"`
19+
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"`
20+
ExtraFields map[string]json.RawMessage `json:"-"`
21+
}
22+
23+
func (u *Usage) UnmarshalJSON(data []byte) error {
24+
if u == nil {
25+
return fmt.Errorf("usage is nil")
26+
}
27+
28+
type Alias Usage
29+
alias := &Alias{}
30+
err := json.Unmarshal(data, alias)
31+
if err != nil {
32+
return err
33+
}
34+
35+
*u = Usage(*alias)
36+
37+
extra, err := openai.UnmarshalExtraFields(reflect.TypeOf(u), data)
38+
if err != nil {
39+
return err
40+
}
41+
42+
u.ExtraFields = extra
43+
44+
return nil
1245
}
1346

1447
// CompletionTokensDetails Breakdown of tokens used in a completion.

common_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package openai_test
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/meguminnnnnnnnn/go-openai"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestUsageUnmarshalJSON(t *testing.T) {
12+
data := []byte(`{
13+
"prompt_tokens": 10,
14+
"completion_tokens": 20,
15+
"total_tokens": 30,
16+
"prompt_tokens_details": {
17+
"cached_tokens": 15
18+
},
19+
"completion_tokens_details": {
20+
"audio_tokens": 10
21+
},
22+
"extra_field": "extra_value"
23+
}`)
24+
25+
usage := &openai.Usage{}
26+
err := json.Unmarshal(data, usage)
27+
assert.NoError(t, err)
28+
assert.Equal(t, 10, usage.PromptTokens)
29+
assert.Equal(t, 20, usage.CompletionTokens)
30+
assert.Equal(t, 30, usage.TotalTokens)
31+
assert.NotNil(t, usage.PromptTokensDetails)
32+
assert.Equal(t, 15, usage.PromptTokensDetails.CachedTokens)
33+
assert.NotNil(t, usage.CompletionTokensDetails)
34+
assert.Equal(t, 10, usage.CompletionTokensDetails.AudioTokens)
35+
assert.Len(t, usage.ExtraFields, 1)
36+
37+
var extraValue string
38+
err = json.Unmarshal(usage.ExtraFields["extra_field"], &extraValue)
39+
assert.NoError(t, err)
40+
assert.Equal(t, "extra_value", extraValue)
41+
}

go.mod

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,20 @@ module github.com/meguminnnnnnnnn/go-openai
22

33
go 1.18
44

5-
require github.com/evanphx/json-patch v0.5.2
5+
require (
6+
github.com/bytedance/sonic v1.14.0
7+
github.com/evanphx/json-patch v0.5.2
8+
github.com/stretchr/testify v1.10.0
9+
)
610

7-
require github.com/pkg/errors v0.9.1 // indirect
11+
require (
12+
github.com/bytedance/sonic/loader v0.3.0 // indirect
13+
github.com/cloudwego/base64x v0.1.5 // indirect
14+
github.com/davecgh/go-spew v1.1.1 // indirect
15+
github.com/klauspost/cpuid/v2 v2.0.9 // indirect
16+
github.com/pkg/errors v0.9.1 // indirect
17+
github.com/pmezard/go-difflib v1.0.0 // indirect
18+
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
19+
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
20+
gopkg.in/yaml.v3 v3.0.1 // indirect
21+
)

go.sum

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,40 @@
1+
github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ=
2+
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
3+
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
4+
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
5+
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
6+
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
7+
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
8+
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
9+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
10+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
11+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
112
github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k=
213
github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ=
314
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
15+
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
16+
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
17+
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
418
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
519
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
20+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
21+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
22+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
23+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
24+
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
25+
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
26+
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
27+
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
28+
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
29+
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
30+
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
31+
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
32+
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
33+
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU=
34+
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
35+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
36+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
37+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
38+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
39+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
40+
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=

internal/unmarshaler.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ package openai
22

33
import (
44
"encoding/json"
5+
"fmt"
6+
"reflect"
7+
8+
"github.com/bytedance/sonic"
59
)
610

711
type Unmarshaler interface {
@@ -13,3 +17,39 @@ type JSONUnmarshaler struct{}
1317
func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error {
1418
return json.Unmarshal(data, v)
1519
}
20+
21+
func UnmarshalExtraFields(typ reflect.Type, data []byte) (map[string]json.RawMessage, error) {
22+
m := make(map[string]json.RawMessage)
23+
if err := sonic.Unmarshal(data, &m); err != nil {
24+
return nil, err
25+
}
26+
27+
for typ.Kind() == reflect.Ptr {
28+
typ = typ.Elem()
29+
}
30+
31+
if typ.Kind() != reflect.Struct {
32+
return nil, fmt.Errorf("type is not a struct")
33+
}
34+
35+
for i := 0; i < typ.NumField(); i++ {
36+
field := typ.Field(i)
37+
38+
jsonTag := field.Tag.Get("json")
39+
if jsonTag != "" {
40+
delete(m, jsonTag)
41+
} else {
42+
if !field.IsExported() {
43+
continue
44+
}
45+
delete(m, field.Name)
46+
}
47+
}
48+
49+
extra := make(map[string]json.RawMessage, len(m))
50+
for k, v := range m {
51+
extra[k] = v
52+
}
53+
54+
return extra, nil
55+
}

internal/unmarshaler_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package openai_test
22

33
import (
4+
"encoding/json"
5+
"reflect"
46
"testing"
57

68
openai "github.com/meguminnnnnnnnn/go-openai/internal"
79
"github.com/meguminnnnnnnnn/go-openai/internal/test/checks"
10+
"github.com/stretchr/testify/assert"
811
)
912

1013
func TestJSONUnmarshaler_Normal(t *testing.T) {
@@ -35,3 +38,25 @@ func TestJSONUnmarshaler_EmptyInput(t *testing.T) {
3538
err := jm.Unmarshal(nil, &v)
3639
checks.HasError(t, err, "should return error for nil input")
3740
}
41+
42+
func TestUnmarshalExtraFields(t *testing.T) {
43+
type TestStruct struct {
44+
Field1 string `json:"field1"`
45+
Field2 int
46+
Field3 struct {
47+
Field4 string `json:"field4"`
48+
} `json:"field3"`
49+
}
50+
51+
testData := []byte(`{"field1":"value1","Field2":2,"field3":{"field4":"value4"},"extraField1":"extraValue1"}`)
52+
testStruct := &TestStruct{}
53+
extra, err := openai.UnmarshalExtraFields(reflect.TypeOf(testStruct), testData)
54+
assert.NoError(t, err)
55+
assert.Len(t, extra, 1)
56+
57+
var extraValue1 string
58+
err = json.Unmarshal(extra["extraField1"], &extraValue1)
59+
assert.NoError(t, err)
60+
61+
assert.Equal(t, "extraValue1", extraValue1)
62+
}

0 commit comments

Comments
 (0)