Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions e2e/sampling_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package e2e_test

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"

	gomcp "github.com/modelcontextprotocol/go-sdk/mcp"
	"github.com/stretchr/testify/require"
)

// TestExec_Gemini_SamplingWithTools runs the CLI against an in-process MCP
// server to cover the sampling-with-tools round-trip end-to-end.
//
// The MCP server (a gomcp.NewServer served through a StreamableHTTPHandler on
// an httptest server) exposes a single tool, ask_with_calculator. Its handler
// drives a sampling loop: it issues sampling/createMessage with a tools array,
// receives a tool_use from the host LLM, "executes" the calculator, sends a
// follow-up sampling request carrying the tool_result, and returns the final
// text. The test passes when the CLI output mentions the tool and the sum.
func TestExec_Gemini_SamplingWithTools(t *testing.T) {
	serverURL := startSamplingToolsServer(t)
	agentFile := writeSamplingToolsAgent(t, serverURL)

	output := runCLI(t, "run", "--exec", "--yolo", agentFile, "--model=google/gemini-2.5-flash", "What is 17 plus 25?")

	// The tool name shows the agent invoked the MCP tool; "42" is 17 + 25.
	for _, want := range []string{"ask_with_calculator", "42"} {
		require.Contains(t, output, want)
	}
}

// startSamplingToolsServer starts an httptest server that serves an MCP server
// over the streamable HTTP handler and returns the httptest server's URL.
//
// The MCP server registers exactly one tool, ask_with_calculator, handled by
// askWithCalculator, which runs a sampling-with-tools loop against the
// connecting client. The httptest server is closed via t.Cleanup.
func startSamplingToolsServer(t *testing.T) string {
	t.Helper()

	impl := &gomcp.Implementation{
		Name:    "sampling-tools-test",
		Version: "0.0.1",
	}
	mcpServer := gomcp.NewServer(impl, nil)

	askTool := &gomcp.Tool{
		Name:        "ask_with_calculator",
		Description: "Answer a math word problem by asking the host LLM for help, with access to a calculator tool the server provides.",
	}
	gomcp.AddTool(mcpServer, askTool, askWithCalculator)

	// Every incoming request is served by the same MCP server instance.
	pickServer := func(*http.Request) *gomcp.Server { return mcpServer }
	ts := httptest.NewServer(gomcp.NewStreamableHTTPHandler(pickServer, nil))
	t.Cleanup(ts.Close)

	return ts.URL
}

// writeSamplingToolsAgent writes an agent configuration file named agent.yaml
// into a fresh t.TempDir() and returns its path.
//
// The configuration declares a single root agent using
// google/gemini-2.5-flash whose instruction tells it to forward math word
// problems to the ask_with_calculator tool, and one MCP toolset that reaches
// the server at mcpURL over the streamable transport. The toolset sets
// allow_private_ips: true — presumably because the httptest server listens on
// a loopback address; confirm against the toolset loader.
//
// The test fails immediately (require.NoError) if the file cannot be written.
func writeSamplingToolsAgent(t *testing.T, mcpURL string) string {
t.Helper()
yamlPath := filepath.Join(t.TempDir(), "agent.yaml")
// mcpURL is substituted for the single %s verb in the remote url field.
agentYAML := fmt.Appendf(nil, `agents:
root:
model: google/gemini-2.5-flash
description: "Test agent for MCP sampling-with-tools end-to-end verification"
instruction: |
You have access to one tool: ask_with_calculator. Whenever the user asks
a math word problem, call ask_with_calculator with the user's question.
Then report its answer verbatim to the user.
toolsets:
- type: mcp
allow_private_ips: true
remote:
url: %s
transport_type: streamable
`, mcpURL)
require.NoError(t, os.WriteFile(yamlPath, agentYAML, 0o644))
return yamlPath
}

// askInput is the typed argument of the ask_with_calculator tool handler
// (askWithCalculator).
type askInput struct {
// Question is the user's question; askWithCalculator sends it as the first
// user message of the sampling conversation.
Question string `json:"question" jsonschema:"the natural-language question to answer with help of the calculator"`
}

func askWithCalculator(ctx context.Context, req *gomcp.CallToolRequest, in askInput) (*gomcp.CallToolResult, any, error) {
calculator := &gomcp.Tool{
Name: "calculator",
Description: "Add two integers. Returns the sum.",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"x": map[string]any{"type": "integer"},
"y": map[string]any{"type": "integer"},
},
"required": []string{"x", "y"},
},
}

messages := []*gomcp.SamplingMessageV2{{
Role: "user",
Content: []gomcp.Content{&gomcp.TextContent{Text: in.Question}},
}}

