Skip to content

Commit c3b2451

Browse files
authored
fix: invalid schema for function 'func_name': None is not of type 'object' (#429)(#432) (#434)
* fix: invalid schema for function 'func_name': None is not of type 'object' (#429)(#432) * test: add integration test for function call (#429)(#432) * style: remove duplicate import (#429)(#432)
1 parent f028c28 commit c3b2451

File tree

3 files changed

+63
-30
lines changed

3 files changed

+63
-30
lines changed

api_integration_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
. "github.com/sashabaranov/go-openai"
1313
"github.com/sashabaranov/go-openai/internal/test/checks"
14+
"github.com/sashabaranov/go-openai/jsonschema"
1415
)
1516

1617
func TestAPI(t *testing.T) {
@@ -100,6 +101,37 @@ func TestAPI(t *testing.T) {
100101
if counter == 0 {
101102
t.Error("Stream did not return any responses")
102103
}
104+
105+
_, err = c.CreateChatCompletion(
106+
context.Background(),
107+
ChatCompletionRequest{
108+
Model: GPT3Dot5Turbo,
109+
Messages: []ChatCompletionMessage{
110+
{
111+
Role: ChatMessageRoleUser,
112+
Content: "What is the weather like in Boston?",
113+
},
114+
},
115+
Functions: []FunctionDefinition{{
116+
Name: "get_current_weather",
117+
Parameters: jsonschema.Definition{
118+
Type: jsonschema.Object,
119+
Properties: map[string]jsonschema.Definition{
120+
"location": {
121+
Type: jsonschema.String,
122+
Description: "The city and state, e.g. San Francisco, CA",
123+
},
124+
"unit": {
125+
Type: jsonschema.String,
126+
Enum: []string{"celsius", "fahrenheit"},
127+
},
128+
},
129+
Required: []string{"location"},
130+
},
131+
}},
132+
},
133+
)
134+
checks.NoError(t, err, "CreateChatCompletion (with functions) returned error")
103135
}
104136

105137
func TestAPIError(t *testing.T) {

jsonschema/json.go

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,14 @@ type Definition struct {
3636
Items *Definition `json:"items,omitempty"`
3737
}
3838

39-
func (d *Definition) MarshalJSON() ([]byte, error) {
40-
d.initializeProperties()
41-
return json.Marshal(*d)
42-
}
43-
44-
func (d *Definition) initializeProperties() {
39+
func (d Definition) MarshalJSON() ([]byte, error) {
4540
if d.Properties == nil {
4641
d.Properties = make(map[string]Definition)
47-
return
48-
}
49-
50-
for k, v := range d.Properties {
51-
if v.Properties == nil {
52-
v.Properties = make(map[string]Definition)
53-
} else {
54-
v.initializeProperties()
55-
}
56-
d.Properties[k] = v
5742
}
43+
type Alias Definition
44+
return json.Marshal(struct {
45+
Alias
46+
}{
47+
Alias: (Alias)(d),
48+
})
5849
}

jsonschema/json_test.go

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,30 +172,40 @@ func TestDefinition_MarshalJSON(t *testing.T) {
172172

173173
for _, tt := range tests {
174174
t.Run(tt.name, func(t *testing.T) {
175-
gotBytes, err := json.Marshal(&tt.def)
176-
if err != nil {
177-
t.Errorf("Failed to Marshal JSON: error = %v", err)
178-
return
179-
}
180-
181-
var got map[string]interface{}
182-
err = json.Unmarshal(gotBytes, &got)
183-
if err != nil {
184-
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
185-
return
186-
}
187-
188175
wantBytes := []byte(tt.want)
189176
var want map[string]interface{}
190-
err = json.Unmarshal(wantBytes, &want)
177+
err := json.Unmarshal(wantBytes, &want)
191178
if err != nil {
192179
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
193180
return
194181
}
195182

183+
got := structToMap(t, tt.def)
184+
gotPtr := structToMap(t, &tt.def)
185+
196186
if !reflect.DeepEqual(got, want) {
197187
t.Errorf("MarshalJSON() got = %v, want %v", got, want)
198188
}
189+
if !reflect.DeepEqual(gotPtr, want) {
190+
t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want)
191+
}
199192
})
200193
}
201194
}
195+
196+
func structToMap(t *testing.T, v any) map[string]any {
197+
t.Helper()
198+
gotBytes, err := json.Marshal(v)
199+
if err != nil {
200+
t.Errorf("Failed to Marshal JSON: error = %v", err)
201+
return nil
202+
}
203+
204+
var got map[string]interface{}
205+
err = json.Unmarshal(gotBytes, &got)
206+
if err != nil {
207+
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
208+
return nil
209+
}
210+
return got
211+
}

0 commit comments

Comments
 (0)