Skip to content

Commit 90dd205

Browse files
committed
Apply 2. review comments
1 parent 93d10fb commit 90dd205

File tree

3 files changed

+100
-28
lines changed

3 files changed

+100
-28
lines changed

cmd/stackrox-mcp/main_test.go

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package main
22

33
import (
44
"context"
5+
"fmt"
6+
"net/http"
57
"os"
68
"syscall"
79
"testing"
@@ -15,6 +17,23 @@ import (
1517
"github.com/stretchr/testify/require"
1618
)
1719

20+
// waitForServerReady polls the server until it's ready to accept connections
21+
func waitForServerReady(address string, timeout time.Duration) error {
22+
deadline := time.Now().Add(timeout)
23+
client := &http.Client{Timeout: 100 * time.Millisecond}
24+
25+
for time.Now().Before(deadline) {
26+
resp, err := client.Get(address)
27+
if err == nil {
28+
_ = resp.Body.Close()
29+
return nil
30+
}
31+
time.Sleep(100 * time.Millisecond)
32+
}
33+
34+
return fmt.Errorf("server did not become ready within %v", timeout)
35+
}
36+
1837
func TestSetupLogging(t *testing.T) {
1938
setupLogging()
2039
assert.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel())
@@ -68,8 +87,17 @@ func TestGracefulShutdown(t *testing.T) {
6887
errChan <- srv.Start(ctx)
6988
}()
7089

71-
// Give server time to start
72-
time.Sleep(100 * time.Millisecond)
90+
// Wait for server to be ready by polling
91+
serverURL := fmt.Sprintf("http://%s:%d", cfg.Server.Address, cfg.Server.Port)
92+
err = waitForServerReady(serverURL, 3*time.Second)
93+
require.NoError(t, err, "Server should start within timeout")
94+
95+
// Establish actual HTTP connection to verify server is responding
96+
resp, err := http.Get(serverURL)
97+
if err == nil {
98+
_ = resp.Body.Close()
99+
}
100+
assert.NoError(t, err, "Should be able to establish HTTP connection to server")
73101

74102
// Simulate shutdown signal by canceling context
75103
cancel()
@@ -120,8 +148,17 @@ func TestGracefulShutdown_WithSignal(t *testing.T) {
120148
errChan <- srv.Start(ctx)
121149
}()
122150

123-
// Give server time to start
124-
time.Sleep(100 * time.Millisecond)
151+
// Wait for server to be ready by polling
152+
serverURL := fmt.Sprintf("http://%s:%d", cfg.Server.Address, cfg.Server.Port)
153+
err = waitForServerReady(serverURL, 3*time.Second)
154+
require.NoError(t, err, "Server should start within timeout")
155+
156+
// Establish actual HTTP connection to verify server is responding
157+
resp, err := http.Get(serverURL)
158+
if err == nil {
159+
_ = resp.Body.Close()
160+
}
161+
assert.NoError(t, err, "Should be able to establish HTTP connection to server")
125162

126163
// Simulate signal by sending to channel
127164
sigChan <- syscall.SIGTERM

internal/server/server.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7+
"time"
78

89
"github.com/modelcontextprotocol/go-sdk/mcp"
910
"github.com/pkg/errors"
@@ -12,7 +13,9 @@ import (
1213
"github.com/stackrox/stackrox-mcp/internal/toolsets"
1314
)
1415

15-
// version is set at build time via ldflags
16+
const shutdownTimeout = 5 * time.Second
17+
18+
// version is set at build time via ldflags (ldflags can't modify constants)
1619
var version = "dev"
1720

1821
// Server represents the MCP HTTP server
@@ -109,7 +112,10 @@ func (s *Server) Start(ctx context.Context) error {
109112
select {
110113
case <-ctx.Done():
111114
log.Info().Msg("Shutting down HTTP server")
112-
return errors.Wrap(s.http.Shutdown(context.Background()), "server shutting down failed")
115+
// Create a context with timeout for graceful shutdown
116+
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout*time.Second)
117+
defer shutdownCancel()
118+
return errors.Wrap(s.http.Shutdown(shutdownCtx), "server shutting down failed")
113119
case err := <-errChan:
114120
return err
115121
}

internal/server/server_test.go

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package server
22

33
import (
44
"context"
5+
"fmt"
6+
"net/http"
57
"testing"
68
"time"
79

@@ -12,6 +14,23 @@ import (
1214
"github.com/stretchr/testify/require"
1315
)
1416

17+
// waitForServerReady polls the server until it's ready to accept connections
18+
func waitForServerReady(address string, timeout time.Duration) error {
19+
deadline := time.Now().Add(timeout)
20+
client := &http.Client{Timeout: 100 * time.Millisecond}
21+
22+
for time.Now().Before(deadline) {
23+
resp, err := client.Get(address)
24+
if err == nil {
25+
_ = resp.Body.Close()
26+
return nil
27+
}
28+
time.Sleep(100 * time.Millisecond)
29+
}
30+
31+
return fmt.Errorf("server did not become ready within %v", timeout)
32+
}
33+
1534
// mockTool implements the Tool interface for testing
1635
type mockTool struct {
1736
name string
@@ -97,20 +116,20 @@ func TestServer_registerTools_AllEnabled(t *testing.T) {
97116
},
98117
}
99118

100-
tool1 := &mockTool{name: "tool1", readOnly: true}
101-
tool2 := &mockTool{name: "tool2", readOnly: false}
119+
readOnlyTestTool := &mockTool{name: "test_read_only_tool", readOnly: true}
120+
readWriteTestTool := &mockTool{name: "test_read_write_tool", readOnly: false}
102121

103122
toolsetList := []toolsets.Toolset{
104-
newMockToolset("toolset1", true, []toolsets.Tool{tool1, tool2}),
123+
newMockToolset("test_toolset", true, []toolsets.Tool{readOnlyTestTool, readWriteTestTool}),
105124
}
106125

107126
registry := toolsets.NewRegistry(cfg, toolsetList)
108127
srv := NewServer(cfg, registry)
109128

110129
srv.registerTools()
111130

112-
assert.True(t, tool1.registerCall, "tool1 should be registered")
113-
assert.True(t, tool2.registerCall, "tool2 should be registered")
131+
assert.True(t, readOnlyTestTool.registerCall, "read-only test tool should be registered")
132+
assert.True(t, readWriteTestTool.registerCall, "read-write test tool should be registered")
114133
}
115134

116135
func TestServer_registerTools_ReadOnlyMode(t *testing.T) {
@@ -124,20 +143,20 @@ func TestServer_registerTools_ReadOnlyMode(t *testing.T) {
124143
},
125144
}
126145

127-
tool1 := &mockTool{name: "tool1", readOnly: true}
128-
tool2 := &mockTool{name: "tool2", readOnly: false}
146+
readOnlyTestTool := &mockTool{name: "test_read_only_tool", readOnly: true}
147+
readWriteTestTool := &mockTool{name: "test_read_write_tool", readOnly: false}
129148

130149
toolsetList := []toolsets.Toolset{
131-
newMockToolset("toolset1", true, []toolsets.Tool{tool1, tool2}),
150+
newMockToolset("test_toolset", true, []toolsets.Tool{readOnlyTestTool, readWriteTestTool}),
132151
}
133152

134153
registry := toolsets.NewRegistry(cfg, toolsetList)
135154
srv := NewServer(cfg, registry)
136155

137156
srv.registerTools()
138157

139-
assert.True(t, tool1.registerCall, "tool1 (read-only) should be registered")
140-
assert.False(t, tool2.registerCall, "tool2 (read-write) should not be registered in read-only mode")
158+
assert.True(t, readOnlyTestTool.registerCall, "read-only test tool should be registered")
159+
assert.False(t, readWriteTestTool.registerCall, "read-write test tool should not be registered in read-only mode")
141160
}
142161

