Skip to content

Commit 93d10fb

Browse files
committed
Improve testing and 1. review
1 parent 51c75c4 commit 93d10fb

File tree

6 files changed

+249
-17
lines changed

6 files changed

+249
-17
lines changed

cmd/stackrox-mcp/main.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ func setupLogging() {
2121
zerolog.SetGlobalLevel(zerolog.InfoLevel)
2222
}
2323

24+
// getToolsets initializes and returns all available toolsets
25+
func getToolsets(cfg *config.Config) []toolsets.Toolset {
26+
return []toolsets.Toolset{
27+
configtools.NewToolset(cfg),
28+
vulnerability.NewToolset(cfg),
29+
}
30+
}
31+
2432
func main() {
2533
setupLogging()
2634

@@ -33,15 +41,8 @@ func main() {
3341
}
3442
log.Info().Interface("config", cfg).Msg("Configuration loaded successfully")
3543

36-
// Initialize toolsets
37-
configToolset := configtools.NewToolset(cfg)
38-
vulnToolset := vulnerability.NewToolset(cfg)
39-
4044
// Create registry with all toolsets
41-
registry := toolsets.NewRegistry(cfg, []toolsets.Toolset{
42-
configToolset,
43-
vulnToolset,
44-
})
45+
registry := toolsets.NewRegistry(cfg, getToolsets(cfg))
4546

4647
// Create MCP server
4748
srv := server.NewServer(cfg, registry)

cmd/stackrox-mcp/main_test.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,139 @@
11
package main
22

33
import (
4+
"context"
5+
"os"
6+
"syscall"
47
"testing"
8+
"time"
59

610
"github.com/rs/zerolog"
11+
"github.com/stackrox/stackrox-mcp/internal/config"
12+
"github.com/stackrox/stackrox-mcp/internal/server"
13+
"github.com/stackrox/stackrox-mcp/internal/toolsets"
714
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
816
)
917

1018
func TestSetupLogging(t *testing.T) {
1119
setupLogging()
1220
assert.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel())
1321
}
22+
23+
func TestGetToolsets(t *testing.T) {
24+
cfg := &config.Config{
25+
Central: config.CentralConfig{
26+
URL: "central.example.com:8443",
27+
},
28+
Tools: config.ToolsConfig{
29+
Vulnerability: config.VulnerabilityConfig{
30+
Enabled: true,
31+
},
32+
ConfigManager: config.ConfigManagerConfig{
33+
Enabled: true,
34+
},
35+
},
36+
}
37+
38+
toolsets := getToolsets(cfg)
39+
40+
require.NotNil(t, toolsets)
41+
assert.Len(t, toolsets, 2, "Should have 2 toolsets")
42+
assert.Equal(t, "config_manager", toolsets[0].GetName())
43+
assert.Equal(t, "vulnerability", toolsets[1].GetName())
44+
}
45+
46+
func TestGracefulShutdown(t *testing.T) {
47+
// Set up minimal valid config
48+
assert.NoError(t, os.Setenv("STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED", "true"))
49+
defer func() { assert.NoError(t, os.Unsetenv("STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED")) }()
50+
51+
cfg, err := config.LoadConfig("")
52+
require.NoError(t, err)
53+
require.NotNil(t, cfg)
54+
55+
// Use a different port to avoid conflicts
56+
cfg.Server.Port = 9999
57+
58+
// Create registry and server
59+
registry := toolsets.NewRegistry(cfg, getToolsets(cfg))
60+
srv := server.NewServer(cfg, registry)
61+
62+
// Set up context with cancellation
63+
ctx, cancel := context.WithCancel(context.Background())
64+
65+
// Start server in goroutine
66+
errChan := make(chan error, 1)
67+
go func() {
68+
errChan <- srv.Start(ctx)
69+
}()
70+
71+
// Give server time to start
72+
time.Sleep(100 * time.Millisecond)
73+
74+
// Simulate shutdown signal by canceling context
75+
cancel()
76+
77+
// Wait for server to shut down with timeout
78+
select {
79+
case err := <-errChan:
80+
// Server should shut down cleanly (either nil or context.Canceled)
81+
if err != nil && err != context.Canceled {
82+
t.Errorf("Server returned unexpected error: %v", err)
83+
}
84+
case <-time.After(5 * time.Second):
85+
t.Fatal("Server did not shut down within timeout period")
86+
}
87+
}
88+
89+
func TestGracefulShutdown_WithSignal(t *testing.T) {
90+
// Set up minimal valid config
91+
assert.NoError(t, os.Setenv("STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED", "true"))
92+
defer func() { assert.NoError(t, os.Unsetenv("STACKROX_MCP__TOOLS__VULNERABILITY__ENABLED")) }()
93+
94+
cfg, err := config.LoadConfig("")
95+
require.NoError(t, err)
96+
require.NotNil(t, cfg)
97+
98+
// Use a different port to avoid conflicts
99+
cfg.Server.Port = 10000
100+
101+
// Create registry and server
102+
registry := toolsets.NewRegistry(cfg, getToolsets(cfg))
103+
srv := server.NewServer(cfg, registry)
104+
105+
// Set up context with cancellation and signal handling (like in main)
106+
ctx, cancel := context.WithCancel(context.Background())
107+
defer cancel()
108+
109+
sigChan := make(chan os.Signal, 1)
110+
// Note: We use a buffered channel and manually send to avoid OS signal complications in tests
111+
112+
go func() {
113+
<-sigChan
114+
cancel()
115+
}()
116+
117+
// Start server in goroutine
118+
errChan := make(chan error, 1)
119+
go func() {
120+
errChan <- srv.Start(ctx)
121+
}()
122+
123+
// Give server time to start
124+
time.Sleep(100 * time.Millisecond)
125+
126+
// Simulate signal by sending to channel
127+
sigChan <- syscall.SIGTERM
128+
129+
// Wait for server to shut down with timeout
130+
select {
131+
case err := <-errChan:
132+
// Server should shut down cleanly
133+
if err != nil && err != context.Canceled {
134+
t.Errorf("Server returned unexpected error: %v", err)
135+
}
136+
case <-time.After(5 * time.Second):
137+
t.Fatal("Server did not shut down within timeout period after signal")
138+
}
139+
}

