Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion model/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ func convertTools(tools map[string]tool.Tool) []anthropic.ToolUnionParam {
result = append(result, anthropic.ToolUnionParam{
OfTool: &anthropic.ToolParam{
Name: declaration.Name,
Description: anthropic.String(declaration.Description),
Description: anthropic.String(buildToolDescription(declaration)),
InputSchema: anthropic.ToolInputSchemaParam{
Type: constant.Object(declaration.InputSchema.Type),
Properties: declaration.InputSchema.Properties,
Expand All @@ -860,6 +860,22 @@ func convertTools(tools map[string]tool.Tool) []anthropic.ToolUnionParam {
return result
}

// buildToolDescription builds the description for a tool.
// It appends the output schema to the description.
func buildToolDescription(declaration *tool.Declaration) string {
desc := declaration.Description
if declaration.OutputSchema == nil {
return desc
}
schemaJSON, err := json.Marshal(declaration.OutputSchema)
if err != nil {
log.Debugf("marshal output schema for tool %s: %v", declaration.Name, err)
return desc
}
desc += "Output schema: " + string(schemaJSON)
return desc
}

// convertMessages builds Anthropic message parameters and system prompts from trpc-agent-go messages.
// Merges consecutive tool results into a single user message and drops empty-content messages.
func convertMessages(messages []model.Message) ([]anthropic.MessageParam, []anthropic.TextBlockParam, error) {
Expand Down
99 changes: 99 additions & 0 deletions model/anthropic/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

agentlog "trpc.group/trpc-go/trpc-agent-go/log"
"trpc.group/trpc-go/trpc-agent-go/model"
"trpc.group/trpc-go/trpc-agent-go/tool"
)
Expand All @@ -37,6 +38,25 @@ func (s stubTool) Call(_ context.Context, _ []byte) (any, error) { return nil, n
// Declaration returns the tool declaration.
func (s stubTool) Declaration() *tool.Declaration { return s.decl }

type stubLogger struct {
debugfCalled bool
debugfMsg string
}

func (stubLogger) Debug(args ...any) {}
func (l *stubLogger) Debugf(format string, args ...any) {
l.debugfCalled = true
l.debugfMsg = fmt.Sprintf(format, args...)
}
func (stubLogger) Info(args ...any) {}
func (stubLogger) Infof(format string, args ...any) {}
func (stubLogger) Warn(args ...any) {}
func (stubLogger) Warnf(format string, args ...any) {}
func (stubLogger) Error(args ...any) {}
func (stubLogger) Errorf(format string, args ...any) {}
func (stubLogger) Fatal(args ...any) {}
func (stubLogger) Fatalf(format string, args ...any) {}

func Test_Model_Info(t *testing.T) {
m := New("claude-3-5-sonnet-latest")
info := m.Info()
Expand Down Expand Up @@ -147,6 +167,85 @@ func Test_convertTools(t *testing.T) {
assert.Equal(t, "t1", params[0].OfTool.Name)
}

func Test_buildToolDescription_AppendsOutputSchema(t *testing.T) {
schema := &tool.Schema{
Type: "object",
Properties: map[string]*tool.Schema{
"status": {Type: "string"},
},
}
decl := &tool.Declaration{
Name: "foo",
Description: "desc",
OutputSchema: schema,
}

desc := buildToolDescription(decl)

assert.Contains(t, desc, "desc", "expected base description to remain")
assert.Contains(t, desc, "Output schema:", "expected output schema label to be present")
assert.Contains(t, desc, `"status"`, "expected output schema to be embedded in description")
}

func Test_buildToolDescription_MarshalError(t *testing.T) {
logger := &stubLogger{}
original := agentlog.Default
agentlog.Default = logger
defer func() { agentlog.Default = original }()

decl := &tool.Declaration{
Name: "foo",
Description: "desc",
OutputSchema: &tool.Schema{
Type: "object",
AdditionalProperties: func() {},
},
}

desc := buildToolDescription(decl)

assert.Equal(t, "desc", desc, "description should fall back when marshal fails")
assert.True(t, logger.debugfCalled, "expected marshal error to be logged")
assert.Contains(t, logger.debugfMsg, "marshal output schema", "expected marshal error message")
}

func Test_buildToolDescription_NoOutputSchema(t *testing.T) {
decl := &tool.Declaration{
Name: "foo",
Description: "bar",
}

desc := buildToolDescription(decl)

assert.Equal(t, "bar", desc, "description should stay unchanged when no output schema")
}

func Test_convertTools_UsesOutputSchemaDescription(t *testing.T) {
outputSchema := &tool.Schema{
Type: "object",
Properties: map[string]*tool.Schema{
"count": {Type: "integer"},
},
}
decl := &tool.Declaration{
Name: "tool_with_out",
Description: "tool desc",
InputSchema: &tool.Schema{Type: "object"},
OutputSchema: outputSchema,
}

params := convertTools(map[string]tool.Tool{
decl.Name: stubTool{decl: decl},
})

require.Len(t, params, 1)
require.NotNil(t, params[0].OfTool)
expected := buildToolDescription(decl)
assert.True(t, params[0].OfTool.Description.Valid(), "description should be set")
assert.Equal(t, expected, params[0].OfTool.Description.Value)
assert.Contains(t, params[0].OfTool.Description.Value, `"count"`, "output schema JSON should appear in description")
}

func Test_decodeToolArguments(t *testing.T) {
// Empty -> empty map.
v := decodeToolArguments(nil)
Expand Down
18 changes: 17 additions & 1 deletion model/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1213,14 +1213,30 @@ func (m *Model) convertTools(tools map[string]tool.Tool) []openai.ChatCompletion
result = append(result, openai.ChatCompletionToolParam{
Function: openai.FunctionDefinitionParam{
Name: declaration.Name,
Description: openai.String(declaration.Description),
Description: openai.String(buildToolDescription(declaration)),
Parameters: parameters,
},
})
}
return result
}

// buildToolDescription builds the description for a tool.
// It appends the output schema to the description.
func buildToolDescription(declaration *tool.Declaration) string {
desc := declaration.Description
if declaration.OutputSchema == nil {
return desc
}
schemaJSON, err := json.Marshal(declaration.OutputSchema)
if err != nil {
log.Errorf("marshal output schema for tool %s: %v", declaration.Name, err)
return desc
}
desc += "\nOutput schema: " + string(schemaJSON)
return desc
}

// handleStreamingResponse handles streaming chat completion responses.
func (m *Model) handleStreamingResponse(
ctx context.Context,
Expand Down
99 changes: 99 additions & 0 deletions model/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
openaigo "github.com/openai/openai-go"
openaiopt "github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/respjson"
agentlog "trpc.group/trpc-go/trpc-agent-go/log"
"trpc.group/trpc-go/trpc-agent-go/model"
"trpc.group/trpc-go/trpc-agent-go/tool"

Expand Down Expand Up @@ -241,6 +242,25 @@ type stubTool struct{ decl *tool.Declaration }
func (s stubTool) Call(_ context.Context, _ []byte) (any, error) { return nil, nil }
func (s stubTool) Declaration() *tool.Declaration { return s.decl }

type stubLogger struct {
errorfCalled bool
errorfMsg string
}

func (stubLogger) Debug(args ...any) {}
func (stubLogger) Debugf(format string, args ...any) {}
func (stubLogger) Info(args ...any) {}
func (stubLogger) Infof(format string, args ...any) {}
func (stubLogger) Warn(args ...any) {}
func (stubLogger) Warnf(format string, args ...any) {}
func (stubLogger) Error(args ...any) {}
func (l *stubLogger) Errorf(format string, args ...any) {
l.errorfCalled = true
l.errorfMsg = fmt.Sprintf(format, args...)
}
func (stubLogger) Fatal(args ...any) {}
func (stubLogger) Fatalf(format string, args ...any) {}

// TestModel_convertMessages verifies that messages are converted to the
// openai-go request format with the expected roles and fields.
func TestModel_convertMessages(t *testing.T) {
Expand Down Expand Up @@ -322,6 +342,85 @@ func TestModel_convertTools(t *testing.T) {
require.False(t, reflect.ValueOf(fn.Parameters).IsZero(), "expected parameters to be populated from schema")
}

func TestBuildToolDescription_AppendsOutputSchema(t *testing.T) {
schema := &tool.Schema{
Type: "object",
Properties: map[string]*tool.Schema{
"result": {Type: "string"},
},
}
decl := &tool.Declaration{
Name: "example",
Description: "base",
OutputSchema: schema,
}

desc := buildToolDescription(decl)

assert.Contains(t, desc, "base", "expected base description to be preserved")
assert.Contains(t, desc, "Output schema:", "expected output schema label to be present")
assert.Contains(t, desc, `"result"`, "expected output schema to be present in description")
}

func TestBuildToolDescription_MarshalError(t *testing.T) {
logger := &stubLogger{}
originalLogger := agentlog.Default
agentlog.Default = logger
defer func() { agentlog.Default = originalLogger }()

decl := &tool.Declaration{
Name: "invalid",
Description: "desc",
OutputSchema: &tool.Schema{
Type: "object",
AdditionalProperties: func() {},
},
}

desc := buildToolDescription(decl)

assert.Equal(t, "desc", desc, "description should fall back when marshal fails")
assert.True(t, logger.errorfCalled, "expected marshal error to be logged")
assert.Contains(t, logger.errorfMsg, "marshal output schema", "expected marshal error message")
}

func TestBuildToolDescription_NoOutputSchema(t *testing.T) {
decl := &tool.Declaration{
Name: "example",
Description: "only desc",
}

desc := buildToolDescription(decl)

assert.Equal(t, "only desc", desc, "description should remain unchanged without output schema")
}

func TestConvertTools_UsesOutputSchemaInDescription(t *testing.T) {
m := New("dummy")
outputSchema := &tool.Schema{
Type: "object",
Properties: map[string]*tool.Schema{
"value": {Type: "number"},
},
}
decl := &tool.Declaration{
Name: "tool1",
Description: "desc",
InputSchema: &tool.Schema{Type: "object"},
OutputSchema: outputSchema,
}

params := m.convertTools(map[string]tool.Tool{
decl.Name: stubTool{decl: decl},
})

require.Len(t, params, 1)
expectedDesc := buildToolDescription(decl)
require.True(t, params[0].Function.Description.Valid(), "function description should be set")
assert.Equal(t, expectedDesc, params[0].Function.Description.Value)
assert.Contains(t, params[0].Function.Description.Value, `"value"`, "output schema JSON should be embedded")
}

// TestModel_Callbacks tests that callback functions are properly called with
// the correct parameters including the request parameter.
func TestModel_Callbacks(t *testing.T) {
Expand Down
Loading