Skip to content

Commit d45e73a

Browse files
authored
Refactor run command so logic is shared with API. (#796)
Previously, a lot of the logic relating to running a workload was only implemented for the CLI. This PR addresses refactors the code so that most of the logic from the run command now lives inside the constructor for the RunConfig struct.
1 parent 12a0b66 commit d45e73a

File tree

6 files changed

+286
-244
lines changed

6 files changed

+286
-244
lines changed

cmd/thv/app/run.go

Lines changed: 33 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package app
22

33
import (
4-
"context"
54
"fmt"
65
"net"
76
"os"
@@ -12,11 +11,9 @@ import (
1211
"github.com/stacklok/toolhive/pkg/logger"
1312
"github.com/stacklok/toolhive/pkg/permissions"
1413
"github.com/stacklok/toolhive/pkg/process"
15-
"github.com/stacklok/toolhive/pkg/registry"
1614
"github.com/stacklok/toolhive/pkg/runner"
1715
"github.com/stacklok/toolhive/pkg/runner/retriever"
1816
"github.com/stacklok/toolhive/pkg/transport"
19-
"github.com/stacklok/toolhive/pkg/transport/types"
2017
"github.com/stacklok/toolhive/pkg/workloads"
2118
)
2219

@@ -208,7 +205,8 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
208205
}
209206
runHost = validatedHost
210207

211-
// Get the server name or image
208+
// Get the name of the MCP server to run.
209+
// This may be a server name from the registry, a container image, or a protocol scheme.
212210
serverOrImage := args[0]
213211

214212
// Process command arguments using os.Args to find everything after --
@@ -233,11 +231,34 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
233231
}
234232
workloadManager := workloads.NewManagerFromRuntime(rt)
235233