examples/config-read-only.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ global:
3636
# When false, both read and write tools may be available (if implemented)
3737
read_only_tools: true
3838

39+
# HTTP server configuration
40+
server:
41+
# Server listen address (optional, default: localhost)
42+
# The address on which the MCP HTTP server will listen
43+
address: localhost
44+
45+
# Server listen port (optional, default: 8080)
46+
# The port on which the MCP HTTP server will listen
47+
# Must be between 1 and 65535
48+
port: 8080
49+
3950
# Configuration of MCP tools
4051
# Each tool has an enable/disable flag. At least one tool has to be enabled.
4152
tools:

internal/config/config_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,51 @@ central:
153153
assert.Error(t, err)
154154
}
155155

156+
func TestLoadConfig_UnmarshalFailure(t *testing.T) {
157+
tmpDir := t.TempDir()
158+
configPath := filepath.Join(tmpDir, "config.yaml")
159+
160+
// YAML with type mismatch - port should be int
161+
invalidTypeYAML := `
162+
server:
163+
port: "not-a-number"
164+
`
165+
166+
err := os.WriteFile(configPath, []byte(invalidTypeYAML), 0644)
167+
require.NoError(t, err)
168+
defer func() { assert.NoError(t, os.Remove(configPath)) }()
169+
170+
_, err = LoadConfig(configPath)
171+
require.Error(t, err)
172+
assert.Contains(t, err.Error(), "failed to unmarshal config")
173+
}
174+
175+
func TestLoadConfig_ValidationFailure(t *testing.T) {
176+
tmpDir := t.TempDir()
177+
configPath := filepath.Join(tmpDir, "config.yaml")
178+
179+
// Valid YAML but fails on central URL validation (no URL)
180+
validYAMLInvalidConfig := `
181+
central:
182+
url: ""
183+
server:
184+
address: localhost
185+
port: 8080
186+
tools:
187+
vulnerability:
188+
enabled: true
189+
`
190+
191+
err := os.WriteFile(configPath, []byte(validYAMLInvalidConfig), 0644)
192+
require.NoError(t, err)
193+
defer func() { assert.NoError(t, os.Remove(configPath)) }()
194+
195+
_, err = LoadConfig(configPath)
196+
require.Error(t, err)
197+
assert.Contains(t, err.Error(), "invalid configuration")
198+
assert.Contains(t, err.Error(), "central.url is required")
199+
}
200+
156201
func TestValidate_MissingURL(t *testing.T) {
157202
cfg := &Config{
158203
Central: CentralConfig{

internal/server/server.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ func NewServer(cfg *config.Config, registry *toolsets.Registry) *Server {
4040
}
4141
}
4242

43-
// RegisterTools registers all tools from the registry with the MCP server
44-
func (s *Server) RegisterTools() {
43+
// registerTools registers all tools from the registry with the MCP server
44+
func (s *Server) registerTools() {
4545
log.Info().Msg("Registering MCP tools")
4646

4747
for _, toolset := range s.registry.GetToolsets() {
@@ -75,7 +75,7 @@ func (s *Server) RegisterTools() {
7575

7676
// Start starts the HTTP server with Streamable HTTP transport
7777
func (s *Server) Start(ctx context.Context) error {
78-
s.RegisterTools()
78+
s.registerTools()
7979

8080
// Create Streamable HTTP handler
8181
handler := mcp.NewStreamableHTTPHandler(

internal/server/server_test.go

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package server
22

33
import (
4+
"context"
45
"testing"
6+
"time"
57

68
"github.com/modelcontextprotocol/go-sdk/mcp"
79
"github.com/stackrox/stackrox-mcp/internal/config"
@@ -84,7 +86,7 @@ func TestNewServer(t *testing.T) {
8486
assert.NotNil(t, srv.mcp)
8587
}
8688

87-
func TestServer_RegisterTools_AllEnabled(t *testing.T) {
89+
func TestServer_registerTools_AllEnabled(t *testing.T) {
8890
cfg := &config.Config{
8991
Global: config.GlobalConfig{
9092
ReadOnlyTools: false,
@@ -105,13 +107,13 @@ func TestServer_RegisterTools_AllEnabled(t *testing.T) {
105107
registry := toolsets.NewRegistry(cfg, toolsetList)
106108
srv := NewServer(cfg, registry)
107109

108-
srv.RegisterTools()
110+
srv.registerTools()
109111

110112
assert.True(t, tool1.registerCall, "tool1 should be registered")
111113
assert.True(t, tool2.registerCall, "tool2 should be registered")
112114
}
113115

114-
func TestServer_RegisterTools_ReadOnlyMode(t *testing.T) {
116+
func TestServer_registerTools_ReadOnlyMode(t *testing.T) {
115117
cfg := &config.Config{
116118
Global: config.GlobalConfig{
117119
ReadOnlyTools: true,
@@ -132,13 +134,13 @@ func TestServer_RegisterTools_ReadOnlyMode(t *testing.T) {
132134
registry := toolsets.NewRegistry(cfg, toolsetList)
133135
srv := NewServer(cfg, registry)
134136

135-
srv.RegisterTools()
137+
srv.registerTools()
136138

137139
assert.True(t, tool1.registerCall, "tool1 (read-only) should be registered")
138140
assert.False(t, tool2.registerCall, "tool2 (read-write) should not be registered in read-only mode")
139141
}
140142

141-
func TestServer_RegisterTools_DisabledToolset(t *testing.T) {
143+
func TestServer_registerTools_DisabledToolset(t *testing.T) {
142144
cfg := &config.Config{
143145
Global: config.GlobalConfig{
144146
ReadOnlyTools: false,
@@ -160,8 +162,55 @@ func TestServer_RegisterTools_DisabledToolset(t *testing.T) {
160162
registry := toolsets.NewRegistry(cfg, toolsetList)
161163
srv := NewServer(cfg, registry)
162164

163-
srv.RegisterTools()
165+
srv.registerTools()
164166

165167
assert.True(t, tool1.registerCall, "tool1 from enabled toolset should be registered")
166168
assert.False(t, tool2.registerCall, "tool2 from disabled toolset should not be registered")
167169
}
170+
171+
func TestServer_Start(t *testing.T) {
172+
cfg := &config.Config{
173+
Server: config.ServerConfig{
174+
Address: "localhost",
175+
Port: 9091, // Use different port to avoid conflicts
176+
},
177+
}
178+
179+
tool1 := &mockTool{name: "tool1", readOnly: true}
180+
181+
toolsetList := []toolsets.Toolset{
182+
newMockToolset("toolset1", true, []toolsets.Tool{tool1}),
183+
}
184+
185+
registry := toolsets.NewRegistry(cfg, toolsetList)
186+
srv := NewServer(cfg, registry)
187+
188+
// Set up context with cancellation
189+
ctx, cancel := context.WithCancel(context.Background())
190+
191+
// Start server in goroutine
192+
errChan := make(chan error, 1)
193+
go func() {
194+
errChan <- srv.Start(ctx)
195+
}()
196+
197+
// Give server time to start
198+
time.Sleep(100 * time.Millisecond)
199+
200+
// Verify tools were registered
201+
assert.True(t, tool1.registerCall, "tool1 should be registered when server starts")
202+
203+
// Trigger graceful shutdown
204+
cancel()
205+
206+
// Wait for server to shut down
207+
select {
208+
case err := <-errChan:
209+
// Server should shut down cleanly
210+
if err != nil && err != context.Canceled {
211+
t.Errorf("Server returned unexpected error: %v", err)
212+
}
213+
case <-time.After(5 * time.Second):
214+
t.Fatal("Server did not shut down within timeout period")
215+
}
216+
}

0 commit comments

Comments
 (0)