Skip to content

Commit 7b55692

Browse files
authored
Make restart an async operation (#781)
Mirrors changes to stop and delete. Also contains some refactoring which removes a redundant internal type from the workload manager.
1 parent ec6f6ed commit 7b55692

File tree

6 files changed

+77
-51
lines changed

6 files changed

+77
-51
lines changed

cmd/thv/app/restart.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66

77
"github.com/spf13/cobra"
8+
"golang.org/x/sync/errgroup"
89

910
"github.com/stacklok/toolhive/pkg/workloads"
1011
)
@@ -48,11 +49,16 @@ func restartCmdFunc(cmd *cobra.Command, args []string) error {
4849

4950
// Restart single container
5051
containerName := args[0]
51-
err = manager.RestartWorkload(ctx, containerName)
52+
restartGroup, err := manager.RestartWorkload(ctx, containerName)
5253
if err != nil {
5354
return err
5455
}
5556

57+
// Wait for the restart group to complete
58+
if err := restartGroup.Wait(); err != nil {
59+
return fmt.Errorf("failed to restart container %s: %v", containerName, err)
60+
}
61+
5662
fmt.Printf("Container %s restarted successfully\n", containerName)
5763
return nil
5864
}
@@ -75,14 +81,31 @@ func restartAllContainers(ctx context.Context, manager workloads.Manager) error
7581

7682
fmt.Printf("Restarting %d MCP server(s)...\n", len(containers))
7783

