Skip to content

Commit ae59712

Browse files
authored
Define types and structure for generic middleware creation (#1298)
This is the first step towards generalizing the handling of middleware in the code. This is intended to allow new middleware to be added without altering the runner code or the runconfig schema. This changes the runconfig to contain a list of generic MiddlewareConfig structs, and modifies the runner to create middleware instances based on this list, assuming a set of factory functions are supplied to the runner. Right now, this list is never populated, and none of the existing middleware have been converted to use this scheme. This will be done in a future PR.
1 parent 5bd5a98 commit ae59712

File tree

18 files changed

+127
-25
lines changed

18 files changed

+127
-25
lines changed

cmd/thv/app/proxy.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
195195
}
196196

197197
// Create middlewares slice for incoming request authentication
198-
var middlewares []types.Middleware
198+
var middlewares []types.MiddlewareFunction
199199

200200
// Get OIDC configuration if enabled (for protecting the proxy endpoint)
201201
var oidcConfig *auth.TokenValidatorConfig
@@ -514,7 +514,7 @@ func resolveClientSecret() (string, error) {
514514
}
515515

516516
// createTokenInjectionMiddleware creates a middleware that injects the OAuth token into requests
517-
func createTokenInjectionMiddleware(tokenSource *oauth2.TokenSource) types.Middleware {
517+
func createTokenInjectionMiddleware(tokenSource *oauth2.TokenSource) types.MiddlewareFunction {
518518
return func(next http.Handler) http.Handler {
519519
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
520520
token, err := (*tokenSource).Token()

docs/server/docs.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/server/swagger.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/server/swagger.yaml

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/audit/config.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"net/http"
99
"os"
1010
"path/filepath"
11+
12+
"github.com/stacklok/toolhive/pkg/transport/types"
1113
)
1214

1315
// Config represents the audit logging configuration.
@@ -103,7 +105,7 @@ func (c *Config) ShouldAuditEvent(eventType string) bool {
103105
}
104106

105107
// CreateMiddleware creates an HTTP middleware from the audit configuration.
106-
func (c *Config) CreateMiddleware() (func(http.Handler) http.Handler, error) {
108+
func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) {
107109
auditor, err := NewAuditor(c)
108110
if err != nil {
109111
return nil, fmt.Errorf("failed to create auditor: %w", err)

pkg/authz/config.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"strings"
1111

1212
"sigs.k8s.io/yaml"
13+
14+
"github.com/stacklok/toolhive/pkg/transport/types"
1315
)
1416

1517
// ConfigType represents the type of authorization configuration.
@@ -117,7 +119,7 @@ func (c *Config) Validate() error {
117119
}
118120

119121
// CreateMiddleware creates an HTTP middleware from the configuration.
120-
func (c *Config) CreateMiddleware() (func(http.Handler) http.Handler, error) {
122+
func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) {
121123
// Create the appropriate middleware based on the configuration type
122124
switch c.Type {
123125
case ConfigTypeCedarV1:

pkg/mcp/tool_filter.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ var errBug = errors.New("there's a bug")
3232
// This middleware is designed to be used ONLY when tool filtering is enabled,
3333
// and expects the list of tools to be "correct" (i.e. not empty and not
3434
// containing nonexisting tools).
35-
func NewToolFilterMiddleware(filterTools []string) (types.Middleware, error) {
35+
func NewToolFilterMiddleware(filterTools []string) (types.MiddlewareFunction, error) {
3636
if len(filterTools) == 0 {
3737
return nil, fmt.Errorf("tools list for filtering is empty")
3838
}
@@ -73,7 +73,7 @@ func NewToolFilterMiddleware(filterTools []string) (types.Middleware, error) {
7373
// This middleware is designed to be used ONLY when tool filtering is enabled,
7474
// and expects the list of tools to be "correct" (i.e. not empty and not
7575
// containing nonexisting tools).
76-
func NewToolCallFilterMiddleware(filterTools []string) (types.Middleware, error) {
76+
func NewToolCallFilterMiddleware(filterTools []string) (types.MiddlewareFunction, error) {
7777
if len(filterTools) == 0 {
7878
return nil, fmt.Errorf("tools list for filtering is empty")
7979
}

pkg/runner/config.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ type RunConfig struct {
128128

129129
// IgnoreConfig contains configuration for ignore processing
130130
IgnoreConfig *ignore.Config `json:"ignore_config,omitempty" yaml:"ignore_config,omitempty"`
131+
132+
// MiddlewareConfigs contains the list of middleware to apply to the transport
133+
// and the configuration for each middleware.
134+
MiddlewareConfigs []types.MiddlewareConfig `json:"middleware_configs,omitempty" yaml:"middleware_configs,omitempty"`
131135
}
132136

133137
// WriteJSON serializes the RunConfig to JSON and writes it to the provided writer
@@ -167,6 +171,12 @@ func (c *RunConfig) WithAudit(config *audit.Config) *RunConfig {
167171
return c
168172
}
169173

174+
// WithMiddlewareConfig adds middleware configuration to the RunConfig
175+
func (c *RunConfig) WithMiddlewareConfig(middlewareConfig []types.MiddlewareConfig) *RunConfig {
176+
c.MiddlewareConfigs = middlewareConfig
177+
return c
178+
}
179+
170180
// WithTransport parses and sets the transport type
171181
func (c *RunConfig) WithTransport(t string) (*RunConfig, error) {
172182
transportType, err := types.ParseTransportType(t)

pkg/runner/config_builder.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ func (b *RunConfigBuilder) WithName(name string) *RunConfigBuilder {
5858
return b
5959
}
6060

61+
// WithMiddlewareConfig sets the middleware configuration
62+
func (b *RunConfigBuilder) WithMiddlewareConfig(middlewareConfig []types.MiddlewareConfig) *RunConfigBuilder {
63+
b.config.MiddlewareConfigs = middlewareConfig
64+
return b
65+
}
66+
6167
// WithCmdArgs sets the command arguments
6268
func (b *RunConfigBuilder) WithCmdArgs(args []string) *RunConfigBuilder {
6369
b.config.CmdArgs = args

pkg/runner/runner.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ type Runner struct {
2929

3030
// telemetryProvider is the OpenTelemetry provider for cleanup
3131
telemetryProvider *telemetry.Provider
32+
33+
// supportedMiddleware is a map of supported middleware types to their factory functions.
34+
supportedMiddleware map[string]types.MiddlewareFactory
3235
}
3336

3437
// NewRunner creates a new Runner with the provided configuration
@@ -53,6 +56,29 @@ func (r *Runner) Run(ctx context.Context) error {
5356
Debug: r.Config.Debug,
5457
}
5558

59+
// Create middleware from the MiddlewareConfigs instances in the RunConfig.
60+
for _, middlewareConfig := range r.Config.MiddlewareConfigs {
61+
// First, get the correct factory function for the middleware type.
62+
factory, ok := r.supportedMiddleware[middlewareConfig.Type]
63+
if !ok {
64+
return fmt.Errorf("unsupported middleware type: %s", middlewareConfig.Type)
65+
}
66+
67+
// Create the middleware instance using the factory function.
68+
middleware, err := factory(&middlewareConfig)
69+
if err != nil {
70+
return fmt.Errorf("failed to create middleware of type %s: %v", middlewareConfig.Type, err)
71+
}
72+
73+
// Ensure middleware is cleaned up on shutdown.
74+
defer func() {
75+
if err := middleware.Close(); err != nil {
76+
logger.Warnf("Failed to close middleware of type %s: %v", middlewareConfig.Type, err)
77+
}
78+
}()
79+
transportConfig.Middlewares = append(transportConfig.Middlewares, middleware.Handler())
80+
}
81+
5682
if len(r.Config.ToolsFilter) > 0 {
5783
toolsFilterMiddleware, err := mcp.NewToolFilterMiddleware(r.Config.ToolsFilter)
5884
if err != nil {

0 commit comments

Comments
 (0)