for round := 1; round <= 4; round++ {
res, err := req.Session.CreateMessageWithTools(ctx, &gomcp.CreateMessageWithToolsParams{
MaxTokens: 1024,
Messages: messages,
Tools: []*gomcp.Tool{calculator},
SystemPrompt: "You are a careful assistant. Use the calculator tool whenever you need to add two integers. After you have the sum, answer the user's question in one short sentence.",
})
if err != nil {
return nil, nil, fmt.Errorf("sampling round %d: %w", round, err)
}

messages = append(messages, &gomcp.SamplingMessageV2{
Role: res.Role,
Content: res.Content,
})

var toolUses []*gomcp.ToolUseContent
var finalText strings.Builder
for _, c := range res.Content {
switch v := c.(type) {
case *gomcp.ToolUseContent:
toolUses = append(toolUses, v)
case *gomcp.TextContent:
finalText.WriteString(v.Text)
}
}

if len(toolUses) == 0 {
return &gomcp.CallToolResult{
Content: []gomcp.Content{&gomcp.TextContent{Text: finalText.String()}},
}, nil, nil
}

toolResults := make([]gomcp.Content, 0, len(toolUses))
for _, tu := range toolUses {
result, err := runCalculator(tu)
if err != nil {
toolResults = append(toolResults, &gomcp.ToolResultContent{
ToolUseID: tu.ID,
Content: []gomcp.Content{&gomcp.TextContent{Text: err.Error()}},
IsError: true,
})
continue
}
toolResults = append(toolResults, &gomcp.ToolResultContent{
ToolUseID: tu.ID,
Content: []gomcp.Content{&gomcp.TextContent{Text: result}},
})
}

messages = append(messages, &gomcp.SamplingMessageV2{
Role: "user",
Content: toolResults,
})
}

return nil, nil, fmt.Errorf("sampling loop did not terminate within 4 rounds")

Check failure on line 162 in e2e/sampling_test.go

View workflow job for this annotation

GitHub Actions / lint

error-format: fmt.Errorf can be replaced with errors.New (perfsprint)
}

func runCalculator(tu *gomcp.ToolUseContent) (string, error) {
if tu.Name != "calculator" {
return "", fmt.Errorf("unknown tool: %s", tu.Name)
}
x, errX := toInt(tu.Input["x"])
y, errY := toInt(tu.Input["y"])
if errX != nil || errY != nil {
raw, _ := json.Marshal(tu.Input)
return "", fmt.Errorf("calculator expects integer x and y, got %s", raw)
}
return fmt.Sprintf("%d", x+y), nil

Check failure on line 175 in e2e/sampling_test.go

View workflow job for this annotation

GitHub Actions / lint

integer-format: fmt.Sprintf can be replaced with faster strconv.FormatInt (perfsprint)
}