84+
var restartRequests []*errgroup.Group
85+
// First, trigger the restarts concurrently.
7886
for _, container := range containers {
7987
containerName := container.Name
8088
fmt.Printf("Restarting %s...", containerName)
81-
err := manager.RestartWorkload(ctx, containerName)
89+
restart, err := manager.RestartWorkload(ctx, containerName)
8290
if err != nil {
8391
fmt.Printf(" failed: %v\n", err)
8492
failedCount++
8593
errors = append(errors, fmt.Sprintf("%s: %v", containerName, err))
94+
} else {
95+
// If it didn't fail during the synchronous part of the operation,
96+
// append to the list of restart requests in flight.
97+
restartRequests = append(restartRequests, restart)
98+
}
99+
}
100+
101+
// Wait for all restarts to complete.
102+
for _, restart := range restartRequests {
103+
err = restart.Wait()
104+
if err != nil {
105+
fmt.Printf(" failed: %v\n", err)
106+
failedCount++
107+
// Unfortunately we don't have the container name here, so we just log a generic error.
108+
errors = append(errors, fmt.Sprintf("Error restarting container: %v", err))
86109
} else {
87110
fmt.Printf(" success\n")
88111
restartedCount++

docs/server/docs.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/server/swagger.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

docs/server/swagger.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,12 +1131,12 @@ paths:
11311131
schema:
11321132
type: string
11331133
responses:
1134-
"204":
1134+
"202":
11351135
content:
11361136
application/json:
11371137
schema:
11381138
type: string
1139-
description: No Content
1139+
description: Accepted
11401140
"404":
11411141
content:
11421142
application/json:
@@ -1182,12 +1182,12 @@ paths:
11821182
schema:
11831183
type: string
11841184
responses:
1185-
"204":
1185+
"202":
11861186
content:
11871187
application/json:
11881188
schema:
11891189
type: string
1190-
description: No Content
1190+
description: Accepted
11911191
"404":
11921192
content:
11931193
application/json:

pkg/api/v1/workloads.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func (s *WorkloadRoutes) stopAllWorkloads(w http.ResponseWriter, r *http.Request
168168
// @Description Delete a workload
169169
// @Tags workloads
170170
// @Param name path string true "Workload name"
171-
// @Success 204 {string} string "No Content"
171+
// @Success 202 {string} string "Accepted"
172172
// @Failure 404 {string} string "Not Found"
173173
// @Router /api/v1beta/workloads/{name} [delete]
174174
func (s *WorkloadRoutes) deleteWorkload(w http.ResponseWriter, r *http.Request) {
@@ -186,7 +186,7 @@ func (s *WorkloadRoutes) deleteWorkload(w http.ResponseWriter, r *http.Request)
186186
http.Error(w, "Failed to delete workload", http.StatusInternalServerError)
187187
return
188188
}
189-
w.WriteHeader(http.StatusNoContent)
189+
w.WriteHeader(http.StatusAccepted)
190190
}
191191

192192
// restartWorkload
@@ -195,13 +195,15 @@ func (s *WorkloadRoutes) deleteWorkload(w http.ResponseWriter, r *http.Request)
195195
// @Description Restart a running workload
196196
// @Tags workloads
197197
// @Param name path string true "Workload name"
198-
// @Success 204 {string} string "No Content"
198+
// @Success 202 {string} string "Accepted"
199199
// @Failure 404 {string} string "Not Found"
200200
// @Router /api/v1beta/workloads/{name}/restart [post]
201201
func (s *WorkloadRoutes) restartWorkload(w http.ResponseWriter, r *http.Request) {
202202
ctx := r.Context()
203203
name := chi.URLParam(r, "name")
204-
err := s.manager.RestartWorkload(ctx, name)
204+
// Note that this is an asynchronous operation.
205+
// In the API, we do not wait for the operation to complete.
206+
_, err := s.manager.RestartWorkload(ctx, name)
205207
if err != nil {
206208
if errors.Is(err, workloads.ErrContainerNotFound) {
207209
http.Error(w, "Workload not found", http.StatusNotFound)
@@ -211,7 +213,7 @@ func (s *WorkloadRoutes) restartWorkload(w http.ResponseWriter, r *http.Request)
211213
http.Error(w, "Failed to restart workload", http.StatusInternalServerError)
212214
return
213215
}
214-
w.WriteHeader(http.StatusNoContent)
216+
w.WriteHeader(http.StatusAccepted)
215217
}
216218

217219
// createWorkload

pkg/workloads/manager.go

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ type Manager interface {
4343
// RunWorkloadDetached runs a container in the background.
4444
RunWorkloadDetached(runConfig *runner.RunConfig) error
4545
// RestartWorkload restarts a previously stopped container.
46-
RestartWorkload(ctx context.Context, name string) error
46+
// It is implemented as an asynchronous operation which returns an errgroup.Group that the caller can wait on for completion.
47+
RestartWorkload(ctx context.Context, name string) (*errgroup.Group, error)
4748
}
4849

4950
type defaultManager struct {
@@ -159,17 +160,12 @@ func (d *defaultManager) StopWorkload(ctx context.Context, name string) (*errgro
159160
return nil, err
160161
}
161162

162-
containerID := container.ID
163-
containerBaseName := labels.GetContainerBaseName(container.Labels)
164163
running := isContainerRunning(container)
165-
166164
if !running {
167165
return nil, fmt.Errorf("%w: %s", ErrContainerNotRunning, name)
168166
}
169167

170-
workload := stopWorkloadRequest{Name: containerBaseName, ID: containerID}
171-
// Do the actual stop operation in the background, and return an error group.
172-
return d.stopWorkloads(ctx, []stopWorkloadRequest{workload}), nil
168+
return d.stopWorkloads(ctx, []*rt.ContainerInfo{container}), nil
173169
}
174170

175171
func (d *defaultManager) StopAllWorkloads(ctx context.Context) (*errgroup.Group, error) {
@@ -181,16 +177,15 @@ func (d *defaultManager) StopAllWorkloads(ctx context.Context) (*errgroup.Group,
181177

182178
// Duplicates the logic of GetWorkloads, but is simple enough that it's not
183179
// worth deduplicating.
184-
stopRequests := make([]stopWorkloadRequest, 0, len(containers))
180+
var containersToStop []*rt.ContainerInfo
185181
for _, c := range containers {
186182
// Only include ToolHive containers which are currently running.
187183
if labels.IsToolHiveContainer(c.Labels) && isContainerRunning(&c) {
188-
req := stopWorkloadRequest{Name: labels.GetContainerBaseName(c.Labels), ID: c.ID}
189-
stopRequests = append(stopRequests, req)
184+
containersToStop = append(containersToStop, &c)
190185
}
191186
}
192187

193-
return d.stopWorkloads(ctx, stopRequests), nil
188+
return d.stopWorkloads(ctx, containersToStop), nil
194189
}
195190

196191
func (*defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error {
@@ -395,7 +390,7 @@ func (*defaultManager) RunWorkloadDetached(runConfig *runner.RunConfig) error {
395390
return nil
396391
}
397392

398-
func (d *defaultManager) RestartWorkload(ctx context.Context, name string) error {
393+
func (d *defaultManager) RestartWorkload(ctx context.Context, name string) (*errgroup.Group, error) {
399394
var containerBaseName string
400395
var running bool
401396
// Try to find the container.
@@ -418,30 +413,41 @@ func (d *defaultManager) RestartWorkload(ctx context.Context, name string) error
418413

419414
if running && proxyRunning {
420415
logger.Infof("Container %s and proxy are already running", name)
421-
return nil
422-
}
423-
424-
containerID := container.ID
425-
// If the container is running but the proxy is not, stop the container first
426-
if container.ID != "" && running { // && !proxyRunning was previously here but is implied by previous if statement.
427-
logger.Infof("Container %s is running but proxy is not. Stopping container...", name)
428-
if err := d.runtime.StopWorkload(ctx, containerID); err != nil {
429-
return fmt.Errorf("failed to stop container: %v", err)
430-
}
431-
logger.Infof("Container %s stopped", name)
416+
// Return an empty error group so that the caller does not need to check for nil.
417+
return &errgroup.Group{}, nil
432418
}
433419

434420
// Load the configuration from the state store
421+
// This is done synchronously since it is a relatively inexpensive operation
422+
// and it allows for better error handling.
435423
mcpRunner, err := d.loadRunnerFromState(ctx, containerBaseName)
436424
if err != nil {
437-
return fmt.Errorf("failed to load state for %s: %v", containerBaseName, err)
425+
return nil, fmt.Errorf("failed to load state for %s: %v", containerBaseName, err)
438426
}
439-
440427
logger.Infof("Loaded configuration from state for %s", containerBaseName)
441428

442429
// Run the tooling server inside a detached process.
430+
// TODO: This will need to be changed when RunWorkloadDetached is converted
431+
// to be async.
443432
logger.Infof("Starting tooling server %s...", name)
444-
return d.RunWorkloadDetached(mcpRunner.Config)
433+
runGroup := &errgroup.Group{}
434+
runGroup.Go(func() error {
435+
containerID := container.ID
436+
// If the container is running but the proxy is not, stop the container first
437+
if container.ID != "" && running { // && !proxyRunning was previously here but is implied by previous if statement.
438+
logger.Infof("Container %s is running but proxy is not. Stopping container...", name)
439+
// n.b. - we do not reuse the `StopWorkload` method here because it
440+
// does some extra things which are not appropriate for resuming a workload.
441+
if err = d.runtime.StopWorkload(ctx, containerID); err != nil {
442+
return fmt.Errorf("failed to stop container %s: %v", name, err)
443+
}
444+
logger.Infof("Container %s stopped", name)
445+
}
446+
447+
return d.RunWorkloadDetached(mcpRunner.Config)
448+
})
449+
450+
return runGroup, nil
445451
}
446452

447453
func (d *defaultManager) findContainerByName(ctx context.Context, name string) (*rt.ContainerInfo, error) {
@@ -555,36 +561,31 @@ func (*defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseNam
555561
return nil
556562
}
557563

558-
// Internal type used when stopping workloads.
559-
type stopWorkloadRequest struct {
560-
Name string
561-
ID string
562-
}
563-
564564
// stopWorkloads stops the named workloads concurrently.
565565
// It assumes that the workloads exist in the running state.
566-
func (d *defaultManager) stopWorkloads(ctx context.Context, workloads []stopWorkloadRequest) *errgroup.Group {
566+
func (d *defaultManager) stopWorkloads(ctx context.Context, workloads []*rt.ContainerInfo) *errgroup.Group {
567567
group := errgroup.Group{}
568568
for _, workload := range workloads {
569569
group.Go(func() error {
570+
name := labels.GetContainerBaseName(workload.Labels)
570571
// Stop the proxy process
571-
proxy.StopProcess(workload.Name)
572+
proxy.StopProcess(name)
572573

573-
logger.Infof("Stopping containers for %s...", workload.Name)
574+
logger.Infof("Stopping containers for %s...", name)
574575
// Stop the container
575576
if err := d.runtime.StopWorkload(ctx, workload.ID); err != nil {
576577
return fmt.Errorf("failed to stop container: %w", err)
577578
}
578579

579580
if shouldRemoveClientConfig() {
580-
if err := removeClientConfigurations(workload.Name); err != nil {
581+
if err := removeClientConfigurations(name); err != nil {
581582
logger.Warnf("Warning: Failed to remove client configurations: %v", err)
582583
} else {
583-
logger.Infof("Client configurations for %s removed", workload.Name)
584+
logger.Infof("Client configurations for %s removed", name)
584585
}
585586
}
586587

587-
logger.Infof("Successfully stopped %s...", workload.Name)
588+
logger.Infof("Successfully stopped %s...", name)
588589
return nil
589590
})
590591
}

0 commit comments

Comments
 (0)