Skip to content

Commit aa0ac23

Browse files
authored
Wireup middleware with new interfaces (#1394)
This PR does the following: 1. Add implementations of the Middleware config and middleware factory interfaces for each type of middleware. 2. Convert the CLI code to instantiate the middleware using the configs. 3. Provides a backwards compatibility mechanism for legacy run configs. Not addressed: Changing how the API works. Middleware is not really supported in the API. I would rather switch the API entirely to this new structure in a future PR.
1 parent ea9081d commit aa0ac23

File tree

5 files changed

+559
-130
lines changed

5 files changed

+559
-130
lines changed

cmd/thv/app/run_flags.go

Lines changed: 210 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ package app
33
import (
44
"context"
55
"fmt"
6+
"strings"
67

78
"github.com/spf13/cobra"
89

10+
"github.com/stacklok/toolhive/pkg/auth"
11+
"github.com/stacklok/toolhive/pkg/authz"
912
cfg "github.com/stacklok/toolhive/pkg/config"
1013
"github.com/stacklok/toolhive/pkg/container"
1114
"github.com/stacklok/toolhive/pkg/container/runtime"
@@ -16,6 +19,7 @@ import (
1619
"github.com/stacklok/toolhive/pkg/registry"
1720
"github.com/stacklok/toolhive/pkg/runner"
1821
"github.com/stacklok/toolhive/pkg/runner/retriever"
22+
"github.com/stacklok/toolhive/pkg/telemetry"
1923
"github.com/stacklok/toolhive/pkg/transport"
2024
"github.com/stacklok/toolhive/pkg/transport/types"
2125
)
@@ -191,14 +195,53 @@ func BuildRunnerConfig(
191195
debugMode bool,
192196
cmd *cobra.Command,
193197
) (*runner.RunConfig, error) {
194-
// Validate the host flag
198+
// Validate and setup basic configuration
195199
validatedHost, err := ValidateAndNormaliseHostFlag(runFlags.Host)
196200
if err != nil {
197201
return nil, fmt.Errorf("invalid host: %s", runFlags.Host)
198202
}
199203

200-
// Get OIDC flags
204+
// Setup OIDC configuration
205+
oidcConfig, err := setupOIDCConfiguration(cmd, runFlags)
206+
if err != nil {
207+
return nil, err
208+
}
209+
210+
// Setup telemetry configuration
211+
telemetryConfig := setupTelemetryConfiguration(cmd, runFlags)
212+
213+
// Setup runtime and validation
214+
rt, envVarValidator, err := setupRuntimeAndValidation(ctx)
215+
if err != nil {
216+
return nil, err
217+
}
218+
219+
// Handle image retrieval
220+
imageURL, imageMetadata, err := handleImageRetrieval(ctx, serverOrImage, runFlags)
221+
if err != nil {
222+
return nil, err
223+
}
224+
225+
// Validate and setup proxy mode
226+
if err := validateAndSetupProxyMode(runFlags); err != nil {
227+
return nil, err
228+
}
229+
230+
// Parse environment variables
231+
envVars, err := environment.ParseEnvironmentVariables(runFlags.Env)
232+
if err != nil {
233+
return nil, fmt.Errorf("failed to parse environment variables: %v", err)
234+
}
235+
236+
// Build the runner config
237+
return buildRunnerConfig(ctx, runFlags, cmdArgs, debugMode, validatedHost, rt, imageURL, imageMetadata,
238+
envVars, envVarValidator, oidcConfig, telemetryConfig)
239+
}
240+
241+
// setupOIDCConfiguration sets up OIDC configuration and validates URLs
242+
func setupOIDCConfiguration(cmd *cobra.Command, runFlags *RunFlags) (*auth.TokenValidatorConfig, error) {
201243
oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL, oidcClientID, oidcClientSecret := getOidcFromFlags(cmd)
244+
202245
if oidcJwksURL != "" {
203246
if err := networking.ValidateEndpointURL(oidcJwksURL); err != nil {
204247
return nil, fmt.Errorf("invalid %s: %w", oidcJwksURL, err)
@@ -210,61 +253,85 @@ func BuildRunnerConfig(
210253
}
211254
}
212255

213-
// Get OTEL flag values with config fallbacks
256+
return createOIDCConfig(oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL,
257+
oidcClientID, oidcClientSecret, runFlags.ResourceURL), nil
258+
}
259+
260+
// setupTelemetryConfiguration sets up telemetry configuration with config fallbacks
261+
func setupTelemetryConfiguration(cmd *cobra.Command, runFlags *RunFlags) *telemetry.Config {
214262
config := cfg.GetConfig()
215263
finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables := getTelemetryFromFlags(cmd, config,
216264
runFlags.OtelEndpoint, runFlags.OtelSamplingRate, runFlags.OtelEnvironmentVariables)
217265

218-
// Create container runtime
266+
return createTelemetryConfig(finalOtelEndpoint, runFlags.OtelEnablePrometheusMetricsPath,
267+
runFlags.OtelServiceName, finalOtelSamplingRate, runFlags.OtelHeaders, runFlags.OtelInsecure,
268+
finalOtelEnvironmentVariables)
269+
}
270+
271+
// setupRuntimeAndValidation creates container runtime and selects environment variable validator
272+
func setupRuntimeAndValidation(ctx context.Context) (runtime.Deployer, runner.EnvVarValidator, error) {
219273
rt, err := container.NewFactory().Create(ctx)
220274
if err != nil {
221-
return nil, fmt.Errorf("failed to create container runtime: %v", err)
275+
return nil, nil, fmt.Errorf("failed to create container runtime: %v", err)
222276
}
223277

224-
// Select an envVars var validation strategy depending on how the CLI is run:
225-
// If we have called the CLI directly, we use the CLIEnvVarValidator.
226-
// If we are running in detached mode, or the CLI is wrapped by the K8s operator,
227-
// we use the DetachedEnvVarValidator.
228278
var envVarValidator runner.EnvVarValidator
229279
if process.IsDetached() || runtime.IsKubernetesRuntime() {
230280
envVarValidator = &runner.DetachedEnvVarValidator{}
231281
} else {
232282
envVarValidator = &runner.CLIEnvVarValidator{}
233283
}
234284

235-
// Image retrieval
285+
return rt, envVarValidator, nil
286+
}
287+
288+
// handleImageRetrieval retrieves and processes the MCP server image
289+
func handleImageRetrieval(
290+
ctx context.Context, serverOrImage string, runFlags *RunFlags,
291+
) (string, *registry.ImageMetadata, error) {
236292
var imageMetadata *registry.ImageMetadata
237293
imageURL := serverOrImage
238-
// Only pull image if we are not running in Kubernetes mode.
239-
// This split will go away if we implement a separate command or binary
240-
// for running MCP servers in Kubernetes.
294+
241295
if !runtime.IsKubernetesRuntime() {
242-
// Take the MCP server we were supplied and either fetch the image, or
243-
// build it from a protocol scheme. If the server URI refers to an image
244-
// in our trusted registry, we will also fetch the image metadata.
296+
var err error
245297
imageURL, imageMetadata, err = retriever.GetMCPServer(ctx, serverOrImage, runFlags.CACertPath, runFlags.VerifyImage)
246298
if err != nil {
247-
return nil, fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err)
299+
return "", nil, fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err)
248300
}
249301
}
250302

251-
// Validate proxy mode early
303+
return imageURL, imageMetadata, nil
304+
}
305+
306+
// validateAndSetupProxyMode validates and sets default proxy mode if needed
307+
func validateAndSetupProxyMode(runFlags *RunFlags) error {
252308
if !types.IsValidProxyMode(runFlags.ProxyMode) {
253309
if runFlags.ProxyMode == "" {
254310
runFlags.ProxyMode = types.ProxyModeSSE.String() // default to SSE for backward compatibility
255311
} else {
256-
return nil, fmt.Errorf("invalid value for --proxy-mode: %s", runFlags.ProxyMode)
312+
return fmt.Errorf("invalid value for --proxy-mode: %s", runFlags.ProxyMode)
257313
}
258314
}
315+
return nil
316+
}
259317

260-
// Parse the environment variables from a list of strings to a map.
261-
envVars, err := environment.ParseEnvironmentVariables(runFlags.Env)
262-
if err != nil {
263-
return nil, fmt.Errorf("failed to parse environment variables: %v", err)
264-
}
265-
266-
// Initialize a new RunConfig with values from command-line flags
267-
return runner.NewRunConfigBuilder().
318+
// buildRunnerConfig creates the final RunnerConfig using the builder pattern
319+
func buildRunnerConfig(
320+
ctx context.Context,
321+
runFlags *RunFlags,
322+
cmdArgs []string,
323+
debugMode bool,
324+
validatedHost string,
325+
rt runtime.Deployer,
326+
imageURL string,
327+
imageMetadata *registry.ImageMetadata,
328+
envVars map[string]string,
329+
envVarValidator runner.EnvVarValidator,
330+
oidcConfig *auth.TokenValidatorConfig,
331+
telemetryConfig *telemetry.Config,
332+
) (*runner.RunConfig, error) {
333+
// Create a builder for the RunConfig
334+
builder := runner.NewRunConfigBuilder().
268335
WithRuntime(rt).
269336
WithCmdArgs(cmdArgs).
270337
WithName(runFlags.Name).
@@ -284,16 +351,59 @@ func BuildRunnerConfig(
284351
WithAuditEnabled(runFlags.EnableAudit, runFlags.AuditConfig).
285352
WithLabels(runFlags.Labels).
286353
WithGroup(runFlags.Group).
287-
WithOIDCConfig(oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL, oidcClientID, oidcClientSecret,
288-
runFlags.ThvCABundle, runFlags.JWKSAuthTokenFile, runFlags.ResourceURL, runFlags.JWKSAllowPrivateIP).
289-
WithTelemetryConfig(finalOtelEndpoint, runFlags.OtelEnablePrometheusMetricsPath, runFlags.OtelServiceName,
290-
finalOtelSamplingRate, runFlags.OtelHeaders, runFlags.OtelInsecure, finalOtelEnvironmentVariables).
291-
WithToolsFilter(runFlags.ToolsFilter).
292354
WithIgnoreConfig(&ignore.Config{
293355
LoadGlobal: runFlags.IgnoreGlobally,
294356
PrintOverlays: runFlags.PrintOverlays,
295-
}).
296-
Build(ctx, imageMetadata, envVars, envVarValidator)
357+
})
358+
359+
// Configure middleware from flags
360+
builder = builder.WithMiddlewareFromFlags(
361+
oidcConfig,
362+
runFlags.ToolsFilter,
363+
telemetryConfig,
364+
runFlags.AuthzConfig,
365+
runFlags.EnableAudit,
366+
runFlags.AuditConfig,
367+
runFlags.Name,
368+
runFlags.Transport,
369+
)
370+
371+
// Load authz config if path is provided
372+
if runFlags.AuthzConfig != "" {
373+
if authzConfigData, err := authz.LoadConfig(runFlags.AuthzConfig); err == nil {
374+
builder = builder.WithAuthzConfig(authzConfigData)
375+
}
376+
// Note: Path is already set via WithAuthzConfigPath above
377+
}
378+
379+
// Get OIDC and telemetry values for legacy configuration
380+
oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL, oidcClientID, oidcClientSecret := extractOIDCValues(oidcConfig)
381+
finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables := extractTelemetryValues(telemetryConfig)
382+
383+
// Set additional configurations that are still needed in old format for other parts of the system
384+
builder = builder.WithOIDCConfig(oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL, oidcClientID, oidcClientSecret,
385+
runFlags.ThvCABundle, runFlags.JWKSAuthTokenFile, runFlags.ResourceURL, runFlags.JWKSAllowPrivateIP).
386+
WithTelemetryConfig(finalOtelEndpoint, runFlags.OtelEnablePrometheusMetricsPath, runFlags.OtelServiceName,
387+
finalOtelSamplingRate, runFlags.OtelHeaders, runFlags.OtelInsecure, finalOtelEnvironmentVariables).
388+
WithToolsFilter(runFlags.ToolsFilter)
389+
390+
return builder.Build(ctx, imageMetadata, envVars, envVarValidator)
391+
}
392+
393+
// extractOIDCValues extracts OIDC values from the OIDC config for legacy configuration
394+
func extractOIDCValues(config *auth.TokenValidatorConfig) (string, string, string, string, string, string) {
395+
if config == nil {
396+
return "", "", "", "", "", ""
397+
}
398+
return config.Issuer, config.Audience, config.JWKSURL, config.IntrospectionURL, config.ClientID, config.ClientSecret
399+
}
400+
401+
// extractTelemetryValues extracts telemetry values from the telemetry config for legacy configuration
402+
func extractTelemetryValues(config *telemetry.Config) (string, float64, []string) {
403+
if config == nil {
404+
return "", 0.0, nil
405+
}
406+
return config.Endpoint, config.SamplingRate, config.EnvironmentVariables
297407
}
298408