// toInt converts a loosely typed numeric value to int64.
//
// Accepted dynamic types are float64, int64, int and json.Number. A float64 is
// accepted only when it is a whole number that fits in int64; fractional
// values, NaN, infinities and out-of-range magnitudes are rejected rather than
// truncated. A json.Number is converted with its Int64 method, whose error is
// returned unchanged. Any other type, including nil, yields an error naming
// the offending type.
func toInt(v any) (int64, error) {
	const (
		// int64 covers [-2^63, 2^63); both bounds are exact as float64.
		minInt64Float          = -float64(1 << 63)
		maxInt64FloatExclusive = float64(1 << 63)
	)

	switch n := v.(type) {
	case float64:
		// NaN fails every comparison, so the range is tested positively.
		if !(n >= minInt64Float && n < maxInt64FloatExclusive) {
			return 0, fmt.Errorf("not an integer: %v", n)
		}
		i := int64(n)
		// A round-trip mismatch means n had a fractional part.
		if float64(i) != n {
			return 0, fmt.Errorf("not an integer: %v", n)
		}
		return i, nil
	case int64:
		return n, nil
	case int:
		return int64(n), nil
	case json.Number:
		return n.Int64()
	default:
		return 0, fmt.Errorf("not a number: %T", v)
	}
}
99 changes: 99 additions & 0 deletions e2e/testdata/cassettes/TestExec_Gemini_SamplingWithTools.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
---
version: 2
interactions:
- id: 0
request:
proto: HTTP/1.1
proto_major: 1
proto_minor: 1
content_length: 0
host: generativelanguage.googleapis.com
body: |
{"contents":[{"parts":[{"text":"You have access to one tool: ask_with_calculator. Whenever the user asks\na math word problem, call ask_with_calculator with the user's question.\nThen report its answer verbatim to the user.\n"}],"role":"user"},{"parts":[{"text":"What is 17 plus 25?"}],"role":"user"}],"generationConfig":{},"toolConfig":{"functionCallingConfig":{"mode":"AUTO"}},"tools":[{"functionDeclarations":[{"description":"Answer a math word problem by asking the host LLM for help, with access to a calculator tool the server provides.","name":"ask_with_calculator","parameters":{"properties":{"question":{"description":"the natural-language question to answer with help of the calculator","type":"string"}},"required":["question"],"type":"object"}}]}]}
form:
alt:
- sse
url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse
method: POST
response:
proto: HTTP/2.0
proto_major: 2
proto_minor: 0
content_length: -1
body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"functionCall\": {\"name\": \"ask_with_calculator\",\"args\": {\"question\": \"What is 17 plus 25?\"}},\"thoughtSignature\": \"CiQBDDnWx0f32fnXCeUit6K5jCaDOjrvkvnvsHY5MIV4ozFP2DkKWQEMOdbH80G2+BEkAXkEaOtaQkTeMtam0g5v81QW4CbZq0EWzMpBvR4qorY0uY1UsLTkkoSC2Jrwtbx9HKcx2O5eg8e7glVpviskbssHnvs/Ym050fnXqUPVCswBAQw51sfCV/T8djp98wEvfztWe9iFUQ6D3oT94v4X9lE0WSuXqEhk/GceAChWa2DgtoR9Su4AzkZJjl1Yi4mR4DpAZp1/+jQBMgm2v8g+zUWqWt4beV/YmMmJREjaI56NpBq19U+t1WzlTk5aQyT7KcH6EoFdJaEyCiY+B1pdo4Nm4HhDb9J+C0evbqlU7jXGy9bn5GBjeAS7xMdVh/IJ2oLmabcwdA4htWMWKLYEyUb87kU+cvumMYbhTQvn40csP+j3IKVazFhsNcmQ\"}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0,\"finishMessage\": \"Model generated function call(s).\"}],\"usageMetadata\": {\"promptTokenCount\": 128,\"candidatesTokenCount\": 26,\"totalTokenCount\": 215,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 128}],\"thoughtsTokenCount\": 61,\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hL8kaue3J7CEz7IPw8H7kAY\"}\r\n\r\n"
headers: {}
status: 200 OK
code: 200
duration: 1.079262793s
- id: 1
request:
proto: HTTP/1.1
proto_major: 1
proto_minor: 1
content_length: 0
host: generativelanguage.googleapis.com
body: |
{"contents":[{"parts":[{"text":"You are a careful assistant. Use the calculator tool whenever you need to add two integers. After you have the sum, answer the user's question in one short sentence."}],"role":"user"},{"parts":[{"text":"What is 17 plus 25?"}],"role":"user"}],"generationConfig":{"maxOutputTokens":1024,"thinkingConfig":{"thinkingBudget":0}},"toolConfig":{"functionCallingConfig":{"mode":"AUTO"}},"tools":[{"functionDeclarations":[{"description":"Add two integers. Returns the sum.","name":"calculator","parameters":{"properties":{"x":{"type":"integer"},"y":{"type":"integer"}},"required":["x","y"],"type":"object"}}]}]}
form:
alt:
- sse
url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse
method: POST
response:
proto: HTTP/2.0
proto_major: 2
proto_minor: 0
content_length: -1
body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"functionCall\": {\"name\": \"calculator\",\"args\": {\"y\": 25,\"x\": 17}}}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0,\"finishMessage\": \"Model generated function call(s).\"}],\"usageMetadata\": {\"promptTokenCount\": 96,\"candidatesTokenCount\": 20,\"totalTokenCount\": 116,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 96}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hb8kas6hIuPqz7IP16znuQ8\"}\r\n\r\n"
headers: {}
status: 200 OK
code: 200
duration: 788.273412ms
- id: 2
request:
proto: HTTP/1.1
proto_major: 1
proto_minor: 1
content_length: 0
host: generativelanguage.googleapis.com
body: |
{"contents":[{"parts":[{"text":"You are a careful assistant. Use the calculator tool whenever you need to add two integers. After you have the sum, answer the user's question in one short sentence."}],"role":"user"},{"parts":[{"text":"What is 17 plus 25?"}],"role":"user"},{"parts":[{"functionCall":{"args":{"x":17,"y":25},"name":"calculator"},"thoughtSignature":"c2tpcF90aG91Z2h0X3NpZ25hdHVyZV92YWxpZGF0b3I="}],"role":"model"},{"parts":[{"functionResponse":{"name":"call_f15543e2-b952-4c7d-ab72-bcc56f00c6f6","response":{"result":"42"}}}],"role":"user"}],"generationConfig":{"maxOutputTokens":1024,"thinkingConfig":{"thinkingBudget":0}},"toolConfig":{"functionCallingConfig":{"mode":"AUTO"}},"tools":[{"functionDeclarations":[{"description":"Add two integers. Returns the sum.","name":"calculator","parameters":{"properties":{"x":{"type":"integer"},"y":{"type":"integer"}},"required":["x","y"],"type":"object"}}]}]}
form:
alt:
- sse
url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse
method: POST
response:
proto: HTTP/2.0
proto_major: 2
proto_minor: 0
content_length: -1
body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"1\"}],\"role\": \"model\"},\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 164,\"candidatesTokenCount\": 1,\"totalTokenCount\": 165,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 164}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hr8kapnPFPDrz7IPoInsiQQ\"}\r\n\r\ndata: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"7 plus 25 is 42.\"}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 164,\"candidatesTokenCount\": 11,\"totalTokenCount\": 175,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 164}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hr8kapnPFPDrz7IPoInsiQQ\"}\r\n\r\n"
headers: {}
status: 200 OK
code: 200
duration: 635.846943ms
- id: 3
request:
proto: HTTP/1.1
proto_major: 1
proto_minor: 1
content_length: 0
host: generativelanguage.googleapis.com
body: |
{"contents":[{"parts":[{"text":"You have access to one tool: ask_with_calculator. Whenever the user asks\na math word problem, call ask_with_calculator with the user's question.\nThen report its answer verbatim to the user.\n"}],"role":"user"},{"parts":[{"text":"What is 17 plus 25?"}],"role":"user"},{"parts":[{"functionCall":{"args":{"question":"What is 17 plus 25?"},"name":"ask_with_calculator"},"thoughtSignature":"CiQBDDnWx0f32fnXCeUit6K5jCaDOjrvkvnvsHY5MIV4ozFP2DkKWQEMOdbH80G2+BEkAXkEaOtaQkTeMtam0g5v81QW4CbZq0EWzMpBvR4qorY0uY1UsLTkkoSC2Jrwtbx9HKcx2O5eg8e7glVpviskbssHnvs/Ym050fnXqUPVCswBAQw51sfCV/T8djp98wEvfztWe9iFUQ6D3oT94v4X9lE0WSuXqEhk/GceAChWa2DgtoR9Su4AzkZJjl1Yi4mR4DpAZp1/+jQBMgm2v8g+zUWqWt4beV/YmMmJREjaI56NpBq19U+t1WzlTk5aQyT7KcH6EoFdJaEyCiY+B1pdo4Nm4HhDb9J+C0evbqlU7jXGy9bn5GBjeAS7xMdVh/IJ2oLmabcwdA4htWMWKLYEyUb87kU+cvumMYbhTQvn40csP+j3IKVazFhsNcmQ"}],"role":"model"},{"parts":[{"functionResponse":{"name":"call_546b970f-9898-4b1f-94e0-e6b140fcbdea","response":{"result":"17 plus 25 is 42."}}}],"role":"user"}],"generationConfig":{},"toolConfig":{"functionCallingConfig":{"mode":"AUTO"}},"tools":[{"functionDeclarations":[{"description":"Answer a math word problem by asking the host LLM for help, with access to a calculator tool the server provides.","name":"ask_with_calculator","parameters":{"properties":{"question":{"description":"the natural-language question to answer with help of the calculator","type":"string"}},"required":["question"],"type":"object"}}]}]}
form:
alt:
- sse
url: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse
method: POST
response:
proto: HTTP/2.0
proto_major: 2
proto_minor: 0
content_length: -1
body: "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"17 plus \"}],\"role\": \"model\"},\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 272,\"candidatesTokenCount\": 3,\"totalTokenCount\": 275,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 272}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hr8kavfTPOTtz7IPmZDEwQE\"}\r\n\r\ndata: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"25 is 42.\"}],\"role\": \"model\"},\"finishReason\": \"STOP\",\"index\": 0}],\"usageMetadata\": {\"promptTokenCount\": 272,\"candidatesTokenCount\": 10,\"totalTokenCount\": 282,\"promptTokensDetails\": [{\"modality\": \"TEXT\",\"tokenCount\": 272}],\"serviceTier\": \"standard\"},\"modelVersion\": \"gemini-2.5-flash\",\"responseId\": \"hr8kavfTPOTtz7IPmZDEwQE\"}\r\n\r\n"
headers: {}
status: 200 OK
code: 200
duration: 537.623295ms
1 change: 1 addition & 0 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events EventSink
tools.ConfigureHandlers(toolset,
r.elicitationHandler,
r.samplingHandler,
r.samplingWithToolsHandler,
func() { events.Emit(Authorization(tools.ElicitationActionAccept, a.Name())) },
r.managedOAuth,
r.unmanagedOAuthRedirectURI,
Expand Down
Loading
Loading