Skip to content

Commit 6febbdd

Browse files
authored
model: append tool outputschema to tool description (#756)
The openai/anthropic SDK don’t expose an output-schema field for tools, so we surface the schema in descriptions to give the LLM stronger guidance on tool return shapes.
1 parent 2fe4460 commit 6febbdd

File tree

4 files changed

+232
-2
lines changed

4 files changed

+232
-2
lines changed

model/anthropic/anthropic.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ func convertTools(tools map[string]tool.Tool) []anthropic.ToolUnionParam {
848848
result = append(result, anthropic.ToolUnionParam{
849849
OfTool: &anthropic.ToolParam{
850850
Name: declaration.Name,
851-
Description: anthropic.String(declaration.Description),
851+
Description: anthropic.String(buildToolDescription(declaration)),
852852
InputSchema: anthropic.ToolInputSchemaParam{
853853
Type: constant.Object(declaration.InputSchema.Type),
854854
Properties: declaration.InputSchema.Properties,
@@ -860,6 +860,22 @@ func convertTools(tools map[string]tool.Tool) []anthropic.ToolUnionParam {
860860
return result
861861
}
862862

863+
// buildToolDescription builds the description for a tool.
864+
// It appends the output schema to the description.
865+
func buildToolDescription(declaration *tool.Declaration) string {
866+
desc := declaration.Description
867+
if declaration.OutputSchema == nil {
868+
return desc
869+
}
870+
schemaJSON, err := json.Marshal(declaration.OutputSchema)
871+
if err != nil {
872+
log.Debugf("marshal output schema for tool %s: %v", declaration.Name, err)
873+
return desc
874+
}
875+
desc += "Output schema: " + string(schemaJSON)
876+
return desc
877+
}
878+
863879
// convertMessages builds Anthropic message parameters and system prompts from trpc-agent-go messages.
864880
// Merges consecutive tool results into a single user message and drops empty-content messages.
865881
func convertMessages(messages []model.Message) ([]anthropic.MessageParam, []anthropic.TextBlockParam, error) {

model/anthropic/anthropic_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/stretchr/testify/assert"
2525
"github.com/stretchr/testify/require"
2626

27+
agentlog "trpc.group/trpc-go/trpc-agent-go/log"
2728
"trpc.group/trpc-go/trpc-agent-go/model"
2829
"trpc.group/trpc-go/trpc-agent-go/tool"
2930
)
@@ -37,6 +38,25 @@ func (s stubTool) Call(_ context.Context, _ []byte) (any, error) { return nil, n
3738
// Declaration returns the tool declaration.
3839
func (s stubTool) Declaration() *tool.Declaration { return s.decl }
3940

41+
type stubLogger struct {
42+
debugfCalled bool
43+
debugfMsg string
44+
}
45+
46+
func (stubLogger) Debug(args ...any) {}
47+
func (l *stubLogger) Debugf(format string, args ...any) {
48+
l.debugfCalled = true
49+
l.debugfMsg = fmt.Sprintf(format, args...)
50+
}
51+
func (stubLogger) Info(args ...any) {}
52+
func (stubLogger) Infof(format string, args ...any) {}
53+
func (stubLogger) Warn(args ...any) {}
54+
func (stubLogger) Warnf(format string, args ...any) {}
55+
func (stubLogger) Error(args ...any) {}
56+
func (stubLogger) Errorf(format string, args ...any) {}
57+
func (stubLogger) Fatal(args ...any) {}
58+
func (stubLogger) Fatalf(format string, args ...any) {}
59+
4060
func Test_Model_Info(t *testing.T) {
4161
m := New("claude-3-5-sonnet-latest")
4262
info := m.Info()
@@ -147,6 +167,85 @@ func Test_convertTools(t *testing.T) {
147167
assert.Equal(t, "t1", params[0].OfTool.Name)
148168
}
149169

170+
func Test_buildToolDescription_AppendsOutputSchema(t *testing.T) {
171+
schema := &tool.Schema{
172+
Type: "object",
173+
Properties: map[string]*tool.Schema{
174+
"status": {Type: "string"},
175+
},
176+
}
177+
decl := &tool.Declaration{
178+
Name: "foo",
179+
Description: "desc",
180+
OutputSchema: schema,
181+
}
182+
183+
desc := buildToolDescription(decl)
184+
185+
assert.Contains(t, desc, "desc", "expected base description to remain")
186+
assert.Contains(t, desc, "Output schema:", "expected output schema label to be present")
187+
assert.Contains(t, desc, `"status"`, "expected output schema to be embedded in description")
188+
}
189+
190+
func Test_buildToolDescription_MarshalError(t *testing.T) {
191+
logger := &stubLogger{}
192+
original := agentlog.Default
193+
agentlog.Default = logger
194+
defer func() { agentlog.Default = original }()
195+
196+
decl := &tool.Declaration{
197+
Name: "foo",
198+
Description: "desc",
199+
OutputSchema: &tool.Schema{
200+
Type: "object",
201+
AdditionalProperties: func() {},
202+
},
203+
}
204+
205+
desc := buildToolDescription(decl)
206+
207+
assert.Equal(t, "desc", desc, "description should fall back when marshal fails")
208+
assert.True(t, logger.debugfCalled, "expected marshal error to be logged")
209+
assert.Contains(t, logger.debugfMsg, "marshal output schema", "expected marshal error message")
210+
}
211+
212+
func Test_buildToolDescription_NoOutputSchema(t *testing.T) {
213+
decl := &tool.Declaration{
214+
Name: "foo",
215+
Description: "bar",
216+
}
217+
218+
desc := buildToolDescription(decl)
219+
220+
assert.Equal(t, "bar", desc, "description should stay unchanged when no output schema")
221+
}
222+
223+
func Test_convertTools_UsesOutputSchemaDescription(t *testing.T) {
224+
outputSchema := &tool.Schema{
225+
Type: "object",
226+
Properties: map[string]*tool.Schema{
227+
"count": {Type: "integer"},
228+
},
229+
}
230+
decl := &tool.Declaration{
231+
Name: "tool_with_out",
232+
Description: "tool desc",
233+
InputSchema: &tool.Schema{Type: "object"},
234+
OutputSchema: outputSchema,
235+
}
236+
237+
params := convertTools(map[string]tool.Tool{
238+
decl.Name: stubTool{decl: decl},
239+
})
240+
241+
require.Len(t, params, 1)
242+
require.NotNil(t, params[0].OfTool)
243+
expected := buildToolDescription(decl)
244+
assert.True(t, params[0].OfTool.Description.Valid(), "description should be set")
245+
assert.Equal(t, expected, params[0].OfTool.Description.Value)
246+
assert.Contains(t, params[0].OfTool.Description.Value, `"count"`, "output schema JSON should appear in description")
247+
}
248+
150249
func Test_decodeToolArguments(t *testing.T) {
151250
// Empty -> empty map.
152251
v := decodeToolArguments(nil)

model/openai/openai.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1213,14 +1213,30 @@ func (m *Model) convertTools(tools map[string]tool.Tool) []openai.ChatCompletion
12131213
result = append(result, openai.ChatCompletionToolParam{
12141214
Function: openai.FunctionDefinitionParam{
12151215
Name: declaration.Name,
1216-
Description: openai.String(declaration.Description),
1216+
Description: openai.String(buildToolDescription(declaration)),
12171217
Parameters: parameters,
12181218
},
12191219
})
12201220
}
12211221
return result
12221222
}
12231223

1224+
// buildToolDescription builds the description for a tool.
1225+
// It appends the output schema to the description.
1226+
func buildToolDescription(declaration *tool.Declaration) string {
1227+
desc := declaration.Description
1228+
if declaration.OutputSchema == nil {
1229+
return desc
1230+
}
1231+
schemaJSON, err := json.Marshal(declaration.OutputSchema)
1232+
if err != nil {
1233+
log.Errorf("marshal output schema for tool %s: %v", declaration.Name, err)
1234+
return desc
1235+
}
1236+
desc += "\nOutput schema: " + string(schemaJSON)
1237+
return desc
1238+
}
1239+
12241240
// handleStreamingResponse handles streaming chat completion responses.
12251241
func (m *Model) handleStreamingResponse(
12261242
ctx context.Context,

model/openai/openai_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
openaigo "github.com/openai/openai-go"
2828
openaiopt "github.com/openai/openai-go/option"
2929
"github.com/openai/openai-go/packages/respjson"
30+
agentlog "trpc.group/trpc-go/trpc-agent-go/log"
3031
"trpc.group/trpc-go/trpc-agent-go/model"
3132
"trpc.group/trpc-go/trpc-agent-go/tool"
3233

@@ -241,6 +242,25 @@ type stubTool struct{ decl *tool.Declaration }
241242
func (s stubTool) Call(_ context.Context, _ []byte) (any, error) { return nil, nil }
242243
func (s stubTool) Declaration() *tool.Declaration { return s.decl }
243244

245+
type stubLogger struct {
246+
errorfCalled bool
247+
errorfMsg string
248+
}
249+
250+
func (stubLogger) Debug(args ...any) {}
251+
func (stubLogger) Debugf(format string, args ...any) {}
252+
func (stubLogger) Info(args ...any) {}
253+
func (stubLogger) Infof(format string, args ...any) {}
254+
func (stubLogger) Warn(args ...any) {}
255+
func (stubLogger) Warnf(format string, args ...any) {}
256+
func (stubLogger) Error(args ...any) {}
257+
func (l *stubLogger) Errorf(format string, args ...any) {
258+
l.errorfCalled = true
259+
l.errorfMsg = fmt.Sprintf(format, args...)
260+
}
261+
func (stubLogger) Fatal(args ...any) {}
262+
func (stubLogger) Fatalf(format string, args ...any) {}
263+
244264
// TestModel_convertMessages verifies that messages are converted to the
245265
// openai-go request format with the expected roles and fields.
246266
func TestModel_convertMessages(t *testing.T) {
@@ -322,6 +342,85 @@ func TestModel_convertTools(t *testing.T) {
322342
require.False(t, reflect.ValueOf(fn.Parameters).IsZero(), "expected parameters to be populated from schema")
323343
}
324344

345+
func TestBuildToolDescription_AppendsOutputSchema(t *testing.T) {
346+
schema := &tool.Schema{
347+
Type: "object",
348+
Properties: map[string]*tool.Schema{
349+
"result": {Type: "string"},
350+
},
351+
}
352+
decl := &tool.Declaration{
353+
Name: "example",
354+
Description: "base",
355+
OutputSchema: schema,
356+
}
357+
358+
desc := buildToolDescription(decl)
359+
360+
assert.Contains(t, desc, "base", "expected base description to be preserved")
361+
assert.Contains(t, desc, "Output schema:", "expected output schema label to be present")
362+
assert.Contains(t, desc, `"result"`, "expected output schema to be present in description")
363+
}
364+
365+
func TestBuildToolDescription_MarshalError(t *testing.T) {
366+
logger := &stubLogger{}
367+
originalLogger := agentlog.Default
368+
agentlog.Default = logger
369+
defer func() { agentlog.Default = originalLogger }()
370+
371+
decl := &tool.Declaration{
372+
Name: "invalid",
373+
Description: "desc",
374+
OutputSchema: &tool.Schema{
375+
Type: "object",
376+
AdditionalProperties: func() {},
377+
},
378+
}
379+
380+
desc := buildToolDescription(decl)
381+
382+
assert.Equal(t, "desc", desc, "description should fall back when marshal fails")
383+
assert.True(t, logger.errorfCalled, "expected marshal error to be logged")
384+
assert.Contains(t, logger.errorfMsg, "marshal output schema", "expected marshal error message")
385+
}
386+
387+
func TestBuildToolDescription_NoOutputSchema(t *testing.T) {
388+
decl := &tool.Declaration{
389+
Name: "example",
390+
Description: "only desc",
391+
}
392+
393+
desc := buildToolDescription(decl)
394+
395+
assert.Equal(t, "only desc", desc, "description should remain unchanged without output schema")
396+
}
397+
398+
func TestConvertTools_UsesOutputSchemaInDescription(t *testing.T) {
399+
m := New("dummy")
400+
outputSchema := &tool.Schema{
401+
Type: "object",
402+
Properties: map[string]*tool.Schema{
403+
"value": {Type: "number"},
404+
},
405+
}
406+
decl := &tool.Declaration{
407+
Name: "tool1",
408+
Description: "desc",
409+
InputSchema: &tool.Schema{Type: "object"},
410+
OutputSchema: outputSchema,
411+
}
412+
413+
params := m.convertTools(map[string]tool.Tool{
414+
decl.Name: stubTool{decl: decl},
415+
})
416+
417+
require.Len(t, params, 1)
418+
expectedDesc := buildToolDescription(decl)
419+
require.True(t, params[0].Function.Description.Valid(), "function description should be set")
420+
assert.Equal(t, expectedDesc, params[0].Function.Description.Value)
421+
assert.Contains(t, params[0].Function.Description.Value, `"value"`, "output schema JSON should be embedded")
422+
}
423+
325424
// TestModel_Callbacks tests that callback functions are properly called with
326425
// the correct parameters including the request parameter.
327426
func TestModel_Callbacks(t *testing.T) {

0 commit comments

Comments
 (0)