Skip to content

Commit e6e29dc

Browse files
authored
Fix bugs found while testing file-based state tracking (#1347)
This PR was intended to enable file-based state tracking. However, I found some bugs during testing which involves touching many files, so this PR addresses these bugs without enabling file-based state tracking. 1. I discovered that I left out a call to change the state to "running" once the proxy-runner is ready. Simply adding the state manager type led to a circular dependency, which required a bunch of refactoring of code. This is responsible for most of the diff here. 2. The code which created the base directory for the status files was using the wrong path - this was fixed. 3. I originally had separate methods for creating the initial status, and updating. This led to problems with the detached process flow (which would try to create the status twice, leading to an error). This PR changes the code to use a single method which creates or update.
1 parent 2466969 commit e6e29dc

File tree

18 files changed

+576
-339
lines changed

18 files changed

+576
-339
lines changed

cmd/thv-proxyrunner/app/run.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,10 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
266266
return fmt.Errorf("failed to create RunConfig: %v", err)
267267
}
268268

269-
workloadManager := workloads.NewManagerFromRuntime(rt)
269+
workloadManager, err := workloads.NewManagerFromRuntime(rt)
270+
if err != nil {
271+
return fmt.Errorf("failed to create workload manager: %v", err)
272+
}
270273
return workloadManager.RunWorkload(ctx, runConfig)
271274
}
272275

cmd/thv/app/run.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,10 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
144144
if err != nil {
145145
return fmt.Errorf("failed to create container runtime: %v", err)
146146
}
147-
workloadManager := workloads.NewManagerFromRuntime(rt)
147+
workloadManager, err := workloads.NewManagerFromRuntime(rt)
148+
if err != nil {
149+
return fmt.Errorf("failed to create workload manager: %v", err)
150+
}
148151

