Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions model/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,10 @@ func (m *Model) convertTools(tools map[string]tool.Tool) []openai.ChatCompletion
log.Errorf("failed to unmarshal tool schema for %s: %v", declaration.Name, err)
continue
}

// Ensure the schema is compatible with OpenAI's strict requirements:
// When type is "object", the "properties" field must be present (even if empty).
ensureObjectHasProperties(parameters)
result = append(result, openai.ChatCompletionToolParam{
Function: openai.FunctionDefinitionParam{
Name: declaration.Name,
Expand All @@ -960,6 +964,37 @@ func (m *Model) convertTools(tools map[string]tool.Tool) []openai.ChatCompletion
return result
}

// ensureObjectHasProperties recursively ensures all object-type schemas have a "properties" field.
// It directly modifies the schema map in-place for optimal performance.
func ensureObjectHasProperties(schema map[string]any) {
// Check if this is an object type schema without properties.
if typeVal, ok := schema["type"].(string); ok && typeVal == "object" {
if _, hasProps := schema["properties"]; !hasProps {
// Add empty properties object.
schema["properties"] = make(map[string]any)
}
}

// Recursively process nested properties.
if props, ok := schema["properties"].(map[string]any); ok {
for _, propSchema := range props {
if propMap, ok := propSchema.(map[string]any); ok {
ensureObjectHasProperties(propMap)
}
}
}

// Process array items.
if items, ok := schema["items"].(map[string]any); ok {
ensureObjectHasProperties(items)
}

// Process additionalProperties if it's a schema object.
if additionalProps, ok := schema["additionalProperties"].(map[string]any); ok {
ensureObjectHasProperties(additionalProps)
}
}

// handleStreamingResponse handles streaming chat completion responses.
func (m *Model) handleStreamingResponse(
ctx context.Context,
Expand Down
144 changes: 144 additions & 0 deletions model/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package openai

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -307,6 +308,149 @@ func TestModel_convertTools(t *testing.T) {
require.False(t, reflect.ValueOf(fn.Parameters).IsZero(), "expected parameters to be populated from schema")
}

// TestModel_convertTools_ObjectSchemaWithoutProperties tests that tools with
// object type schemas but missing properties field are properly handled.
// This is required for compatibility with OpenAI o3 and other models that
// enforce strict schema validation.
func TestModel_convertTools_ObjectSchemaWithoutProperties(t *testing.T) {
m := New("dummy")

tests := []struct {
name string
inputSchema *tool.Schema
expectProps bool // whether properties should exist in the result.
}{
{
name: "object without properties",
inputSchema: &tool.Schema{Type: "object"},
expectProps: true,
},
{
name: "object with properties",
inputSchema: &tool.Schema{
Type: "object",
Properties: map[string]*tool.Schema{
"arg1": {Type: "string"},
},
},
expectProps: true,
},
{
name: "nested object without properties",
inputSchema: &tool.Schema{
Type: "object",
Properties: map[string]*tool.Schema{
"nested": {Type: "object"},
},
},
expectProps: true,
},
{
name: "non-object schema",
inputSchema: &tool.Schema{Type: "string"},
expectProps: false, // string types don't have properties.
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolsMap := map[string]tool.Tool{
"test": stubTool{decl: &tool.Declaration{
Name: "test",
Description: "test",
InputSchema: tt.inputSchema,
}},
}

params := m.convertTools(toolsMap)
require.Len(t, params, 1)

// Marshal the parameters to check the JSON structure.
paramBytes, err := json.Marshal(params[0].Function.Parameters)
require.NoError(t, err, "failed to marshal parameters")

var paramMap map[string]any
err = json.Unmarshal(paramBytes, &paramMap)
require.NoError(t, err, "failed to unmarshal parameters")

if tt.expectProps {
// For object types, properties must exist.
if typeVal, ok := paramMap["type"].(string); ok && typeVal == "object" {
_, hasProps := paramMap["properties"]
assert.True(t, hasProps, "object type schema must have properties field")

// If there are nested objects, check them too.
if props, ok := paramMap["properties"].(map[string]any); ok {
for propName, propVal := range props {
if propMap, ok := propVal.(map[string]any); ok {
if propType, ok := propMap["type"].(string); ok && propType == "object" {
_, nestedHasProps := propMap["properties"]
assert.True(t, nestedHasProps, "nested object %s must have properties field", propName)
}
}
}
}
}
}
})
}
}

// TestEnsureObjectHasProperties tests the helper function that ensures
// object schemas have properties fields.
func TestEnsureObjectHasProperties(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "add missing properties",
input: `{"type":"object"}`,
expected: `{"type":"object","properties":{}}`,
},
{
name: "preserve existing properties",
input: `{"type":"object","properties":{"foo":{"type":"string"}}}`,
expected: `{"type":"object","properties":{"foo":{"type":"string"}}}`,
},
{
name: "fix nested objects",
input: `{"type":"object","properties":{"nested":{"type":"object"}}}`,
expected: `{"type":"object","properties":{"nested":{"type":"object","properties":{}}}}`,
},
{
name: "non-object types unchanged",
input: `{"type":"string"}`,
expected: `{"type":"string"}`,
},
{
name: "array with object items",
input: `{"type":"array","items":{"type":"object"}}`,
expected: `{"type":"array","items":{"type":"object","properties":{}}}`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Parse input
var schemaMap map[string]any
err := json.Unmarshal([]byte(tt.input), &schemaMap)
require.NoError(t, err, "failed to unmarshal input")

// Apply fix
ensureObjectHasProperties(schemaMap)

// Parse expected
var expectedMap map[string]any
err = json.Unmarshal([]byte(tt.expected), &expectedMap)
require.NoError(t, err, "failed to unmarshal expected")

assert.Equal(t, expectedMap, schemaMap, "schema mismatch")
})
}
}