299409
// getOidcFromFlags extracts OIDC configuration from command flags
@@ -329,3 +439,69 @@ func getTelemetryFromFlags(cmd *cobra.Command, config *cfg.Config, otelEndpoint
329439

330440
return finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables
331441
}
442+
443+
// createOIDCConfig creates an OIDC configuration if any OIDC parameters are provided
444+
func createOIDCConfig(oidcIssuer, oidcAudience, oidcJwksURL, oidcIntrospectionURL,
445+
oidcClientID, oidcClientSecret, resourceURL string) *auth.TokenValidatorConfig {
446+
if oidcIssuer != "" || oidcAudience != "" || oidcJwksURL != "" || oidcIntrospectionURL != "" ||
447+
oidcClientID != "" || oidcClientSecret != "" || resourceURL != "" {
448+
return &auth.TokenValidatorConfig{
449+
Issuer: oidcIssuer,
450+
Audience: oidcAudience,
451+
JWKSURL: oidcJwksURL,
452+
IntrospectionURL: oidcIntrospectionURL,
453+
ClientID: oidcClientID,
454+
ClientSecret: oidcClientSecret,
455+
ResourceURL: resourceURL,
456+
}
457+
}
458+
return nil
459+
}
460+
461+
// createTelemetryConfig creates a telemetry configuration if any telemetry parameters are provided
462+
func createTelemetryConfig(otelEndpoint string, otelEnablePrometheusMetricsPath bool,
463+
otelServiceName string, otelSamplingRate float64, otelHeaders []string,
464+
otelInsecure bool, otelEnvironmentVariables []string) *telemetry.Config {
465+
if otelEndpoint == "" && !otelEnablePrometheusMetricsPath {
466+
return nil
467+
}
468+
469+
// Parse headers from key=value format
470+
headers := make(map[string]string)
471+
for _, header := range otelHeaders {
472+
parts := strings.SplitN(header, "=", 2)
473+
if len(parts) == 2 {
474+
headers[parts[0]] = parts[1]
475+
}
476+
}
477+
478+
// Use provided service name or default
479+
serviceName := otelServiceName
480+
if serviceName == "" {
481+
serviceName = telemetry.DefaultConfig().ServiceName
482+
}
483+
484+
// Process environment variables - split comma-separated values
485+
var processedEnvVars []string
486+
for _, envVarEntry := range otelEnvironmentVariables {
487+
// Split by comma and trim whitespace
488+
envVars := strings.Split(envVarEntry, ",")
489+
for _, envVar := range envVars {
490+
trimmed := strings.TrimSpace(envVar)
491+
if trimmed != "" {
492+
processedEnvVars = append(processedEnvVars, trimmed)
493+
}
494+
}
495+
}
496+
497+
return &telemetry.Config{
498+
Endpoint: otelEndpoint,
499+
ServiceName: serviceName,
500+
ServiceVersion: telemetry.DefaultConfig().ServiceVersion,
501+
SamplingRate: otelSamplingRate,
502+
Headers: headers,
503+
Insecure: otelInsecure,
504+
EnablePrometheusMetricsPath: otelEnablePrometheusMetricsPath,
505+
EnvironmentVariables: processedEnvVars,
506+
}
507+
}

pkg/audit/middleware.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
6060
auditConfig = DefaultConfig()
6161
}
6262

63-
// Set component name if provided
64-
if params.Component != "" {
63+
// Set component name if provided and config doesn't already have one
64+
if params.Component != "" && auditConfig.Component == "" {
6565
auditConfig.Component = params.Component
6666
}
6767

0 commit comments

Comments
 (0)