Skip to content

Commit 6d51fc3

Browse files
authored
Add group restart (#1198)
1 parent 3f5b033 commit 6d51fc3

File tree

2 files changed

+276
-13
lines changed

2 files changed

+276
-13
lines changed

cmd/thv/app/restart.go

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ import (
77
"github.com/spf13/cobra"
88
"golang.org/x/sync/errgroup"
99

10+
"github.com/stacklok/toolhive/pkg/groups"
1011
"github.com/stacklok/toolhive/pkg/workloads"
1112
)
1213

1314
var (
14-
restartAll bool
15+
restartAll bool
16+
restartGroup string
1517
)
1618

1719
var restartCmd = &cobra.Command{
@@ -25,17 +27,24 @@ var restartCmd = &cobra.Command{
2527

2628
func init() {
2729
restartCmd.Flags().BoolVarP(&restartAll, "all", "a", false, "Restart all MCP servers")
30+
// TODO: Uncomment when groups are fully supported
31+
//restartCmd.Flags().StringVarP(&restartGroup, "group", "g", "", "Restart all MCP servers in a specific group")
32+
//
33+
//// Mark the flags as mutually exclusive
34+
//restartCmd.MarkFlagsMutuallyExclusive("all", "group")
2835
}
2936

3037
func restartCmdFunc(cmd *cobra.Command, args []string) error {
3138
ctx := cmd.Context()
3239

33-
// Validate arguments
34-
if restartAll && len(args) > 0 {
35-
return fmt.Errorf("cannot specify both --all flag and workload name")
40+
// Validate arguments - check mutual exclusivity with positional arguments
41+
// Cobra already handles mutual exclusivity between --all and --group
42+
if (restartAll || restartGroup != "") && len(args) > 0 {
43+
return fmt.Errorf("cannot specify both flags and workload name")
3644
}
37-
if !restartAll && len(args) == 0 {
38-
return fmt.Errorf("must specify either --all flag or workload name")
45+
46+
if !restartAll && restartGroup == "" && len(args) == 0 {
47+
return fmt.Errorf("must specify either --all flag, --group flag, or workload name")
3948
}
4049

4150
// Create workload managers.
@@ -48,6 +57,10 @@ func restartCmdFunc(cmd *cobra.Command, args []string) error {
4857
return restartAllContainers(ctx, workloadManager)
4958
}
5059

60+
if restartGroup != "" {
61+
return restartWorkloadsByGroup(ctx, workloadManager, restartGroup)
62+
}
63+
5164
// Restart single workload
5265
workloadName := args[0]
5366
restartGroup, err := workloadManager.RestartWorkloads(ctx, []string{workloadName})
@@ -76,16 +89,56 @@ func restartAllContainers(ctx context.Context, workloadManager workloads.Manager
7689
return nil
7790
}
7891

79-
var restartedCount int
80-
var failedCount int
92+
// Extract workload names
93+
workloadNames := make([]string, len(allWorkloads))
94+
for i, workload := range allWorkloads {
95+
workloadNames[i] = workload.Name
96+
}
97+
98+
return restartMultipleWorkloads(ctx, workloadManager, workloadNames)
99+
}
100+
101+
func restartWorkloadsByGroup(ctx context.Context, workloadManager workloads.Manager, groupName string) error {
102+
// Create a groups manager to list workloads in the group
103+
groupManager, err := groups.NewManager()
104+
if err != nil {
105+
return fmt.Errorf("failed to create group manager: %v", err)
106+
}
107+
108+
// Check if the group exists
109+
exists, err := groupManager.Exists(ctx, groupName)
110+
if err != nil {
111+
return fmt.Errorf("failed to check if group '%s' exists: %v", groupName, err)
112+
}
113+
if !exists {
114+
return fmt.Errorf("group '%s' does not exist", groupName)
115+
}
116+
117+
// Get all workload names in the group
118+
workloadNames, err := groupManager.ListWorkloadsInGroup(ctx, groupName)
119+
if err != nil {
120+
return fmt.Errorf("failed to list workloads in group '%s': %v", groupName, err)
121+
}
122+
123+
if len(workloadNames) == 0 {
124+
fmt.Printf("No MCP servers found in group '%s' to restart\n", groupName)
125+
return nil
126+
}
127+
128+
return restartMultipleWorkloads(ctx, workloadManager, workloadNames)
129+
}
130+
131+
// restartMultipleWorkloads handles restarting multiple workloads and reporting results
132+
func restartMultipleWorkloads(ctx context.Context, workloadManager workloads.Manager, workloadNames []string) error {
133+
restartedCount := 0
134+
failedCount := 0
81135
var errors []string
82136

83-
fmt.Printf("Restarting %d MCP server(s)...\n", len(allWorkloads))
137+
fmt.Printf("Restarting %d MCP server(s)...\n", len(workloadNames))
84138

85139
var restartRequests []*errgroup.Group
86140
// First, trigger the restarts concurrently.
87-
for _, workload := range allWorkloads {
88-
workloadName := workload.Name
141+
for _, workloadName := range workloadNames {
89142
fmt.Printf("Restarting %s...", workloadName)
90143
restart, err := workloadManager.RestartWorkloads(ctx, []string{workloadName})
91144
if err != nil {
@@ -101,14 +154,13 @@ func restartAllContainers(ctx context.Context, workloadManager workloads.Manager
101154

102155
// Wait for all restarts to complete.
103156
for _, restart := range restartRequests {
104-
err = restart.Wait()
157+
err := restart.Wait()
105158
if err != nil {
106159
fmt.Printf(" failed: %v\n", err)
107160
failedCount++
108161
// Unfortunately we don't have the workload name here, so we just log a generic error.
109162
errors = append(errors, fmt.Sprintf("Error restarting workload: %v", err))
110163
} else {
111-
fmt.Printf(" success\n")
112164
restartedCount++
113165
}
114166
}

test/e2e/restart_test.go

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
package e2e_test
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
"time"
7+
8+
. "github.com/onsi/ginkgo/v2"
9+
. "github.com/onsi/gomega"
10+
11+
"github.com/stacklok/toolhive/test/e2e"
12+
)
13+
14+
var _ = Describe("Server Restart", func() {
15+
var (
16+
config *e2e.TestConfig
17+
serverName string
18+
)
19+
20+
BeforeEach(func() {
21+
config = e2e.NewTestConfig()
22+
serverName = generateTestServerName("restart-test")
23+
24+
// Check if thv binary is available
25+
err := e2e.CheckTHVBinaryAvailable(config)
26+
Expect(err).ToNot(HaveOccurred(), "thv binary should be available")
27+
})
28+
29+
AfterEach(func() {
30+
if config.CleanupAfter {
31+
// Clean up the server if it exists
32+
err := e2e.StopAndRemoveMCPServer(config, serverName)
33+
Expect(err).ToNot(HaveOccurred(), "Should be able to stop and remove server")
34+
}
35+
})
36+
37+
Describe("Restarting MCP servers", func() {
38+
Context("when restarting a running server", func() {
39+
It("should successfully restart and remain accessible", func() {
40+
By("Starting an OSV MCP server")
41+
stdout, stderr := e2e.NewTHVCommand(config, "run", "--name", serverName, "osv").ExpectSuccess()
42+
43+
// The command should indicate success
44+
Expect(stdout+stderr).To(ContainSubstring("osv"), "Output should mention the osv server")
45+
46+
By("Waiting for the server to be running")
47+
err := e2e.WaitForMCPServer(config, serverName, 60*time.Second)
48+
Expect(err).ToNot(HaveOccurred(), "Server should be running within 60 seconds")
49+
50+
// Get the server URL before restart
51+
originalURL, err := e2e.GetMCPServerURL(config, serverName)
52+
Expect(err).ToNot(HaveOccurred(), "Should be able to get server URL")
53+
54+
By("Restarting the server")
55+
stdout, stderr = e2e.NewTHVCommand(config, "restart", serverName).ExpectSuccess()
56+
Expect(stdout+stderr).To(ContainSubstring("restart"), "Output should mention restart operation")
57+
58+
By("Waiting for the server to be running again")
59+
err = e2e.WaitForMCPServer(config, serverName, 60*time.Second)
60+
Expect(err).ToNot(HaveOccurred(), "Server should be running again within 60 seconds")
61+
62+
// Get the server URL after restart
63+
newURL, err := e2e.GetMCPServerURL(config, serverName)
64+
Expect(err).ToNot(HaveOccurred(), "Should be able to get server URL after restart")
65+
66+
// The URLs should be the same after restart
67+
Expect(newURL).To(Equal(originalURL), "Server URL should remain the same after restart")
68+
69+
By("Verifying the server is functional after restart")
70+
// List server to verify it's operational
71+
stdout, _ = e2e.NewTHVCommand(config, "list").ExpectSuccess()
72+
Expect(stdout).To(ContainSubstring(serverName), "Server should be listed")
73+
Expect(stdout).To(ContainSubstring("running"), "Server should be in running state")
74+
})
75+
})
76+
77+
Context("when restarting a stopped server", func() {
78+
It("should start the server if it was stopped", func() {
79+
By("Starting an OSV MCP server")
80+
stdout, stderr := e2e.NewTHVCommand(config, "run", "--name", serverName, "osv").ExpectSuccess()
81+
Expect(stdout+stderr).To(ContainSubstring("osv"), "Output should mention the osv server")
82+
83+
By("Waiting for the server to be running")
84+
err := e2e.WaitForMCPServer(config, serverName, 60*time.Second)
85+
Expect(err).ToNot(HaveOccurred(), "Server should be running within 60 seconds")
86+
87+
By("Stopping the server")
88+
stdout, _ = e2e.NewTHVCommand(config, "stop", serverName).ExpectSuccess()
89+
Expect(stdout).To(ContainSubstring("stop"), "Output should mention stop operation")
90+
91+
By("Verifying the server is stopped")
92+
Eventually(func() bool {
93+
stdout, _ := e2e.NewTHVCommand(config, "list", "--all").ExpectSuccess()
94+
lines := strings.Split(stdout, "\n")
95+
for _, line := range lines {
96+
if strings.Contains(line, serverName) {
97+
// Check if this specific server line contains "running"
98+
return !strings.Contains(line, "running")
99+
}
100+
}
101+
return false // Server not found in list
102+
}, 10*time.Second, 1*time.Second).Should(BeTrue(), "Server should be stopped")
103+
104+
By("Restarting the stopped server")
105+
stdout, stderr = e2e.NewTHVCommand(config, "restart", serverName).ExpectSuccess()
106+
Expect(stdout+stderr).To(ContainSubstring("restart"), "Output should mention restart operation")
107+
108+
By("Waiting for the server to be running again")
109+
err = e2e.WaitForMCPServer(config, serverName, 60*time.Second)
110+
Expect(err).ToNot(HaveOccurred(), "Server should be running again within 60 seconds")
111+
112+
By("Verifying the server is functional after restart")
113+
stdout, _ = e2e.NewTHVCommand(config, "list").ExpectSuccess()
114+
Expect(stdout).To(ContainSubstring(serverName), "Server should be listed")
115+
Expect(stdout).To(ContainSubstring("running"), "Server should be in running state")
116+
})
117+
})
118+
119+
// TODO: Uncomment when groups are fully supported
120+
//Context("when restarting servers with --groups flag", func() {
121+
// It("should restart servers belonging to the specified group", func() {
122+
// // Define group name
123+
// groupName := fmt.Sprintf("restart-group-%d", GinkgoRandomSeed())
124+
//
125+
// // Create two servers
126+
// serverName1 := generateTestServerName("restart-group-test1")
127+
// serverName2 := generateTestServerName("restart-group-test2")
128+
//
129+
// By("Creating a group first")
130+
// stdout, stderr := e2e.NewTHVCommand(config, "group", "create", groupName).ExpectSuccess()
131+
// Expect(stdout+stderr).To(ContainSubstring("group"), "Output should mention group creation")
132+
//
133+
// By("Starting the first server")
134+
// stdout, stderr = e2e.NewTHVCommand(config, "run", "--name", serverName1, "--group", groupName, "osv").ExpectSuccess()
135+
// Expect(stdout+stderr).To(ContainSubstring("osv"), "Output should mention the osv server")
136+
//
137+
// By("Starting the second server")
138+
// stdout, stderr = e2e.NewTHVCommand(config, "run", "--name", serverName2, "--group", groupName, "osv").ExpectSuccess()
139+
// Expect(stdout+stderr).To(ContainSubstring("osv"), "Output should mention the osv server")
140+
//
141+
// By("Waiting for both servers to be running")
142+
// err := e2e.WaitForMCPServer(config, serverName1, 60*time.Second)
143+
// Expect(err).ToNot(HaveOccurred(), "First server should be running within 60 seconds")
144+
//
145+
// err = e2e.WaitForMCPServer(config, serverName2, 60*time.Second)
146+
// Expect(err).ToNot(HaveOccurred(), "Second server should be running within 60 seconds")
147+
//
148+
// By("Stopping both servers")
149+
// stdout, _ = e2e.NewTHVCommand(config, "stop", serverName1).ExpectSuccess()
150+
// Expect(stdout).To(ContainSubstring("stop"), "Output should mention stop operation for first server")
151+
//
152+
// stdout, _ = e2e.NewTHVCommand(config, "stop", serverName2).ExpectSuccess()
153+
// Expect(stdout).To(ContainSubstring("stop"), "Output should mention stop operation for second server")
154+
//
155+
// By("Verifying the servers are stopped")
156+
// Eventually(func() bool {
157+
// stdout, _ := e2e.NewTHVCommand(config, "list", "--all").ExpectSuccess()
158+
// lines := strings.Split(stdout, "\n")
159+
// server1Found := false
160+
// server2Found := false
161+
// server1Running := false
162+
// server2Running := false
163+
//
164+
// for _, line := range lines {
165+
// if strings.Contains(line, serverName1) {
166+
// server1Found = true
167+
// server1Running = strings.Contains(line, "running")
168+
// }
169+
// if strings.Contains(line, serverName2) {
170+
// server2Found = true
171+
// server2Running = strings.Contains(line, "running")
172+
// }
173+
// }
174+
//
175+
// // Both servers should be found and neither should be running
176+
// return server1Found && server2Found && !server1Running && !server2Running
177+
// }, 10*time.Second, 1*time.Second).Should(BeTrue(), "Both servers should be stopped")
178+
//
179+
// By("Restarting all servers in the group")
180+
// stdout, stderr = e2e.NewTHVCommand(config, "restart", "--group", groupName).ExpectSuccess()
181+
// Expect(stdout+stderr).To(ContainSubstring("restart"), "Output should mention restart operation")
182+
//
183+
// By("Waiting for both servers to be running again")
184+
// err = e2e.WaitForMCPServer(config, serverName1, 60*time.Second)
185+
// Expect(err).ToNot(HaveOccurred(), "First server should be running again within 60 seconds")
186+
//
187+
// err = e2e.WaitForMCPServer(config, serverName2, 60*time.Second)
188+
// Expect(err).ToNot(HaveOccurred(), "Second server should be running again within 60 seconds")
189+
//
190+
// By("Verifying both servers are functional after restart")
191+
// stdout, _ = e2e.NewTHVCommand(config, "list").ExpectSuccess()
192+
// Expect(stdout).To(ContainSubstring(serverName1), "First server should be listed")
193+
// Expect(stdout).To(ContainSubstring(serverName2), "Second server should be listed")
194+
// Expect(stdout).To(ContainSubstring("running"), "Servers should be in running state")
195+
//
196+
// // Clean up these specific servers at the end of the test
197+
// defer func() {
198+
// if config.CleanupAfter {
199+
// _ = e2e.StopAndRemoveMCPServer(config, serverName1)
200+
// _ = e2e.StopAndRemoveMCPServer(config, serverName2)
201+
// }
202+
// }()
203+
// })
204+
//})
205+
})
206+
})
207+
208+
// generateTestServerName creates a unique server name for restart tests
209+
func generateTestServerName(prefix string) string {
210+
return fmt.Sprintf("%s-%d", prefix, GinkgoRandomSeed())
211+
}

0 commit comments

Comments
 (0)