// TestModel_Callbacks tests that callback functions are properly called with
// the correct parameters including the request parameter.
func TestModel_Callbacks(t *testing.T) {
Expand Down
18 changes: 18 additions & 0 deletions tool/mcp/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,17 @@ func convertMCPSchemaToSchema(mcpSchema any) *tool.Schema {
}

schema := &tool.Schema{}
// Handle type field - can be string or array of strings (for nullable/union types).
if typeVal, ok := schemaMap["type"].(string); ok {
schema.Type = typeVal
} else if typeArr, ok := schemaMap["type"].([]any); ok {
// For nullable types like ["integer", "null"], extract the non-null type.
for _, t := range typeArr {
if typeStr, ok := t.(string); ok && typeStr != "null" {
schema.Type = typeStr
break
}
}
}
if descVal, ok := schemaMap["description"].(string); ok {
schema.Description = descVal
Expand Down Expand Up @@ -64,8 +73,17 @@ func convertProperties(props map[string]any) map[string]*tool.Schema {
for name, prop := range props {
if propMap, ok := prop.(map[string]any); ok {
propSchema := &tool.Schema{}
// Handle type field - can be string or array of strings (for nullable/union types).
if typeVal, ok := propMap["type"].(string); ok {
propSchema.Type = typeVal
} else if typeArr, ok := propMap["type"].([]any); ok {
// For nullable types like ["integer", "null"], extract the non-null type.
for _, t := range typeArr {
if typeStr, ok := t.(string); ok && typeStr != "null" {
propSchema.Type = typeStr
break
}
}
}
if descVal, ok := propMap["description"].(string); ok {
propSchema.Description = descVal
Expand Down
93 changes: 93 additions & 0 deletions tool/mcp/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,96 @@ func TestNewMCPTool_WithOutputSchema(t *testing.T) {
require.Equal(t, "Success flag", convertedSchema.Properties["success"].Description)
require.ElementsMatch(t, []string{"result", "success"}, convertedSchema.Required)
}

func TestConvertMCPSchema_UnionTypes(t *testing.T) {
// Test handling of union types (nullable types represented as ["type", "null"]).
mcpSchema := map[string]any{
"type": "object",
"description": "schema with optional parameters",
"required": []any{"text"}, // text is required, size is optional
"properties": map[string]any{
"text": map[string]any{
"type": "string",
"description": "Required text parameter",
},
"size": map[string]any{
"type": []any{"integer", "null"}, // Nullable integer.
"description": "Size of the variable in bytes",
},
"count": map[string]any{
"type": []any{"number", "null"}, // Nullable number.
"description": "Optional count parameter",
},
"flag": map[string]any{
"type": []any{"boolean", "null"}, // Nullable boolean.
"description": "Optional flag parameter",
},
},
}

schema := convertMCPSchemaToSchema(mcpSchema)

// Verify top-level schema.
require.Equal(t, "object", schema.Type)
require.Equal(t, "schema with optional parameters", schema.Description)
require.ElementsMatch(t, []string{"text"}, schema.Required)

// Verify required string parameter.
textSchema := schema.Properties["text"]
require.NotNil(t, textSchema)
require.Equal(t, "string", textSchema.Type)
require.Equal(t, "Required text parameter", textSchema.Description)

// Verify optional integer parameter (should extract "integer" from ["integer", "null"]).
sizeSchema := schema.Properties["size"]
require.NotNil(t, sizeSchema)
require.Equal(t, "integer", sizeSchema.Type, "should extract 'integer' from union type")
require.Equal(t, "Size of the variable in bytes", sizeSchema.Description)

// Verify optional number parameter.
countSchema := schema.Properties["count"]
require.NotNil(t, countSchema)
require.Equal(t, "number", countSchema.Type, "should extract 'number' from union type")
require.Equal(t, "Optional count parameter", countSchema.Description)

// Verify optional boolean parameter.
flagSchema := schema.Properties["flag"]
require.NotNil(t, flagSchema)
require.Equal(t, "boolean", flagSchema.Type, "should extract 'boolean' from union type")
require.Equal(t, "Optional flag parameter", flagSchema.Description)
}

func TestConvertMCPSchema_UnionType_NullOnly(t *testing.T) {
// Edge case: what if type is ["null"] only?
mcpSchema := map[string]any{
"type": "object",
"properties": map[string]any{
"nullable_field": map[string]any{
"type": []any{"null"},
"description": "A field that can only be null",
},
},
}

schema := convertMCPSchemaToSchema(mcpSchema)

// When only "null" is in the array, Type should remain empty.
nullableSchema := schema.Properties["nullable_field"]
require.NotNil(t, nullableSchema)
require.Equal(t, "", nullableSchema.Type, "should not set type when only 'null' is present")
require.Equal(t, "A field that can only be null", nullableSchema.Description)
}

func TestConvertMCPSchema_TopLevelUnionType(t *testing.T) {
// Test union type at top level (rare but possible).
mcpSchema := map[string]any{
"type": []any{"string", "null"},
"description": "A top-level union type",
}

schema := convertMCPSchemaToSchema(mcpSchema)

// Should extract "string" from the union type.
require.Equal(t, "string", schema.Type, "should extract 'string' from top-level union type")
require.Equal(t, "A top-level union type", schema.Description)
}
Loading