Skip to content

Commit 37dffa7

Browse files
committed
move output settings to own sub-struct to avoid the requirement to set those values
1 parent 80157eb commit 37dffa7

File tree

5 files changed

+59
-25
lines changed

5 files changed

+59
-25
lines changed

internal/config/model/llm/tools/command/expression_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,29 @@ const cmdDescriptor = {
171171
assert.Equal(t, `Echo: Hello World`, strings.TrimSpace(string(result)))
172172
}
173173

174+
func TestCommandExpression_CommandFn_RunCommand_WithLimit(t *testing.T) {
175+
toTest := Expression(`
176+
const pa = JSON.parse(` + expression.VarNameContext + `.args)
177+
const cmdDescriptor = {
178+
"command": "echo",
179+
"arguments": ["Echo:", pa.message],
180+
"output": {
181+
"firstNBytes": 1,
182+
}
183+
}
184+
185+
` + expression.FuncNameRun + `(cmdDescriptor)
186+
`)
187+
require.NoError(t, toTest.Validate())
188+
189+
llmArgs := `{"message": "Hello World"}`
190+
191+
ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)
192+
result, err := toTest.CommandFn(FunctionDefinition{})(ctx, llmArgs)
193+
assert.NoError(t, err)
194+
assert.Equal(t, "E\n{{ 17 bytes skipped }}", strings.TrimSpace(string(result)))
195+
}
196+
174197
func TestCommandExpression_CommandFn_RunCommand_WithEnv(t *testing.T) {
175198
toTest := Expression(`
176199
const pa = JSON.parse(` + expression.VarNameContext + `.args)

internal/controller/llm_tool_approval.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package controller
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"github.com/tmc/langchaingo/llms"
78
"log/slog"
@@ -41,7 +42,7 @@ func (c *Controller) waitForApproval(ctx context.Context, call llms.ToolCall) er
4142
if a.Message != "" {
4243
errMsg = a.Message
4344
}
44-
return fmt.Errorf(errMsg)
45+
return errors.New(errMsg)
4546
}
4647
case <-ctx.Done():
4748
return fmt.Errorf("Approval for tool '%s' timed out!", call.FunctionCall.Name)

internal/mcp/server/builtin/tools/command/command.go

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ type CommandDescriptor struct {
1616
Environment map[string]string `json:"env"`
1717
AdditionalEnvironment map[string]string `json:"additionalEnv"`
1818
WorkingDirectory string `json:"workingDir"`
19-
DisableStdOut bool `json:"disableStdOut"`
20-
DisableStdErr bool `json:"disableStdErr"`
21-
FirstNBytes int `json:"firstNBytes"`
22-
LastNBytes int `json:"lastNBytes"`
19+
Output *OutputSettings `json:"output,omitempty"`
20+
}
21+
22+
type OutputSettings struct {
23+
DisableStdOut bool `json:"disableStdOut"`
24+
DisableStdErr bool `json:"disableStdErr"`
25+
FirstNBytes int `json:"firstNBytes"`
26+
LastNBytes int `json:"lastNBytes"`
2327
}
2428

2529
func (c CommandDescriptor) Run(ctx context.Context) ([]byte, error) {
@@ -45,10 +49,10 @@ func (c CommandDescriptor) Run(ctx context.Context) ([]byte, error) {
4549
}()
4650

4751
cmd := cmdBuild.WithErrorChecker(cmdchain.IgnoreExitErrors()).Finalize()
48-
if !c.DisableStdOut {
52+
if c.Output == nil || !c.Output.DisableStdOut {
4953
cmd = cmd.WithOutput(oFile)
5054
}
51-
if !c.DisableStdErr {
55+
if c.Output == nil || !c.Output.DisableStdErr {
5256
cmd = cmd.WithError(oFile)
5357
}
5458

@@ -57,7 +61,11 @@ func (c CommandDescriptor) Run(ctx context.Context) ([]byte, error) {
5761
}
5862

5963
func (c CommandDescriptor) getOutput(f *os.File) []byte {
60-
if c.FirstNBytes < 0 || c.LastNBytes < 0 {
64+
if c.Output == nil {
65+
return readFile(f)
66+
}
67+
cfg := c.Output
68+
if cfg.FirstNBytes < 0 || cfg.LastNBytes < 0 {
6169
return readFile(f)
6270
}
6371

@@ -69,7 +77,7 @@ func (c CommandDescriptor) getOutput(f *os.File) []byte {
6977
)
7078
return nil
7179
}
72-
if c.FirstNBytes+c.LastNBytes > int(fs.Size()) {
80+
if cfg.FirstNBytes+cfg.LastNBytes > int(fs.Size()) {
7381
return readFile(f)
7482
}
7583

@@ -83,34 +91,34 @@ func (c CommandDescriptor) getOutput(f *os.File) []byte {
8391
return nil
8492
}
8593

86-
if c.FirstNBytes > 0 {
87-
_, err = io.CopyN(buf, f, int64(c.FirstNBytes))
94+
if cfg.FirstNBytes > 0 {
95+
_, err = io.CopyN(buf, f, int64(cfg.FirstNBytes))
8896
if err != nil && err != io.EOF {
8997
slog.Error("Could not read first bytes from command output file.",
90-
"bytes", c.FirstNBytes,
98+
"bytes", cfg.FirstNBytes,
9199
"path", f.Name(),
92100
"error", err,
93101
)
94102
return nil
95103
}
96104
} else {
97105
// Indicate that there were bytes skipped
98-
buf.WriteString(skippedBytesIndicator(fs.Size() - int64(c.LastNBytes)))
106+
buf.WriteString(skippedBytesIndicator(fs.Size() - int64(cfg.LastNBytes)))
99107
buf.WriteRune('\n')
100108
}
101109

102-
if c.FirstNBytes > 0 && c.LastNBytes > 0 {
110+
if cfg.FirstNBytes > 0 && cfg.LastNBytes > 0 {
103111
// Indicate that there were bytes skipped
104112
buf.WriteRune('\n')
105-
buf.WriteString(skippedBytesIndicator(fs.Size() - int64(c.FirstNBytes+c.LastNBytes)))
113+
buf.WriteString(skippedBytesIndicator(fs.Size() - int64(cfg.FirstNBytes+cfg.LastNBytes)))
106114
buf.WriteRune('\n')
107115
}
108116

109-
if c.LastNBytes > 0 {
110-
_, err = f.Seek(-int64(c.LastNBytes), io.SeekEnd) // Seek to the last N bytes
117+
if cfg.LastNBytes > 0 {
118+
_, err = f.Seek(-int64(cfg.LastNBytes), io.SeekEnd) // Seek to the last N bytes
111119
if err != nil {
112120
slog.Error("Could not seek to the last bytes of command output file.",
113-
"bytes", c.LastNBytes,
121+
"bytes", cfg.LastNBytes,
114122
"path", f.Name(),
115123
"error", err,
116124
)
@@ -119,7 +127,7 @@ func (c CommandDescriptor) getOutput(f *os.File) []byte {
119127
_, err = io.Copy(buf, f)
120128
if err != nil && err != io.EOF {
121129
slog.Error("Could not read last bytes from command output file.",
122-
"bytes", c.LastNBytes,
130+
"bytes", cfg.LastNBytes,
123131
"path", f.Name(),
124132
"error", err,
125133
)
@@ -128,7 +136,7 @@ func (c CommandDescriptor) getOutput(f *os.File) []byte {
128136
} else {
129137
// Indicate that there were bytes skipped
130138
buf.WriteRune('\n')
131-
buf.WriteString(skippedBytesIndicator(fs.Size() - int64(c.FirstNBytes)))
139+
buf.WriteString(skippedBytesIndicator(fs.Size() - int64(cfg.FirstNBytes)))
132140
}
133141

134142
return buf.Bytes()

internal/mcp/server/builtin/tools/command/command_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func Test_getOutput(t *testing.T) {
7171
_, err = f.Seek(0, 0)
7272
require.NoError(t, err)
7373

74-
toTest := CommandDescriptor{LastNBytes: tc.lastNBytes, FirstNBytes: tc.firstNBytes}
74+
toTest := CommandDescriptor{Output: &OutputSettings{LastNBytes: tc.lastNBytes, FirstNBytes: tc.firstNBytes}}
7575
result := toTest.getOutput(f)
7676

7777
require.Equal(t, tc.expectedOutput, string(result))

internal/mcp/server/builtin/tools/command/exec.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ var CommandExecutionToolHandler = func(ctx context.Context, request mcp.CallTool
8484
Arguments: pArgs.Arguments,
8585
AdditionalEnvironment: pArgs.Environment,
8686
WorkingDirectory: pArgs.WorkingDirectory,
87-
DisableStdOut: pArgs.DisableOut,
88-
DisableStdErr: pArgs.DisableErr,
89-
FirstNBytes: pArgs.FirstBytes,
90-
LastNBytes: pArgs.LastBytes,
87+
Output: &OutputSettings{
88+
DisableStdOut: pArgs.DisableOut,
89+
DisableStdErr: pArgs.DisableErr,
90+
FirstNBytes: pArgs.FirstBytes,
91+
LastNBytes: pArgs.LastBytes,
92+
},
9193
}
9294

9395
raw, err := cmdDesc.Run(ctx)

0 commit comments

Comments
 (0)