143162
func TestServer_registerTools_DisabledToolset(t *testing.T) {
@@ -151,21 +170,21 @@ func TestServer_registerTools_DisabledToolset(t *testing.T) {
151170
},
152171
}
153172

154-
tool1 := &mockTool{name: "tool1", readOnly: true}
155-
tool2 := &mockTool{name: "tool2", readOnly: true}
173+
enabledTestTool := &mockTool{name: "test_enabled_tool", readOnly: true}
174+
disabledTestTool := &mockTool{name: "test_disabled_tool", readOnly: true}
156175

157176
toolsetList := []toolsets.Toolset{
158-
newMockToolset("enabled_toolset", true, []toolsets.Tool{tool1}),
159-
newMockToolset("disabled_toolset", false, []toolsets.Tool{tool2}),
177+
newMockToolset("enabled_toolset", true, []toolsets.Tool{enabledTestTool}),
178+
newMockToolset("disabled_toolset", false, []toolsets.Tool{disabledTestTool}),
160179
}
161180

162181
registry := toolsets.NewRegistry(cfg, toolsetList)
163182
srv := NewServer(cfg, registry)
164183

165184
srv.registerTools()
166185

167-
assert.True(t, tool1.registerCall, "tool1 from enabled toolset should be registered")
168-
assert.False(t, tool2.registerCall, "tool2 from disabled toolset should not be registered")
186+
assert.True(t, enabledTestTool.registerCall, "tool from enabled toolset should be registered")
187+
assert.False(t, disabledTestTool.registerCall, "tool from disabled toolset should not be registered")
169188
}
170189