149152
err = validateGroup(ctx, workloadManager, serverOrImage)
150153
if err != nil {
@@ -298,7 +301,10 @@ func runFromConfigFile(ctx context.Context) error {
298301
runConfig.Deployer = rt
299302

300303
// Create workload manager
301-
workloadManager := workloads.NewManagerFromRuntime(rt)
304+
workloadManager, err := workloads.NewManagerFromRuntime(rt)
305+
if err != nil {
306+
return fmt.Errorf("failed to create workload manager: %v", err)
307+
}
302308

303309
// Run the workload based on foreground flag
304310
if runFlags.Foreground {

cmd/thv/app/stop.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
rt "github.com/stacklok/toolhive/pkg/container/runtime"
1111
"github.com/stacklok/toolhive/pkg/workloads"
12+
"github.com/stacklok/toolhive/pkg/workloads/types"
1213
)
1314

1415
var stopCmd = &cobra.Command{
@@ -100,7 +101,7 @@ func stopCmdFunc(cmd *cobra.Command, args []string) error {
100101
// If the workload is not found or not running, treat as a non-fatal error.
101102
if errors.Is(err, rt.ErrWorkloadNotFound) ||
102103
errors.Is(err, workloads.ErrWorkloadNotRunning) ||
103-
errors.Is(err, workloads.ErrInvalidWorkloadName) {
104+
errors.Is(err, types.ErrInvalidWorkloadName) {
104105
fmt.Printf("workload %s is not running\n", workloadName)
105106
return nil
106107
}

pkg/api/server.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,10 @@ func Serve(
171171
return fmt.Errorf("failed to create client manager: %v", err)
172172
}
173173

174-
workloadManager := workloads.NewManagerFromRuntime(containerRuntime)
174+
workloadManager, err := workloads.NewManagerFromRuntime(containerRuntime)
175+
if err != nil {
176+
return fmt.Errorf("failed to create workload manager: %v", err)
177+
}
175178

176179
// Create group manager
177180
groupManager, err := groups.NewManager()

pkg/api/v1/workloads.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/stacklok/toolhive/pkg/transport/types"
2323
"github.com/stacklok/toolhive/pkg/validation"
2424
"github.com/stacklok/toolhive/pkg/workloads"
25+
wt "github.com/stacklok/toolhive/pkg/workloads/types"
2526
)
2627

2728
// WorkloadRoutes defines the routes for workload management.
@@ -135,7 +136,7 @@ func (s *WorkloadRoutes) getWorkload(w http.ResponseWriter, r *http.Request) {
135136
if errors.Is(err, runtime.ErrWorkloadNotFound) {
136137
http.Error(w, "Workload not found", http.StatusNotFound)
137138
return
138-
} else if errors.Is(err, workloads.ErrInvalidWorkloadName) {
139+
} else if errors.Is(err, wt.ErrInvalidWorkloadName) {
139140
http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest)
140141
return
141142
}
@@ -169,7 +170,7 @@ func (s *WorkloadRoutes) stopWorkload(w http.ResponseWriter, r *http.Request) {
169170
// Use the bulk method with a single workload
170171
_, err := s.workloadManager.StopWorkloads(ctx, []string{name})
171172
if err != nil {
172-
if errors.Is(err, workloads.ErrInvalidWorkloadName) {
173+
if errors.Is(err, wt.ErrInvalidWorkloadName) {
173174
http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest)
174175
return
175176
}
@@ -198,7 +199,7 @@ func (s *WorkloadRoutes) restartWorkload(w http.ResponseWriter, r *http.Request)
198199
// Note: In the API, we always assume that the restart is a background operation
199200
_, err := s.workloadManager.RestartWorkloads(ctx, []string{name}, false)
200201
if err != nil {
201-
if errors.Is(err, workloads.ErrInvalidWorkloadName) {
202+
if errors.Is(err, wt.ErrInvalidWorkloadName) {
202203
http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest)
203204
return
204205
}
@@ -226,7 +227,7 @@ func (s *WorkloadRoutes) deleteWorkload(w http.ResponseWriter, r *http.Request)
226227
// Use the bulk method with a single workload
227228
_, err := s.workloadManager.DeleteWorkloads(ctx, []string{name})
228229
if err != nil {
229-
if errors.Is(err, workloads.ErrInvalidWorkloadName) {
230+
if errors.Is(err, wt.ErrInvalidWorkloadName) {
230231
http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest)
231232
return
232233
}
@@ -366,7 +367,7 @@ func (s *WorkloadRoutes) stopWorkloadsBulk(w http.ResponseWriter, r *http.Reques
366367
// The request is not blocked on completion.
367368
_, err = s.workloadManager.StopWorkloads(ctx, workloadNames)
368369
if err != nil {
369-
if errors.Is(err, workloads.ErrInvalidWorkloadName) {
370+
if errors.Is(err, wt.ErrInvalidWorkloadName) {
370371
http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest)
371372
return
372373
}
@@ -412,7 +413,7 @@ func (s *WorkloadRoutes) restartWorkloadsBulk(w http.ResponseWriter, r *http.Req
412413
// Note: In the API, we always assume that the restart is a background operation.
413414
_, err = s.workloadManager.RestartWorkloads(ctx, workloadNames, false)
414415
if err != nil {
415-
if errors.Is(err, workloads.ErrInvalidWorkloadName) {
416+
if errors.Is(err, wt.ErrInvalidWorkloadName) {
416417
http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest)
417418
return
418419
}
@@ -457,7 +458,7 @@ func (s *WorkloadRoutes) deleteWorkloadsBulk(w http.ResponseWriter, r *http.Requ
457458
// The request is not blocked on completion.
458459
_, err = s.workloadManager.DeleteWorkloads(ctx, workloadNames)
459460
if err != nil {
460-
if errors.Is(err, workloads.ErrInvalidWorkloadName) {
461+
if errors.Is(err, wt.ErrInvalidWorkloadName) {
461462
http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest)
462463
return
463464
}

pkg/runner/config.go

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,7 @@ func (c *RunConfig) WriteJSON(w io.Writer) error {
144144
// ReadJSON deserializes the RunConfig from JSON read from the provided reader
145145
func ReadJSON(r io.Reader) (*RunConfig, error) {
146146
var config RunConfig
147-
decoder := json.NewDecoder(r)
148-
if err := decoder.Decode(&config); err != nil {
147+
if err := state.ReadJSON(r, &config); err != nil {
149148
return nil, err
150149
}
151150
return &config, nil
@@ -313,38 +312,17 @@ func (c *RunConfig) WithStandardLabels() *RunConfig {
313312
return c
314313
}
315314

315+
// GetBaseName returns the base name for the run configuration
316+
func (c *RunConfig) GetBaseName() string {
317+
return c.BaseName
318+
}
319+
316320
// SaveState saves the run configuration to the state store
317321
func (c *RunConfig) SaveState(ctx context.Context) error {
318-
// Create a state store
319-
store, err := state.NewRunConfigStore(state.DefaultAppName)
320-
if err != nil {
321-
return fmt.Errorf("failed to create state store: %w", err)
322-
}
323-
324-
// Get a writer for the state
325-
writer, err := store.GetWriter(ctx, c.BaseName)
326-
if err != nil {
327-
return fmt.Errorf("failed to get writer for state: %w", err)
328-
}
329-
defer writer.Close()
330-
331-
// Serialize the configuration to JSON and write it directly to the state store
332-
if err := c.WriteJSON(writer); err != nil {
333-
return fmt.Errorf("failed to write run configuration: %w", err)
334-
}
335-
336-
logger.Infof("Saved run configuration for %s", c.BaseName)
337-
return nil
322+
return state.SaveRunConfig(ctx, c)
338323
}
339324

340325
// LoadState loads a run configuration from the state store
341326
func LoadState(ctx context.Context, name string) (*RunConfig, error) {
342-
reader, err := state.LoadRunConfigJSON(ctx, name)
343-
if err != nil {
344-
return nil, err
345-
}
346-
defer reader.Close()
347-
348-
// Deserialize the configuration
349-
return ReadJSON(reader)
327+
return state.LoadRunConfig(ctx, name, ReadJSON)
350328
}

pkg/runner/runner.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/stacklok/toolhive/pkg/auth"
1313
"github.com/stacklok/toolhive/pkg/client"
1414
"github.com/stacklok/toolhive/pkg/config"
15+
rt "github.com/stacklok/toolhive/pkg/container/runtime"
1516
"github.com/stacklok/toolhive/pkg/labels"
1617
"github.com/stacklok/toolhive/pkg/logger"
1718
"github.com/stacklok/toolhive/pkg/mcp"
@@ -20,6 +21,7 @@ import (
2021
"github.com/stacklok/toolhive/pkg/telemetry"
2122
"github.com/stacklok/toolhive/pkg/transport"
2223
"github.com/stacklok/toolhive/pkg/transport/types"
24+
"github.com/stacklok/toolhive/pkg/workloads/statuses"
2325
)
2426

2527
// Runner is responsible for running an MCP server with the provided configuration
@@ -32,12 +34,15 @@ type Runner struct {
3234

3335
// supportedMiddleware is a map of supported middleware types to their factory functions.
3436
supportedMiddleware map[string]types.MiddlewareFactory
37+
38+
statusManager statuses.StatusManager
3539
}
3640

3741
// NewRunner creates a new Runner with the provided configuration
38-
func NewRunner(runConfig *RunConfig) *Runner {
42+
func NewRunner(runConfig *RunConfig, statusManager statuses.StatusManager) *Runner {
3943
return &Runner{
40-
Config: runConfig,
44+
Config: runConfig,
45+
statusManager: statusManager,
4146
}
4247
}
4348

@@ -294,6 +299,12 @@ func (r *Runner) Run(ctx context.Context) error {
294299
}
295300
}()
296301

302+
// At this point, we can consider the workload started successfully.
303+
if err := r.statusManager.SetWorkloadStatus(ctx, r.Config.ContainerName, rt.WorkloadStatusRunning, ""); err != nil {
304+
// If we can't set the status to `running` - treat it as a fatal error.
305+
return fmt.Errorf("failed to set workload status: %v", err)
306+
}
307+
297308
// Wait for either a signal or the done channel to be closed
298309
select {
299310
case sig := <-sigCh:

pkg/state/runconfig.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package state
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"io"
78

@@ -60,3 +61,89 @@ func DeleteSavedRunConfig(ctx context.Context, name string) error {
6061
logger.Infof("Deleted run configuration for %s", name)
6162
return nil
6263
}
64+
65+
// RunConfigPersister defines an interface for objects that can be persisted and loaded as JSON
66+
type RunConfigPersister interface {
67+
// WriteJSON serializes the object to JSON and writes it to the provided writer
68+
WriteJSON(w io.Writer) error
69+
// GetBaseName returns the base name used for persistence
70+
GetBaseName() string
71+
}
72+
73+
// ReadJSONFunc defines a function type for reading JSON into an object
74+
type ReadJSONFunc[T any] func(r io.Reader) (T, error)
75+
76+
// SaveRunConfig saves a run configuration to the state store
77+
func SaveRunConfig[T RunConfigPersister](ctx context.Context, config T) error {
78+
// Create a state store
79+
store, err := NewRunConfigStore(DefaultAppName)
80+
if err != nil {
81+
return fmt.Errorf("failed to create state store: %w", err)
82+
}
83+
84+
// Get a writer for the state
85+
writer, err := store.GetWriter(ctx, config.GetBaseName())
86+
if err != nil {
87+
return fmt.Errorf("failed to get writer for state: %w", err)
88+
}
89+
defer writer.Close()
90+
91+
// Serialize the configuration to JSON and write it directly to the state store
92+
if err := config.WriteJSON(writer); err != nil {
93+
return fmt.Errorf("failed to write run configuration: %w", err)
94+
}
95+
96+
logger.Infof("Saved run configuration for %s", config.GetBaseName())
97+
return nil
98+
}
99+
100+
// LoadRunConfig loads a run configuration from the state store using the provided reader function
101+
func LoadRunConfig[T any](ctx context.Context, name string, readJSONFunc ReadJSONFunc[T]) (T, error) {
102+
var zero T
103+
reader, err := LoadRunConfigJSON(ctx, name)
104+
if err != nil {
105+
return zero, err
106+
}
107+
defer reader.Close()
108+
109+
// Deserialize the configuration using the provided function
110+
return readJSONFunc(reader)
111+
}
112+
113+
// ReadRunConfigJSON deserializes a run configuration from JSON read from the provided reader
114+
// This is a generic JSON deserializer for any type that can be unmarshalled from JSON
115+
func ReadRunConfigJSON[T any](r io.Reader) (*T, error) {
116+
var config T
117+
decoder := json.NewDecoder(r)
118+
if err := decoder.Decode(&config); err != nil {
119+
return nil, err
120+
}
121+
return &config, nil
122+
}
123+
124+
// LoadRunConfigOfType loads a run configuration of a specific type T from the state store
125+
func LoadRunConfigOfType[T any](ctx context.Context, name string) (*T, error) {
126+
return LoadRunConfig(ctx, name, ReadRunConfigJSON[T])
127+
}
128+
129+
// RunConfigReadJSONFunc defines the function signature for reading a RunConfig from JSON
130+
// This allows us to accept the runner.ReadJSON function without creating a circular dependency
131+
type RunConfigReadJSONFunc func(r io.Reader) (interface{}, error)
132+
133+
// LoadRunConfigWithFunc loads a run configuration using a provided read function
134+
func LoadRunConfigWithFunc(ctx context.Context, name string, readFunc RunConfigReadJSONFunc) (interface{}, error) {
135+
reader, err := LoadRunConfigJSON(ctx, name)
136+
if err != nil {
137+
return nil, err
138+
}
139+
defer reader.Close()
140+
141+
return readFunc(reader)
142+
}
143+
144+
// ReadJSON deserializes JSON from the provided reader into a generic interface
145+
// This function is moved from the runner package to avoid circular dependencies
146+
func ReadJSON(r io.Reader, target interface{}) error {
147+
decoder := json.NewDecoder(r)
148+
return decoder.Decode(target)
149+
}

0 commit comments

Comments
 (0)