234+
// Select an env var validation strategy depending on how the CLI is run:
235+
// If we have called the CLI directly, we use the CLIEnvVarValidator.
236+
// If we are running in detached mode, or the CLI is wrapped by the K8s operator,
237+
// we use the DetachedEnvVarValidator.
238+
var envVarValidator runner.EnvVarValidator
239+
if process.IsDetached() || container.IsKubernetesRuntime() {
240+
envVarValidator = &runner.DetachedEnvVarValidator{}
241+
} else {
242+
envVarValidator = &runner.CLIEnvVarValidator{}
243+
}
244+
245+
// Take the MCP server we were supplied and either fetch the image, or
246+
// build it from a protocol scheme. If the server URI refers to an image
247+
// in our trusted registry, we will also fetch the image metadata.
248+
imageURL, imageMetadata, err := retriever.GetMCPServer(ctx, serverOrImage, runCACertPath, runVerifyImage)
249+
if err != nil {
250+
return fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err)
251+
}
252+
236253
// Initialize a new RunConfig with values from command-line flags
237-
runConfig := runner.NewRunConfigFromFlags(
254+
// TODO: As noted elsewhere, we should use the builder pattern here to make it more readable.
255+
runConfig, err := runner.NewRunConfigFromFlags(
256+
ctx,
238257
rt,
239258
cmdArgs,
240259
runName,
260+
imageURL,
261+
imageMetadata,
241262
runHost,
242263
debugMode,
243264
runVolumes,
@@ -247,6 +268,10 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
247268
runEnableAudit,
248269
runPermissionProfile,
249270
runTargetHost,
271+
runTransport,
272+
runPort,
273+
runTargetPort,
274+
runEnv,
250275
oidcIssuer,
251276
oidcAudience,
252277
oidcJwksURL,
@@ -259,30 +284,11 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
259284
runOtelEnablePrometheusMetricsPath,
260285
runOtelEnvironmentVariables,
261286
runIsolateNetwork,
287+
runK8sPodPatch,
288+
envVarValidator,
262289
)
263-
264-
// Set the Kubernetes pod template patch if provided
265-
if runK8sPodPatch != "" {
266-
runConfig.K8sPodTemplatePatch = runK8sPodPatch
267-
}
268-
269-
imageURL, imageMetadata, err := retriever.GetMCPServer(ctx, serverOrImage, runCACertPath, runVerifyImage)
270290
if err != nil {
271-
return fmt.Errorf("failed to retrieve MCP server: %v", err)
272-
}
273-
runConfig.Image = imageURL
274-
275-
if imageMetadata != nil {
276-
// If the image came from our registry, apply settings from the registry metadata.
277-
err = applyRegistrySettings(ctx, cmd, serverOrImage, imageMetadata, runConfig, debugMode)
278-
if err != nil {
279-
return fmt.Errorf("failed to apply registry settings: %v", err)
280-
}
281-
}
282-
283-
// Configure the RunConfig with transport, ports, permissions, etc.
284-
if err := configureRunConfig(runConfig, runTransport, runPort, runTargetPort, runEnv); err != nil {
285-
return err
291+
return fmt.Errorf("failed to create RunConfig: %v", err)
286292
}
287293

288294
// Once we have built the RunConfig, start the MCP workload.
@@ -293,85 +299,6 @@ func runCmdFunc(cmd *cobra.Command, args []string) error {
293299
return workloadManager.RunWorkloadDetached(runConfig)
294300
}
295301

296-
// applyRegistrySettings applies settings from a registry server to the run config
297-
func applyRegistrySettings(
298-
ctx context.Context,
299-
cmd *cobra.Command,
300-
serverName string,
301-
metadata *registry.ImageMetadata,
302-
runConfig *runner.RunConfig,
303-
debugMode bool,
304-
) error {
305-
// Use the image from the registry
306-
runConfig.Image = metadata.Image
307-
308-
// If name is not provided, use the metadata name from registry
309-
if runConfig.Name == "" {
310-
runConfig.Name = serverName
311-
}
312-
313-
// Use registry transport if not overridden
314-
if !cmd.Flags().Changed("transport") {
315-
logDebug(debugMode, "Using registry transport: %s", metadata.Transport)
316-
runTransport = metadata.Transport
317-
} else {
318-
logDebug(debugMode, "Using provided transport: %s (overriding registry default: %s)",
319-
runTransport, metadata.Transport)
320-
}
321-
322-
// Use registry target port if not overridden and transport is SSE or Streamable HTTP
323-
if !cmd.Flags().Changed("target-port") && (metadata.Transport == types.TransportTypeSSE.String() ||
324-
metadata.Transport == types.TransportTypeStreamableHTTP.String()) && metadata.TargetPort > 0 {
325-
logDebug(debugMode, "Using registry target port: %d", metadata.TargetPort)
326-
runTargetPort = metadata.TargetPort
327-
}
328-
329-
// Prepend registry args to command-line args if available
330-
if len(metadata.Args) > 0 {
331-
logDebug(debugMode, "Prepending registry args: %v", metadata.Args)
332-
runConfig.CmdArgs = append(metadata.Args, runConfig.CmdArgs...)
333-
}
334-
335-
// Note this logic will be moved elsewhere in a future PR.
336-
// Select an env var validation strategy depending on how the CLI is run:
337-
// If we have called the CLI directly, we use the CLIEnvVarValidator.
338-
// If we are running in detached mode, or the CLI is wrapped by the K8s operator,
339-
// we use the DetachedEnvVarValidator.
340-
var envVarValidator runner.EnvVarValidator
341-
if process.IsDetached() || container.IsKubernetesRuntime() {
342-
envVarValidator = &runner.DetachedEnvVarValidator{}
343-
} else {
344-
envVarValidator = &runner.CLIEnvVarValidator{}
345-
}
346-
347-
// Process environment variables from registry.
348-
// This will be merged with command-line env vars in configureRunConfig
349-
if err := envVarValidator.Validate(ctx, metadata, runConfig, runEnv); err != nil {
350-
return fmt.Errorf("failed to validate required configuration values: %v", err)
351-
}
352-
353-
// Create a temporary file for the permission profile if not explicitly provided
354-
if !cmd.Flags().Changed("permission-profile") {
355-
permProfilePath, err := runner.CreatePermissionProfileFile(serverName, metadata.Permissions)
356-
if err != nil {
357-
// Just log the error and continue with the default permission profile
358-
logger.Warnf("Warning: Failed to create permission profile file: %v", err)
359-
} else {
360-
// Update the permission profile path
361-
runConfig.PermissionProfileNameOrPath = permProfilePath
362-
}
363-
}
364-
365-
return nil
366-
}
367-
368-
// logDebug logs a message if debug mode is enabled
369-
func logDebug(debugMode bool, format string, args ...interface{}) {
370-
if debugMode {
371-
logger.Infof(format+"", args...)
372-
}
373-
}
374-
375302
// parseCommandArguments processes command-line arguments to find everything after the -- separator
376303
// which are the arguments to be passed to the MCP server
377304
func parseCommandArguments(args []string) []string {

cmd/thv/app/run_common.go

Lines changed: 0 additions & 72 deletions
This file was deleted.

0 commit comments

Comments
 (0)