diff --git a/model/anthropic/anthropic.go b/model/anthropic/anthropic.go index 2bc5949a4..23819e914 100644 --- a/model/anthropic/anthropic.go +++ b/model/anthropic/anthropic.go @@ -848,7 +848,7 @@ func convertTools(tools map[string]tool.Tool) []anthropic.ToolUnionParam { result = append(result, anthropic.ToolUnionParam{ OfTool: &anthropic.ToolParam{ Name: declaration.Name, - Description: anthropic.String(declaration.Description), + Description: anthropic.String(buildToolDescription(declaration)), InputSchema: anthropic.ToolInputSchemaParam{ Type: constant.Object(declaration.InputSchema.Type), Properties: declaration.InputSchema.Properties, @@ -860,6 +860,22 @@ func convertTools(tools map[string]tool.Tool) []anthropic.ToolUnionParam { return result } +// buildToolDescription builds the description for a tool. +// It appends the output schema to the description. +func buildToolDescription(declaration *tool.Declaration) string { + desc := declaration.Description + if declaration.OutputSchema == nil { + return desc + } + schemaJSON, err := json.Marshal(declaration.OutputSchema) + if err != nil { + log.Debugf("marshal output schema for tool %s: %v", declaration.Name, err) + return desc + } + desc += "Output schema: " + string(schemaJSON) + return desc +} + // convertMessages builds Anthropic message parameters and system prompts from trpc-agent-go messages. // Merges consecutive tool results into a single user message and drops empty-content messages. func convertMessages(messages []model.Message) ([]anthropic.MessageParam, []anthropic.TextBlockParam, error) { diff --git a/model/anthropic/anthropic_test.go b/model/anthropic/anthropic_test.go index 02a20390c..512880298 100644 --- a/model/anthropic/anthropic_test.go +++ b/model/anthropic/anthropic_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + agentlog "trpc.group/trpc-go/trpc-agent-go/log" "trpc.group/trpc-go/trpc-agent-go/model" "trpc.group/trpc-go/trpc-agent-go/tool" ) @@ -37,6 +38,25 @@ func (s stubTool) Call(_ context.Context, _ []byte) (any, error) { return nil, n // Declaration returns the tool declaration. func (s stubTool) Declaration() *tool.Declaration { return s.decl } +type stubLogger struct { + debugfCalled bool + debugfMsg string +} + +func (stubLogger) Debug(args ...any) {} +func (l *stubLogger) Debugf(format string, args ...any) { + l.debugfCalled = true + l.debugfMsg = fmt.Sprintf(format, args...) +} +func (stubLogger) Info(args ...any) {} +func (stubLogger) Infof(format string, args ...any) {} +func (stubLogger) Warn(args ...any) {} +func (stubLogger) Warnf(format string, args ...any) {} +func (stubLogger) Error(args ...any) {} +func (stubLogger) Errorf(format string, args ...any) {} +func (stubLogger) Fatal(args ...any) {} +func (stubLogger) Fatalf(format string, args ...any) {} + func Test_Model_Info(t *testing.T) { m := New("claude-3-5-sonnet-latest") info := m.Info() @@ -147,6 +167,85 @@ func Test_convertTools(t *testing.T) { assert.Equal(t, "t1", params[0].OfTool.Name) } +func Test_buildToolDescription_AppendsOutputSchema(t *testing.T) { + schema := &tool.Schema{ + Type: "object", + Properties: map[string]*tool.Schema{ + "status": {Type: "string"}, + }, + } + decl := &tool.Declaration{ + Name: "foo", + Description: "desc", + OutputSchema: schema, + } + + desc := buildToolDescription(decl) + + assert.Contains(t, desc, "desc", "expected base description to remain") + assert.Contains(t, desc, "Output schema:", "expected output schema label to be present") + assert.Contains(t, desc, `"status"`, "expected output schema to be embedded in description") +} + +func Test_buildToolDescription_MarshalError(t *testing.T) { + logger := &stubLogger{} + original := agentlog.Default + agentlog.Default = logger + defer func() { agentlog.Default = original }() + + decl := &tool.Declaration{ + Name: "foo", + Description: "desc", + OutputSchema: &tool.Schema{ + Type: "object", + AdditionalProperties: func() {}, + }, + } + + desc := buildToolDescription(decl) + + assert.Equal(t, "desc", desc, "description should fall back when marshal fails") + assert.True(t, logger.debugfCalled, "expected marshal error to be logged") + assert.Contains(t, logger.debugfMsg, "marshal output schema", "expected marshal error message") +} + +func Test_buildToolDescription_NoOutputSchema(t *testing.T) { + decl := &tool.Declaration{ + Name: "foo", + Description: "bar", + } + + desc := buildToolDescription(decl) + + assert.Equal(t, "bar", desc, "description should stay unchanged when no output schema") +} + +func Test_convertTools_UsesOutputSchemaDescription(t *testing.T) { + outputSchema := &tool.Schema{ + Type: "object", + Properties: map[string]*tool.Schema{ + "count": {Type: "integer"}, + }, + } + decl := &tool.Declaration{ + Name: "tool_with_out", + Description: "tool desc", + InputSchema: &tool.Schema{Type: "object"}, + OutputSchema: outputSchema, + } + + params := convertTools(map[string]tool.Tool{ + decl.Name: stubTool{decl: decl}, + }) + + require.Len(t, params, 1) + require.NotNil(t, params[0].OfTool) + expected := buildToolDescription(decl) + assert.True(t, params[0].OfTool.Description.Valid(), "description should be set") + assert.Equal(t, expected, params[0].OfTool.Description.Value) + assert.Contains(t, params[0].OfTool.Description.Value, `"count"`, "output schema JSON should appear in description") +} + func Test_decodeToolArguments(t *testing.T) { // Empty -> empty map. v := decodeToolArguments(nil) diff --git a/model/openai/openai.go b/model/openai/openai.go index 17de14e6e..0779fd2f0 100644 --- a/model/openai/openai.go +++ b/model/openai/openai.go @@ -1213,7 +1213,7 @@ func (m *Model) convertTools(tools map[string]tool.Tool) []openai.ChatCompletion result = append(result, openai.ChatCompletionToolParam{ Function: openai.FunctionDefinitionParam{ Name: declaration.Name, - Description: openai.String(declaration.Description), + Description: openai.String(buildToolDescription(declaration)), Parameters: parameters, }, }) @@ -1221,6 +1221,22 @@ func (m *Model) convertTools(tools map[string]tool.Tool) []openai.ChatCompletion return result } +// buildToolDescription builds the description for a tool. +// It appends the output schema to the description. +func buildToolDescription(declaration *tool.Declaration) string { + desc := declaration.Description + if declaration.OutputSchema == nil { + return desc + } + schemaJSON, err := json.Marshal(declaration.OutputSchema) + if err != nil { + log.Errorf("marshal output schema for tool %s: %v", declaration.Name, err) + return desc + } + desc += "\nOutput schema: " + string(schemaJSON) + return desc +} + // handleStreamingResponse handles streaming chat completion responses. func (m *Model) handleStreamingResponse( ctx context.Context, diff --git a/model/openai/openai_test.go b/model/openai/openai_test.go index 5318a62df..f4afb16fe 100644 --- a/model/openai/openai_test.go +++ b/model/openai/openai_test.go @@ -27,6 +27,7 @@ import ( openaigo "github.com/openai/openai-go" openaiopt "github.com/openai/openai-go/option" "github.com/openai/openai-go/packages/respjson" + agentlog "trpc.group/trpc-go/trpc-agent-go/log" "trpc.group/trpc-go/trpc-agent-go/model" "trpc.group/trpc-go/trpc-agent-go/tool" @@ -241,6 +242,25 @@ type stubTool struct{ decl *tool.Declaration } func (s stubTool) Call(_ context.Context, _ []byte) (any, error) { return nil, nil } func (s stubTool) Declaration() *tool.Declaration { return s.decl } +type stubLogger struct { + errorfCalled bool + errorfMsg string +} + +func (stubLogger) Debug(args ...any) {} +func (stubLogger) Debugf(format string, args ...any) {} +func (stubLogger) Info(args ...any) {} +func (stubLogger) Infof(format string, args ...any) {} +func (stubLogger) Warn(args ...any) {} +func (stubLogger) Warnf(format string, args ...any) {} +func (stubLogger) Error(args ...any) {} +func (l *stubLogger) Errorf(format string, args ...any) { + l.errorfCalled = true + l.errorfMsg = fmt.Sprintf(format, args...) +} +func (stubLogger) Fatal(args ...any) {} +func (stubLogger) Fatalf(format string, args ...any) {} + // TestModel_convertMessages verifies that messages are converted to the // openai-go request format with the expected roles and fields. func TestModel_convertMessages(t *testing.T) { @@ -322,6 +342,85 @@ func TestModel_convertTools(t *testing.T) { require.False(t, reflect.ValueOf(fn.Parameters).IsZero(), "expected parameters to be populated from schema") } +func TestBuildToolDescription_AppendsOutputSchema(t *testing.T) { + schema := &tool.Schema{ + Type: "object", + Properties: map[string]*tool.Schema{ + "result": {Type: "string"}, + }, + } + decl := &tool.Declaration{ + Name: "example", + Description: "base", + OutputSchema: schema, + } + + desc := buildToolDescription(decl) + + assert.Contains(t, desc, "base", "expected base description to be preserved") + assert.Contains(t, desc, "Output schema:", "expected output schema label to be present") + assert.Contains(t, desc, `"result"`, "expected output schema to be present in description") +} + +func TestBuildToolDescription_MarshalError(t *testing.T) { + logger := &stubLogger{} + originalLogger := agentlog.Default + agentlog.Default = logger + defer func() { agentlog.Default = originalLogger }() + + decl := &tool.Declaration{ + Name: "invalid", + Description: "desc", + OutputSchema: &tool.Schema{ + Type: "object", + AdditionalProperties: func() {}, + }, + } + + desc := buildToolDescription(decl) + + assert.Equal(t, "desc", desc, "description should fall back when marshal fails") + assert.True(t, logger.errorfCalled, "expected marshal error to be logged") + assert.Contains(t, logger.errorfMsg, "marshal output schema", "expected marshal error message") +} + +func TestBuildToolDescription_NoOutputSchema(t *testing.T) { + decl := &tool.Declaration{ + Name: "example", + Description: "only desc", + } + + desc := buildToolDescription(decl) + + assert.Equal(t, "only desc", desc, "description should remain unchanged without output schema") +} + +func TestConvertTools_UsesOutputSchemaInDescription(t *testing.T) { + m := New("dummy") + outputSchema := &tool.Schema{ + Type: "object", + Properties: map[string]*tool.Schema{ + "value": {Type: "number"}, + }, + } + decl := &tool.Declaration{ + Name: "tool1", + Description: "desc", + InputSchema: &tool.Schema{Type: "object"}, + OutputSchema: outputSchema, + } + + params := m.convertTools(map[string]tool.Tool{ + decl.Name: stubTool{decl: decl}, + }) + + require.Len(t, params, 1) + expectedDesc := buildToolDescription(decl) + require.True(t, params[0].Function.Description.Valid(), "function description should be set") + assert.Equal(t, expectedDesc, params[0].Function.Description.Value) + assert.Contains(t, params[0].Function.Description.Value, `"value"`, "output schema JSON should be embedded") +} + // TestModel_Callbacks tests that callback functions are properly called with // the correct parameters including the request parameter. func TestModel_Callbacks(t *testing.T) {