171190
func TestServer_Start(t *testing.T) {
@@ -176,10 +195,10 @@ func TestServer_Start(t *testing.T) {
176195
},
177196
}
178197

179-
tool1 := &mockTool{name: "tool1", readOnly: true}
198+
testTool := &mockTool{name: "test_tool", readOnly: true}
180199

181200
toolsetList := []toolsets.Toolset{
182-
newMockToolset("toolset1", true, []toolsets.Tool{tool1}),
201+
newMockToolset("test_toolset", true, []toolsets.Tool{testTool}),
183202
}
184203

185204
registry := toolsets.NewRegistry(cfg, toolsetList)
@@ -194,11 +213,21 @@ func TestServer_Start(t *testing.T) {
194213
errChan <- srv.Start(ctx)
195214
}()
196215

197-
// Give server time to start
198-
time.Sleep(100 * time.Millisecond)
216+
// Wait for server to be ready by polling
217+
serverURL := fmt.Sprintf("http://%s:%d", cfg.Server.Address, cfg.Server.Port)
218+
err := waitForServerReady(serverURL, 3*time.Second)
219+
require.NoError(t, err, "Server should start within timeout")
199220

200221
// Verify tools were registered
201-
assert.True(t, tool1.registerCall, "tool1 should be registered when server starts")
222+
assert.True(t, testTool.registerCall, "test tool should be registered when server starts")
223+
224+
// Establish actual HTTP connection to verify server is responding
225+
resp, err := http.Get(serverURL)
226+
if err == nil {
227+
_ = resp.Body.Close()
228+
}
229+
// We don't require a successful response, just that we can connect
230+
assert.NoError(t, err, "Should be able to establish HTTP connection to server")
202231

203232
// Trigger graceful shutdown
204233
cancel()
@@ -210,7 +239,7 @@ func TestServer_Start(t *testing.T) {
210239
if err != nil && err != context.Canceled {
211240
t.Errorf("Server returned unexpected error: %v", err)
212241
}
213-
case <-time.After(5 * time.Second):
242+
case <-time.After(shutdownTimeout):
214243
t.Fatal("Server did not shut down within timeout period")
215244
}
216245
}

0 commit comments

Comments
 (0)