diff --git a/cmd/regup/app/root.go b/cmd/regup/app/root.go index 28e33cfba..13c92df4b 100644 --- a/cmd/regup/app/root.go +++ b/cmd/regup/app/root.go @@ -3,8 +3,12 @@ package app import ( "github.com/spf13/cobra" + + log "github.com/stacklok/toolhive/pkg/logger" ) +var logger = log.NewLogger() + var rootCmd = &cobra.Command{ Use: "regup", DisableAutoGenTag: true, diff --git a/cmd/regup/app/update.go b/cmd/regup/app/update.go index 394ef7bce..773064ee8 100644 --- a/cmd/regup/app/update.go +++ b/cmd/regup/app/update.go @@ -15,7 +15,6 @@ import ( "github.com/spf13/cobra" "github.com/stacklok/toolhive/pkg/container/verifier" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" ) @@ -304,7 +303,7 @@ func verifyServerProvenance(name string, server *registry.ImageMetadata) error { logger.Infof("Verifying provenance for server %s with image %s", name, server.Image) // Create verifier - v, err := verifier.New(server) + v, err := verifier.New(server, logger) if err != nil { return fmt.Errorf("failed to create verifier: %w", err) } diff --git a/cmd/regup/app/update_test.go b/cmd/regup/app/update_test.go index 606be5550..c3f49b890 100644 --- a/cmd/regup/app/update_test.go +++ b/cmd/regup/app/update_test.go @@ -9,14 +9,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" ) //nolint:paralleltest // This test manages temporary directories and cannot run in parallel func TestUpdateCmdFunc(t *testing.T) { - // Initialize logger for tests - logger.Initialize() tests := []struct { name string diff --git a/cmd/regup/main.go b/cmd/regup/main.go index b6f9f7ba6..eef0341bc 100644 --- a/cmd/regup/main.go +++ b/cmd/regup/main.go @@ -5,12 +5,11 @@ import ( "os" "github.com/stacklok/toolhive/cmd/regup/app" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func main() { - // Initialize the 
logger system - logger.Initialize() + logger := log.NewLogger() if err := app.NewRootCmd().Execute(); err != nil { logger.Errorf("%v", err) diff --git a/cmd/thv-operator/controllers/mcpserver_controller.go b/cmd/thv-operator/controllers/mcpserver_controller.go index c3533fc8d..6139e9d2f 100644 --- a/cmd/thv-operator/controllers/mcpserver_controller.go +++ b/cmd/thv-operator/controllers/mcpserver_controller.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" @@ -28,13 +29,13 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - "github.com/stacklok/toolhive/pkg/logger" ) // MCPServerReconciler reconciles a MCPServer object type MCPServerReconciler struct { client.Client Scheme *runtime.Scheme + logger *zap.SugaredLogger } // defaultRBACRules are the default RBAC rules that the @@ -222,7 +223,7 @@ func (r *MCPServerReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( } // Check if the deployment spec changed - if deploymentNeedsUpdate(deployment, mcpServer) { + if deploymentNeedsUpdate(deployment, mcpServer, r.logger) { // Update the deployment newDeployment := r.deploymentForMCPServer(mcpServer) deployment.Spec = newDeployment.Spec @@ -283,7 +284,7 @@ func (r *MCPServerReconciler) createRBACResource( ) error { desired := createResource() if err := controllerutil.SetControllerReference(mcpServer, desired, r.Scheme); err != nil { - logger.Error(fmt.Sprintf("Failed to set controller reference for %s", resourceType), err) + r.logger.Error(fmt.Sprintf("Failed to set controller reference for %s", resourceType), err) return nil } @@ -309,7 +310,7 @@ func (r *MCPServerReconciler) updateRBACResourceIfNeeded( ) error { desired := createResource() if err := controllerutil.SetControllerReference(mcpServer, desired, r.Scheme); err != nil { - logger.Error(fmt.Sprintf("Failed to set controller 
reference for %s", resourceType), err) + r.logger.Error(fmt.Sprintf("Failed to set controller reference for %s", resourceType), err) return nil } @@ -404,7 +405,7 @@ func (r *MCPServerReconciler) deploymentForMCPServer(m *mcpv1alpha1.MCPServer) * if finalPodTemplateSpec != nil { podTemplatePatch, err := json.Marshal(finalPodTemplateSpec) if err != nil { - logger.Errorf("Failed to marshal pod template spec: %v", err) + r.logger.Errorf("Failed to marshal pod template spec: %v", err) } else { args = append(args, fmt.Sprintf("--k8s-pod-patch=%s", string(podTemplatePatch))) } @@ -619,7 +620,7 @@ func (r *MCPServerReconciler) deploymentForMCPServer(m *mcpv1alpha1.MCPServer) * // Set MCPServer instance as the owner and controller if err := controllerutil.SetControllerReference(m, dep, r.Scheme); err != nil { - logger.Error("Failed to set controller reference for Deployment", err) + r.logger.Error("Failed to set controller reference for Deployment", err) return nil } return dep @@ -676,7 +677,7 @@ func (r *MCPServerReconciler) serviceForMCPServer(m *mcpv1alpha1.MCPServer) *cor // Set MCPServer instance as the owner and controller if err := controllerutil.SetControllerReference(m, svc, r.Scheme); err != nil { - logger.Error("Failed to set controller reference for Service", err) + r.logger.Error("Failed to set controller reference for Service", err) return nil } return svc @@ -773,7 +774,7 @@ func (r *MCPServerReconciler) finalizeMCPServer(ctx context.Context, m *mcpv1alp // deploymentNeedsUpdate checks if the deployment needs to be updated // //nolint:gocyclo -func deploymentNeedsUpdate(deployment *appsv1.Deployment, mcpServer *mcpv1alpha1.MCPServer) bool { +func deploymentNeedsUpdate(deployment *appsv1.Deployment, mcpServer *mcpv1alpha1.MCPServer, logger *zap.SugaredLogger) bool { // Check if the container args have changed if len(deployment.Spec.Template.Spec.Containers) > 0 { container := deployment.Spec.Template.Spec.Containers[0] @@ -1198,13 +1199,13 @@ func (r 
*MCPServerReconciler) generateOIDCArgs(ctx context.Context, m *mcpv1alph } // generateKubernetesOIDCArgs generates OIDC args for Kubernetes service account token validation -func (*MCPServerReconciler) generateKubernetesOIDCArgs(m *mcpv1alpha1.MCPServer) []string { +func (r *MCPServerReconciler) generateKubernetesOIDCArgs(m *mcpv1alpha1.MCPServer) []string { var args []string config := m.Spec.OIDCConfig.Kubernetes // Set defaults if config is nil if config == nil { - logger.Infof("Kubernetes OIDCConfig is nil for MCPServer %s, using default configuration", m.Name) + r.logger.Infof("Kubernetes OIDCConfig is nil for MCPServer %s, using default configuration", m.Name) defaultUseClusterAuth := true config = &mcpv1alpha1.KubernetesOIDCConfig{ UseClusterAuth: &defaultUseClusterAuth, // Default to true @@ -1272,7 +1273,7 @@ func (r *MCPServerReconciler) generateConfigMapOIDCArgs( // nolint:gocyclo Namespace: m.Namespace, }, configMap) if err != nil { - logger.Errorf("Failed to get ConfigMap %s: %v", config.Name, err) + r.logger.Errorf("Failed to get ConfigMap %s: %v", config.Name, err) return args } diff --git a/cmd/thv-operator/controllers/mcpserver_oidc_test.go b/cmd/thv-operator/controllers/mcpserver_oidc_test.go index 5cfb40b8c..073575b9c 100644 --- a/cmd/thv-operator/controllers/mcpserver_oidc_test.go +++ b/cmd/thv-operator/controllers/mcpserver_oidc_test.go @@ -27,14 +27,9 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) -func init() { - // Initialize logger for tests - logger.Initialize() -} - func TestGenerateOIDCArgs(t *testing.T) { t.Parallel() @@ -230,6 +225,7 @@ func TestGenerateOIDCArgs(t *testing.T) { reconciler := &MCPServerReconciler{ Client: fakeClient, Scheme: scheme, + logger: log.NewLogger(), } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -245,7 
+241,9 @@ func TestGenerateOIDCArgs(t *testing.T) { func TestGenerateKubernetesOIDCArgs(t *testing.T) { t.Parallel() - reconciler := &MCPServerReconciler{} + reconciler := &MCPServerReconciler{ + logger: log.NewLogger(), + } tests := []struct { name string @@ -357,7 +355,9 @@ func TestGenerateKubernetesOIDCArgs(t *testing.T) { func TestGenerateInlineOIDCArgs(t *testing.T) { t.Parallel() - reconciler := &MCPServerReconciler{} + reconciler := &MCPServerReconciler{ + logger: log.NewLogger(), + } tests := []struct { name string diff --git a/cmd/thv-operator/controllers/mcpserver_resource_overrides_test.go b/cmd/thv-operator/controllers/mcpserver_resource_overrides_test.go index 6a4703694..fc9d8a042 100644 --- a/cmd/thv-operator/controllers/mcpserver_resource_overrides_test.go +++ b/cmd/thv-operator/controllers/mcpserver_resource_overrides_test.go @@ -25,6 +25,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestResourceOverrides(t *testing.T) { @@ -263,6 +264,7 @@ func TestResourceOverrides(t *testing.T) { r := &MCPServerReconciler{ Client: client, Scheme: scheme, + logger: log.NewLogger(), } // Test deployment creation @@ -366,6 +368,7 @@ func TestDeploymentNeedsUpdateServiceAccount(t *testing.T) { r := &MCPServerReconciler{ Client: client, Scheme: scheme, + logger: log.NewLogger(), } mcpServer := &mcpv1alpha1.MCPServer{ @@ -384,7 +387,7 @@ func TestDeploymentNeedsUpdateServiceAccount(t *testing.T) { require.NotNil(t, deployment) // Test with the current deployment - this should NOT need update - needsUpdate := deploymentNeedsUpdate(deployment, mcpServer) + needsUpdate := deploymentNeedsUpdate(deployment, mcpServer, r.logger) // With the service account bug fixed, this should now return false assert.False(t, needsUpdate, "deploymentNeedsUpdate should return false when deployment matches MCPServer spec") @@ -400,6 +403,7 @@ func 
TestDeploymentNeedsUpdateProxyEnv(t *testing.T) { r := &MCPServerReconciler{ Client: client, Scheme: scheme, + logger: log.NewLogger(), } tests := []struct { @@ -563,7 +567,7 @@ func TestDeploymentNeedsUpdateProxyEnv(t *testing.T) { deployment.Spec.Template.Spec.Containers[0].Image = getToolhiveRunnerImage() // Test if deployment needs update - should correlate with env change expectation - needsUpdate := deploymentNeedsUpdate(deployment, tt.mcpServer) + needsUpdate := deploymentNeedsUpdate(deployment, tt.mcpServer, r.logger) if tt.expectEnvChange { assert.True(t, needsUpdate, "Expected deployment update due to proxy env change") @@ -588,6 +592,7 @@ func TestDeploymentNeedsUpdateToolsFilter(t *testing.T) { r := &MCPServerReconciler{ Client: client, Scheme: scheme, + logger: log.NewLogger(), } tests := []struct { @@ -643,7 +648,7 @@ func TestDeploymentNeedsUpdateToolsFilter(t *testing.T) { mcpServer.Spec.ToolsFilter = tt.newToolsFilter - needsUpdate := deploymentNeedsUpdate(deployment, mcpServer) + needsUpdate := deploymentNeedsUpdate(deployment, mcpServer, r.logger) assert.Equal(t, tt.expectEnvChange, needsUpdate) }) } diff --git a/cmd/thv-operator/main.go b/cmd/thv-operator/main.go index d84d18a4d..2b3e3e0a1 100644 --- a/cmd/thv-operator/main.go +++ b/cmd/thv-operator/main.go @@ -48,9 +48,6 @@ func main() { "Enabling this will ensure there is only one active controller manager.") flag.Parse() - // Initialize the structured logger - logger.Initialize() - // Set the controller-runtime logger to use our structured logger ctrl.SetLogger(logger.NewLogr()) diff --git a/cmd/thv-proxyrunner/app/commands.go b/cmd/thv-proxyrunner/app/commands.go index cfbc6f129..c957ab48b 100644 --- a/cmd/thv-proxyrunner/app/commands.go +++ b/cmd/thv-proxyrunner/app/commands.go @@ -4,10 +4,13 @@ package app import ( "github.com/spf13/cobra" "github.com/spf13/viper" + "go.uber.org/zap" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) +var logger 
*zap.SugaredLogger = log.NewLogger() + var rootCmd = &cobra.Command{ Use: "thv-proxyrunner", DisableAutoGenTag: true, @@ -20,9 +23,6 @@ It is written in Go and has extensive test coverage—including input validation logger.Errorf("Error displaying help: %v", err) } }, - PersistentPreRun: func(_ *cobra.Command, _ []string) { - logger.Initialize() - }, } // NewRootCmd creates a new root command for the ToolHive CLI. diff --git a/cmd/thv-proxyrunner/app/run.go b/cmd/thv-proxyrunner/app/run.go index 4c6248dd3..5b5822f58 100644 --- a/cmd/thv-proxyrunner/app/run.go +++ b/cmd/thv-proxyrunner/app/run.go @@ -8,7 +8,6 @@ import ( "github.com/spf13/cobra" "github.com/stacklok/toolhive/pkg/container" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/runner" "github.com/stacklok/toolhive/pkg/transport" @@ -224,7 +223,7 @@ func runCmdFunc(cmd *cobra.Command, args []string) error { finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables := "", 0.0, []string{} // Create container runtime - rt, err := container.NewFactory().Create(ctx) + rt, err := container.NewFactory(logger).Create(ctx) if err != nil { return fmt.Errorf("failed to create container runtime: %v", err) } @@ -233,12 +232,12 @@ func runCmdFunc(cmd *cobra.Command, args []string) error { // If we have called the CLI directly, we use the CLIEnvVarValidator. // If we are running in detached mode, or the CLI is wrapped by the K8s operator, // we use the DetachedEnvVarValidator. - envVarValidator := &runner.DetachedEnvVarValidator{} + envVarValidator := runner.NewDetachedEnvVarValidator(logger) var imageMetadata *registry.ImageMetadata // Initialize a new RunConfig with values from command-line flags - runConfig, err := runner.NewRunConfigBuilder(). + runConfig, err := runner.NewRunConfigBuilder(logger). WithRuntime(rt). WithCmdArgs(cmdArgs). WithName(runName). 
@@ -266,7 +265,7 @@ func runCmdFunc(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to create RunConfig: %v", err) } - workloadManager, err := workloads.NewManagerFromRuntime(rt) + workloadManager, err := workloads.NewManagerFromRuntime(rt, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } diff --git a/cmd/thv-proxyrunner/main.go b/cmd/thv-proxyrunner/main.go index 1cd17b86a..fe73ddb27 100644 --- a/cmd/thv-proxyrunner/main.go +++ b/cmd/thv-proxyrunner/main.go @@ -5,13 +5,9 @@ import ( "os" "github.com/stacklok/toolhive/cmd/thv-proxyrunner/app" - "github.com/stacklok/toolhive/pkg/logger" ) func main() { - // Initialize the logger - logger.Initialize() - // Skip update check for completion command or if we are running in kubernetes if err := app.NewRootCmd().Execute(); err != nil { os.Exit(1) diff --git a/cmd/thv/app/build.go b/cmd/thv/app/build.go index 1fce82d4d..217791b60 100644 --- a/cmd/thv/app/build.go +++ b/cmd/thv/app/build.go @@ -7,7 +7,6 @@ import ( "github.com/spf13/cobra" "github.com/stacklok/toolhive/pkg/container/images" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/runner" ) @@ -70,11 +69,12 @@ func buildCmdFunc(cmd *cobra.Command, args []string) error { } // Create image manager (even for dry-run, we pass it but it won't be used) - imageManager := images.NewImageManager(ctx) + imageManager := images.NewImageManager(ctx, logger) // If dry-run or output is specified, just generate the Dockerfile if buildFlags.DryRun || buildFlags.Output != "" { - dockerfileContent, err := runner.BuildFromProtocolSchemeWithName(ctx, imageManager, protocolScheme, "", buildFlags.Tag, true) + dockerfileContent, err := runner.BuildFromProtocolSchemeWithName( + ctx, imageManager, protocolScheme, "", buildFlags.Tag, true, logger) if err != nil { return fmt.Errorf("failed to generate Dockerfile for %s: %v", protocolScheme, err) } @@ -96,7 +96,7 @@ func buildCmdFunc(cmd *cobra.Command, 
args []string) error { logger.Infof("Building container for protocol scheme: %s", protocolScheme) // Build the image using the new protocol handler with custom name - imageName, err := runner.BuildFromProtocolSchemeWithName(ctx, imageManager, protocolScheme, "", buildFlags.Tag, false) + imageName, err := runner.BuildFromProtocolSchemeWithName(ctx, imageManager, protocolScheme, "", buildFlags.Tag, false, logger) if err != nil { return fmt.Errorf("failed to build container for %s: %v", protocolScheme, err) } diff --git a/cmd/thv/app/client.go b/cmd/thv/app/client.go index fe9a15140..d3948c896 100644 --- a/cmd/thv/app/client.go +++ b/cmd/thv/app/client.go @@ -6,13 +6,13 @@ import ( "sort" "github.com/spf13/cobra" + "go.uber.org/zap" "github.com/stacklok/toolhive/cmd/thv/app/ui" "github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -110,7 +110,7 @@ func init() { } func clientStatusCmdFunc(_ *cobra.Command, _ []string) error { - clientStatuses, err := client.GetClientStatus() + clientStatuses, err := client.GetClientStatus(logger) if err != nil { return fmt.Errorf("failed to get client status: %w", err) } @@ -118,7 +118,8 @@ func clientStatusCmdFunc(_ *cobra.Command, _ []string) error { } func clientSetupCmdFunc(cmd *cobra.Command, _ []string) error { - clientStatuses, err := client.GetClientStatus() + clientStatuses, err := client.GetClientStatus(logger) + if err != nil { return fmt.Errorf("failed to get client status: %w", err) } @@ -128,7 +129,7 @@ func clientSetupCmdFunc(cmd *cobra.Command, _ []string) error { return nil } // Get available groups for the UI - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %w", err) } @@ -211,11 +212,11 @@ func 
clientRemoveCmdFunc(cmd *cobra.Command, args []string) error { clientType) } - return performClientRemoval(cmd.Context(), client.Client{Name: client.MCPClient(clientType)}, groupNames) + return performClientRemoval(cmd.Context(), client.Client{Name: client.MCPClient(clientType)}, groupNames, logger) } func listRegisteredClientsCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if len(cfg.Clients.RegisteredClients) == 0 { fmt.Println("No clients are currently registered.") return nil @@ -234,12 +235,12 @@ func listRegisteredClientsCmdFunc(_ *cobra.Command, _ []string) error { } func performClientRegistration(ctx context.Context, clients []client.Client, groupNames []string) error { - clientManager, err := client.NewManager(ctx) + clientManager, err := client.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create client manager: %w", err) } - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %w", err) } @@ -266,7 +267,7 @@ func registerClientsWithGroups( ) error { fmt.Printf("Filtering workloads to groups: %v\n", groupNames) - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %w", err) } @@ -312,7 +313,7 @@ func registerClientsGlobally( } c.Clients.RegisteredClients = append(c.Clients.RegisteredClients, string(clientToRegister.Name)) - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration for client %s: %w", clientToRegister.Name, err) } @@ -329,13 +330,18 @@ func registerClientsGlobally( return nil } -func performClientRemoval(ctx context.Context, clientToRemove client.Client, groupNames []string) error { - clientManager, err := client.NewManager(ctx) +func performClientRemoval( + ctx context.Context, + clientToRemove 
client.Client, + groupNames []string, + logger *zap.SugaredLogger, +) error { + clientManager, err := client.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create client manager: %w", err) } - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %w", err) } @@ -345,7 +351,7 @@ func performClientRemoval(ctx context.Context, clientToRemove client.Client, gro return fmt.Errorf("failed to list running workloads: %w", err) } - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %w", err) } @@ -354,7 +360,7 @@ func performClientRemoval(ctx context.Context, clientToRemove client.Client, gro return removeClientFromGroups(ctx, clientToRemove, groupNames, runningWorkloads, groupManager, clientManager) } - return removeClientGlobally(ctx, clientToRemove, runningWorkloads, groupManager, clientManager) + return removeClientGlobally(ctx, clientToRemove, runningWorkloads, groupManager, clientManager, logger) } func removeClientFromGroups( @@ -396,6 +402,7 @@ func removeClientGlobally( runningWorkloads []core.Workload, groupManager groups.Manager, clientManager client.Manager, + logger *zap.SugaredLogger, ) error { // Remove the workloads from the client's configuration file err := clientManager.UnregisterClients(ctx, []client.Client{clientToRemove}, runningWorkloads) @@ -432,7 +439,7 @@ func removeClientGlobally( } } logger.Warnf("Client %s was not found in registered clients list", clientToRemove.Name) - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration for client %s: %w", clientToRemove.Name, err) } diff --git a/cmd/thv/app/commands.go b/cmd/thv/app/commands.go index ba13630ce..32ba7390a 100644 --- a/cmd/thv/app/commands.go +++ b/cmd/thv/app/commands.go @@ -5,7 +5,6 @@ import ( 
"github.com/spf13/cobra" "github.com/spf13/viper" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/updates" ) @@ -25,9 +24,6 @@ container-based isolation for running MCP servers.`, logger.Errorf("Error displaying help: %v", err) } }, - PersistentPreRun: func(_ *cobra.Command, _ []string) { - logger.Initialize() - }, } // NewRootCmd creates a new root command for the ToolHive CLI. diff --git a/cmd/thv/app/common.go b/cmd/thv/app/common.go index 8d4f402df..f9de2d1c1 100644 --- a/cmd/thv/app/common.go +++ b/cmd/thv/app/common.go @@ -5,12 +5,16 @@ import ( "fmt" "github.com/spf13/cobra" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/config" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" "github.com/stacklok/toolhive/pkg/workloads" ) +var logger *zap.SugaredLogger = log.NewLogger() + // AddOIDCFlags adds OIDC validation flags to the provided command. func AddOIDCFlags(cmd *cobra.Command) { cmd.Flags().String("oidc-issuer", "", "OIDC issuer URL (e.g., https://accounts.google.com)") @@ -64,7 +68,7 @@ func SetSecretsProvider(provider secrets.ProviderType) error { // Validate that the provider can be created and works correctly ctx := context.Background() - result := secrets.ValidateProvider(ctx, provider) + result := secrets.ValidateProvider(ctx, provider, logger) if !result.Success { return fmt.Errorf("provider validation failed: %w", result.Error) } @@ -73,7 +77,7 @@ func SetSecretsProvider(provider secrets.ProviderType) error { err := config.UpdateConfig(func(c *config.Config) { c.Secrets.ProviderType = string(provider) c.Secrets.SetupCompleted = true - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -94,7 +98,7 @@ func completeMCPServerNames(cmd *cobra.Command, args []string, _ string) ([]stri ctx := cmd.Context() // Create status manager - manager, err := workloads.NewManager(ctx) + manager, err := workloads.NewManager(ctx, logger) if err 
!= nil { return nil, cobra.ShellCompDirectiveError } @@ -126,7 +130,7 @@ func completeLogsArgs(cmd *cobra.Command, args []string, _ string) ([]string, co ctx := cmd.Context() // Create status manager - manager, err := workloads.NewManager(ctx) + manager, err := workloads.NewManager(ctx, logger) if err != nil { return []string{"prune"}, cobra.ShellCompDirectiveNoFileComp } diff --git a/cmd/thv/app/config.go b/cmd/thv/app/config.go index eefb73ec9..146b858c6 100644 --- a/cmd/thv/app/config.go +++ b/cmd/thv/app/config.go @@ -113,14 +113,14 @@ func setCACertCmdFunc(_ *cobra.Command, args []string) error { } // Validate the certificate format - if err := certs.ValidateCACertificate(certContent); err != nil { + if err := certs.ValidateCACertificate(certContent, logger); err != nil { return fmt.Errorf("invalid CA certificate: %w", err) } // Update the configuration err = config.UpdateConfig(func(c *config.Config) { c.CACertificatePath = certPath - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -130,7 +130,7 @@ func setCACertCmdFunc(_ *cobra.Command, args []string) error { } func getCACertCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if cfg.CACertificatePath == "" { fmt.Println("No CA certificate is currently configured.") @@ -148,7 +148,7 @@ func getCACertCmdFunc(_ *cobra.Command, _ []string) error { } func unsetCACertCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if cfg.CACertificatePath == "" { fmt.Println("No CA certificate is currently configured.") @@ -158,7 +158,7 @@ func unsetCACertCmdFunc(_ *cobra.Command, _ []string) error { // Update the configuration err := config.UpdateConfig(func(c *config.Config) { c.CACertificatePath = "" - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -173,7 +173,7 @@ func setRegistryCmdFunc(_ *cobra.Command, args 
[]string) error { switch registryType { case config.RegistryTypeURL: - err := config.SetRegistryURL(cleanPath, allowPrivateRegistryIp) + err := config.SetRegistryURL(cleanPath, allowPrivateRegistryIp, logger) if err != nil { return err } @@ -188,14 +188,14 @@ func setRegistryCmdFunc(_ *cobra.Command, args []string) error { } return nil case config.RegistryTypeFile: - return config.SetRegistryFile(cleanPath) + return config.SetRegistryFile(cleanPath, logger) default: return fmt.Errorf("unsupported registry type") } } func getRegistryCmdFunc(_ *cobra.Command, _ []string) error { - url, localPath, _, registryType := config.GetRegistryConfig() + url, localPath, _, registryType := config.GetRegistryConfig(logger) switch registryType { case config.RegistryTypeURL: @@ -213,14 +213,14 @@ func getRegistryCmdFunc(_ *cobra.Command, _ []string) error { } func unsetRegistryCmdFunc(_ *cobra.Command, _ []string) error { - url, localPath, _, registryType := config.GetRegistryConfig() + url, localPath, _, registryType := config.GetRegistryConfig(logger) if registryType == "default" { fmt.Println("No custom registry is currently configured.") return nil } - err := config.UnsetRegistry() + err := config.UnsetRegistry(logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } diff --git a/cmd/thv/app/group.go b/cmd/thv/app/group.go index eda36f478..f21e49727 100644 --- a/cmd/thv/app/group.go +++ b/cmd/thv/app/group.go @@ -66,7 +66,7 @@ func groupCreateCmdFunc(cmd *cobra.Command, args []string) error { groupName := args[0] ctx := cmd.Context() - manager, err := groups.NewManager() + manager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %w", err) } @@ -82,7 +82,7 @@ func groupCreateCmdFunc(cmd *cobra.Command, args []string) error { func groupListCmdFunc(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() - manager, err := groups.NewManager() + manager, err := groups.NewManager(logger) if err != 
nil { return fmt.Errorf("failed to create group manager: %w", err) } @@ -121,7 +121,7 @@ func groupRmCmdFunc(cmd *cobra.Command, args []string) error { if strings.EqualFold(groupName, groups.DefaultGroup) { return fmt.Errorf("cannot delete the %s group", groups.DefaultGroup) } - manager, err := groups.NewManager() + manager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %w", err) } @@ -136,7 +136,7 @@ func groupRmCmdFunc(cmd *cobra.Command, args []string) error { } // Create workloads manager - workloadsManager, err := workloads.NewManager(ctx) + workloadsManager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workloads manager: %w", err) } @@ -223,7 +223,7 @@ func showWarningAndGetConfirmation(groupName string, groupWorkloads []string) (b func deleteWorkloadsInGroup(ctx context.Context, groupWorkloads []string, groupName string) error { // Delete workloads - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %w", err) } @@ -246,7 +246,7 @@ func deleteWorkloadsInGroup(ctx context.Context, groupWorkloads []string, groupN // removeWorkloadsFromGroup removes the group membership from the workloads // in the group. 
func removeWorkloadsMembershipFromGroup(ctx context.Context, groupWorkloads []string, groupName string) error { - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %w", err) } diff --git a/cmd/thv/app/inspector.go b/cmd/thv/app/inspector.go index 389dd060e..353abc9fa 100644 --- a/cmd/thv/app/inspector.go +++ b/cmd/thv/app/inspector.go @@ -14,7 +14,6 @@ import ( "github.com/stacklok/toolhive/pkg/container/images" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/runner" "github.com/stacklok/toolhive/pkg/transport/types" @@ -101,8 +100,8 @@ func inspectorCmdFunc(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to find server: %v", err) } - imageManager := images.NewImageManager(ctx) - processedImage, err := runner.HandleProtocolScheme(ctx, imageManager, inspector.Image, "") + imageManager := images.NewImageManager(ctx, logger) + processedImage, err := runner.HandleProtocolScheme(ctx, imageManager, inspector.Image, "", logger) if err != nil { return fmt.Errorf("failed to handle protocol scheme: %v", err) } @@ -114,7 +113,7 @@ func inspectorCmdFunc(cmd *cobra.Command, args []string) error { options := buildInspectorContainerOptions(uiPortStr, mcpPortStr) // Create workload runtime - rt, err := container.NewFactory().Create(ctx) + rt, err := container.NewFactory(logger).Create(ctx) if err != nil { return fmt.Errorf("failed to create workload runtime: %v", err) } @@ -168,7 +167,7 @@ func inspectorCmdFunc(cmd *cobra.Command, args []string) error { func getServerPortAndTransport(ctx context.Context, serverName string) (int, types.TransportType, error) { // Instantiate the status manager and list all workloads. 
- manager, err := workloads.NewManager(ctx) + manager, err := workloads.NewManager(ctx, logger) if err != nil { return 0, types.TransportTypeSSE, fmt.Errorf("failed to create status manager: %v", err) } diff --git a/cmd/thv/app/list.go b/cmd/thv/app/list.go index 7fe7bc34d..659c48096 100644 --- a/cmd/thv/app/list.go +++ b/cmd/thv/app/list.go @@ -7,9 +7,9 @@ import ( "text/tabwriter" "github.com/spf13/cobra" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -42,7 +42,7 @@ func listCmdFunc(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() // Instantiate the status manager. - manager, err := workloads.NewManager(ctx) + manager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create status manager: %v", err) } @@ -76,7 +76,7 @@ func listCmdFunc(cmd *cobra.Command, _ []string) error { case "mcpservers": return printMCPServersOutput(workloadList) default: - printTextOutput(workloadList) + printTextOutput(workloadList, logger) return nil } } @@ -122,7 +122,7 @@ func printMCPServersOutput(workloadList []core.Workload) error { } // printTextOutput prints workload information in text format -func printTextOutput(workloadList []core.Workload) { +func printTextOutput(workloadList []core.Workload, logger *zap.SugaredLogger) { // Create a tabwriter for pretty output w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) fmt.Fprintln(w, "NAME\tPACKAGE\tSTATUS\tURL\tPORT\tTOOL TYPE\tGROUP\tCREATED AT") diff --git a/cmd/thv/app/logs.go b/cmd/thv/app/logs.go index 5a88a4220..2fbdd7daa 100644 --- a/cmd/thv/app/logs.go +++ b/cmd/thv/app/logs.go @@ -14,7 +14,6 @@ import ( "github.com/spf13/viper" rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -82,7 +81,7 @@ func logsCmdFunc(cmd *cobra.Command, args []string) error { return 
getProxyLogs(workloadName, follow) } - manager, err := workloads.NewManager(ctx) + manager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } @@ -144,7 +143,7 @@ func getLogsDirectory() (string, error) { } func getManagedContainerNames(ctx context.Context) (map[string]bool, error) { - manager, err := workloads.NewManager(ctx) + manager, err := workloads.NewManager(ctx, logger) if err != nil { return nil, fmt.Errorf("failed to create status manager: %v", err) } diff --git a/cmd/thv/app/mcp.go b/cmd/thv/app/mcp.go index c47437d82..be760151a 100644 --- a/cmd/thv/app/mcp.go +++ b/cmd/thv/app/mcp.go @@ -14,7 +14,6 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/spf13/cobra" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport/ssecommon" "github.com/stacklok/toolhive/pkg/transport/streamable" "github.com/stacklok/toolhive/pkg/transport/types" diff --git a/cmd/thv/app/mcp_serve.go b/cmd/thv/app/mcp_serve.go index 07f6f0d35..ce7f7f831 100644 --- a/cmd/thv/app/mcp_serve.go +++ b/cmd/thv/app/mcp_serve.go @@ -14,7 +14,6 @@ import ( "github.com/mark3labs/mcp-go/server" "github.com/spf13/cobra" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/runner/retriever" "github.com/stacklok/toolhive/pkg/versions" @@ -225,13 +224,13 @@ type toolHiveHandler struct { // newToolHiveHandler creates a new ToolHive handler func newToolHiveHandler(ctx context.Context) (*toolHiveHandler, error) { // Create workload manager - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return nil, fmt.Errorf("failed to create workload manager: %w", err) } // Create registry provider - registryProvider, err := registry.GetDefaultProvider() + registryProvider, err := registry.GetDefaultProvider(logger) if err != nil { return nil, fmt.Errorf("failed 
to get registry provider: %w", err) } @@ -304,7 +303,7 @@ func (h *toolHiveHandler) runServer(ctx context.Context, request mcp.CallToolReq // Use retriever to properly fetch and prepare the MCP server // TODO: make this configurable so we could warn or even fail - imageURL, imageMetadata, err := retriever.GetMCPServer(ctx, args.Server, "", "disabled") + imageURL, imageMetadata, err := retriever.GetMCPServer(ctx, args.Server, "", "disabled", logger) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("Failed to get MCP server: %v", err)), nil } diff --git a/cmd/thv/app/mcp_serve_helpers.go b/cmd/thv/app/mcp_serve_helpers.go index a3903f9bb..3ab68e2d6 100644 --- a/cmd/thv/app/mcp_serve_helpers.go +++ b/cmd/thv/app/mcp_serve_helpers.go @@ -7,7 +7,6 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/stacklok/toolhive/pkg/container" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/runner" transporttypes "github.com/stacklok/toolhive/pkg/transport/types" @@ -49,13 +48,13 @@ func buildServerConfig( imageMetadata *registry.ImageMetadata, ) (*runner.RunConfig, error) { // Create container runtime - rt, err := container.NewFactory().Create(ctx) + rt, err := container.NewFactory(logger).Create(ctx) if err != nil { return nil, fmt.Errorf("failed to create container runtime: %w", err) } // Build configuration using the builder pattern - builder := runner.NewRunConfigBuilder(). + builder := runner.NewRunConfigBuilder(logger). WithRuntime(rt). WithImage(imageURL). WithName(args.Name). 
@@ -117,7 +116,7 @@ func prepareEnvironmentVariables(imageMetadata *registry.ImageMetadata, userEnv // saveAndRunServer saves the configuration and runs the server func (h *toolHiveHandler) saveAndRunServer(ctx context.Context, runConfig *runner.RunConfig, name string) error { // Save the run configuration state before starting - if err := runConfig.SaveState(ctx); err != nil { + if err := runConfig.SaveState(ctx, logger); err != nil { logger.Warnf("Failed to save run configuration for %s: %v", name, err) // Continue anyway, as this is not critical for running } diff --git a/cmd/thv/app/otel.go b/cmd/thv/app/otel.go index acf7eed23..0bad96394 100644 --- a/cmd/thv/app/otel.go +++ b/cmd/thv/app/otel.go @@ -126,7 +126,7 @@ func setOtelEndpointCmdFunc(_ *cobra.Command, args []string) error { // Update the configuration err := config.UpdateConfig(func(c *config.Config) { c.OTEL.Endpoint = endpoint - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -136,7 +136,7 @@ func setOtelEndpointCmdFunc(_ *cobra.Command, args []string) error { } func getOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if cfg.OTEL.Endpoint == "" { fmt.Println("No OpenTelemetry endpoint is currently configured.") @@ -148,7 +148,7 @@ func getOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { } func unsetOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if cfg.OTEL.Endpoint == "" { fmt.Println("No OpenTelemetry endpoint is currently configured.") @@ -158,7 +158,7 @@ func unsetOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { // Update the configuration err := config.UpdateConfig(func(c *config.Config) { c.OTEL.Endpoint = "" - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -181,7 +181,7 @@ func setOtelSamplingRateCmdFunc(_ *cobra.Command, args 
[]string) error { // Update the configuration err = config.UpdateConfig(func(c *config.Config) { c.OTEL.SamplingRate = rate - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -191,7 +191,7 @@ func setOtelSamplingRateCmdFunc(_ *cobra.Command, args []string) error { } func getOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if cfg.OTEL.SamplingRate == 0.0 { fmt.Println("No OpenTelemetry sampling rate is currently configured.") @@ -203,7 +203,7 @@ func getOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { } func unsetOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if cfg.OTEL.SamplingRate == 0.0 { fmt.Println("No OpenTelemetry sampling rate is currently configured.") @@ -213,7 +213,7 @@ func unsetOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { // Update the configuration err := config.UpdateConfig(func(c *config.Config) { c.OTEL.SamplingRate = 0.0 - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -233,7 +233,7 @@ func setOtelEnvVarsCmdFunc(_ *cobra.Command, args []string) error { // Update the configuration err := config.UpdateConfig(func(c *config.Config) { c.OTEL.EnvVars = vars - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -243,7 +243,7 @@ func setOtelEnvVarsCmdFunc(_ *cobra.Command, args []string) error { } func getOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if len(cfg.OTEL.EnvVars) == 0 { fmt.Println("No OpenTelemetry environment variables are currently configured.") @@ -255,7 +255,7 @@ func getOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { } func unsetOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { - cfg := config.GetConfig() + cfg := 
config.GetConfig(logger) if len(cfg.OTEL.EnvVars) == 0 { fmt.Println("No OpenTelemetry environment variables are currently configured.") @@ -265,7 +265,7 @@ func unsetOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { // Update the configuration err := config.UpdateConfig(func(c *config.Config) { c.OTEL.EnvVars = []string{} - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } diff --git a/cmd/thv/app/proxy.go b/cmd/thv/app/proxy.go index 90da87e18..718bcae73 100644 --- a/cmd/thv/app/proxy.go +++ b/cmd/thv/app/proxy.go @@ -17,7 +17,6 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/auth/oauth" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/transport" "github.com/stacklok/toolhive/pkg/transport/proxy/transparent" @@ -195,7 +194,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { } // Select a port for the HTTP proxy (host port) - port, err := networking.FindOrUsePort(proxyPort) + port, err := networking.FindOrUsePort(proxyPort, logger) if err != nil { return err } @@ -245,7 +244,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { } // Get authentication middleware for incoming requests - authMiddleware, authInfoHandler, err := auth.GetAuthenticationMiddleware(ctx, oidcConfig) + authMiddleware, authInfoHandler, err := auth.GetAuthenticationMiddleware(ctx, oidcConfig, logger) if err != nil { return fmt.Errorf("failed to create authentication middleware: %v", err) } @@ -266,6 +265,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { proxyHost, port, serverName, proxyTargetURI, nil, authInfoHandler, false, + logger, middlewares...) 
if err := proxy.Start(ctx); err != nil { return fmt.Errorf("failed to start proxy: %v", err) @@ -436,7 +436,7 @@ func performOAuthFlow(ctx context.Context, issuer, clientID, clientSecret string } // Create OAuth flow - flow, err := oauth.NewFlow(oauthConfig) + flow, err := oauth.NewFlow(oauthConfig, logger) if err != nil { return nil, nil, fmt.Errorf("failed to create OAuth flow: %w", err) } diff --git a/cmd/thv/app/proxy_tunnel.go b/cmd/thv/app/proxy_tunnel.go index 19abc9c40..816971e04 100644 --- a/cmd/thv/app/proxy_tunnel.go +++ b/cmd/thv/app/proxy_tunnel.go @@ -11,7 +11,6 @@ import ( "github.com/spf13/cobra" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -111,7 +110,7 @@ func resolveTarget(ctx context.Context, target string) (string, error) { } // Otherwise treat as workload name - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return "", fmt.Errorf("failed to create workload manager: %w", err) } diff --git a/cmd/thv/app/registry.go b/cmd/thv/app/registry.go index 20e55370a..b74eabfc3 100644 --- a/cmd/thv/app/registry.go +++ b/cmd/thv/app/registry.go @@ -54,7 +54,7 @@ func init() { func registryListCmdFunc(_ *cobra.Command, _ []string) error { // Get all servers from registry - provider, err := registry.GetDefaultProvider() + provider, err := registry.GetDefaultProvider(logger) if err != nil { return fmt.Errorf("failed to get registry provider: %v", err) } @@ -79,7 +79,7 @@ func registryListCmdFunc(_ *cobra.Command, _ []string) error { func registryInfoCmdFunc(_ *cobra.Command, args []string) error { // Get server information serverName := args[0] - provider, err := registry.GetDefaultProvider() + provider, err := registry.GetDefaultProvider(logger) if err != nil { return fmt.Errorf("failed to get registry provider: %v", err) } diff --git a/cmd/thv/app/restart.go b/cmd/thv/app/restart.go 
index 7f8d1b27a..6c11d04d4 100644 --- a/cmd/thv/app/restart.go +++ b/cmd/thv/app/restart.go @@ -52,7 +52,7 @@ func restartCmdFunc(cmd *cobra.Command, args []string) error { } // Create workload managers. - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } @@ -104,7 +104,7 @@ func restartAllContainers(ctx context.Context, workloadManager workloads.Manager func restartWorkloadsByGroup(ctx context.Context, workloadManager workloads.Manager, groupName string, foreground bool) error { // Create a groups manager to list workloads in the group - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %v", err) } diff --git a/cmd/thv/app/rm.go b/cmd/thv/app/rm.go index d77a092fc..1d1cbec28 100644 --- a/cmd/thv/app/rm.go +++ b/cmd/thv/app/rm.go @@ -53,7 +53,7 @@ func rmCmdFunc(cmd *cobra.Command, args []string) error { } // Create workload manager. 
- manager, err := workloads.NewManager(ctx) + manager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } @@ -75,7 +75,7 @@ func rmCmdFunc(cmd *cobra.Command, args []string) error { func deleteAllWorkloadsInGroup(ctx context.Context, groupName string) error { // Create group manager - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %v", err) } @@ -90,7 +90,7 @@ func deleteAllWorkloadsInGroup(ctx context.Context, groupName string) error { } // Create workload manager - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } diff --git a/cmd/thv/app/run.go b/cmd/thv/app/run.go index 0c76b3597..bad299024 100644 --- a/cmd/thv/app/run.go +++ b/cmd/thv/app/run.go @@ -15,7 +15,6 @@ import ( "github.com/stacklok/toolhive/pkg/container" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/process" "github.com/stacklok/toolhive/pkg/runner" "github.com/stacklok/toolhive/pkg/workloads" @@ -140,11 +139,11 @@ func runCmdFunc(cmd *cobra.Command, args []string) error { debugMode, _ := cmd.Flags().GetBool("debug") // Create container runtime - rt, err := container.NewFactory().Create(ctx) + rt, err := container.NewFactory(logger).Create(ctx) if err != nil { return fmt.Errorf("failed to create container runtime: %v", err) } - workloadManager, err := workloads.NewManagerFromRuntime(rt) + workloadManager, err := workloads.NewManagerFromRuntime(rt, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } @@ -162,7 +161,7 @@ func runCmdFunc(cmd *cobra.Command, args []string) error { // Always save the run config to disk 
before starting (both foreground and detached modes) // NOTE: Save before secrets processing to avoid storing secrets in the state store - if err := runnerConfig.SaveState(ctx); err != nil { + if err := runnerConfig.SaveState(ctx, logger); err != nil { return fmt.Errorf("failed to save run configuration: %v", err) } @@ -205,7 +204,7 @@ func validateGroup(ctx context.Context, workloadsManager workloads.Manager, serv } // Create group manager - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %v", err) } @@ -292,7 +291,7 @@ func runFromConfigFile(ctx context.Context) error { } // Create container runtime - rt, err := container.NewFactory().Create(ctx) + rt, err := container.NewFactory(logger).Create(ctx) if err != nil { return fmt.Errorf("failed to create container runtime: %v", err) } @@ -301,7 +300,7 @@ func runFromConfigFile(ctx context.Context) error { runConfig.Deployer = rt // Create workload manager - workloadManager, err := workloads.NewManagerFromRuntime(rt) + workloadManager, err := workloads.NewManagerFromRuntime(rt, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index 07ca11d34..cf97d0190 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -210,12 +210,12 @@ func BuildRunnerConfig( } // Get OTEL flag values with config fallbacks - config := cfg.GetConfig() + config := cfg.GetConfig(logger) finalOtelEndpoint, finalOtelSamplingRate, finalOtelEnvironmentVariables := getTelemetryFromFlags(cmd, config, runConfig.OtelEndpoint, runConfig.OtelSamplingRate, runConfig.OtelEnvironmentVariables) // Create container runtime - rt, err := container.NewFactory().Create(ctx) + rt, err := container.NewFactory(logger).Create(ctx) if err != nil { return nil, fmt.Errorf("failed to create container runtime: %v", err) } @@ -226,9 +226,9 @@ func 
BuildRunnerConfig( // we use the DetachedEnvVarValidator. var envVarValidator runner.EnvVarValidator if process.IsDetached() || runtime.IsKubernetesRuntime() { - envVarValidator = &runner.DetachedEnvVarValidator{} + envVarValidator = runner.NewDetachedEnvVarValidator(logger) } else { - envVarValidator = &runner.CLIEnvVarValidator{} + envVarValidator = runner.NewCLIEnvVarValidator(logger) } // Image retrieval @@ -241,7 +241,7 @@ func BuildRunnerConfig( // Take the MCP server we were supplied and either fetch the image, or // build it from a protocol scheme. If the server URI refers to an image // in our trusted registry, we will also fetch the image metadata. - imageURL, imageMetadata, err = retriever.GetMCPServer(ctx, serverOrImage, runConfig.CACertPath, runConfig.VerifyImage) + imageURL, imageMetadata, err = retriever.GetMCPServer(ctx, serverOrImage, runConfig.CACertPath, runConfig.VerifyImage, logger) if err != nil { return nil, fmt.Errorf("failed to find or create the MCP server %s: %v", serverOrImage, err) } @@ -257,7 +257,7 @@ func BuildRunnerConfig( } // Initialize a new RunConfig with values from command-line flags - return runner.NewRunConfigBuilder(). + return runner.NewRunConfigBuilder(logger). WithRuntime(rt). WithCmdArgs(cmdArgs). WithName(runConfig.Name). diff --git a/cmd/thv/app/run_flags_test.go b/cmd/thv/app/run_flags_test.go index 5d1d701e4..f1b6f2a70 100644 --- a/cmd/thv/app/run_flags_test.go +++ b/cmd/thv/app/run_flags_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/logger" ) // mockConfig creates a temporary config file with the provided configuration. 
@@ -25,7 +24,7 @@ func mockConfig(t *testing.T, cfg *config.Config) func() { err := os.MkdirAll(configDir, 0755) require.NoError(t, err) if cfg != nil { - err = config.UpdateConfig(func(c *config.Config) { *c = *cfg }) + err = config.UpdateConfig(func(c *config.Config) { *c = *cfg }, logger) require.NoError(t, err) } return func() { @@ -37,8 +36,6 @@ func mockConfig(t *testing.T, cfg *config.Config) func() { //nolint:paralleltest // Cannot use t.Parallel() with t.Setenv() in Go 1.24+ func TestBuildRunnerConfig_TelemetryProcessing(t *testing.T) { // Initialize logger to prevent nil pointer dereference - logger.Initialize() - tests := []struct { name string setupFlags func(*cobra.Command) @@ -159,7 +156,7 @@ func TestBuildRunnerConfig_TelemetryProcessing(t *testing.T) { OTEL: tt.configOTEL, }) defer cleanup() - configInstance := config.GetConfig() + configInstance := config.GetConfig(logger) finalEndpoint, finalSamplingRate, finalEnvVars := getTelemetryFromFlags( cmd, configInstance, @@ -180,7 +177,6 @@ func TestBuildRunnerConfig_TelemetryProcessing(t *testing.T) { func TestBuildRunnerConfig_TelemetryProcessing_Integration(t *testing.T) { // This is a more complete integration test that tests telemetry processing // within the full BuildRunnerConfig function context - logger.Initialize() cmd := &cobra.Command{} runFlags := &RunFlags{ Transport: "sse", @@ -204,7 +200,7 @@ func TestBuildRunnerConfig_TelemetryProcessing_Integration(t *testing.T) { }) defer cleanup() - configInstance := config.GetConfig() + configInstance := config.GetConfig(logger) finalEndpoint, finalSamplingRate, finalEnvVars := getTelemetryFromFlags( cmd, configInstance, diff --git a/cmd/thv/app/runtime.go b/cmd/thv/app/runtime.go index c8e59e0cb..9337e0697 100644 --- a/cmd/thv/app/runtime.go +++ b/cmd/thv/app/runtime.go @@ -69,7 +69,7 @@ func createWithTimeout(ctx context.Context) (runtime.Runtime, error) { err error }, 1) go func() { - rt, err := container.NewFactory().Create(ctx) + rt, err := 
container.NewFactory(logger).Create(ctx) done <- struct { rt runtime.Runtime err error diff --git a/cmd/thv/app/search.go b/cmd/thv/app/search.go index 9c6d49515..bd5a8e7ed 100644 --- a/cmd/thv/app/search.go +++ b/cmd/thv/app/search.go @@ -34,7 +34,7 @@ func init() { func searchCmdFunc(_ *cobra.Command, args []string) error { // Search for servers query := args[0] - provider, err := registry.GetDefaultProvider() + provider, err := registry.GetDefaultProvider(logger) if err != nil { return fmt.Errorf("failed to get registry provider: %v", err) } diff --git a/cmd/thv/app/secret.go b/cmd/thv/app/secret.go index 79dfe60a0..5ca4b12f4 100644 --- a/cmd/thv/app/secret.go +++ b/cmd/thv/app/secret.go @@ -167,7 +167,7 @@ Note that some providers (like 1Password) are read-only and do not support setti // Check if the provider supports writing secrets if !manager.Capabilities().CanWrite { - providerType, _ := config.GetConfig().Secrets.GetProviderType() + providerType, _ := config.GetConfig(logger).Secrets.GetProviderType() fmt.Fprintf(os.Stderr, "Error: The %s secrets provider does not support setting secrets (read-only)\n", providerType) return } @@ -250,7 +250,7 @@ If your provider is read-only or doesn't support deletion, this command returns // Check if the provider supports deleting secrets if !manager.Capabilities().CanDelete { - providerType, _ := config.GetConfig().Secrets.GetProviderType() + providerType, _ := config.GetConfig(logger).Secrets.GetProviderType() fmt.Fprintf(os.Stderr, "Error: The %s secrets provider does not support deleting secrets\n", providerType) return } @@ -284,7 +284,7 @@ If descriptions exist for the secrets, the command displays them alongside the n // Check if the provider supports listing secrets if !manager.Capabilities().CanList { - providerType, _ := config.GetConfig().Secrets.GetProviderType() + providerType, _ := config.GetConfig(logger).Secrets.GetProviderType() fmt.Fprintf(os.Stderr, "Error: The %s secrets provider does not support 
listing secrets\n", providerType) return } @@ -344,7 +344,7 @@ This command only works with the 'encrypted' secrets provider.`, } func getSecretsManager() (secrets.Provider, error) { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) // Check if secrets setup has been completed if !cfg.Secrets.SetupCompleted { @@ -356,7 +356,7 @@ func getSecretsManager() (secrets.Provider, error) { return nil, fmt.Errorf("failed to get secrets provider type: %w", err) } - manager, err := secrets.CreateSecretProvider(providerType) + manager, err := secrets.CreateSecretProvider(providerType, logger) if err != nil { return nil, fmt.Errorf("failed to create secrets manager: %w", err) } diff --git a/cmd/thv/app/server.go b/cmd/thv/app/server.go index 326466723..b071943de 100644 --- a/cmd/thv/app/server.go +++ b/cmd/thv/app/server.go @@ -59,7 +59,7 @@ var serveCmd = &cobra.Command{ } } - return s.Serve(ctx, address, isUnixSocket, debugMode, enableDocs, oidcConfig) + return s.Serve(ctx, address, isUnixSocket, debugMode, enableDocs, oidcConfig, logger) }, } diff --git a/cmd/thv/app/stop.go b/cmd/thv/app/stop.go index 931ec70fd..a07149106 100644 --- a/cmd/thv/app/stop.go +++ b/cmd/thv/app/stop.go @@ -54,7 +54,7 @@ func validateStopArgs(cmd *cobra.Command, args []string) error { func stopCmdFunc(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - workloadManager, err := workloads.NewManager(ctx) + workloadManager, err := workloads.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } diff --git a/cmd/thv/main.go b/cmd/thv/main.go index 4b563fd65..e4c092578 100644 --- a/cmd/thv/main.go +++ b/cmd/thv/main.go @@ -7,16 +7,14 @@ import ( "github.com/stacklok/toolhive/cmd/thv/app" "github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func main() { - // Initialize the logger - 
logger.Initialize() - + logger := log.NewLogger() // Check and perform auto-discovery migration if needed // Handles the auto-discovery flag depreciation, only executes once on old config files - client.CheckAndPerformAutoDiscoveryMigration() + client.CheckAndPerformAutoDiscoveryMigration(logger) // Check and perform default group migration if needed // Migrates existing workloads to the default group, only executes once diff --git a/go.mod b/go.mod index 376c79f80..9b4655758 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/docker/docker v28.3.3+incompatible github.com/docker/go-connections v0.6.0 github.com/go-chi/chi/v5 v5.2.2 + github.com/go-logr/zapr v1.3.0 github.com/gofrs/flock v0.12.1 github.com/google/go-containerregistry v0.20.6 github.com/google/uuid v1.6.0 @@ -42,6 +43,7 @@ require ( golang.ngrok.com/ngrok/v2 v2.0.0 golang.org/x/exp/jsonrpc2 v0.0.0-20250811191247-51f88131bc50 golang.org/x/mod v0.27.0 + go.uber.org/zap v1.27.0 golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.34.0 @@ -238,7 +240,6 @@ require ( go.opentelemetry.io/proto/otlp v1.7.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect - go.uber.org/zap v1.27.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.ngrok.com/muxado/v2 v2.0.1 // indirect golang.org/x/exp/event v0.0.0-20250718183923-645b1fa84792 // indirect diff --git a/pkg/api/server.go b/pkg/api/server.go index 21ea4c7e7..b6be46942 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -25,13 +25,13 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "go.uber.org/zap" v1 "github.com/stacklok/toolhive/pkg/api/v1" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/container" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/updates" 
"github.com/stacklok/toolhive/pkg/workloads" ) @@ -74,7 +74,7 @@ func setupUnixSocket(address string) (net.Listener, error) { return listener, nil } -func cleanupUnixSocket(address string) { +func cleanupUnixSocket(address string, logger *zap.SugaredLogger) { if err := os.Remove(address); err != nil && !os.IsNotExist(err) { logger.Warnf("failed to remove socket file: %v", err) } @@ -90,7 +90,7 @@ func headersMiddleware(next http.Handler) http.Handler { } // updateCheckMiddleware triggers update checks for API usage -func updateCheckMiddleware() func(next http.Handler) http.Handler { +func updateCheckMiddleware(logger *zap.SugaredLogger) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { go func() { @@ -141,6 +141,7 @@ func Serve( debugMode bool, enableDocs bool, oidcConfig *auth.TokenValidatorConfig, + logger *zap.SugaredLogger, ) error { r := chi.NewRouter() r.Use( @@ -151,33 +152,33 @@ func Serve( ) // Add update check middleware - r.Use(updateCheckMiddleware()) + r.Use(updateCheckMiddleware(logger)) // Add authentication middleware - authMiddleware, _, err := auth.GetAuthenticationMiddleware(ctx, oidcConfig) + authMiddleware, _, err := auth.GetAuthenticationMiddleware(ctx, oidcConfig, logger) if err != nil { return fmt.Errorf("failed to create authentication middleware: %v", err) } r.Use(authMiddleware) // Create container runtime - containerRuntime, err := container.NewFactory().Create(ctx) + containerRuntime, err := container.NewFactory(logger).Create(ctx) if err != nil { return fmt.Errorf("failed to create container runtime: %v", err) } - clientManager, err := client.NewManager(ctx) + clientManager, err := client.NewManager(ctx, logger) if err != nil { return fmt.Errorf("failed to create client manager: %v", err) } - workloadManager, err := workloads.NewManagerFromRuntime(containerRuntime) + workloadManager, err := 
workloads.NewManagerFromRuntime(containerRuntime, logger) if err != nil { return fmt.Errorf("failed to create workload manager: %v", err) } // Create group manager - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { return fmt.Errorf("failed to create group manager: %v", err) } @@ -185,12 +186,12 @@ func Serve( routers := map[string]http.Handler{ "/health": v1.HealthcheckRouter(containerRuntime), "/api/v1beta/version": v1.VersionRouter(), - "/api/v1beta/workloads": v1.WorkloadRouter(workloadManager, containerRuntime, groupManager, debugMode), - "/api/v1beta/registry": v1.RegistryRouter(), - "/api/v1beta/discovery": v1.DiscoveryRouter(), - "/api/v1beta/clients": v1.ClientRouter(clientManager, workloadManager, groupManager), - "/api/v1beta/secrets": v1.SecretsRouter(), - "/api/v1beta/groups": v1.GroupsRouter(groupManager, workloadManager), + "/api/v1beta/workloads": v1.WorkloadRouter(workloadManager, containerRuntime, groupManager, debugMode, logger), + "/api/v1beta/registry": v1.RegistryRouter(logger), + "/api/v1beta/discovery": v1.DiscoveryRouter(logger), + "/api/v1beta/clients": v1.ClientRouter(clientManager, workloadManager, groupManager, logger), + "/api/v1beta/secrets": v1.SecretsRouter(logger), + "/api/v1beta/groups": v1.GroupsRouter(groupManager, workloadManager, logger), } // Only mount docs router if enabled @@ -236,13 +237,13 @@ func Serve( <-ctx.Done() if err := srv.Shutdown(ctx); err != nil { if isUnixSocket { - cleanupUnixSocket(address) + cleanupUnixSocket(address, logger) } return fmt.Errorf("server shutdown failed: %w", err) } if isUnixSocket { - cleanupUnixSocket(address) + cleanupUnixSocket(address, logger) } logger.Infof("%s server stopped", addrType) diff --git a/pkg/api/v1/clients.go b/pkg/api/v1/clients.go index 18da01b1c..a57225c3a 100644 --- a/pkg/api/v1/clients.go +++ b/pkg/api/v1/clients.go @@ -7,12 +7,12 @@ import ( "net/http" "github.com/go-chi/chi/v5" + "go.uber.org/zap" 
"github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -21,6 +21,7 @@ type ClientRoutes struct { clientManager client.Manager workloadManager workloads.Manager groupManager groups.Manager + logger *zap.SugaredLogger } // ClientRouter creates a new router for the client API. @@ -28,11 +29,13 @@ func ClientRouter( manager client.Manager, workloadManager workloads.Manager, groupManager groups.Manager, + logger *zap.SugaredLogger, ) http.Handler { routes := ClientRoutes{ clientManager: manager, workloadManager: workloadManager, groupManager: groupManager, + logger: logger, } r := chi.NewRouter() @@ -56,7 +59,7 @@ func ClientRouter( func (c *ClientRoutes) listClients(w http.ResponseWriter, _ *http.Request) { clients, err := c.clientManager.ListClients() if err != nil { - logger.Errorf("Failed to list clients: %v", err) + c.logger.Errorf("Failed to list clients: %v", err) http.Error(w, "Failed to list clients", http.StatusInternalServerError) return } @@ -84,7 +87,7 @@ func (c *ClientRoutes) registerClient(w http.ResponseWriter, r *http.Request) { var newClient createClientRequest err := json.NewDecoder(r.Body).Decode(&newClient) if err != nil { - logger.Errorf("Failed to decode request body: %v", err) + c.logger.Errorf("Failed to decode request body: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -93,7 +96,7 @@ func (c *ClientRoutes) registerClient(w http.ResponseWriter, r *http.Request) { if len(newClient.Groups) == 0 { defaultGroup, err := c.groupManager.Get(r.Context(), groups.DefaultGroupName) if err != nil { - logger.Debugf("Failed to get default group: %v", err) + c.logger.Debugf("Failed to get default group: %v", err) } if defaultGroup != nil { newClient.Groups = []string{groups.DefaultGroupName} @@ -102,7 +105,7 @@ func (c 
*ClientRoutes) registerClient(w http.ResponseWriter, r *http.Request) { err = c.performClientRegistration(r.Context(), []client.Client{{Name: newClient.Name}}, newClient.Groups) if err != nil { - logger.Errorf("Failed to register client: %v", err) + c.logger.Errorf("Failed to register client: %v", err) http.Error(w, "Failed to register client", http.StatusInternalServerError) return } @@ -133,7 +136,7 @@ func (c *ClientRoutes) unregisterClient(w http.ResponseWriter, r *http.Request) err := c.removeClient(r.Context(), []client.Client{{Name: client.MCPClient(clientName)}}, nil) if err != nil { - logger.Errorf("Failed to unregister client: %v", err) + c.logger.Errorf("Failed to unregister client: %v", err) http.Error(w, "Failed to unregister client", http.StatusInternalServerError) return } @@ -168,7 +171,7 @@ func (c *ClientRoutes) unregisterClientFromGroup(w http.ResponseWriter, r *http. // Remove client from the specific group err := c.removeClient(r.Context(), []client.Client{{Name: client.MCPClient(clientName)}}, []string{groupName}) if err != nil { - logger.Errorf("Failed to unregister client from group: %v", err) + c.logger.Errorf("Failed to unregister client from group: %v", err) http.Error(w, "Failed to unregister client from group", http.StatusInternalServerError) return } @@ -191,7 +194,7 @@ func (c *ClientRoutes) registerClientsBulk(w http.ResponseWriter, r *http.Reques var req bulkClientRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - logger.Errorf("Failed to decode request body: %v", err) + c.logger.Errorf("Failed to decode request body: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -208,7 +211,7 @@ func (c *ClientRoutes) registerClientsBulk(w http.ResponseWriter, r *http.Reques err = c.performClientRegistration(r.Context(), clients, req.Groups) if err != nil { - logger.Errorf("Failed to register clients: %v", err) + c.logger.Errorf("Failed to register clients: %v", err) http.Error(w, err.Error(), 
http.StatusInternalServerError) return } @@ -239,7 +242,7 @@ func (c *ClientRoutes) unregisterClientsBulk(w http.ResponseWriter, r *http.Requ var req bulkClientRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - logger.Errorf("Failed to decode request body: %v", err) + c.logger.Errorf("Failed to decode request body: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -257,7 +260,7 @@ func (c *ClientRoutes) unregisterClientsBulk(w http.ResponseWriter, r *http.Requ err = c.removeClient(r.Context(), clients, req.Groups) if err != nil { - logger.Errorf("Failed to unregister clients: %v", err) + c.logger.Errorf("Failed to unregister clients: %v", err) http.Error(w, "Failed to unregister clients", http.StatusInternalServerError) return } @@ -293,7 +296,7 @@ func (c *ClientRoutes) performClientRegistration(ctx context.Context, clients [] } if len(groupNames) > 0 { - logger.Infof("Filtering workloads to groups: %v", groupNames) + c.logger.Infof("Filtering workloads to groups: %v", groupNames) filteredWorkloads, err := workloads.FilterByGroups(runningWorkloads, groupNames) if err != nil { @@ -320,21 +323,21 @@ func (c *ClientRoutes) performClientRegistration(ctx context.Context, clients [] } else { // We should never reach this point once groups are enabled for _, clientToRegister := range clients { - err := config.UpdateConfig(func(c *config.Config) { - for _, registeredClient := range c.Clients.RegisteredClients { + err := config.UpdateConfig(func(config *config.Config) { + for _, registeredClient := range config.Clients.RegisteredClients { if registeredClient == string(clientToRegister.Name) { - logger.Infof("Client %s is already registered, skipping...", clientToRegister.Name) + c.logger.Infof("Client %s is already registered, skipping...", clientToRegister.Name) return } } - c.Clients.RegisteredClients = append(c.Clients.RegisteredClients, string(clientToRegister.Name)) - }) + config.Clients.RegisteredClients = 
append(config.Clients.RegisteredClients, string(clientToRegister.Name)) + }, c.logger) if err != nil { return fmt.Errorf("failed to update configuration for client %s: %w", clientToRegister.Name, err) } - logger.Infof("Successfully registered client: %s\n", clientToRegister.Name) + c.logger.Infof("Successfully registered client: %s\n", clientToRegister.Name) } err = c.clientManager.RegisterClients(clients, runningWorkloads) @@ -428,17 +431,17 @@ func (c *ClientRoutes) removeClientGlobally( // Remove clients from global registered clients list for _, clientToRemove := range clients { - err := config.UpdateConfig(func(c *config.Config) { - for i, registeredClient := range c.Clients.RegisteredClients { + err := config.UpdateConfig(func(config *config.Config) { + for i, registeredClient := range config.Clients.RegisteredClients { if registeredClient == string(clientToRemove.Name) { // Remove client from slice - c.Clients.RegisteredClients = append(c.Clients.RegisteredClients[:i], c.Clients.RegisteredClients[i+1:]...) - logger.Infof("Successfully unregistered client: %s\n", clientToRemove.Name) + config.Clients.RegisteredClients = append(config.Clients.RegisteredClients[:i], config.Clients.RegisteredClients[i+1:]...) + c.logger.Infof("Successfully unregistered client: %s\n", clientToRemove.Name) return } } - logger.Warnf("Client %s was not found in registered clients list", clientToRemove.Name) - }) + c.logger.Warnf("Client %s was not found in registered clients list", clientToRemove.Name) + }, c.logger) if err != nil { return fmt.Errorf("failed to update configuration for client %s: %w", clientToRemove.Name, err) } diff --git a/pkg/api/v1/discovery.go b/pkg/api/v1/discovery.go index 687793df8..ffbbae396 100644 --- a/pkg/api/v1/discovery.go +++ b/pkg/api/v1/discovery.go @@ -5,16 +5,19 @@ import ( "net/http" "github.com/go-chi/chi/v5" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/client" ) // DiscoveryRoutes defines the routes for the client discovery API. 
-type DiscoveryRoutes struct{} +type DiscoveryRoutes struct { + logger *zap.SugaredLogger +} // DiscoveryRouter creates a new router for the client discovery API. -func DiscoveryRouter() http.Handler { - routes := DiscoveryRoutes{} +func DiscoveryRouter(logger *zap.SugaredLogger) http.Handler { + routes := DiscoveryRoutes{logger} r := chi.NewRouter() r.Get("/clients", routes.discoverClients) @@ -29,8 +32,8 @@ func DiscoveryRouter() http.Handler { // @Produce json // @Success 200 {object} clientStatusResponse // @Router /api/v1beta/discovery/clients [get] -func (*DiscoveryRoutes) discoverClients(w http.ResponseWriter, _ *http.Request) { - clients, err := client.GetClientStatus() +func (d *DiscoveryRoutes) discoverClients(w http.ResponseWriter, _ *http.Request) { + clients, err := client.GetClientStatus(d.logger) if err != nil { // TODO: Error should be JSON marshaled http.Error(w, "Failed to get client status", http.StatusInternalServerError) diff --git a/pkg/api/v1/groups.go b/pkg/api/v1/groups.go index 181320f4b..23fc94d5c 100644 --- a/pkg/api/v1/groups.go +++ b/pkg/api/v1/groups.go @@ -6,10 +6,10 @@ import ( "net/http" "github.com/go-chi/chi/v5" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/validation" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -18,13 +18,15 @@ import ( type GroupsRoutes struct { groupManager groups.Manager workloadManager workloads.Manager + logger *zap.SugaredLogger } // GroupsRouter creates a new GroupsRoutes instance. 
-func GroupsRouter(groupManager groups.Manager, workloadManager workloads.Manager) http.Handler { +func GroupsRouter(groupManager groups.Manager, workloadManager workloads.Manager, logger *zap.SugaredLogger) http.Handler { routes := GroupsRoutes{ groupManager: groupManager, workloadManager: workloadManager, + logger: logger, } r := chi.NewRouter() @@ -55,7 +57,7 @@ func (s *GroupsRoutes) listGroups(w http.ResponseWriter, r *http.Request) { ctx := r.Context() groupList, err := s.groupManager.List(ctx) if err != nil { - logger.Errorf("Failed to list groups: %v", err) + s.logger.Errorf("Failed to list groups: %v", err) http.Error(w, "Failed to list groups", http.StatusInternalServerError) return } @@ -63,7 +65,7 @@ func (s *GroupsRoutes) listGroups(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(groupListResponse{Groups: groupList}) if err != nil { - logger.Errorf("Failed to marshal group list: %v", err) + s.logger.Errorf("Failed to marshal group list: %v", err) http.Error(w, "Failed to marshal group list", http.StatusInternalServerError) return } @@ -87,21 +89,21 @@ func (s *GroupsRoutes) createGroup(w http.ResponseWriter, r *http.Request) { var req createGroupRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - logger.Errorf("Failed to decode create group request: %v", err) + s.logger.Errorf("Failed to decode create group request: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } // Validate group name if err := validation.ValidateGroupName(req.Name); err != nil { - logger.Errorf("Invalid group name: %v", err) + s.logger.Errorf("Invalid group name: %v", err) http.Error(w, fmt.Sprintf("Invalid group name: %v", err), http.StatusBadRequest) return } err := s.groupManager.Create(ctx, req.Name) if err != nil { - logger.Errorf("Failed to create group: %v", err) + s.logger.Errorf("Failed to create group: %v", err) if errors.IsGroupAlreadyExists(err) 
{ http.Error(w, err.Error(), http.StatusConflict) } else { @@ -114,7 +116,7 @@ func (s *GroupsRoutes) createGroup(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) response := createGroupResponse(req) if err := json.NewEncoder(w).Encode(response); err != nil { - logger.Errorf("Failed to marshal create group response: %v", err) + s.logger.Errorf("Failed to marshal create group response: %v", err) http.Error(w, "Failed to marshal response", http.StatusInternalServerError) return } @@ -137,21 +139,21 @@ func (s *GroupsRoutes) getGroup(w http.ResponseWriter, r *http.Request) { // Validate group name if err := validation.ValidateGroupName(name); err != nil { - logger.Errorf("Invalid group name: %v", err) + s.logger.Errorf("Invalid group name: %v", err) http.Error(w, fmt.Sprintf("Invalid group name: %v", err), http.StatusBadRequest) return } group, err := s.groupManager.Get(ctx, name) if err != nil { - logger.Errorf("Failed to get group %s: %v", name, err) + s.logger.Errorf("Failed to get group %s: %v", name, err) http.Error(w, "Group not found", http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(group); err != nil { - logger.Errorf("Failed to marshal group: %v", err) + s.logger.Errorf("Failed to marshal group: %v", err) http.Error(w, "Failed to marshal group", http.StatusInternalServerError) return } @@ -175,7 +177,7 @@ func (s *GroupsRoutes) deleteGroup(w http.ResponseWriter, r *http.Request) { // Validate group name if err := validation.ValidateGroupName(name); err != nil { - logger.Errorf("Invalid group name: %v", err) + s.logger.Errorf("Invalid group name: %v", err) http.Error(w, fmt.Sprintf("Invalid group name: %v", err), http.StatusBadRequest) return } @@ -189,7 +191,7 @@ func (s *GroupsRoutes) deleteGroup(w http.ResponseWriter, r *http.Request) { // Check if group exists before deleting exists, err := s.groupManager.Exists(ctx, name) if err != nil { - logger.Errorf("Failed 
to check if group exists %s: %v", name, err) + s.logger.Errorf("Failed to check if group exists %s: %v", name, err) http.Error(w, "Failed to check group existence", http.StatusInternalServerError) return } @@ -205,7 +207,7 @@ func (s *GroupsRoutes) deleteGroup(w http.ResponseWriter, r *http.Request) { // Get all workloads in the group groupWorkloads, err := s.workloadManager.ListWorkloadsInGroup(ctx, name) if err != nil { - logger.Errorf("Failed to list workloads in group %s: %v", name, err) + s.logger.Errorf("Failed to list workloads in group %s: %v", name, err) http.Error(w, "Failed to list workloads in group", http.StatusInternalServerError) return } @@ -216,35 +218,35 @@ func (s *GroupsRoutes) deleteGroup(w http.ResponseWriter, r *http.Request) { // Delete all workloads in the group group, err := s.workloadManager.DeleteWorkloads(ctx, groupWorkloads) if err != nil { - logger.Errorf("Failed to delete workloads in group %s: %v", name, err) + s.logger.Errorf("Failed to delete workloads in group %s: %v", name, err) http.Error(w, "Failed to delete workloads in group", http.StatusInternalServerError) return } // Wait for the deletion to complete if err := group.Wait(); err != nil { - logger.Errorf("Failed to delete workloads in group %s: %v", name, err) + s.logger.Errorf("Failed to delete workloads in group %s: %v", name, err) http.Error(w, "Failed to delete workloads in group", http.StatusInternalServerError) return } - logger.Infof("Deleted %d workload(s) from group '%s'", len(groupWorkloads), name) + s.logger.Infof("Deleted %d workload(s) from group '%s'", len(groupWorkloads), name) } else { // Move workloads to default group if err := s.workloadManager.MoveToDefaultGroup(ctx, groupWorkloads, name); err != nil { - logger.Errorf("Failed to move workloads to default group: %v", err) + s.logger.Errorf("Failed to move workloads to default group: %v", err) http.Error(w, "Failed to move workloads to default group", http.StatusInternalServerError) return } - 
logger.Infof("Moved %d workload(s) from group '%s' to default group", len(groupWorkloads), name) + s.logger.Infof("Moved %d workload(s) from group '%s' to default group", len(groupWorkloads), name) } } // Delete the group err = s.groupManager.Delete(ctx, name) if err != nil { - logger.Errorf("Failed to delete group %s: %v", name, err) + s.logger.Errorf("Failed to delete group %s: %v", name, err) http.Error(w, "Failed to delete group", http.StatusInternalServerError) return } diff --git a/pkg/api/v1/groups_test.go b/pkg/api/v1/groups_test.go index 6e39f94c1..53c40f092 100644 --- a/pkg/api/v1/groups_test.go +++ b/pkg/api/v1/groups_test.go @@ -16,7 +16,7 @@ import ( "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/groups" groupsmocks "github.com/stacklok/toolhive/pkg/groups/mocks" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads" workloadsmocks "github.com/stacklok/toolhive/pkg/workloads/mocks" ) @@ -24,8 +24,7 @@ import ( func TestGroupsRouter(t *testing.T) { t.Parallel() - // Initialize logger to prevent panic - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -222,7 +221,7 @@ func TestGroupsRouter(t *testing.T) { } // Create router - router := GroupsRouter(mockGroupManager, mockWorkloadManager) + router := GroupsRouter(mockGroupManager, mockWorkloadManager, logger) // Create request var req *http.Request @@ -269,18 +268,20 @@ func TestGroupsRouter(t *testing.T) { func TestGroupsRouter_Integration(t *testing.T) { t.Parallel() + logger := log.NewLogger() + // Test with real managers (integration test) - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { t.Skip("Skipping integration test: failed to create group manager") } - workloadManager, err := workloads.NewManager(context.Background()) + workloadManager, err := workloads.NewManager(context.Background(), logger) if 
err != nil { t.Skip("Skipping integration test: failed to create workload manager") } - router := GroupsRouter(groupManager, workloadManager) + router := GroupsRouter(groupManager, workloadManager, logger) // Test creating a group t.Run("create and list group", func(t *testing.T) { diff --git a/pkg/api/v1/registry.go b/pkg/api/v1/registry.go index a481b0c3c..c6a3347e8 100644 --- a/pkg/api/v1/registry.go +++ b/pkg/api/v1/registry.go @@ -6,9 +6,9 @@ import ( "net/http" "github.com/go-chi/chi/v5" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" ) @@ -30,8 +30,8 @@ const ( ) // getRegistryInfo returns the registry type and the source -func getRegistryInfo() (RegistryType, string) { - url, localPath, _, registryType := config.GetRegistryConfig() +func getRegistryInfo(logger *zap.SugaredLogger) (RegistryType, string) { + url, localPath, _, registryType := config.GetRegistryConfig(logger) switch registryType { case "url": @@ -45,8 +45,8 @@ func getRegistryInfo() (RegistryType, string) { } // getCurrentProvider returns the current registry provider -func getCurrentProvider(w http.ResponseWriter) (registry.Provider, bool) { - provider, err := registry.GetDefaultProvider() +func getCurrentProvider(w http.ResponseWriter, logger *zap.SugaredLogger) (registry.Provider, bool) { + provider, err := registry.GetDefaultProvider(logger) if err != nil { http.Error(w, "Failed to get registry provider", http.StatusInternalServerError) logger.Errorf("Failed to get registry provider: %v", err) @@ -56,11 +56,15 @@ func getCurrentProvider(w http.ResponseWriter) (registry.Provider, bool) { } // RegistryRoutes defines the routes for the registry API. -type RegistryRoutes struct{} +type RegistryRoutes struct { + logger *zap.SugaredLogger +} // RegistryRouter creates a new router for the registry API. 
-func RegistryRouter() http.Handler { - routes := RegistryRoutes{} +func RegistryRouter(logger *zap.SugaredLogger) http.Handler { + routes := RegistryRoutes{ + logger: logger, + } r := chi.NewRouter() r.Get("/", routes.listRegistries) @@ -85,8 +89,8 @@ func RegistryRouter() http.Handler { // @Produce json // @Success 200 {object} registryListResponse // @Router /api/v1beta/registry [get] -func (*RegistryRoutes) listRegistries(w http.ResponseWriter, _ *http.Request) { - provider, ok := getCurrentProvider(w) +func (routes *RegistryRoutes) listRegistries(w http.ResponseWriter, _ *http.Request) { + provider, ok := getCurrentProvider(w, routes.logger) if !ok { return } @@ -97,7 +101,7 @@ func (*RegistryRoutes) listRegistries(w http.ResponseWriter, _ *http.Request) { return } - registryType, source := getRegistryInfo() + registryType, source := getRegistryInfo(routes.logger) registries := []registryInfo{ { @@ -143,7 +147,7 @@ func (*RegistryRoutes) addRegistry(w http.ResponseWriter, _ *http.Request) { // @Success 200 {object} getRegistryResponse // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/registry/{name} [get] -func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { +func (routes *RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") // Only "default" registry is supported currently @@ -152,7 +156,7 @@ func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { return } - provider, ok := getCurrentProvider(w) + provider, ok := getCurrentProvider(w, routes.logger) if !ok { return } @@ -163,7 +167,7 @@ func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { return } - registryType, source := getRegistryInfo() + registryType, source := getRegistryInfo(routes.logger) response := getRegistryResponse{ Name: defaultRegistryName, @@ -177,7 +181,7 @@ func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { 
w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - logger.Errorf("Failed to encode response: %v", err) + routes.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -196,7 +200,7 @@ func (*RegistryRoutes) getRegistry(w http.ResponseWriter, r *http.Request) { // @Failure 400 {string} string "Bad Request" // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/registry/{name} [put] -func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { +func (routes *RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") // Only "default" registry can be updated currently @@ -222,8 +226,8 @@ func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { // Handle reset to default (no URL or LocalPath specified) if req.URL == nil && req.LocalPath == nil { - if err := config.UnsetRegistry(); err != nil { - logger.Errorf("Failed to unset registry: %v", err) + if err := config.UnsetRegistry(routes.logger); err != nil { + routes.logger.Errorf("Failed to unset registry: %v", err) http.Error(w, "Failed to reset registry configuration", http.StatusInternalServerError) return } @@ -236,8 +240,8 @@ func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { allowPrivateIP = *req.AllowPrivateIP } - if err := config.SetRegistryURL(*req.URL, allowPrivateIP); err != nil { - logger.Errorf("Failed to set registry URL: %v", err) + if err := config.SetRegistryURL(*req.URL, allowPrivateIP, routes.logger); err != nil { + routes.logger.Errorf("Failed to set registry URL: %v", err) http.Error(w, fmt.Sprintf("Failed to set registry URL: %v", err), http.StatusBadRequest) return } @@ -245,8 +249,8 @@ func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { message = fmt.Sprintf("Successfully set registry URL: %s", 
*req.URL) } else if req.LocalPath != nil { // Handle local path update - if err := config.SetRegistryFile(*req.LocalPath); err != nil { - logger.Errorf("Failed to set registry file: %v", err) + if err := config.SetRegistryFile(*req.LocalPath, routes.logger); err != nil { + routes.logger.Errorf("Failed to set registry file: %v", err) http.Error(w, fmt.Sprintf("Failed to set registry file: %v", err), http.StatusBadRequest) return } @@ -264,7 +268,7 @@ func (*RegistryRoutes) updateRegistry(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - logger.Errorf("Failed to encode response: %v", err) + routes.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -303,7 +307,7 @@ func (*RegistryRoutes) removeRegistry(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} listServersResponse // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/registry/{name}/servers [get] -func (*RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { +func (routes *RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { registryName := chi.URLParam(r, "name") // Only "default" registry is supported currently @@ -312,7 +316,7 @@ func (*RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { return } - provider, ok := getCurrentProvider(w) + provider, ok := getCurrentProvider(w, routes.logger) if !ok { return } @@ -320,7 +324,7 @@ func (*RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { // Get the full registry to access both container and remote servers reg, err := provider.GetRegistry() if err != nil { - logger.Errorf("Failed to get registry: %v", err) + routes.logger.Errorf("Failed to get registry: %v", err) http.Error(w, "Failed to get registry", http.StatusInternalServerError) return } @@ -343,7 +347,7 @@ func 
(*RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - logger.Errorf("Failed to encode response: %v", err) + routes.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -360,7 +364,7 @@ func (*RegistryRoutes) listServers(w http.ResponseWriter, r *http.Request) { // @Success 200 {object} getServerResponse // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/registry/{name}/servers/{serverName} [get] -func (*RegistryRoutes) getServer(w http.ResponseWriter, r *http.Request) { +func (routes *RegistryRoutes) getServer(w http.ResponseWriter, r *http.Request) { registryName := chi.URLParam(r, "name") serverName := chi.URLParam(r, "serverName") @@ -370,7 +374,7 @@ func (*RegistryRoutes) getServer(w http.ResponseWriter, r *http.Request) { return } - provider, ok := getCurrentProvider(w) + provider, ok := getCurrentProvider(w, routes.logger) if !ok { return } @@ -378,7 +382,7 @@ func (*RegistryRoutes) getServer(w http.ResponseWriter, r *http.Request) { // Try to get the server (could be container or remote) server, err := provider.GetServer(serverName) if err != nil { - logger.Errorf("Failed to get server '%s': %v", serverName, err) + routes.logger.Errorf("Failed to get server '%s': %v", serverName, err) http.Error(w, "Server not found", http.StatusNotFound) return } @@ -403,7 +407,7 @@ func (*RegistryRoutes) getServer(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - logger.Errorf("Failed to encode response: %v", err) + routes.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } diff --git a/pkg/api/v1/registry_test.go b/pkg/api/v1/registry_test.go index 
30907372e..f58b4e600 100644 --- a/pkg/api/v1/registry_test.go +++ b/pkg/api/v1/registry_test.go @@ -16,12 +16,15 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func MockConfig(t *testing.T, cfg *config.Config) func() { t.Helper() + // Setup logger + logger := log.NewLogger() + // Create a temporary directory for the test tempDir := t.TempDir() @@ -38,7 +41,7 @@ func MockConfig(t *testing.T, cfg *config.Config) func() { // Write the config file if one is provided if cfg != nil { - err = config.UpdateConfig(func(c *config.Config) { *c = *cfg }) + err = config.UpdateConfig(func(c *config.Config) { *c = *cfg }, logger) require.NoError(t, err) } @@ -50,15 +53,15 @@ func MockConfig(t *testing.T, cfg *config.Config) func() { func TestRegistryRouter(t *testing.T) { t.Parallel() - logger.Initialize() + logger := log.NewLogger() - router := RegistryRouter() + router := RegistryRouter(logger) assert.NotNil(t, router) } //nolint:paralleltest // Cannot use t.Parallel() with t.Setenv() in Go 1.24+ func TestGetRegistryInfo(t *testing.T) { - logger.Initialize() + logger := log.NewLogger() // Setup temporary config to avoid modifying user's real config cleanup := MockConfig(t, nil) @@ -73,7 +76,7 @@ func TestGetRegistryInfo(t *testing.T) { { name: "default registry", setupConfig: func() { - _ = config.UnsetRegistry() + _ = config.UnsetRegistry(logger) }, expectedType: RegistryTypeDefault, expectedSource: "", @@ -81,7 +84,7 @@ func TestGetRegistryInfo(t *testing.T) { { name: "URL registry", setupConfig: func() { - _ = config.SetRegistryURL("https://test.com/registry.json", false) + _ = config.SetRegistryURL("https://test.com/registry.json", false, logger) }, expectedType: RegistryTypeURL, expectedSource: "https://test.com/registry.json", @@ -89,10 +92,10 @@ func TestGetRegistryInfo(t *testing.T) { { name: "file registry", setupConfig: func() { - 
_ = config.UnsetRegistry() + _ = config.UnsetRegistry(logger) _ = config.UpdateConfig(func(c *config.Config) { c.LocalRegistryPath = "/tmp/test-registry.json" - }) + }, logger) }, expectedType: RegistryTypeFile, expectedSource: "/tmp/test-registry.json", @@ -105,7 +108,7 @@ func TestGetRegistryInfo(t *testing.T) { tt.setupConfig() } - registryType, source := getRegistryInfo() + registryType, source := getRegistryInfo(logger) assert.Equal(t, tt.expectedType, registryType, "Registry type should match expected") assert.Equal(t, tt.expectedSource, source, "Registry source should match expected") }) @@ -115,9 +118,9 @@ func TestGetRegistryInfo(t *testing.T) { func TestRegistryAPI_PutEndpoint(t *testing.T) { t.Parallel() - logger.Initialize() + logger := log.NewLogger() - routes := &RegistryRoutes{} + routes := &RegistryRoutes{logger} tests := []struct { name string diff --git a/pkg/api/v1/secrets.go b/pkg/api/v1/secrets.go index b491f52cc..5c4b192e3 100644 --- a/pkg/api/v1/secrets.go +++ b/pkg/api/v1/secrets.go @@ -8,9 +8,9 @@ import ( "net/http" "github.com/go-chi/chi/v5" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -20,11 +20,15 @@ const ( ) // SecretsRoutes defines the routes for the secrets API. -type SecretsRoutes struct{} +type SecretsRoutes struct { + logger *zap.SugaredLogger +} // SecretsRouter creates a new router for the secrets API. 
-func SecretsRouter() http.Handler { - routes := SecretsRoutes{} +func SecretsRouter(logger *zap.SugaredLogger) http.Handler { + routes := SecretsRoutes{ + logger: logger, + } r := chi.NewRouter() @@ -59,10 +63,10 @@ func SecretsRouter() http.Handler { // @Failure 400 {string} string "Bad Request" // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets [post] -func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Request) { +func (s *SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Request) { var req setupSecretsRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - logger.Errorf("Failed to decode request body: %v", err) + s.logger.Errorf("Failed to decode request body: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -87,13 +91,13 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques } // Check current secrets provider configuration for appropriate messaging - cfg := config.GetConfig() + cfg := config.GetConfig(s.logger) isReconfiguration := false isInitialSetup := !cfg.Secrets.SetupCompleted if cfg.Secrets.SetupCompleted { currentProviderType, err := cfg.Secrets.GetProviderType() if err != nil { - logger.Errorf("Failed to get current provider type: %v", err) + s.logger.Errorf("Failed to get current provider type: %v", err) http.Error(w, "Failed to get current provider configuration", http.StatusInternalServerError) return } @@ -101,10 +105,10 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques // TODO Handle provider reconfiguration in a better way if currentProviderType == providerType { isReconfiguration = true - logger.Infof("Reconfiguring existing %s secrets provider", providerType) + s.logger.Infof("Reconfiguring existing %s secrets provider", providerType) } else { isReconfiguration = true // Changing provider type is also considered reconfiguration - logger.Warnf("Changing 
secrets provider from %s to %s", currentProviderType, providerType) + s.logger.Warnf("Changing secrets provider from %s to %s", currentProviderType, providerType) } } @@ -115,17 +119,17 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques if req.Password != "" { // Use provided password passwordToUse = req.Password - logger.Infof("Using provided password for encrypted provider setup") + s.logger.Infof("Using provided password for encrypted provider setup") } else { // Generate a secure random password generatedPassword, err := secrets.GenerateSecurePassword() if err != nil { - logger.Errorf("Failed to generate secure password: %v", err) + s.logger.Errorf("Failed to generate secure password: %v", err) http.Error(w, "Failed to generate secure password", http.StatusInternalServerError) return } passwordToUse = generatedPassword - logger.Infof("Generated secure random password for encrypted provider setup") + s.logger.Infof("Generated secure random password for encrypted provider setup") } } @@ -133,9 +137,9 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques // Validate that the provider can be created and works correctly // Use the password from the request for encrypted provider validation and setup ctx := context.Background() - result := secrets.ValidateProviderWithPassword(ctx, providerType, passwordToUse) + result := secrets.ValidateProviderWithPassword(ctx, providerType, passwordToUse, s.logger) if !result.Success { - logger.Errorf("Provider validation failed: %v", result.Error) + s.logger.Errorf("Provider validation failed: %v", result.Error) if errors.Is(result.Error, secrets.ErrKeyringNotAvailable) { http.Error(w, result.Error.Error(), http.StatusBadRequest) return @@ -147,22 +151,22 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques // For encrypted provider during initial setup or reconfiguration, ensure we create the provider // at least once to save password in 
keyring if providerType == secrets.EncryptedType && (isInitialSetup || isReconfiguration) { - _, err := secrets.CreateSecretProviderWithPassword(providerType, passwordToUse) + _, err := secrets.CreateSecretProviderWithPassword(providerType, passwordToUse, s.logger) if err != nil { - logger.Errorf("Failed to initialize encrypted provider: %v", err) + s.logger.Errorf("Failed to initialize encrypted provider: %v", err) http.Error(w, fmt.Sprintf("Failed to initialize encrypted provider: %v", err), http.StatusInternalServerError) return } - logger.Info("Encrypted provider initialized and password saved to keyring") + s.logger.Info("Encrypted provider initialized and password saved to keyring") } // Update the secrets provider type and mark setup as completed err := config.UpdateConfig(func(c *config.Config) { c.Secrets.ProviderType = string(providerType) c.Secrets.SetupCompleted = true - }) + }, s.logger) if err != nil { - logger.Errorf("Failed to update configuration: %v", err) + s.logger.Errorf("Failed to update configuration: %v", err) http.Error(w, fmt.Sprintf("Failed to update configuration: %v", err), http.StatusInternalServerError) return } @@ -182,7 +186,7 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques Message: message, } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) + s.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -199,7 +203,7 @@ func (*SecretsRoutes) setupSecretsProvider(w http.ResponseWriter, r *http.Reques // @Failure 500 {string} string "Internal Server Error" // @Router /api/v1beta/secrets/default [get] func (s *SecretsRoutes) getSecretsProvider(w http.ResponseWriter, _ *http.Request) { - cfg := config.GetConfig() + cfg := config.GetConfig(s.logger) // Check if secrets provider is setup if !cfg.Secrets.SetupCompleted { @@ -209,7 +213,7 @@ func (s *SecretsRoutes) 
getSecretsProvider(w http.ResponseWriter, _ *http.Reques providerType, err := cfg.Secrets.GetProviderType() if err != nil { - logger.Errorf("Failed to get provider type: %v", err) + s.logger.Errorf("Failed to get provider type: %v", err) http.Error(w, "Failed to get provider type", http.StatusInternalServerError) return } @@ -217,7 +221,7 @@ func (s *SecretsRoutes) getSecretsProvider(w http.ResponseWriter, _ *http.Reques // Get provider capabilities provider, err := s.getSecretsManager() if err != nil { - logger.Errorf("Failed to create secrets provider: %v", err) + s.logger.Errorf("Failed to create secrets provider: %v", err) http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) return } @@ -237,7 +241,7 @@ func (s *SecretsRoutes) getSecretsProvider(w http.ResponseWriter, _ *http.Reques }, } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) + s.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -261,7 +265,7 @@ func (s *SecretsRoutes) listSecrets(w http.ResponseWriter, r *http.Request) { http.Error(w, "Secrets provider not setup", http.StatusNotFound) return } - logger.Errorf("Failed to get secrets manager: %v", err) + s.logger.Errorf("Failed to get secrets manager: %v", err) http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) return } @@ -274,7 +278,7 @@ func (s *SecretsRoutes) listSecrets(w http.ResponseWriter, r *http.Request) { secretDescriptions, err := provider.ListSecrets(r.Context()) if err != nil { - logger.Errorf("Failed to list secrets: %v", err) + s.logger.Errorf("Failed to list secrets: %v", err) http.Error(w, "Failed to list secrets", http.StatusInternalServerError) return } @@ -290,7 +294,7 @@ func (s *SecretsRoutes) listSecrets(w http.ResponseWriter, r *http.Request) { } } if err := json.NewEncoder(w).Encode(resp); err != nil { - 
logger.Errorf("Failed to encode response: %v", err) + s.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -314,7 +318,7 @@ func (s *SecretsRoutes) listSecrets(w http.ResponseWriter, r *http.Request) { func (s *SecretsRoutes) createSecret(w http.ResponseWriter, r *http.Request) { var req createSecretRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - logger.Errorf("Failed to decode request body: %v", err) + s.logger.Errorf("Failed to decode request body: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -330,7 +334,7 @@ func (s *SecretsRoutes) createSecret(w http.ResponseWriter, r *http.Request) { http.Error(w, "Secrets provider not setup", http.StatusNotFound) return } - logger.Errorf("Failed to get secrets manager: %v", err) + s.logger.Errorf("Failed to get secrets manager: %v", err) http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) return } @@ -352,7 +356,7 @@ func (s *SecretsRoutes) createSecret(w http.ResponseWriter, r *http.Request) { // Create the secret if err := provider.SetSecret(r.Context(), req.Key, req.Value); err != nil { - logger.Errorf("Failed to create secret: %v", err) + s.logger.Errorf("Failed to create secret: %v", err) http.Error(w, "Failed to create secret", http.StatusInternalServerError) return } @@ -364,7 +368,7 @@ func (s *SecretsRoutes) createSecret(w http.ResponseWriter, r *http.Request) { Message: "Secret created successfully", } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) + s.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -394,7 +398,7 @@ func (s *SecretsRoutes) updateSecret(w http.ResponseWriter, r *http.Request) { var req updateSecretRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - 
logger.Errorf("Failed to decode request body: %v", err) + s.logger.Errorf("Failed to decode request body: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -410,7 +414,7 @@ func (s *SecretsRoutes) updateSecret(w http.ResponseWriter, r *http.Request) { http.Error(w, "Secrets provider not setup", http.StatusNotFound) return } - logger.Errorf("Failed to get secrets manager: %v", err) + s.logger.Errorf("Failed to get secrets manager: %v", err) http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) return } @@ -432,7 +436,7 @@ func (s *SecretsRoutes) updateSecret(w http.ResponseWriter, r *http.Request) { // Update the secret if err := provider.SetSecret(r.Context(), key, req.Value); err != nil { - logger.Errorf("Failed to update secret: %v", err) + s.logger.Errorf("Failed to update secret: %v", err) http.Error(w, "Failed to update secret", http.StatusInternalServerError) return } @@ -443,7 +447,7 @@ func (s *SecretsRoutes) updateSecret(w http.ResponseWriter, r *http.Request) { Message: "Secret updated successfully", } if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Errorf("Failed to encode response: %v", err) + s.logger.Errorf("Failed to encode response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } @@ -473,7 +477,7 @@ func (s *SecretsRoutes) deleteSecret(w http.ResponseWriter, r *http.Request) { http.Error(w, "Secrets provider not setup", http.StatusNotFound) return } - logger.Errorf("Failed to get secrets manager: %v", err) + s.logger.Errorf("Failed to get secrets manager: %v", err) http.Error(w, "Failed to access secrets provider", http.StatusInternalServerError) return } @@ -486,7 +490,7 @@ func (s *SecretsRoutes) deleteSecret(w http.ResponseWriter, r *http.Request) { // Delete the secret if err := provider.DeleteSecret(r.Context(), key); err != nil { - logger.Errorf("Failed to delete secret: %v", err) + s.logger.Errorf("Failed to delete 
secret: %v", err) // Check if it's a "not found" error if err.Error() == "cannot delete non-existent secret: "+key { http.Error(w, "Secret not found", http.StatusNotFound) @@ -500,8 +504,8 @@ func (s *SecretsRoutes) deleteSecret(w http.ResponseWriter, r *http.Request) { } // getSecretsManager is a helper function to get the secrets manager -func (*SecretsRoutes) getSecretsManager() (secrets.Provider, error) { - cfg := config.GetConfig() +func (s *SecretsRoutes) getSecretsManager() (secrets.Provider, error) { + cfg := config.GetConfig(s.logger) // Check if secrets setup has been completed if !cfg.Secrets.SetupCompleted { @@ -513,7 +517,7 @@ func (*SecretsRoutes) getSecretsManager() (secrets.Provider, error) { return nil, err } - return secrets.CreateSecretProvider(providerType) + return secrets.CreateSecretProvider(providerType, s.logger) } // Request and response type definitions diff --git a/pkg/api/v1/secrets_test.go b/pkg/api/v1/secrets_test.go index afc51de24..9543746fa 100644 --- a/pkg/api/v1/secrets_test.go +++ b/pkg/api/v1/secrets_test.go @@ -13,19 +13,20 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" ) func TestSecretsRouter(t *testing.T) { t.Parallel() - router := SecretsRouter() + logger := log.NewLogger() + router := SecretsRouter(logger) assert.NotNil(t, router) } func TestSetupSecretsProvider_ValidRequests(t *testing.T) { t.Parallel() - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -53,7 +54,7 @@ func TestSetupSecretsProvider_ValidRequests(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := &SecretsRoutes{logger} routes.setupSecretsProvider(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -72,7 +73,7 @@ func TestSetupSecretsProvider_ValidRequests(t 
*testing.T) { func TestSetupSecretsProvider_InvalidRequests(t *testing.T) { t.Parallel() - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -114,7 +115,7 @@ func TestSetupSecretsProvider_InvalidRequests(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := &SecretsRoutes{logger} routes.setupSecretsProvider(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -125,7 +126,7 @@ func TestSetupSecretsProvider_InvalidRequests(t *testing.T) { func TestCreateSecret_InvalidRequests(t *testing.T) { t.Parallel() - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -177,7 +178,7 @@ func TestCreateSecret_InvalidRequests(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := &SecretsRoutes{logger} routes.createSecret(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -188,7 +189,7 @@ func TestCreateSecret_InvalidRequests(t *testing.T) { func TestUpdateSecret_InvalidRequests(t *testing.T) { t.Parallel() - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -249,7 +250,7 @@ func TestUpdateSecret_InvalidRequests(t *testing.T) { w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := &SecretsRoutes{logger} routes.updateSecret(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -260,7 +261,7 @@ func TestUpdateSecret_InvalidRequests(t *testing.T) { func TestDeleteSecret_InvalidRequests(t *testing.T) { t.Parallel() - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -291,7 +292,7 @@ func TestDeleteSecret_InvalidRequests(t *testing.T) { w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := &SecretsRoutes{logger} routes.deleteSecret(w, req) assert.Equal(t, tt.expectedCode, w.Code) @@ -401,8 +402,7 @@ func TestRequestResponseTypes(t *testing.T) { func TestErrorHandling(t 
*testing.T) { t.Parallel() - logger.Initialize() - + logger := log.NewLogger() t.Run("malformed json request", func(t *testing.T) { t.Parallel() malformedJSON := `{"provider_type": "encrypted", "invalid": json}` @@ -410,7 +410,7 @@ func TestErrorHandling(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := &SecretsRoutes{logger} routes.setupSecretsProvider(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) @@ -423,7 +423,7 @@ func TestErrorHandling(t *testing.T) { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := &SecretsRoutes{logger} routes.createSecret(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) @@ -435,7 +435,7 @@ func TestErrorHandling(t *testing.T) { // Deliberately not setting Content-Type header w := httptest.NewRecorder() - routes := &SecretsRoutes{} + routes := &SecretsRoutes{logger} routes.setupSecretsProvider(w, req) // Should still work as the handler doesn't strictly require content-type @@ -445,11 +445,12 @@ func TestErrorHandling(t *testing.T) { func TestRouterIntegration(t *testing.T) { t.Parallel() - logger.Initialize() + + logger := log.NewLogger() t.Run("router setup test", func(t *testing.T) { t.Parallel() - router := SecretsRouter() + router := SecretsRouter(logger) // Test POST / endpoint setupReq := setupSecretsRequest{ diff --git a/pkg/api/v1/workloads.go b/pkg/api/v1/workloads.go index c4b5867ba..681070757 100644 --- a/pkg/api/v1/workloads.go +++ b/pkg/api/v1/workloads.go @@ -8,12 +8,12 @@ import ( "net/http" "github.com/go-chi/chi/v5" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" thverrors "github.com/stacklok/toolhive/pkg/errors" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/runner" 
"github.com/stacklok/toolhive/pkg/runner/retriever" @@ -31,6 +31,7 @@ type WorkloadRoutes struct { containerRuntime runtime.Runtime debugMode bool groupManager groups.Manager + logger *zap.SugaredLogger } // @title ToolHive API @@ -45,12 +46,14 @@ func WorkloadRouter( containerRuntime runtime.Runtime, groupManager groups.Manager, debugMode bool, + logger *zap.SugaredLogger, ) http.Handler { routes := WorkloadRoutes{ workloadManager: workloadManager, containerRuntime: containerRuntime, debugMode: debugMode, groupManager: groupManager, + logger: logger, } r := chi.NewRouter() @@ -86,7 +89,7 @@ func (s *WorkloadRoutes) listWorkloads(w http.ResponseWriter, r *http.Request) { workloadList, err := s.workloadManager.ListWorkloads(ctx, listAll) if err != nil { - logger.Errorf("Failed to list workloads: %v", err) + s.logger.Errorf("Failed to list workloads: %v", err) http.Error(w, "Failed to list workloads", http.StatusInternalServerError) return } @@ -102,7 +105,7 @@ func (s *WorkloadRoutes) listWorkloads(w http.ResponseWriter, r *http.Request) { if thverrors.IsGroupNotFound(err) { http.Error(w, "Group not found", http.StatusNotFound) } else { - logger.Errorf("Failed to filter workloads by group: %v", err) + s.logger.Errorf("Failed to filter workloads by group: %v", err) http.Error(w, "Failed to list workloads in group", http.StatusInternalServerError) } return @@ -140,7 +143,7 @@ func (s *WorkloadRoutes) getWorkload(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) return } - logger.Errorf("Failed to get workload: %v", err) + s.logger.Errorf("Failed to get workload: %v", err) http.Error(w, "Failed to get workload", http.StatusInternalServerError) return } @@ -174,7 +177,7 @@ func (s *WorkloadRoutes) stopWorkload(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) return } - logger.Errorf("Failed to stop workload: %v", err) + 
s.logger.Errorf("Failed to stop workload: %v", err) http.Error(w, "Failed to stop workload", http.StatusInternalServerError) return } @@ -203,7 +206,7 @@ func (s *WorkloadRoutes) restartWorkload(w http.ResponseWriter, r *http.Request) http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) return } - logger.Errorf("Failed to restart workload: %v", err) + s.logger.Errorf("Failed to restart workload: %v", err) http.Error(w, "Failed to restart workload", http.StatusInternalServerError) return } @@ -231,7 +234,7 @@ func (s *WorkloadRoutes) deleteWorkload(w http.ResponseWriter, r *http.Request) http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) return } - logger.Errorf("Failed to delete workload: %v", err) + s.logger.Errorf("Failed to delete workload: %v", err) http.Error(w, "Failed to delete workload", http.StatusInternalServerError) return } @@ -265,6 +268,7 @@ func (s *WorkloadRoutes) createWorkload(w http.ResponseWriter, r *http.Request) req.Image, "", // We do not let the user specify a CA cert path here. retriever.VerifyImageWarn, + s.logger, ) if err != nil { if errors.Is(err, retriever.ErrImageNotFound) { @@ -277,7 +281,7 @@ func (s *WorkloadRoutes) createWorkload(w http.ResponseWriter, r *http.Request) // NOTE: None of the k8s-related config logic is included here. runSecrets := secrets.SecretParametersToCLI(req.Secrets) - runConfig, err := runner.NewRunConfigBuilder(). + runConfig, err := runner.NewRunConfigBuilder(s.logger). WithRuntime(s.containerRuntime). WithCmdArgs(req.CmdArguments). WithName(req.Name). @@ -299,15 +303,15 @@ func (s *WorkloadRoutes) createWorkload(w http.ResponseWriter, r *http.Request) "", "", "", "", "", false). // JWKS auth parameters not exposed through API yet WithTelemetryConfig("", false, "", 0.0, nil, false, nil). // Not exposed through API yet. WithToolsFilter(req.ToolsFilter). 
- Build(ctx, imageMetadata, req.EnvVars, &runner.DetachedEnvVarValidator{}) + Build(ctx, imageMetadata, req.EnvVars, runner.NewDetachedEnvVarValidator(s.logger)) if err != nil { - logger.Errorf("Failed to create run config: %v", err) + s.logger.Errorf("Failed to create run config: %v", err) http.Error(w, "Failed to create run config", http.StatusBadRequest) return } - if err := runConfig.SaveState(ctx); err != nil { - logger.Errorf("Failed to save workload config: %v", err) + if err := runConfig.SaveState(ctx, s.logger); err != nil { + s.logger.Errorf("Failed to save workload config: %v", err) http.Error(w, "Failed to save workload config", http.StatusInternalServerError) return } @@ -315,7 +319,7 @@ func (s *WorkloadRoutes) createWorkload(w http.ResponseWriter, r *http.Request) // Start workload with specified RunConfig. err = s.workloadManager.RunWorkloadDetached(ctx, runConfig) if err != nil { - logger.Errorf("Failed to start workload: %v", err) + s.logger.Errorf("Failed to start workload: %v", err) http.Error(w, "Failed to start workload", http.StatusInternalServerError) return } @@ -371,7 +375,7 @@ func (s *WorkloadRoutes) stopWorkloadsBulk(w http.ResponseWriter, r *http.Reques http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) return } - logger.Errorf("Failed to stop workloads: %v", err) + s.logger.Errorf("Failed to stop workloads: %v", err) http.Error(w, "Failed to stop workloads", http.StatusInternalServerError) return } @@ -417,7 +421,7 @@ func (s *WorkloadRoutes) restartWorkloadsBulk(w http.ResponseWriter, r *http.Req http.Error(w, "Invalid workload name: "+err.Error(), http.StatusBadRequest) return } - logger.Errorf("Failed to restart workloads: %v", err) + s.logger.Errorf("Failed to restart workloads: %v", err) http.Error(w, "Failed to restart workloads", http.StatusInternalServerError) return } @@ -462,7 +466,7 @@ func (s *WorkloadRoutes) deleteWorkloadsBulk(w http.ResponseWriter, r *http.Requ http.Error(w, "Invalid workload 
name: "+err.Error(), http.StatusBadRequest) return } - logger.Errorf("Failed to delete workloads: %v", err) + s.logger.Errorf("Failed to delete workloads: %v", err) http.Error(w, "Failed to delete workloads", http.StatusInternalServerError) return } @@ -489,14 +493,14 @@ func (s *WorkloadRoutes) getLogsForWorkload(w http.ResponseWriter, r *http.Reque http.Error(w, "Workload not found", http.StatusNotFound) return } - logger.Errorf("Failed to get logs: %v", err) + s.logger.Errorf("Failed to get logs: %v", err) http.Error(w, "Failed to get logs", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "text/plain") _, err = w.Write([]byte(logs)) if err != nil { - logger.Errorf("Failed to write logs response: %v", err) + s.logger.Errorf("Failed to write logs response: %v", err) http.Error(w, "Failed to write logs response", http.StatusInternalServerError) return } @@ -512,7 +516,7 @@ func (s *WorkloadRoutes) getLogsForWorkload(w http.ResponseWriter, r *http.Reque // @Success 200 {object} runner.RunConfig // @Failure 404 {string} string "Not Found" // @Router /api/v1beta/workloads/{name}/export [get] -func (*WorkloadRoutes) exportWorkload(w http.ResponseWriter, r *http.Request) { +func (s *WorkloadRoutes) exportWorkload(w http.ResponseWriter, r *http.Request) { ctx := r.Context() name := chi.URLParam(r, "name") @@ -523,7 +527,7 @@ func (*WorkloadRoutes) exportWorkload(w http.ResponseWriter, r *http.Request) { http.Error(w, "Workload configuration not found", http.StatusNotFound) return } - logger.Errorf("Failed to load workload configuration: %v", err) + s.logger.Errorf("Failed to load workload configuration: %v", err) http.Error(w, "Failed to load workload configuration", http.StatusInternalServerError) return } @@ -531,7 +535,7 @@ func (*WorkloadRoutes) exportWorkload(w http.ResponseWriter, r *http.Request) { // Return the configuration as JSON w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(runConfig); err != 
nil { - logger.Errorf("Failed to encode workload configuration: %v", err) + s.logger.Errorf("Failed to encode workload configuration: %v", err) http.Error(w, "Failed to encode workload configuration", http.StatusInternalServerError) return } diff --git a/pkg/audit/auditor.go b/pkg/audit/auditor.go index 1de1ef135..77fb07f32 100644 --- a/pkg/audit/auditor.go +++ b/pkg/audit/auditor.go @@ -12,8 +12,9 @@ import ( "strings" "time" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/mcp" ) @@ -40,7 +41,7 @@ type Auditor struct { } // NewAuditor creates a new Auditor with the given configuration. -func NewAuditor(config *Config) (*Auditor, error) { +func NewAuditor(config *Config, logger *zap.SugaredLogger) (*Auditor, error) { var logWriter io.Writer = os.Stdout // default to stdout if config != nil { diff --git a/pkg/audit/auditor_test.go b/pkg/audit/auditor_test.go index 610add0ad..c8eecffe9 100644 --- a/pkg/audit/auditor_test.go +++ b/pkg/audit/auditor_test.go @@ -15,18 +15,14 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) -func init() { - // Initialize logger for tests - logger.Initialize() -} - func TestNewAuditor(t *testing.T) { t.Parallel() + logger := log.NewLogger() config := &Config{} - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) assert.NoError(t, err) assert.NotNil(t, auditor) @@ -35,8 +31,9 @@ func TestNewAuditor(t *testing.T) { func TestAuditorMiddlewareDisabled(t *testing.T) { t.Parallel() + logger := log.NewLogger() config := &Config{} - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -57,11 +54,12 @@ func TestAuditorMiddlewareDisabled(t *testing.T) { func 
TestAuditorMiddlewareWithRequestData(t *testing.T) { t.Parallel() + logger := log.NewLogger() config := &Config{ IncludeRequestData: true, MaxDataSize: 1024, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -87,11 +85,12 @@ func TestAuditorMiddlewareWithRequestData(t *testing.T) { func TestAuditorMiddlewareWithResponseData(t *testing.T) { t.Parallel() + logger := log.NewLogger() config := &Config{ IncludeResponseData: true, MaxDataSize: 1024, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) responseData := `{"result": "success"}` @@ -114,7 +113,8 @@ func TestAuditorMiddlewareWithResponseData(t *testing.T) { func TestDetermineEventType(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + logger := log.NewLogger() + auditor, err := NewAuditor(&Config{}, logger) require.NoError(t, err) tests := []struct { @@ -155,6 +155,7 @@ func TestDetermineEventType(t *testing.T) { func TestMapMCPMethodToEventType(t *testing.T) { t.Parallel() + logger := log.NewLogger() tests := []struct { mcpMethod string expected string @@ -174,7 +175,7 @@ func TestMapMCPMethodToEventType(t *testing.T) { {"unknown_method", "mcp_request"}, } - auditor, err := NewAuditor(&Config{}) + auditor, err := NewAuditor(&Config{}, logger) require.NoError(t, err) for _, tt := range tests { t.Run(tt.mcpMethod, func(t *testing.T) { @@ -187,7 +188,8 @@ func TestMapMCPMethodToEventType(t *testing.T) { func TestDetermineOutcome(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + logger := log.NewLogger() + auditor, err := NewAuditor(&Config{}, logger) require.NoError(t, err) tests := []struct { @@ -218,7 +220,8 @@ func TestDetermineOutcome(t *testing.T) { func TestGetClientIP(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + logger := log.NewLogger() + auditor, err 
:= NewAuditor(&Config{}, logger) require.NoError(t, err) tests := []struct { @@ -268,7 +271,8 @@ func TestGetClientIP(t *testing.T) { func TestExtractSubjects(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + logger := log.NewLogger() + auditor, err := NewAuditor(&Config{}, logger) require.NoError(t, err) t.Run("with JWT claims", func(t *testing.T) { @@ -339,10 +343,11 @@ func TestExtractSubjects(t *testing.T) { func TestDetermineComponent(t *testing.T) { t.Parallel() + logger := log.NewLogger() t.Run("with configured component", func(t *testing.T) { t.Parallel() config := &Config{Component: "custom-component"} - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) req := httptest.NewRequest("GET", "/test", nil) @@ -354,7 +359,7 @@ func TestDetermineComponent(t *testing.T) { t.Run("without configured component", func(t *testing.T) { t.Parallel() config := &Config{} - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) req := httptest.NewRequest("GET", "/test", nil) @@ -366,7 +371,8 @@ func TestDetermineComponent(t *testing.T) { func TestExtractTarget(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + logger := log.NewLogger() + auditor, err := NewAuditor(&Config{}, logger) require.NoError(t, err) tests := []struct { @@ -423,7 +429,8 @@ func TestExtractTarget(t *testing.T) { func TestAddMetadata(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + logger := log.NewLogger() + auditor, err := NewAuditor(&Config{}, logger) require.NoError(t, err) event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test") @@ -444,13 +451,14 @@ func TestAddMetadata(t *testing.T) { func TestAddEventData(t *testing.T) { t.Parallel() + logger := log.NewLogger() t.Run("with request and response data", func(t *testing.T) { t.Parallel() config := &Config{ IncludeRequestData: true, IncludeResponseData: 
true, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test") @@ -483,7 +491,7 @@ func TestAddEventData(t *testing.T) { IncludeRequestData: true, IncludeResponseData: true, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test") @@ -511,7 +519,7 @@ func TestAddEventData(t *testing.T) { IncludeRequestData: false, IncludeResponseData: false, } - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) event := NewAuditEvent("test", EventSource{}, OutcomeSuccess, map[string]string{}, "test") @@ -527,11 +535,12 @@ func TestAddEventData(t *testing.T) { func TestResponseWriterCapture(t *testing.T) { t.Parallel() + logger := log.NewLogger() config := &Config{ IncludeResponseData: true, MaxDataSize: 10, // Small limit for testing } - auditor, err := NewAuditor(config) + auditor, err := NewAuditor(config, logger) require.NoError(t, err) rw := &responseWriter{ @@ -568,7 +577,8 @@ func TestResponseWriterStatusCode(t *testing.T) { func TestExtractSourceWithHeaders(t *testing.T) { t.Parallel() - auditor, err := NewAuditor(&Config{}) + logger := log.NewLogger() + auditor, err := NewAuditor(&Config{}, logger) require.NoError(t, err) req := httptest.NewRequest("GET", "/test", nil) diff --git a/pkg/audit/config.go b/pkg/audit/config.go index 712358a50..b070836a0 100644 --- a/pkg/audit/config.go +++ b/pkg/audit/config.go @@ -9,6 +9,8 @@ import ( "os" "path/filepath" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -105,8 +107,8 @@ func (c *Config) ShouldAuditEvent(eventType string) bool { } // CreateMiddleware creates an HTTP middleware from the audit configuration. 
-func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) { - auditor, err := NewAuditor(c) +func (c *Config) CreateMiddleware(logger *zap.SugaredLogger) (types.MiddlewareFunction, error) { + auditor, err := NewAuditor(c, logger) if err != nil { return nil, fmt.Errorf("failed to create auditor: %w", err) } @@ -114,7 +116,7 @@ func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) { } // GetMiddlewareFromFile loads the audit configuration from a file and creates an HTTP middleware. -func GetMiddlewareFromFile(path string) (func(http.Handler) http.Handler, error) { +func GetMiddlewareFromFile(path string, logger *zap.SugaredLogger) (func(http.Handler) http.Handler, error) { // Load the configuration config, err := LoadFromFile(path) if err != nil { @@ -122,7 +124,7 @@ func GetMiddlewareFromFile(path string) (func(http.Handler) http.Handler, error) } // Create the middleware - return config.CreateMiddleware() + return config.CreateMiddleware(logger) } // Validate validates the audit configuration. 
diff --git a/pkg/audit/config_test.go b/pkg/audit/config_test.go index 5f6fb2a7b..769cd133b 100644 --- a/pkg/audit/config_test.go +++ b/pkg/audit/config_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + log "github.com/stacklok/toolhive/pkg/logger" ) func TestDefaultConfig(t *testing.T) { @@ -112,7 +114,9 @@ func TestCreateMiddleware(t *testing.T) { t.Parallel() config := &Config{} - middleware, err := config.CreateMiddleware() + logger := log.NewLogger() + + middleware, err := config.CreateMiddleware(logger) assert.NoError(t, err) assert.NotNil(t, middleware) } @@ -235,8 +239,11 @@ func TestConfigMinimalJSON(t *testing.T) { func TestGetMiddlewareFromFileError(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + // Test with non-existent file - _, err := GetMiddlewareFromFile("/non/existent/file.json") + _, err := GetMiddlewareFromFile("/non/existent/file.json", logger) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to load audit config") } diff --git a/pkg/auth/oauth/flow.go b/pkg/auth/oauth/flow.go index 371fdf902..aecae37e0 100644 --- a/pkg/auth/oauth/flow.go +++ b/pkg/auth/oauth/flow.go @@ -17,9 +17,9 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/pkg/browser" + "go.uber.org/zap" "golang.org/x/oauth2" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" ) @@ -66,6 +66,8 @@ type Flow struct { state string tokenSource oauth2.TokenSource + + logger *zap.SugaredLogger } // TokenResult contains the result of the OAuth flow @@ -79,7 +81,7 @@ type TokenResult struct { } // NewFlow creates a new OAuth flow -func NewFlow(config *Config) (*Flow, error) { +func NewFlow(config *Config, logger *zap.SugaredLogger) (*Flow, error) { if config == nil { return nil, errors.New("OAuth config cannot be nil") } @@ -97,7 +99,7 @@ func NewFlow(config *Config) (*Flow, error) { } // Use specified callback port or find an available port for the local server - 
port, err := networking.FindOrUsePort(config.CallbackPort) + port, err := networking.FindOrUsePort(config.CallbackPort, logger) if err != nil { return nil, fmt.Errorf("failed to find available port: %w", err) } @@ -124,6 +126,7 @@ func NewFlow(config *Config) (*Flow, error) { config: config, oauth2Config: oauth2Config, port: port, + logger: logger, } // Generate PKCE parameters if enabled @@ -186,7 +189,7 @@ func (f *Flow) Start(ctx context.Context, skipBrowser bool) (*TokenResult, error // Start the server in a goroutine go func() { - logger.Infof("Starting OAuth callback server on port %d", f.port) + f.logger.Infof("Starting OAuth callback server on port %d", f.port) if err := f.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { errorChan <- fmt.Errorf("failed to start callback server: %w", err) } @@ -197,7 +200,7 @@ func (f *Flow) Start(ctx context.Context, skipBrowser bool) (*TokenResult, error shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := f.server.Shutdown(shutdownCtx); err != nil { - logger.Warnf("Failed to shutdown OAuth callback server: %v", err) + f.logger.Warnf("Failed to shutdown OAuth callback server: %v", err) } }() @@ -206,16 +209,16 @@ func (f *Flow) Start(ctx context.Context, skipBrowser bool) (*TokenResult, error // Open browser or display URL if !skipBrowser { - logger.Infof("Opening browser to: %s", authURL) + f.logger.Infof("Opening browser to: %s", authURL) if err := browser.OpenURL(authURL); err != nil { - logger.Warnf("Failed to open browser: %v", err) - logger.Infof("Please manually open this URL in your browser: %s", authURL) + f.logger.Warnf("Failed to open browser: %v", err) + f.logger.Infof("Please manually open this URL in your browser: %s", authURL) } } else { - logger.Infof("Please open this URL in your browser: %s", authURL) + f.logger.Infof("Please open this URL in your browser: %s", authURL) } - logger.Info("Waiting for OAuth callback...") + 
f.logger.Info("Waiting for OAuth callback...") // Set up signal handling for graceful shutdown sigChan := make(chan os.Signal, 1) @@ -224,7 +227,7 @@ func (f *Flow) Start(ctx context.Context, skipBrowser bool) (*TokenResult, error // Wait for token, error, or cancellation select { case token := <-tokenChan: - logger.Info("OAuth flow completed successfully") + f.logger.Info("OAuth flow completed successfully") return f.processToken(token), nil case err := <-errorChan: return nil, fmt.Errorf("OAuth flow failed: %w", err) @@ -354,7 +357,7 @@ func (f *Flow) handleRoot() http.HandlerFunc { ` if _, err := w.Write([]byte(htmlContent)); err != nil { - logger.Warnf("Failed to write HTML content: %v", err) + f.logger.Warnf("Failed to write HTML content: %v", err) } } } @@ -386,12 +389,12 @@ func (f *Flow) writeSuccessPage(w http.ResponseWriter) { ` if _, err := w.Write([]byte(htmlContent)); err != nil { - logger.Warnf("Failed to write HTML content: %v", err) + f.logger.Warnf("Failed to write HTML content: %v", err) } } // writeErrorPage writes an error page to the response -func (*Flow) writeErrorPage(w http.ResponseWriter, err error) { +func (f *Flow) writeErrorPage(w http.ResponseWriter, err error) { w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "DENY") @@ -425,7 +428,7 @@ func (*Flow) writeErrorPage(w http.ResponseWriter, err error) { `, escapedError) if _, err := w.Write([]byte(htmlContent)); err != nil { - logger.Warnf("Failed to write HTML content: %v", err) + f.logger.Warnf("Failed to write HTML content: %v", err) } } @@ -449,17 +452,17 @@ func (f *Flow) processToken(token *oauth2.Token) *TokenResult { result.IDToken = idToken if claims, err := f.extractJWTClaims(idToken); err == nil { result.Claims = claims - logger.Debugf("Successfully extracted JWT claims from ID token") + f.logger.Debugf("Successfully extracted JWT claims from ID token") } else { - 
logger.Debugf("Could not extract JWT claims from ID token: %v", err) + f.logger.Debugf("Could not extract JWT claims from ID token: %v", err) } } else { // Fallback: try to extract claims from the access token (e.g., Keycloak) if claims, err := f.extractJWTClaims(token.AccessToken); err == nil { result.Claims = claims - logger.Debugf("Successfully extracted JWT claims from access token") + f.logger.Debugf("Successfully extracted JWT claims from access token") } else { - logger.Debugf("Could not extract JWT claims from access token (may be opaque token): %v", err) + f.logger.Debugf("Could not extract JWT claims from access token (may be opaque token): %v", err) } } diff --git a/pkg/auth/oauth/flow_test.go b/pkg/auth/oauth/flow_test.go index d88403478..1fd25de4f 100644 --- a/pkg/auth/oauth/flow_test.go +++ b/pkg/auth/oauth/flow_test.go @@ -18,13 +18,10 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestMain(m *testing.M) { - // Initialize logger for tests - logger.Initialize() - // Run tests code := m.Run() @@ -34,6 +31,9 @@ func TestMain(m *testing.M) { func TestNewFlow(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + tests := []struct { name string config *Config @@ -99,7 +99,7 @@ func TestNewFlow(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - flow, err := NewFlow(tt.config) + flow, err := NewFlow(tt.config, logger) if tt.expectError { require.Error(t, err) @@ -201,6 +201,9 @@ func TestGenerateState(t *testing.T) { func TestBuildAuthURL(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + tests := []struct { name string config *Config @@ -262,7 +265,7 @@ func TestBuildAuthURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - flow, err := NewFlow(tt.config) + flow, err := NewFlow(tt.config, logger) require.NoError(t, err) authURL := 
flow.buildAuthURL() @@ -275,6 +278,9 @@ func TestBuildAuthURL(t *testing.T) { func TestHandleCallback_SecurityValidation(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + config := &Config{ ClientID: "test-client", AuthURL: "https://example.com/auth", @@ -282,7 +288,7 @@ func TestHandleCallback_SecurityValidation(t *testing.T) { UsePKCE: true, } - flow, err := NewFlow(config) + flow, err := NewFlow(config, logger) require.NoError(t, err) tokenChan := make(chan *oauth2.Token, 1) @@ -466,6 +472,10 @@ func TestWriteErrorPage_XSSPrevention(t *testing.T) { func TestProcessToken(t *testing.T) { t.Parallel() + + // Setup logger + logger := log.NewLogger() + // Create a proper flow with config to avoid nil pointer issues config := &Config{ ClientID: "test-client", @@ -473,7 +483,7 @@ func TestProcessToken(t *testing.T) { TokenURL: "https://example.com/token", } - flow, err := NewFlow(config) + flow, err := NewFlow(config, logger) require.NoError(t, err) // Test with a valid OAuth2 token @@ -608,6 +618,9 @@ func TestStateSecurityProperties(t *testing.T) { func TestStart(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + tests := []struct { name string config *Config @@ -640,7 +653,7 @@ func TestStart(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - flow, err := NewFlow(tt.config) + flow, err := NewFlow(tt.config, logger) require.NoError(t, err) // Generate the auth URL before starting the flow @@ -730,6 +743,10 @@ func TestWriteSuccessPage(t *testing.T) { func TestHandleCallback_SuccessfulFlow(t *testing.T) { t.Parallel() + + // Setup logger + logger := log.NewLogger() + // Create a mock token server tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "POST", r.Method) @@ -774,7 +791,7 @@ func TestHandleCallback_SuccessfulFlow(t *testing.T) { UsePKCE: true, } - flow, err := NewFlow(config) + flow, err := NewFlow(config, logger) require.NoError(t, 
err) tokenChan := make(chan *oauth2.Token, 1) @@ -811,13 +828,16 @@ func TestHandleCallback_SuccessfulFlow(t *testing.T) { func TestProcessToken_WithJWTClaims(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + config := &Config{ ClientID: "test-client", AuthURL: "https://example.com/auth", TokenURL: "https://example.com/token", } - flow, err := NewFlow(config) + flow, err := NewFlow(config, logger) require.NoError(t, err) // Create a test JWT token @@ -859,13 +879,16 @@ func TestProcessToken_WithJWTClaims(t *testing.T) { func TestProcessToken_WithOpaqueToken(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + config := &Config{ ClientID: "test-client", AuthURL: "https://example.com/auth", TokenURL: "https://example.com/token", } - flow, err := NewFlow(config) + flow, err := NewFlow(config, logger) require.NoError(t, err) // Test with opaque access token diff --git a/pkg/auth/token.go b/pkg/auth/token.go index 8957f7c5e..b83d85869 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -15,8 +15,8 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/lestrrat-go/httprc/v3" "github.com/lestrrat-go/jwx/v3/jwk" + "go.uber.org/zap" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/versions" ) @@ -474,7 +474,7 @@ type RFC9728AuthInfo struct { } // NewAuthInfoHandler creates an HTTP handler that returns RFC-9728 compliant OAuth Protected Resource metadata -func NewAuthInfoHandler(issuer, jwksURL, resourceURL string, scopes []string) http.Handler { +func NewAuthInfoHandler(issuer, jwksURL, resourceURL string, scopes []string, logger *zap.SugaredLogger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Set CORS headers for all requests origin := r.Header.Get("Origin") diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go index b1cbecc93..5e7b0e798 100644 --- a/pkg/auth/token_test.go +++ b/pkg/auth/token_test.go @@ -16,6 +16,8 @@ 
import ( "github.com/golang-jwt/jwt/v5" "github.com/lestrrat-go/jwx/v3/jwk" + + log "github.com/stacklok/toolhive/pkg/logger" ) const testKeyID = "test-key-1" @@ -764,6 +766,8 @@ func TestTokenValidator_OpaqueToken(t *testing.T) { func TestNewAuthInfoHandler(t *testing.T) { t.Parallel() + logger := log.NewLogger() + testCases := []struct { name string issuer string @@ -843,7 +847,7 @@ func TestNewAuthInfoHandler(t *testing.T) { t.Parallel() // Create the handler - handler := NewAuthInfoHandler(tc.issuer, tc.jwksURL, tc.resourceURL, tc.scopes) + handler := NewAuthInfoHandler(tc.issuer, tc.jwksURL, tc.resourceURL, tc.scopes, logger) // Create test request req := httptest.NewRequest(tc.method, "/", nil) diff --git a/pkg/auth/utils.go b/pkg/auth/utils.go index 213eef22f..ca3689a14 100644 --- a/pkg/auth/utils.go +++ b/pkg/auth/utils.go @@ -7,8 +7,7 @@ import ( "os/user" "github.com/golang-jwt/jwt/v5" - - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) // GetClaimsFromContext retrieves the claims from the request context. @@ -26,7 +25,7 @@ func GetClaimsFromContext(ctx context.Context) (jwt.MapClaims, bool) { // GetAuthenticationMiddleware returns the appropriate authentication middleware based on the configuration. // If OIDC config is provided, it returns JWT middleware. Otherwise, it returns local user middleware. 
-func GetAuthenticationMiddleware(ctx context.Context, oidcConfig *TokenValidatorConfig, +func GetAuthenticationMiddleware(ctx context.Context, oidcConfig *TokenValidatorConfig, logger *zap.SugaredLogger, ) (func(http.Handler) http.Handler, http.Handler, error) { if oidcConfig != nil { logger.Info("OIDC validation enabled") @@ -37,7 +36,7 @@ func GetAuthenticationMiddleware(ctx context.Context, oidcConfig *TokenValidator return nil, nil, err } - authInfoHandler := NewAuthInfoHandler(oidcConfig.Issuer, jwtValidator.jwksURL, oidcConfig.ResourceURL, nil) + authInfoHandler := NewAuthInfoHandler(oidcConfig.Issuer, jwtValidator.jwksURL, oidcConfig.ResourceURL, nil, logger) return jwtValidator.Middleware, authInfoHandler, nil } diff --git a/pkg/auth/utils_test.go b/pkg/auth/utils_test.go index 1e25e81c2..1980b20b3 100644 --- a/pkg/auth/utils_test.go +++ b/pkg/auth/utils_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestGetClaimsFromContext(t *testing.T) { @@ -104,12 +104,12 @@ func TestGetClaimsFromContextWithDifferentClaimTypes(t *testing.T) { func TestGetAuthenticationMiddleware(t *testing.T) { t.Parallel() // Initialize logger for testing - logger.Initialize() + logger := log.NewLogger() ctx := context.Background() // Test with nil OIDC config (should return local user middleware) - middleware, _, err := GetAuthenticationMiddleware(ctx, nil) + middleware, _, err := GetAuthenticationMiddleware(ctx, nil, logger) require.NoError(t, err, "Expected no error when OIDC config is nil") require.NotNil(t, middleware, "Expected middleware to be returned") diff --git a/pkg/authz/cedar.go b/pkg/authz/cedar.go index f39948ad1..d96be4a4e 100644 --- a/pkg/authz/cedar.go +++ b/pkg/authz/cedar.go @@ -11,9 +11,9 @@ import ( cedar "github.com/cedar-policy/cedar-go" "github.com/golang-jwt/jwt/v5" + "go.uber.org/zap" 
"github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/logger" ) // Common errors for Cedar authorization @@ -75,6 +75,8 @@ type CedarAuthorizer struct { entityFactory *EntityFactory // Mutex for thread safety mu sync.RWMutex + + logger *zap.SugaredLogger } // CedarAuthorizerConfig contains configuration for the Cedar authorizer. @@ -86,11 +88,12 @@ type CedarAuthorizerConfig struct { } // NewCedarAuthorizer creates a new Cedar authorizer. -func NewCedarAuthorizer(config CedarAuthorizerConfig) (*CedarAuthorizer, error) { +func NewCedarAuthorizer(config CedarAuthorizerConfig, logger *zap.SugaredLogger) (*CedarAuthorizer, error) { authorizer := &CedarAuthorizer{ policySet: cedar.NewPolicySet(), entities: cedar.EntityMap{}, entityFactory: NewEntityFactory(), + logger: logger, } // Load policies @@ -260,15 +263,15 @@ func (a *CedarAuthorizer) IsAuthorized( } // Debug logging for authorization - logger.Debugf("Cedar authorization check - Principal: %s, Action: %s, Resource: %s", + a.logger.Debugf("Cedar authorization check - Principal: %s, Action: %s, Resource: %s", req.Principal, req.Action, req.Resource) - logger.Debugf("Cedar context: %+v", req.Context) + a.logger.Debugf("Cedar context: %+v", req.Context) // Check authorization decision, diagnostic := cedar.Authorize(a.policySet, entityMap, req) // Log the decision - logger.Debugf("Cedar decision: %v, diagnostic: %+v", decision, diagnostic) + a.logger.Debugf("Cedar decision: %v, diagnostic: %+v", decision, diagnostic) // Cedar's Authorize returns a Decision and a Diagnostic // Check if the Diagnostic contains any errors diff --git a/pkg/authz/cedar_test.go b/pkg/authz/cedar_test.go index 567c3a815..e8742e0a6 100644 --- a/pkg/authz/cedar_test.go +++ b/pkg/authz/cedar_test.go @@ -9,15 +9,15 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) // TestNewCedarAuthorizer 
tests the creation of a new Cedar authorizer with different configurations. func TestNewCedarAuthorizer(t *testing.T) { t.Parallel() - // Initialize logger for tests - logger.Initialize() + logger := log.NewLogger() + // Test cases testCases := []struct { name string @@ -73,7 +73,7 @@ func TestNewCedarAuthorizer(t *testing.T) { authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: tc.policies, EntitiesJSON: tc.entitiesJSON, - }) + }, logger) // Check error expectations if tc.expectError { @@ -95,6 +95,9 @@ func TestNewCedarAuthorizer(t *testing.T) { // TestAuthorizeWithJWTClaims tests the AuthorizeWithJWTClaims function with different roles in claims. func TestAuthorizeWithJWTClaims(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + // Test cases testCases := []struct { name string @@ -317,7 +320,7 @@ func TestAuthorizeWithJWTClaims(t *testing.T) { authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: []string{tc.policy}, EntitiesJSON: `[]`, - }) + }, logger) require.NoError(t, err, "Failed to create Cedar authorizer") // Create a context with JWT claims @@ -334,6 +337,9 @@ func TestAuthorizeWithJWTClaims(t *testing.T) { // TestAuthorizeWithJWTClaimsErrors tests error cases for AuthorizeWithJWTClaims. 
func TestAuthorizeWithJWTClaimsErrors(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + // Create a context ctx := context.Background() @@ -341,7 +347,7 @@ func TestAuthorizeWithJWTClaimsErrors(t *testing.T) { authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: []string{`permit(principal, action, resource);`}, EntitiesJSON: `[]`, - }) + }, logger) require.NoError(t, err, "Failed to create Cedar authorizer") // Test cases diff --git a/pkg/authz/config.go b/pkg/authz/config.go index c025d2f0f..faccd9af8 100644 --- a/pkg/authz/config.go +++ b/pkg/authz/config.go @@ -9,6 +9,7 @@ import ( "path/filepath" "strings" + "go.uber.org/zap" "sigs.k8s.io/yaml" "github.com/stacklok/toolhive/pkg/transport/types" @@ -119,7 +120,7 @@ func (c *Config) Validate() error { } // CreateMiddleware creates an HTTP middleware from the configuration. -func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) { +func (c *Config) CreateMiddleware(logger *zap.SugaredLogger) (types.MiddlewareFunction, error) { // Create the appropriate middleware based on the configuration type switch c.Type { case ConfigTypeCedarV1: @@ -127,7 +128,7 @@ func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) { authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: c.Cedar.Policies, EntitiesJSON: c.Cedar.EntitiesJSON, - }) + }, logger) if err != nil { return nil, fmt.Errorf("failed to create Cedar authorizer: %w", err) } @@ -140,7 +141,7 @@ func (c *Config) CreateMiddleware() (types.MiddlewareFunction, error) { } // GetMiddlewareFromFile loads the authorization configuration from a file and creates an HTTP middleware. 
-func GetMiddlewareFromFile(path string) (func(http.Handler) http.Handler, error) { +func GetMiddlewareFromFile(path string, logger *zap.SugaredLogger) (func(http.Handler) http.Handler, error) { // Load the configuration config, err := LoadConfig(path) if err != nil { @@ -148,5 +149,5 @@ func GetMiddlewareFromFile(path string) (func(http.Handler) http.Handler, error) } // Create the middleware - return config.CreateMiddleware() + return config.CreateMiddleware(logger) } diff --git a/pkg/authz/config_test.go b/pkg/authz/config_test.go index e7fe85e7f..a855c280e 100644 --- a/pkg/authz/config_test.go +++ b/pkg/authz/config_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth" + log "github.com/stacklok/toolhive/pkg/logger" mcpparser "github.com/stacklok/toolhive/pkg/mcp" ) @@ -205,6 +206,10 @@ func TestValidateConfig(t *testing.T) { func TestCreateMiddleware(t *testing.T) { t.Parallel() + + // Setup logger + logger := log.NewLogger() + // Create a valid configuration config := &Config{ Version: "1.0", @@ -218,7 +223,7 @@ func TestCreateMiddleware(t *testing.T) { } // Create the middleware - middleware, err := config.CreateMiddleware() + middleware, err := config.CreateMiddleware(logger) require.NoError(t, err, "Failed to create middleware") require.NotNil(t, middleware, "Middleware is nil") diff --git a/pkg/authz/integration_test.go b/pkg/authz/integration_test.go index c6fe0ba76..535570aa4 100644 --- a/pkg/authz/integration_test.go +++ b/pkg/authz/integration_test.go @@ -15,7 +15,7 @@ import ( "golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" mcpparser "github.com/stacklok/toolhive/pkg/mcp" ) @@ -24,8 +24,9 @@ import ( func TestIntegrationListFiltering(t *testing.T) { t.Parallel() - // Initialize logger for tests - logger.Initialize() + // Setup logger + logger := log.NewLogger() + // Create a 
realistic Cedar authorizer with role-based policies authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: []string{ @@ -48,7 +49,7 @@ func TestIntegrationListFiltering(t *testing.T) { `permit(principal, action == Action::"read_resource", resource) when { principal.claim_role == "admin" };`, }, EntitiesJSON: `[]`, - }) + }, logger) require.NoError(t, err, "Failed to create Cedar authorizer") testCases := []struct { @@ -323,6 +324,10 @@ // TestIntegrationNonListOperations verifies that non-list operations still work correctly func TestIntegrationNonListOperations(t *testing.T) { t.Parallel() + + // Setup logger + logger := log.NewLogger() + // Create a Cedar authorizer with specific permissions authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: []string{ @@ -330,7 +335,7 @@ `permit(principal, action == Action::"call_tool", resource) when { principal.claim_role == "admin" };`, }, EntitiesJSON: `[]`, - }) + }, logger) require.NoError(t, err, "Failed to create Cedar authorizer") testCases := []struct { diff --git a/pkg/authz/middleware_test.go b/pkg/authz/middleware_test.go index f0a0df68e..55a8d7dad 100644 --- a/pkg/authz/middleware_test.go +++ b/pkg/authz/middleware_test.go @@ -15,15 +15,15 @@ import ( "golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" mcpparser "github.com/stacklok/toolhive/pkg/mcp" ) func TestMiddleware(t *testing.T) { t.Parallel() - // Initialize logger for tests - logger.Initialize() + // Setup logger + logger := log.NewLogger() // Create a Cedar authorizer authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ @@ -33,7 +33,7 @@
require.NoError(t, err, "Failed to create Cedar authorizer") // Test cases @@ -249,13 +249,17 @@ func TestMiddleware(t *testing.T) { // TestMiddlewareWithGETRequest tests that the middleware doesn't panic with GET requests. func TestMiddlewareWithGETRequest(t *testing.T) { t.Parallel() + + // Setup logger + logger := log.NewLogger() + // Create a Cedar authorizer authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: []string{ `permit(principal, action == Action::"call_tool", resource == Tool::"weather");`, }, EntitiesJSON: `[]`, - }) + }, logger) require.NoError(t, err, "Failed to create Cedar authorizer") // Create a handler that records if it was called diff --git a/pkg/authz/response_filter_test.go b/pkg/authz/response_filter_test.go index 39c4c6cd6..506419694 100644 --- a/pkg/authz/response_filter_test.go +++ b/pkg/authz/response_filter_test.go @@ -14,14 +14,14 @@ import ( "golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestResponseFilteringWriter(t *testing.T) { t.Parallel() - // Initialize logger for tests - logger.Initialize() + // Setup logger + logger := log.NewLogger() // Create a Cedar authorizer with specific tool permissions authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ @@ -31,7 +31,7 @@ func TestResponseFilteringWriter(t *testing.T) { `permit(principal, action == Action::"read_resource", resource == Resource::"data");`, }, EntitiesJSON: `[]`, - }) + }, logger) require.NoError(t, err, "Failed to create Cedar authorizer") testCases := []struct { @@ -211,13 +211,17 @@ func TestResponseFilteringWriter(t *testing.T) { func TestResponseFilteringWriter_NonListOperations(t *testing.T) { t.Parallel() + + // Setup logger + logger := log.NewLogger() + // Create a Cedar authorizer authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: []string{ `permit(principal, action == Action::"call_tool", 
resource == Tool::"weather");`, }, EntitiesJSON: `[]`, - }) + }, logger) require.NoError(t, err, "Failed to create Cedar authorizer") // Test that non-list operations pass through unchanged @@ -260,13 +264,17 @@ func TestResponseFilteringWriter_NonListOperations(t *testing.T) { func TestResponseFilteringWriter_ErrorResponse(t *testing.T) { t.Parallel() + + // Setup logger + logger := log.NewLogger() + // Create a Cedar authorizer authorizer, err := NewCedarAuthorizer(CedarAuthorizerConfig{ Policies: []string{ `permit(principal, action == Action::"call_tool", resource == Tool::"weather");`, }, EntitiesJSON: `[]`, - }) + }, logger) require.NoError(t, err, "Failed to create Cedar authorizer") // Create an error response diff --git a/pkg/certs/validation.go b/pkg/certs/validation.go index bc83f5f11..affecdfb6 100644 --- a/pkg/certs/validation.go +++ b/pkg/certs/validation.go @@ -6,11 +6,11 @@ import ( "encoding/pem" "fmt" - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) // ValidateCACertificate validates that the provided data contains a valid PEM-encoded certificate -func ValidateCACertificate(certData []byte) error { +func ValidateCACertificate(certData []byte, logger *zap.SugaredLogger) error { // Check if the data contains PEM blocks block, _ := pem.Decode(certData) if block == nil { diff --git a/pkg/certs/validation_test.go b/pkg/certs/validation_test.go index 5969c66bf..f7c0d1f93 100644 --- a/pkg/certs/validation_test.go +++ b/pkg/certs/validation_test.go @@ -6,13 +6,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestValidateCACertificate(t *testing.T) { t.Parallel() - // Initialize logger for testing - logger.Initialize() + + logger := log.NewLogger() tests := []struct { name string @@ -108,7 +108,7 @@ aW52YWxpZCBjZXJ0aWZpY2F0ZSBkYXRh for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { 
t.Parallel() - err := ValidateCACertificate(tt.certData) + err := ValidateCACertificate(tt.certData, logger) if tt.wantErr { require.Error(t, err, "ValidateCACertificate() should return an error") diff --git a/pkg/client/config.go b/pkg/client/config.go index fefe4c6c1..9813dfedb 100644 --- a/pkg/client/config.go +++ b/pkg/client/config.go @@ -11,8 +11,8 @@ import ( "time" "github.com/tailscale/hujson" + "go.uber.org/zap" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -339,9 +339,9 @@ type MCPServerConfig struct { } // FindClientConfig returns the client configuration file for a given client type. -func FindClientConfig(clientType MCPClient) (*ConfigFile, error) { +func FindClientConfig(clientType MCPClient, logger *zap.SugaredLogger) (*ConfigFile, error) { // retrieve the metadata of the config files - configFile, err := retrieveConfigFileMetadata(clientType) + configFile, err := retrieveConfigFileMetadata(clientType, logger) if err != nil { if errors.Is(err, ErrConfigFileNotFound) { // Propagate the error if the file is not found @@ -359,8 +359,8 @@ func FindClientConfig(clientType MCPClient) (*ConfigFile, error) { } // FindRegisteredClientConfigs finds all registered client configs and creates them if they don't exist. 
-func FindRegisteredClientConfigs() ([]ConfigFile, error) { - clientStatuses, err := GetClientStatus() +func FindRegisteredClientConfigs(logger *zap.SugaredLogger) ([]ConfigFile, error) { + clientStatuses, err := GetClientStatus(logger) if err != nil { return nil, fmt.Errorf("failed to get client status: %w", err) } @@ -370,11 +370,11 @@ func FindRegisteredClientConfigs() ([]ConfigFile, error) { if !clientStatus.Installed || !clientStatus.Registered { continue } - cf, err := FindClientConfig(clientStatus.ClientType) + cf, err := FindClientConfig(clientStatus.ClientType, logger) if err != nil { if errors.Is(err, ErrConfigFileNotFound) { logger.Infof("Client config file not found for %s, creating it...", clientStatus.ClientType) - cf, err = CreateClientConfig(clientStatus.ClientType) + cf, err = CreateClientConfig(clientStatus.ClientType, logger) if err != nil { logger.Warnf("Unable to create client config for %s: %v", clientStatus.ClientType, err) continue @@ -392,7 +392,7 @@ func FindRegisteredClientConfigs() ([]ConfigFile, error) { } // CreateClientConfig creates a new client configuration file for a given client type. -func CreateClientConfig(clientType MCPClient) (*ConfigFile, error) { +func CreateClientConfig(clientType MCPClient, logger *zap.SugaredLogger) (*ConfigFile, error) { // Get home directory home, err := os.UserHomeDir() if err != nil { @@ -427,7 +427,7 @@ func CreateClientConfig(clientType MCPClient) (*ConfigFile, error) { return nil, fmt.Errorf("failed to create client config file: %w", err) } - return FindClientConfig(clientType) + return FindClientConfig(clientType, logger) } // Upsert updates/inserts an MCP server in a client configuration file @@ -459,7 +459,7 @@ func Upsert(cf ConfigFile, name string, url string, transportType string) error } // retrieveConfigFileMetadata retrieves the metadata for client configuration files for a given client type. 
-func retrieveConfigFileMetadata(clientType MCPClient) (*ConfigFile, error) { +func retrieveConfigFileMetadata(clientType MCPClient, logger *zap.SugaredLogger) (*ConfigFile, error) { // Get home directory home, err := os.UserHomeDir() if err != nil { @@ -491,6 +491,7 @@ func retrieveConfigFileMetadata(clientType MCPClient) (*ConfigFile, error) { configUpdater := &JSONConfigUpdater{ Path: path, MCPServersPathPrefix: clientCfg.MCPServersPathPrefix, + logger: logger, } // Return the configuration file metadata diff --git a/pkg/client/config_editor.go b/pkg/client/config_editor.go index 261795d69..c3d8a221e 100644 --- a/pkg/client/config_editor.go +++ b/pkg/client/config_editor.go @@ -11,8 +11,7 @@ import ( "github.com/gofrs/flock" "github.com/tailscale/hujson" "github.com/tidwall/gjson" - - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) // ConfigUpdater defines the interface for types which can edit MCP client config files. @@ -33,6 +32,7 @@ type MCPServer struct { type JSONConfigUpdater struct { Path string MCPServersPathPrefix string + logger *zap.SugaredLogger } // Upsert inserts or updates an MCP server in the MCP client config file @@ -56,7 +56,7 @@ func (jcu *JSONConfigUpdater) Upsert(serverName string, data MCPServer) error { content, err := os.ReadFile(jcu.Path) if err != nil { - logger.Errorf("Failed to read file: %v", err) + jcu.logger.Errorf("Failed to read file: %v", err) } if len(content) == 0 { @@ -64,32 +64,32 @@ func (jcu *JSONConfigUpdater) Upsert(serverName string, data MCPServer) error { content = []byte("{}") } - content = ensurePathExists(content, jcu.MCPServersPathPrefix) + content = ensurePathExists(content, jcu.MCPServersPathPrefix, jcu.logger) v, _ := hujson.Parse(content) dataJSON, err := json.Marshal(data) if err != nil { - logger.Errorf("Unable to marshal the MCPServer into JSON: %v", err) + jcu.logger.Errorf("Unable to marshal the MCPServer into JSON: %v", err) } patch := fmt.Sprintf(`[{ "op": "add", "path": "%s/%s", 
"value": %s } ]`, jcu.MCPServersPathPrefix, serverName, dataJSON) err = v.Patch([]byte(patch)) if err != nil { - logger.Errorf("Failed to patch file: %v", err) + jcu.logger.Errorf("Failed to patch file: %v", err) } formatted, _ := hujson.Format(v.Pack()) if err != nil { - logger.Errorf("Failed to format the patched file: %v", err) + jcu.logger.Errorf("Failed to format the patched file: %v", err) } // Write back to the file if err := os.WriteFile(jcu.Path, formatted, 0600); err != nil { - logger.Errorf("Failed to write file: %v", err) + jcu.logger.Errorf("Failed to write file: %v", err) } - logger.Infof("Successfully updated the client config file for MCPServer %s", serverName) + jcu.logger.Infof("Successfully updated the client config file for MCPServer %s", serverName) return nil } @@ -115,7 +115,7 @@ func (jcu *JSONConfigUpdater) Remove(serverName string) error { content, err := os.ReadFile(jcu.Path) if err != nil { - logger.Errorf("Failed to read file: %v", err) + jcu.logger.Errorf("Failed to read file: %v", err) } if len(content) == 0 { @@ -128,17 +128,17 @@ func (jcu *JSONConfigUpdater) Remove(serverName string) error { patch := fmt.Sprintf(`[{ "op": "remove", "path": "%s/%s" } ]`, jcu.MCPServersPathPrefix, serverName) err = v.Patch([]byte(patch)) if err != nil { - logger.Errorf("Failed to patch file: %v", err) + jcu.logger.Errorf("Failed to patch file: %v", err) } formatted, _ := hujson.Format(v.Pack()) // Write back to the file if err := os.WriteFile(jcu.Path, formatted, 0600); err != nil { - logger.Errorf("Failed to write file: %v", err) + jcu.logger.Errorf("Failed to write file: %v", err) } - logger.Infof("Successfully removed the MCPServer %s from the client config file", serverName) + jcu.logger.Infof("Successfully removed the MCPServer %s from the client config file", serverName) return nil } @@ -155,7 +155,7 @@ func (jcu *JSONConfigUpdater) Remove(serverName string) error { // // This is necessary because the MCP client config file is a JSON object, // 
and we need to ensure that the path exists before we can add a new key to it. -func ensurePathExists(content []byte, path string) []byte { +func ensurePathExists(content []byte, path string, logger *zap.SugaredLogger) []byte { // Special case: if path is root ("/"), just return everything (formatted) if path == "/" { v, _ := hujson.Parse(content) diff --git a/pkg/client/config_editor_test.go b/pkg/client/config_editor_test.go index 037dfa4f3..1a0a03d60 100644 --- a/pkg/client/config_editor_test.go +++ b/pkg/client/config_editor_test.go @@ -3,7 +3,6 @@ package client import ( "encoding/json" "fmt" - "log" "os" "path/filepath" "testing" @@ -12,14 +11,12 @@ import ( "github.com/tidwall/gjson" "gotest.tools/assert" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestUpsertMCPServerConfig(t *testing.T) { t.Parallel() - logger.Initialize() - tests := []struct { mcpServerPatchPath string // the path used by the patch operation mcpServerKeyPath string // the path used to retrieve the value from the config file (for testing purposes) @@ -40,6 +37,7 @@ func TestUpsertMCPServerConfig(t *testing.T) { jsu := JSONConfigUpdater{ Path: configPath, MCPServersPathPrefix: tt.mcpServerPatchPath, + logger: log.NewLogger(), } mcpServer := MCPServer{ @@ -76,6 +74,7 @@ func TestUpsertMCPServerConfig(t *testing.T) { jsu := JSONConfigUpdater{ Path: configPath, MCPServersPathPrefix: tt.mcpServerPatchPath, + logger: log.NewLogger(), } // add an MCP server so we can update it @@ -117,8 +116,6 @@ func TestUpsertMCPServerConfig(t *testing.T) { func TestRemoveMCPServerConfigNew(t *testing.T) { t.Parallel() - logger.Initialize() - tests := []struct { mcpServerPatchPath string // the path used by the patch operation mcpServerKeyPath string // the path used to retrieve the value from the config file (for testing purposes) @@ -138,6 +135,7 @@ func TestRemoveMCPServerConfigNew(t *testing.T) { jsu := JSONConfigUpdater{ Path: configPath, 
MCPServersPathPrefix: tt.mcpServerPatchPath, + logger: log.NewLogger(), } // add an MCP server so we can remove it @@ -160,7 +158,7 @@ func TestRemoveMCPServerConfigNew(t *testing.T) { // read the config file and check that the mcp servers are removed content, err := os.ReadFile(configPath) if err != nil { - log.Fatalf("Failed to read file: %v", err) + t.Fatalf("Failed to read file: %v", err) } testMcpServerJson := gjson.GetBytes(content, tt.mcpServerKeyPath+"."+tt.mcpServerName).Raw @@ -228,7 +226,7 @@ func getMCPServerFromFile(t *testing.T, configPath string, key string) MCPServer func TestEnsurePathExists(t *testing.T) { t.Parallel() - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -278,7 +276,7 @@ func TestEnsurePathExists(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - result := ensurePathExists(tt.content, tt.path) + result := ensurePathExists(tt.content, tt.path, logger) assert.DeepEqual(t, tt.expectedResult, result) }) diff --git a/pkg/client/config_test.go b/pkg/client/config_test.go index 9d88a4588..4520dd719 100644 --- a/pkg/client/config_test.go +++ b/pkg/client/config_test.go @@ -12,9 +12,10 @@ import ( "github.com/adrg/xdg" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -74,7 +75,7 @@ func createMockClientConfigs() []mcpClientConfig { // MockConfig creates a temporary config file with the provided configuration. // It returns a cleanup function that should be deferred. 
-func MockConfig(t *testing.T, cfg *config.Config) func() { +func MockConfig(t *testing.T, cfg *config.Config, logger *zap.SugaredLogger) func() { t.Helper() // Create a temporary directory for the test @@ -93,7 +94,7 @@ func MockConfig(t *testing.T, cfg *config.Config) func() { // Write the config file if one is provided if cfg != nil { - err = config.UpdateConfig(func(c *config.Config) { *c = *cfg }) + err = config.UpdateConfig(func(c *config.Config) { *c = *cfg }, logger) require.NoError(t, err) } @@ -104,8 +105,6 @@ func MockConfig(t *testing.T, cfg *config.Config) func() { } func TestFindClientConfigs(t *testing.T) { //nolint:paralleltest // Uses environment variables - logger.Initialize() - // Setup a temporary home directory for testing originalHome := os.Getenv("HOME") tempHome := t.TempDir() @@ -137,8 +136,8 @@ func TestFindClientConfigs(t *testing.T) { //nolint:paralleltest // Uses environ r, w, _ := os.Pipe() os.Stderr = w - // Re-initialize logger to use the captured stderr - logger.Initialize() + // Setup logger + logger := log.NewLogger() // Create an invalid JSON file invalidPath := filepath.Join(tempHome, ".cursor", "invalid.json") @@ -179,12 +178,12 @@ func TestFindClientConfigs(t *testing.T) { //nolint:paralleltest // Uses environ }, } - cleanup := MockConfig(t, testConfig) + cleanup := MockConfig(t, testConfig, logger) defer cleanup() // Find client configs - this should NOT fail due to the invalid JSON // Instead, it should log a warning and continue - configs, err := FindRegisteredClientConfigs() + configs, err := FindRegisteredClientConfigs(logger) assert.NoError(t, err, "FindRegisteredClientConfigs should not return an error for invalid config files") // The invalid client should be skipped, so we should get configs for valid clients only @@ -207,7 +206,8 @@ func TestFindClientConfigs(t *testing.T) { //nolint:paralleltest // Uses environ } func TestSuccessfulClientConfigOperations(t *testing.T) { - logger.Initialize() + // Setup logger + 
logger := log.NewLogger() // Setup a temporary home directory for testing originalHome := os.Getenv("HOME") @@ -251,11 +251,11 @@ func TestSuccessfulClientConfigOperations(t *testing.T) { }, } - cleanup := MockConfig(t, testConfig) + cleanup := MockConfig(t, testConfig, logger) defer cleanup() t.Run("FindAllConfiguredClients", func(t *testing.T) { //nolint:paralleltest // Uses environment variables - configs, err := FindRegisteredClientConfigs() + configs, err := FindRegisteredClientConfigs(logger) require.NoError(t, err) assert.Len(t, configs, len(supportedClientIntegrations), "Should find all mock client configs") @@ -272,7 +272,7 @@ func TestSuccessfulClientConfigOperations(t *testing.T) { }) t.Run("VerifyConfigFileContents", func(t *testing.T) { //nolint:paralleltest // Uses environment variables - configs, err := FindRegisteredClientConfigs() + configs, err := FindRegisteredClientConfigs(logger) require.NoError(t, err) require.NotEmpty(t, configs) @@ -326,7 +326,7 @@ func TestSuccessfulClientConfigOperations(t *testing.T) { }) t.Run("AddAndVerifyMCPServer", func(t *testing.T) { //nolint:paralleltest // Uses environment variables - configs, err := FindRegisteredClientConfigs() + configs, err := FindRegisteredClientConfigs(logger) require.NoError(t, err) require.NotEmpty(t, configs) diff --git a/pkg/client/discovery.go b/pkg/client/discovery.go index be90fb293..f1541ac39 100644 --- a/pkg/client/discovery.go +++ b/pkg/client/discovery.go @@ -7,6 +7,8 @@ import ( "runtime" "sort" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/config" ) @@ -23,7 +25,7 @@ type MCPClientStatus struct { } // GetClientStatus returns the installation status of all supported MCP clients -func GetClientStatus() ([]MCPClientStatus, error) { +func GetClientStatus(logger *zap.SugaredLogger) ([]MCPClientStatus, error) { var statuses []MCPClientStatus // Get home directory @@ -33,7 +35,7 @@ func GetClientStatus() ([]MCPClientStatus, error) { } // Get app configuration to check for 
registered clients - appConfig := config.GetConfig() + appConfig := config.GetConfig(logger) registeredClients := make(map[string]bool) // Create a map of registered clients for quick lookup diff --git a/pkg/client/discovery_test.go b/pkg/client/discovery_test.go index e9e01a28c..988b0444c 100644 --- a/pkg/client/discovery_test.go +++ b/pkg/client/discovery_test.go @@ -9,9 +9,12 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/config" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestGetClientStatus(t *testing.T) { + logger := log.NewLogger() + // Setup a temporary home directory for testing tempHome, err := os.MkdirTemp("", "toolhive-test-home") require.NoError(t, err) @@ -25,7 +28,7 @@ func TestGetClientStatus(t *testing.T) { RegisteredClients: []string{string(ClaudeCode)}, }, } - cleanup := MockConfig(t, mockConfig) + cleanup := MockConfig(t, mockConfig, logger) defer cleanup() // Create a mock Cursor config file @@ -36,7 +39,7 @@ func TestGetClientStatus(t *testing.T) { _, err = os.Create(filepath.Join(tempHome, ".claude.json")) require.NoError(t, err) - statuses, err := GetClientStatus() + statuses, err := GetClientStatus(logger) require.NoError(t, err) require.NotNil(t, statuses) @@ -63,6 +66,8 @@ func TestGetClientStatus(t *testing.T) { } func TestGetClientStatus_Sorting(t *testing.T) { + logger := log.NewLogger() + // Setup a temporary home directory for testing origHome := os.Getenv("HOME") tempHome, err := os.MkdirTemp("", "toolhive-test-home") @@ -78,10 +83,10 @@ func TestGetClientStatus_Sorting(t *testing.T) { RegisteredClients: []string{}, }, } - cleanup := MockConfig(t, mockConfig) + cleanup := MockConfig(t, mockConfig, logger) defer cleanup() - statuses, err := GetClientStatus() + statuses, err := GetClientStatus(logger) require.NoError(t, err) require.NotNil(t, statuses) require.Greater(t, len(statuses), 1, "Need at least 2 clients to test sorting") diff --git a/pkg/client/manager.go 
b/pkg/client/manager.go index b975c158f..6e4d706a6 100644 --- a/pkg/client/manager.go +++ b/pkg/client/manager.go @@ -5,12 +5,13 @@ import ( "errors" "fmt" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/config" ct "github.com/stacklok/toolhive/pkg/container" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" ) const ( @@ -39,16 +40,17 @@ type Manager interface { type defaultManager struct { runtime rt.Runtime groupManager groups.Manager + logger *zap.SugaredLogger } // NewManager creates a new client manager instance. -func NewManager(ctx context.Context) (Manager, error) { - runtime, err := ct.NewFactory().Create(ctx) +func NewManager(ctx context.Context, logger *zap.SugaredLogger) (Manager, error) { + runtime, err := ct.NewFactory(logger).Create(ctx) if err != nil { return nil, err } - groupManager, err := groups.NewManager() + groupManager, err := groups.NewManager(logger) if err != nil { return nil, err } @@ -56,12 +58,13 @@ func NewManager(ctx context.Context) (Manager, error) { return &defaultManager{ runtime: runtime, groupManager: groupManager, + logger: logger, }, nil } -func (*defaultManager) ListClients() ([]Client, error) { +func (m *defaultManager) ListClients() ([]Client, error) { clients := []Client{} - appConfig := config.GetConfig() + appConfig := config.GetConfig(m.logger) for _, clientName := range appConfig.Clients.RegisteredClients { clients = append(clients, Client{Name: MCPClient(clientName)}) @@ -101,14 +104,14 @@ func (m *defaultManager) AddServerToClients( targetClients := m.getTargetClients(ctx, serverName, group) if len(targetClients) == 0 { - logger.Infof("No target clients found for server %s", serverName) + m.logger.Infof("No target clients found for server %s", serverName) return nil } // Add the server to each target client for _, clientName := range targetClients { if err := 
m.updateClientWithServer(clientName, serverName, serverURL, transportType); err != nil { - logger.Warnf("Warning: Failed to update client %s: %v", clientName, err) + m.logger.Warnf("Warning: Failed to update client %s: %v", clientName, err) } } return nil @@ -121,14 +124,14 @@ func (m *defaultManager) RemoveServerFromClients(ctx context.Context, serverName targetClients := m.getTargetClients(ctx, serverName, group) if len(targetClients) == 0 { - logger.Infof("No target clients found for server %s", serverName) + m.logger.Infof("No target clients found for server %s", serverName) return nil } // Remove the server from each target client for _, clientName := range targetClients { if err := m.removeServerFromClient(MCPClient(clientName), serverName); err != nil { - logger.Warnf("Warning: Failed to remove server from client %s: %v", clientName, err) + m.logger.Warnf("Warning: Failed to remove server from client %s: %v", clientName, err) } } @@ -156,7 +159,7 @@ func (m *defaultManager) addWorkloadsToClient(clientType MCPClient, workloads [] return fmt.Errorf("failed to add workload %s to client %s: %v", workload.Name, clientType, err) } - logger.Infof("Added MCP server %s to client %s\n", workload.Name, clientType) + m.logger.Infof("Added MCP server %s to client %s\n", workload.Name, clientType) } return nil @@ -185,8 +188,8 @@ func (m *defaultManager) removeWorkloadsFromClient(clientType MCPClient, workloa } // removeServerFromClient removes an MCP server from a single client configuration -func (*defaultManager) removeServerFromClient(clientName MCPClient, serverName string) error { - clientConfig, err := FindClientConfig(clientName) +func (m *defaultManager) removeServerFromClient(clientName MCPClient, serverName string) error { + clientConfig, err := FindClientConfig(clientName, m.logger) if err != nil { return fmt.Errorf("failed to find client configurations: %w", err) } @@ -196,17 +199,17 @@ func (*defaultManager) removeServerFromClient(clientName MCPClient, 
serverName s return fmt.Errorf("failed to remove MCP server configuration from %s: %v", clientConfig.Path, err) } - logger.Infof("Removed MCP server %s from client %s\n", serverName, clientName) + m.logger.Infof("Removed MCP server %s from client %s\n", serverName, clientName) return nil } // updateClientWithServer updates a single client with an MCP server configuration, creating config if needed -func (*defaultManager) updateClientWithServer(clientName, serverName, serverURL, transportType string) error { - clientConfig, err := FindClientConfig(MCPClient(clientName)) +func (m *defaultManager) updateClientWithServer(clientName, serverName, serverURL, transportType string) error { + clientConfig, err := FindClientConfig(MCPClient(clientName), m.logger) if err != nil { if errors.Is(err, ErrConfigFileNotFound) { // Create a new client configuration if it doesn't exist - clientConfig, err = CreateClientConfig(MCPClient(clientName)) + clientConfig, err = CreateClientConfig(MCPClient(clientName), m.logger) if err != nil { return fmt.Errorf("failed to create client configuration for %s: %w", clientName, err) } @@ -215,13 +218,13 @@ func (*defaultManager) updateClientWithServer(clientName, serverName, serverURL, } } - logger.Infof("Updating client configuration: %s", clientConfig.Path) + m.logger.Infof("Updating client configuration: %s", clientConfig.Path) if err := Upsert(*clientConfig, serverName, serverURL, transportType); err != nil { return fmt.Errorf("failed to update MCP server configuration in %s: %v", clientConfig.Path, err) } - logger.Infof("Successfully updated client configuration: %s", clientConfig.Path) + m.logger.Infof("Successfully updated client configuration: %s", clientConfig.Path) return nil } @@ -231,14 +234,14 @@ func (m *defaultManager) getTargetClients(ctx context.Context, serverName, group if groupName != "" { group, err := m.groupManager.Get(ctx, groupName) if err != nil { - logger.Warnf( + m.logger.Warnf( "Warning: Failed to get group %s for 
server %s, skipping client config updates: %v", group, serverName, err, ) return nil } - logger.Infof( + m.logger.Infof( "Server %s belongs to group %s, updating %d registered client(s)", serverName, group, len(group.RegisteredClients), ) @@ -246,9 +249,9 @@ func (m *defaultManager) getTargetClients(ctx context.Context, serverName, group } // Server has no group - use backward compatible behavior (update all registered clients) - appConfig := config.GetConfig() + appConfig := config.GetConfig(m.logger) targetClients := appConfig.Clients.RegisteredClients - logger.Infof( + m.logger.Infof( "Server %s has no group, updating %d globally registered client(s) for backward compatibility", serverName, len(targetClients), ) diff --git a/pkg/client/migration.go b/pkg/client/migration.go index 306ef6c70..50a9fbaa0 100644 --- a/pkg/client/migration.go +++ b/pkg/client/migration.go @@ -4,8 +4,9 @@ import ( "fmt" "sync" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/logger" ) // migrationOnce ensures the migration only runs once @@ -13,31 +14,30 @@ var migrationOnce sync.Once // CheckAndPerformAutoDiscoveryMigration checks if auto-discovery migration is needed and performs it // This is called once at application startup -func CheckAndPerformAutoDiscoveryMigration() { +func CheckAndPerformAutoDiscoveryMigration(logger *zap.SugaredLogger) { migrationOnce.Do(func() { - appConfig := config.GetConfig() - + appConfig := config.GetConfig(logger) // Check if auto-discovery flag is set to true, use of deprecated object is expected here if appConfig.Clients.AutoDiscovery { - performAutoDiscoveryMigration() + performAutoDiscoveryMigration(logger) } }) } // performAutoDiscoveryMigration discovers and registers all installed clients -func performAutoDiscoveryMigration() { +func performAutoDiscoveryMigration(logger *zap.SugaredLogger) { fmt.Println("Migrating from deprecated auto-discovery to manual client registration...") fmt.Println() 
// Get current client statuses to determine what to register - clientStatuses, err := GetClientStatus() + clientStatuses, err := GetClientStatus(logger) if err != nil { logger.Errorf("Error discovering clients during migration: %v", err) return } // Get current config to see what's already registered - appConfig := config.GetConfig() + appConfig := config.GetConfig(logger) var clientsToRegister []string var alreadyRegistered = appConfig.Clients.RegisteredClients @@ -69,7 +69,7 @@ func performAutoDiscoveryMigration() { // Remove the auto-discovery flag during the same config update c.Clients.AutoDiscovery = false - }) + }, logger) if err != nil { logger.Errorf("Error updating config during migration: %v", err) diff --git a/pkg/config/config.go b/pkg/config/config.go index eb36b92b3..15ac788b2 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -12,9 +12,9 @@ import ( "github.com/adrg/xdg" "github.com/gofrs/flock" + "go.uber.org/zap" "gopkg.in/yaml.v3" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -132,14 +132,14 @@ func applyBackwardCompatibility(config *Config) error { // LoadOrCreateConfig fetches the application configuration. // If it does not already exist - it will create a new config file with default values. -func LoadOrCreateConfig() (*Config, error) { - return LoadOrCreateConfigWithPath("") +func LoadOrCreateConfig(logger *zap.SugaredLogger) (*Config, error) { + return LoadOrCreateConfigWithPath("", logger) } // LoadOrCreateConfigWithPath fetches the application configuration from a specific path. // If configPath is empty, it uses the default path. // If it does not already exist - it will create a new config file with default values. 
-func LoadOrCreateConfigWithPath(configPath string) (*Config, error) { +func LoadOrCreateConfigWithPath(configPath string, logger *zap.SugaredLogger) (*Config, error) { var config Config var err error @@ -225,14 +225,14 @@ func (c *Config) saveToPath(configPath string) error { // UpdateConfig locks a separate lock file, reads from disk, applies the changes // from the anonymous function, writes to disk and unlocks the file. -func UpdateConfig(updateFn func(*Config)) error { - return UpdateConfigAtPath("", updateFn) +func UpdateConfig(updateFn func(*Config), logger *zap.SugaredLogger) error { + return UpdateConfigAtPath("", updateFn, logger) } // UpdateConfigAtPath locks a separate lock file, reads from disk, applies the changes // from the anonymous function, writes to disk and unlocks the file. // If configPath is empty, it uses the default path. -func UpdateConfigAtPath(configPath string, updateFn func(*Config)) error { +func UpdateConfigAtPath(configPath string, updateFn func(*Config), logger *zap.SugaredLogger) error { if configPath == "" { var err error configPath, err = getConfigPath() @@ -258,7 +258,7 @@ func UpdateConfigAtPath(configPath string, updateFn func(*Config)) error { defer fileLock.Unlock() // Load the config after acquiring the lock to avoid race conditions - c, err := LoadOrCreateConfigWithPath(configPath) + c, err := LoadOrCreateConfigWithPath(configPath, logger) if err != nil { return fmt.Errorf("failed to load config from disk: %w", err) } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 53e45bf37..2888c1952 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -41,7 +41,8 @@ func SetupTestConfig(t *testing.T, configContent *Config) (string, string) { func TestLoadOrCreateConfig(t 
*testing.T) { t.Parallel() - logger.Initialize() + + logger := log.NewLogger() t.Run("TestLoadOrCreateConfigWithMockConfig", func(t *testing.T) { t.Parallel() @@ -55,7 +56,7 @@ func TestLoadOrCreateConfig(t *testing.T) { }) // Load the config - config, err := LoadOrCreateConfigWithPath(configPath) + config, err := LoadOrCreateConfigWithPath(configPath, logger) require.NoError(t, err) // Verify the loaded config matches our mock @@ -75,7 +76,7 @@ func TestLoadOrCreateConfig(t *testing.T) { tempDir, configPath := SetupTestConfig(t, nil) // Load the config - this should create a new one since none exists - config, err := LoadOrCreateConfigWithPath(configPath) + config, err := LoadOrCreateConfigWithPath(configPath, logger) require.NoError(t, err) // Verify the default values @@ -93,7 +94,6 @@ func TestLoadOrCreateConfig(t *testing.T) { func TestSave(t *testing.T) { t.Parallel() - logger.Initialize() t.Run("TestSave", func(t *testing.T) { t.Parallel() @@ -143,7 +143,8 @@ func TestSave(t *testing.T) { func TestRegistryURLConfig(t *testing.T) { t.Parallel() - logger.Initialize() + + logger := log.NewLogger() t.Run("TestSetAndGetRegistryURL", func(t *testing.T) { t.Parallel() @@ -161,22 +162,22 @@ func TestRegistryURLConfig(t *testing.T) { testURL := "https://example.com/registry.json" err := UpdateConfigAtPath(configPath, func(c *Config) { c.RegistryUrl = testURL - }) + }, logger) require.NoError(t, err) // Load the config and verify the URL was set - config, err := LoadOrCreateConfigWithPath(configPath) + config, err := LoadOrCreateConfigWithPath(configPath, logger) require.NoError(t, err) assert.Equal(t, testURL, config.RegistryUrl) // Test unsetting the registry URL err = UpdateConfigAtPath(configPath, func(c *Config) { c.RegistryUrl = "" - }) + }, logger) require.NoError(t, err) // Load the config and verify the URL was unset - config, err = LoadOrCreateConfigWithPath(configPath) + config, err = LoadOrCreateConfigWithPath(configPath, logger) require.NoError(t, err) 
assert.Equal(t, "", config.RegistryUrl) @@ -196,11 +197,11 @@ func TestRegistryURLConfig(t *testing.T) { // Set the registry URL err := UpdateConfigAtPath(configPath, func(c *Config) { c.RegistryUrl = testURL - }) + }, logger) require.NoError(t, err) // Load config again to verify persistence - config, err := LoadOrCreateConfigWithPath(configPath) + config, err := LoadOrCreateConfigWithPath(configPath, logger) require.NoError(t, err) assert.Equal(t, testURL, config.RegistryUrl) @@ -227,22 +228,22 @@ func TestRegistryURLConfig(t *testing.T) { // Test enabling err := UpdateConfigAtPath(configPath, func(c *Config) { c.AllowPrivateRegistryIp = true - }) + }, logger) require.NoError(t, err) // Load the config and verify the setting was toggled to true - config, err := LoadOrCreateConfigWithPath(configPath) + config, err := LoadOrCreateConfigWithPath(configPath, logger) require.NoError(t, err) assert.Equal(t, true, config.AllowPrivateRegistryIp) // Test toggling setting to false err = UpdateConfigAtPath(configPath, func(c *Config) { c.AllowPrivateRegistryIp = false - }) + }, logger) require.NoError(t, err) // Load the config and verify the setting was toggled to false - config, err = LoadOrCreateConfigWithPath(configPath) + config, err = LoadOrCreateConfigWithPath(configPath, logger) require.NoError(t, err) assert.Equal(t, false, config.AllowPrivateRegistryIp) @@ -256,7 +257,6 @@ func TestRegistryURLConfig(t *testing.T) { func TestSecrets_GetProviderType_EnvironmentVariable(t *testing.T) { t.Parallel() - logger.Initialize() // Save original env value and restore at the end originalEnv := os.Getenv(secrets.ProviderEnvVar) diff --git a/pkg/config/registry.go b/pkg/config/registry.go index 3adbb21f7..badbb106f 100644 --- a/pkg/config/registry.go +++ b/pkg/config/registry.go @@ -7,6 +7,8 @@ import ( "path/filepath" "strings" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/networking" ) @@ -34,7 +36,7 @@ func DetectRegistryType(input string) (registryType string, 
cleanPath string) { } // SetRegistryURL validates and sets a registry URL -func SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { +func SetRegistryURL(registryURL string, allowPrivateRegistryIp bool, logger *zap.SugaredLogger) error { // Basic URL validation - check if it starts with http:// or https:// if registryURL != "" && !strings.HasPrefix(registryURL, "http://") && !strings.HasPrefix(registryURL, "https://") { return fmt.Errorf("registry URL must start with http:// or https://") @@ -56,7 +58,7 @@ func SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { c.RegistryUrl = registryURL c.LocalRegistryPath = "" // Clear local path when setting URL c.AllowPrivateRegistryIp = allowPrivateRegistryIp - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -65,7 +67,7 @@ func SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { } // SetRegistryFile validates and sets a local registry file -func SetRegistryFile(registryPath string) error { +func SetRegistryFile(registryPath string, logger *zap.SugaredLogger) error { // Validate that the file exists and is readable if _, err := os.Stat(registryPath); err != nil { return fmt.Errorf("local registry file not found or not accessible: %w", err) @@ -93,7 +95,7 @@ func SetRegistryFile(registryPath string) error { err = UpdateConfig(func(c *Config) { c.LocalRegistryPath = registryPath c.RegistryUrl = "" // Clear URL when setting local path - }) + }, logger) if err != nil { return fmt.Errorf("failed to update configuration: %w", err) } @@ -102,12 +104,12 @@ func SetRegistryFile(registryPath string) error { } // UnsetRegistry resets registry configuration to defaults -func UnsetRegistry() error { +func UnsetRegistry(logger *zap.SugaredLogger) error { err := UpdateConfig(func(c *Config) { c.RegistryUrl = "" c.LocalRegistryPath = "" c.AllowPrivateRegistryIp = false - }) + }, logger) if err != nil { return fmt.Errorf("failed to 
update configuration: %w", err) } @@ -115,8 +117,8 @@ func UnsetRegistry() error { } // GetRegistryConfig returns current registry configuration -func GetRegistryConfig() (url, localPath string, allowPrivateIP bool, registryType string) { - cfg := GetConfig() +func GetRegistryConfig(logger *zap.SugaredLogger) (url, localPath string, allowPrivateIP bool, registryType string) { + cfg := GetConfig(logger) if cfg.RegistryUrl != "" { return cfg.RegistryUrl, "", cfg.AllowPrivateRegistryIp, RegistryTypeURL diff --git a/pkg/config/singleton.go b/pkg/config/singleton.go index ef983e969..434b446f2 100644 --- a/pkg/config/singleton.go +++ b/pkg/config/singleton.go @@ -4,7 +4,7 @@ import ( "os" "sync" - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) // Singleton value - should only be written to by the GetConfig function. @@ -13,12 +13,12 @@ var appConfig *Config var lock = &sync.Mutex{} // GetConfig is a Singleton that returns the application configuration. -func GetConfig() *Config { +func GetConfig(logger *zap.SugaredLogger) *Config { if appConfig == nil { lock.Lock() defer lock.Unlock() if appConfig == nil { - appConfig, err := LoadOrCreateConfig() + appConfig, err := LoadOrCreateConfig(logger) if err != nil { logger.Errorf("error loading configuration: %v", err) os.Exit(1) diff --git a/pkg/container/docker/client.go b/pkg/container/docker/client.go index 15cb1326c..28e7f2f9e 100644 --- a/pkg/container/docker/client.go +++ b/pkg/container/docker/client.go @@ -23,13 +23,13 @@ import ( "github.com/docker/docker/client" "github.com/docker/docker/pkg/stdcopy" "github.com/docker/go-connections/nat" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/container/docker/sdk" "github.com/stacklok/toolhive/pkg/container/images" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/ignore" lb "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" 
"github.com/stacklok/toolhive/pkg/permissions" ) @@ -49,22 +49,24 @@ type Client struct { socketPath string client *client.Client imageManager images.ImageManager + logger *zap.SugaredLogger } // NewClient creates a new container client -func NewClient(ctx context.Context) (*Client, error) { - dockerClient, socketPath, runtimeType, err := sdk.NewDockerClient(ctx) +func NewClient(ctx context.Context, logger *zap.SugaredLogger) (*Client, error) { + dockerClient, socketPath, runtimeType, err := sdk.NewDockerClient(ctx, logger) if err != nil { return nil, err // there is already enough context in the error. } - imageManager := images.NewRegistryImageManager(dockerClient) + imageManager := images.NewRegistryImageManager(dockerClient, logger) c := &Client{ runtimeType: runtimeType, socketPath: socketPath, client: dockerClient, imageManager: imageManager, + logger: logger, } return c, nil @@ -156,7 +158,7 @@ func (c *Client) DeployWorkload( } // only remap if is not an auxiliary tool - newPortBindings, hostPort, err := generatePortBindings(labels, options.PortBindings) + newPortBindings, hostPort, err := generatePortBindings(labels, options.PortBindings, c.logger) if err != nil { return 0, fmt.Errorf("failed to generate port bindings: %v", err) } @@ -317,7 +319,7 @@ func (c *Client) RemoveWorkload(ctx context.Context, workloadName string) error // get container name from ID containerResponse, err := c.inspectContainerByName(ctx, workloadName) if err != nil { - logger.Warnf("Failed to inspect container %s: %v", workloadName, err) + c.logger.Warnf("Failed to inspect container %s: %v", workloadName, err) } // remove the / if it starts with it @@ -346,7 +348,7 @@ func (c *Client) RemoveWorkload(ctx context.Context, workloadName string) error // This also deletes the external network if no other workloads are using it. 
err = c.deleteNetworks(ctx, containerName) if err != nil { - logger.Warnf("Failed to delete networks for container %s: %v", containerName, err) + c.logger.Warnf("Failed to delete networks for container %s: %v", containerName, err) } return nil } @@ -374,7 +376,7 @@ func (c *Client) GetWorkloadLogs(ctx context.Context, workloadName string, follo if follow { _, err = stdcopy.StdCopy(os.Stdout, os.Stderr, logs) if err != nil && err != io.EOF { - logger.Errorf("Error reading workload logs: %v", err) + c.logger.Errorf("Error reading workload logs: %v", err) return "", NewContainerError(err, workloadName, fmt.Sprintf("failed to follow workload logs: %v", err)) } } @@ -423,7 +425,7 @@ func (c *Client) GetWorkloadInfo(ctx context.Context, workloadName string) (runt hostPort := 0 if _, err := fmt.Sscanf(binding.HostPort, "%d", &hostPort); err != nil { // If we can't parse the port, just use 0 - logger.Warnf("Warning: Failed to parse host port %s: %v", binding.HostPort, err) + c.logger.Warnf("Warning: Failed to parse host port %s: %v", binding.HostPort, err) } ports = append(ports, runtime.PortMapping{ @@ -482,7 +484,7 @@ func (c *Client) AttachToWorkload(ctx context.Context, workloadName string) (io. 
// Use stdcopy to demultiplex the container streams _, err := stdcopy.StdCopy(stdoutWriter, io.Discard, resp.Reader) if err != nil && err != io.EOF { - logger.Errorf("Error demultiplexing container streams: %v", err) + c.logger.Errorf("Error demultiplexing container streams: %v", err) } }() @@ -504,7 +506,7 @@ func (c *Client) IsRunning(ctx context.Context) error { // getPermissionConfigFromProfile converts a permission profile to a container permission config // with transport-specific settings (internal function) // addReadOnlyMounts adds read-only mounts to the permission config -func (*Client) addReadOnlyMounts( +func (c *Client) addReadOnlyMounts( config *runtime.PermissionConfig, mounts []permissions.MountDeclaration, ignoreConfig *ignore.Config, @@ -513,18 +515,18 @@ func (*Client) addReadOnlyMounts( source, target, err := mountDecl.Parse() if err != nil { // Skip invalid mounts - logger.Warnf("Warning: Skipping invalid mount declaration: %s (%v)", mountDecl, err) + c.logger.Warnf("Warning: Skipping invalid mount declaration: %s (%v)", mountDecl, err) continue } // Skip resource URIs for now (they need special handling) if strings.Contains(source, "://") { - logger.Warnf("Warning: Resource URI mounts not yet supported: %s", source) + c.logger.Warnf("Warning: Resource URI mounts not yet supported: %s", source) continue } // Convert relative paths to absolute paths - absPath, ok := convertRelativePathToAbsolute(source, mountDecl) + absPath, ok := convertRelativePathToAbsolute(source, mountDecl, c.logger) if !ok { continue } @@ -537,12 +539,12 @@ func (*Client) addReadOnlyMounts( }) // Process ignore patterns and add tmpfs overlays - addIgnoreOverlays(config, absPath, target, ignoreConfig) + addIgnoreOverlays(config, absPath, target, ignoreConfig, c.logger) } } // addReadWriteMounts adds read-write mounts to the permission config -func (*Client) addReadWriteMounts( +func (c *Client) addReadWriteMounts( config *runtime.PermissionConfig, mounts 
[]permissions.MountDeclaration, ignoreConfig *ignore.Config, @@ -551,18 +553,18 @@ func (*Client) addReadWriteMounts( source, target, err := mountDecl.Parse() if err != nil { // Skip invalid mounts - logger.Warnf("Warning: Skipping invalid mount declaration: %s (%v)", mountDecl, err) + c.logger.Warnf("Warning: Skipping invalid mount declaration: %s (%v)", mountDecl, err) continue } // Skip resource URIs for now (they need special handling) if strings.Contains(source, "://") { - logger.Warnf("Warning: Resource URI mounts not yet supported: %s", source) + c.logger.Warnf("Warning: Resource URI mounts not yet supported: %s", source) continue } // Convert relative paths to absolute paths - absPath, ok := convertRelativePathToAbsolute(source, mountDecl) + absPath, ok := convertRelativePathToAbsolute(source, mountDecl, c.logger) if !ok { continue } @@ -589,19 +591,25 @@ func (*Client) addReadWriteMounts( } // Process ignore patterns and add tmpfs overlays - addIgnoreOverlays(config, absPath, target, ignoreConfig) + addIgnoreOverlays(config, absPath, target, ignoreConfig, c.logger) } } // addIgnoreOverlays processes ignore patterns for a mount and adds overlay mounts -func addIgnoreOverlays(config *runtime.PermissionConfig, sourceDir, containerPath string, ignoreConfig *ignore.Config) { +func addIgnoreOverlays( + config *runtime.PermissionConfig, + sourceDir, + containerPath string, + ignoreConfig *ignore.Config, + logger *zap.SugaredLogger, +) { // Skip if no ignore configuration is provided if ignoreConfig == nil { return } // Create ignore processor with configuration - ignoreProcessor := ignore.NewProcessor(ignoreConfig) + ignoreProcessor := ignore.NewProcessor(ignoreConfig, logger) // Load global ignore patterns if enabled if ignoreConfig.LoadGlobal { @@ -639,13 +647,18 @@ func addIgnoreOverlays(config *runtime.PermissionConfig, sourceDir, containerPat ReadOnly: false, Type: mountType, }) - logger.Debugf("Added %s overlay for ignored path: %s -> %s", 
overlayMount.Type, source, overlayMount.ContainerPath) + logger.Debugf("Added %s overlay for ignored path: %s -> %s", + overlayMount.Type, source, overlayMount.ContainerPath) } } // convertRelativePathToAbsolute converts a relative path to an absolute path // Returns the absolute path and a boolean indicating if the conversion was successful -func convertRelativePathToAbsolute(source string, mountDecl permissions.MountDeclaration) (string, bool) { +func convertRelativePathToAbsolute( + source string, + mountDecl permissions.MountDeclaration, + logger *zap.SugaredLogger, +) (string, bool) { // If it's already an absolute path, return it as is if filepath.IsAbs(source) { return source, true @@ -1016,7 +1029,7 @@ func (c *Client) deleteNetwork(ctx context.Context, name string) error { // If the network does not exist, there is nothing to do here. if len(networks) == 0 { - logger.Debugf("network %s not found, nothing to delete", name) + c.logger.Debugf("network %s not found, nothing to delete", name) return nil } @@ -1067,7 +1080,7 @@ func (c *Client) removeProxyContainers( containerName := fmt.Sprintf("%s-%s", containerName, suffix) containerId, err := c.findExistingContainer(ctx, containerName) if err != nil { - logger.Debugf("Failed to find %s container %s: %v", suffix, containerName, err) + c.logger.Debugf("Failed to find %s container %s: %v", suffix, containerName, err) continue } if containerId == "" { @@ -1221,7 +1234,7 @@ func (c *Client) createContainer( func (c *Client) createDnsContainer(ctx context.Context, dnsContainerName string, attachStdio bool, networkName string, endpointsConfig map[string]*network.EndpointSettings) (string, string, error) { - logger.Infof("Setting up DNS container for %s with image %s...", dnsContainerName, DnsImage) + c.logger.Infof("Setting up DNS container for %s with image %s...", dnsContainerName, DnsImage) dnsLabels := map[string]string{} lb.AddStandardLabels(dnsLabels, dnsContainerName, dnsContainerName, "stdio", 80) 
dnsLabels[ToolhiveAuxiliaryWorkloadLabel] = LabelValueTrue @@ -1232,7 +1245,7 @@ func (c *Client) createDnsContainer(ctx context.Context, dnsContainerName string // Check if the DNS image exists locally before failing _, inspectErr := c.client.ImageInspect(ctx, DnsImage) if inspectErr == nil { - logger.Infof("DNS image %s exists locally, continuing despite pull failure", DnsImage) + c.logger.Infof("DNS image %s exists locally, continuing despite pull failure", DnsImage) } else { return "", "", fmt.Errorf("failed to pull DNS image: %v", err) } @@ -1363,7 +1376,7 @@ func addEgressEnvVars(envVars map[string]string, egressContainerName string) map func (c *Client) createIngressContainer(ctx context.Context, containerName string, upstreamPort int, attachStdio bool, externalEndpointsConfig map[string]*network.EndpointSettings) (int, error) { - squidPort, err := networking.FindOrUsePort(upstreamPort + 1) + squidPort, err := networking.FindOrUsePort(upstreamPort+1, c.logger) if err != nil { return 0, fmt.Errorf("failed to find or use port %d: %v", squidPort, err) } @@ -1428,7 +1441,7 @@ func (c *Client) createExternalNetworks(ctx context.Context) error { } func generatePortBindings(labels map[string]string, - portBindings map[string][]runtime.PortBinding) (map[string][]runtime.PortBinding, int, error) { + portBindings map[string][]runtime.PortBinding, logger *zap.SugaredLogger) (map[string][]runtime.PortBinding, int, error) { var hostPort int // check if we need to map to a random port of not if _, ok := labels["toolhive-auxiliary"]; ok && labels["toolhive-auxiliary"] == "true" { @@ -1446,7 +1459,7 @@ func generatePortBindings(labels map[string]string, } } else { // bind to a random host port - hostPort = networking.FindAvailable() + hostPort = networking.FindAvailable(logger) if hostPort == 0 { return nil, 0, fmt.Errorf("could not find an available port") } @@ -1467,11 +1480,11 @@ func generatePortBindings(labels map[string]string, func (c *Client) stopProxyContainer(ctx 
context.Context, containerName string, timeoutSeconds int) { containerId, err := c.findExistingContainer(ctx, containerName) if err != nil { - logger.Debugf("Failed to find internal container %s: %v", containerName, err) + c.logger.Debugf("Failed to find internal container %s: %v", containerName, err) } else { err = c.client.ContainerStop(ctx, containerId, container.StopOptions{Timeout: &timeoutSeconds}) if err != nil { - logger.Debugf("Failed to stop internal container %s: %v", containerName, err) + c.logger.Debugf("Failed to stop internal container %s: %v", containerName, err) } } } @@ -1490,14 +1503,14 @@ func (c *Client) deleteNetworks(ctx context.Context, containerName string) error networkName := fmt.Sprintf("toolhive-%s-internal", containerName) if err := c.deleteNetwork(ctx, networkName); err != nil { // just log the error and continue - logger.Warnf("failed to delete network %q: %v", networkName, err) + c.logger.Warnf("failed to delete network %q: %v", networkName, err) } if len(toolHiveContainers) == 0 { // remove external network if err := c.deleteNetwork(ctx, "toolhive-external"); err != nil { // just log the error and continue - logger.Warnf("failed to delete network %q: %v", "toolhive-external", err) + c.logger.Warnf("failed to delete network %q: %v", "toolhive-external", err) } } return nil diff --git a/pkg/container/docker/sdk/client_unix.go b/pkg/container/docker/sdk/client_unix.go index 464db850f..a655a8fa1 100644 --- a/pkg/container/docker/sdk/client_unix.go +++ b/pkg/container/docker/sdk/client_unix.go @@ -12,9 +12,9 @@ import ( "path/filepath" "github.com/docker/docker/client" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/logger" ) // ErrRuntimeNotFound is returned when a container runtime is not found @@ -42,7 +42,7 @@ func newPlatformClient(socketPath string) (*http.Client, []client.Opt) { } // findPlatformContainerSocket finds a container socket path on Unix systems -func 
findPlatformContainerSocket(rt runtime.Type) (string, runtime.Type, error) { +func findPlatformContainerSocket(rt runtime.Type, logger *zap.SugaredLogger) (string, runtime.Type, error) { // First check for custom socket paths via environment variables if customSocketPath := os.Getenv(PodmanSocketEnv); customSocketPath != "" { logger.Debugf("Using Podman socket from env: %s", customSocketPath) @@ -63,14 +63,14 @@ func findPlatformContainerSocket(rt runtime.Type) (string, runtime.Type, error) } if rt == runtime.TypePodman { - socketPath, err := findPodmanSocket() + socketPath, err := findPodmanSocket(logger) if err == nil { return socketPath, runtime.TypePodman, nil } } if rt == runtime.TypeDocker { - socketPath, err := findDockerSocket() + socketPath, err := findDockerSocket(logger) if err == nil { return socketPath, runtime.TypeDocker, nil } @@ -80,7 +80,7 @@ func findPlatformContainerSocket(rt runtime.Type) (string, runtime.Type, error) } // findPodmanSocket attempts to locate a Podman socket -func findPodmanSocket() (string, error) { +func findPodmanSocket(logger *zap.SugaredLogger) (string, error) { // Check standard Podman location _, err := os.Stat(PodmanSocketPath) if err == nil { @@ -120,7 +120,7 @@ func findPodmanSocket() (string, error) { } // findDockerSocket attempts to locate a Docker socket -func findDockerSocket() (string, error) { +func findDockerSocket(logger *zap.SugaredLogger) (string, error) { // Try Docker socket as fallback _, err := os.Stat(DockerSocketPath) diff --git a/pkg/container/docker/sdk/factory.go b/pkg/container/docker/sdk/factory.go index ec940c0f6..b320bb04a 100644 --- a/pkg/container/docker/sdk/factory.go +++ b/pkg/container/docker/sdk/factory.go @@ -6,9 +6,9 @@ import ( "fmt" "github.com/docker/docker/client" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/logger" ) /* @@ -41,14 +41,14 @@ const ( var supportedSocketPaths = []runtime.Type{runtime.TypePodman, 
runtime.TypeDocker} // NewDockerClient creates a new container client -func NewDockerClient(ctx context.Context) (*client.Client, string, runtime.Type, error) { +func NewDockerClient(ctx context.Context, logger *zap.SugaredLogger) (*client.Client, string, runtime.Type, error) { var lastErr error // We try to find a container socket for the given runtime // We try Podman first, then Docker as fallback for _, sp := range supportedSocketPaths { // Try to find a container socket for the given runtime - socketPath, runtimeType, err := findContainerSocket(sp) + socketPath, runtimeType, err := findContainerSocket(sp, logger) if err != nil { logger.Debugf("Failed to find socket for %s: %v", sp, err) lastErr = err @@ -93,7 +93,7 @@ func newClientWithSocketPath(ctx context.Context, socketPath string) (*client.Cl } // findContainerSocket finds a container socket path, preferring Podman over Docker -func findContainerSocket(rt runtime.Type) (string, runtime.Type, error) { +func findContainerSocket(rt runtime.Type, logger *zap.SugaredLogger) (string, runtime.Type, error) { // Use platform-specific implementation - return findPlatformContainerSocket(rt) + return findPlatformContainerSocket(rt, logger) } diff --git a/pkg/container/docker/squid.go b/pkg/container/docker/squid.go index a57843f6e..fef333893 100644 --- a/pkg/container/docker/squid.go +++ b/pkg/container/docker/squid.go @@ -12,7 +12,6 @@ import ( "github.com/stacklok/toolhive/pkg/container/runtime" lb "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" ) @@ -88,7 +87,7 @@ func createSquidContainer( squidConfPath string, ) (string, error) { - logger.Infof("Setting up squid container for %s with image %s...", squidContainerName, getSquidImage()) + c.logger.Infof("Setting up squid container for %s with image %s...", squidContainerName, getSquidImage()) squidLabels := map[string]string{} lb.AddStandardLabels(squidLabels, squidContainerName, 
squidContainerName, "stdio", 80) squidLabels[ToolhiveAuxiliaryWorkloadLabel] = LabelValueTrue @@ -101,7 +100,7 @@ func createSquidContainer( // Check if the squid image exists locally before failing _, inspectErr := c.client.ImageInspect(ctx, squidImage) if inspectErr == nil { - logger.Infof("Squid image %s exists locally, continuing despite pull failure", squidImage) + c.logger.Infof("Squid image %s exists locally, continuing despite pull failure", squidImage) } else { return "", fmt.Errorf("failed to pull squid image: %v", err) } diff --git a/pkg/container/factory.go b/pkg/container/factory.go index b11d633c6..79d8b299d 100644 --- a/pkg/container/factory.go +++ b/pkg/container/factory.go @@ -5,30 +5,34 @@ package container import ( "context" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/container/docker" "github.com/stacklok/toolhive/pkg/container/kubernetes" "github.com/stacklok/toolhive/pkg/container/runtime" ) // Factory creates container runtimes -type Factory struct{} +type Factory struct { + logger *zap.SugaredLogger +} // NewFactory creates a new container factory -func NewFactory() *Factory { - return &Factory{} +func NewFactory(logger *zap.SugaredLogger) *Factory { + return &Factory{logger} } // Create creates a container runtime -func (*Factory) Create(ctx context.Context) (runtime.Runtime, error) { +func (f *Factory) Create(ctx context.Context) (runtime.Runtime, error) { if !runtime.IsKubernetesRuntime() { - client, err := docker.NewClient(ctx) + client, err := docker.NewClient(ctx, f.logger) if err != nil { return nil, err } return client, nil } - client, err := kubernetes.NewClient(ctx) + client, err := kubernetes.NewClient(ctx, f.logger) if err != nil { return nil, err } diff --git a/pkg/container/images/docker.go b/pkg/container/images/docker.go index 0c0b563ad..7940facf8 100644 --- a/pkg/container/images/docker.go +++ b/pkg/container/images/docker.go @@ -13,21 +13,22 @@ import ( "github.com/docker/docker/api/types/filters" dockerimage 
"github.com/docker/docker/api/types/image" "github.com/docker/docker/client" - - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) // DockerImageManager implements the ImageManager interface for Docker, // or compatible runtimes such as Podman. type DockerImageManager struct { client *client.Client + logger *zap.SugaredLogger } // NewDockerImageManager creates a new DockerImageManager instance // This is intended for the Docker runtime implementation. -func NewDockerImageManager(dockerClient *client.Client) *DockerImageManager { +func NewDockerImageManager(dockerClient *client.Client, logger *zap.SugaredLogger) *DockerImageManager { return &DockerImageManager{ client: dockerClient, + logger: logger, } } @@ -49,12 +50,12 @@ func (d *DockerImageManager) ImageExists(ctx context.Context, imageName string) // BuildImage builds a Docker image from a Dockerfile in the specified context directory func (d *DockerImageManager) BuildImage(ctx context.Context, contextDir, imageName string) error { - return buildDockerImage(ctx, d.client, contextDir, imageName) + return buildDockerImage(ctx, d.client, contextDir, imageName, d.logger) } // PullImage pulls an image from a registry func (d *DockerImageManager) PullImage(ctx context.Context, imageName string) error { - logger.Infof("Pulling image: %s", imageName) + d.logger.Infof("Pulling image: %s", imageName) // Pull the image reader, err := d.client.ImagePull(ctx, imageName, dockerimage.PullOptions{}) @@ -72,8 +73,14 @@ func (d *DockerImageManager) PullImage(ctx context.Context, imageName string) er } // buildDockerImage builds a Docker image using the Docker client API -func buildDockerImage(ctx context.Context, dockerClient *client.Client, contextDir, imageName string) error { - logger.Infof("Building image %s from context directory %s", imageName, contextDir) +func buildDockerImage( + ctx context.Context, + dockerClient *client.Client, + contextDir, + imageName string, + logger *zap.SugaredLogger, +) error { + 
logger.Infof("Building image %s from context directory %s", imageName, contextDir) // Create a tar archive of the context directory tarFile, err := os.CreateTemp("", "docker-build-context-*.tar") diff --git a/pkg/container/images/image.go b/pkg/container/images/image.go index 0ed732025..d1a30b3c4 100644 --- a/pkg/container/images/image.go +++ b/pkg/container/images/image.go @@ -4,9 +4,10 @@ package images import ( "context" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/container/docker/sdk" "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/logger" ) // ImageManager defines the interface for managing container images. @@ -26,7 +27,7 @@ type ImageManager interface { // NewImageManager creates an instance of ImageManager appropriate // for the current environment, or returns an error if it is not supported. -func NewImageManager(ctx context.Context) ImageManager { +func NewImageManager(ctx context.Context, logger *zap.SugaredLogger) ImageManager { // Check if we are running in a Kubernetes environment if runtime.IsKubernetesRuntime() { logger.Debug("running in Kubernetes environment, using no-op image manager") @@ -34,13 +35,13 @@ func NewImageManager(ctx context.Context) ImageManager { // Check if we are running in a Docker or compatible environment - dockerClient, _, _, err := sdk.NewDockerClient(ctx) + dockerClient, _, _, err := sdk.NewDockerClient(ctx, logger) if err != nil { logger.Debug("no docker runtime found, using no-op image manager") return &NoopImageManager{} } - return NewRegistryImageManager(dockerClient) + return NewRegistryImageManager(dockerClient, logger) } // NoopImageManager is a no-op implementation of ImageManager. 
diff --git a/pkg/container/images/registry.go b/pkg/container/images/registry.go index 55f9f8908..de10c98a1 100644 --- a/pkg/container/images/registry.go +++ b/pkg/container/images/registry.go @@ -12,8 +12,7 @@ import ( "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/daemon" "github.com/google/go-containerregistry/pkg/v1/remote" - - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) // RegistryImageManager implements the ImageManager interface using go-containerregistry @@ -23,14 +22,16 @@ type RegistryImageManager struct { keychain authn.Keychain platform *v1.Platform dockerClient *client.Client + logger *zap.SugaredLogger } // NewRegistryImageManager creates a new RegistryImageManager instance -func NewRegistryImageManager(dockerClient *client.Client) *RegistryImageManager { +func NewRegistryImageManager(dockerClient *client.Client, logger *zap.SugaredLogger) *RegistryImageManager { return &RegistryImageManager{ keychain: NewCompositeKeychain(), // Use composite keychain (env vars + default) platform: getDefaultPlatform(), // Use a default platform based on host architecture dockerClient: dockerClient, // Used solely for building images from Dockerfiles + logger: logger, } } @@ -62,7 +63,7 @@ func (r *RegistryImageManager) ImageExists(_ context.Context, imageName string) // PullImage pulls an image from a registry and saves it to the local daemon func (r *RegistryImageManager) PullImage(ctx context.Context, imageName string) error { - logger.Infof("Pulling image: %s", imageName) + r.logger.Infof("Pulling image: %s", imageName) // Parse the image reference ref, err := name.ParseReference(imageName) @@ -104,14 +105,14 @@ func (r *RegistryImageManager) PullImage(ctx context.Context, imageName string) // Display success message fmt.Fprintf(os.Stdout, "Successfully pulled %s\n", imageName) - logger.Infof("Pull complete for image: %s, response: %s", imageName, response) + r.logger.Infof("Pull complete for 
image: %s, response: %s", imageName, response) return nil } // BuildImage builds a Docker image from a Dockerfile in the specified context directory func (r *RegistryImageManager) BuildImage(ctx context.Context, contextDir, imageName string) error { - return buildDockerImage(ctx, r.dockerClient, contextDir, imageName) + return buildDockerImage(ctx, r.dockerClient, contextDir, imageName, r.logger) } // WithKeychain sets the keychain for authentication diff --git a/pkg/container/kubernetes/client.go b/pkg/container/kubernetes/client.go index a9f6af049..f7ee90829 100644 --- a/pkg/container/kubernetes/client.go +++ b/pkg/container/kubernetes/client.go @@ -13,6 +13,7 @@ import ( "time" "github.com/cenkalti/backoff/v5" + "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" @@ -29,7 +30,6 @@ import ( "k8s.io/client-go/tools/watch" "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" transtypes "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -47,11 +47,18 @@ type Client struct { runtimeType runtime.Type client kubernetes.Interface // waitForStatefulSetReadyFunc is used for testing to mock the waitForStatefulSetReady function - waitForStatefulSetReadyFunc func(ctx context.Context, clientset kubernetes.Interface, namespace, name string) error + waitForStatefulSetReadyFunc func( + ctx context.Context, + clientset kubernetes.Interface, + namespace, + name string, + logger *zap.SugaredLogger, + ) error + logger *zap.SugaredLogger } // NewClient creates a new container client -func NewClient(_ context.Context) (*Client, error) { +func NewClient(_ context.Context, logger *zap.SugaredLogger) (*Client, error) { // creates the in-cluster config config, err := rest.InClusterConfig() if err != nil { @@ -66,6 +73,7 @@ func NewClient(_ context.Context) (*Client, error) { return &Client{ runtimeType: runtime.TypeKubernetes, client: 
clientset, + logger: logger, }, nil } @@ -117,7 +125,7 @@ func (c *Client) AttachToWorkload(ctx context.Context, workloadName string) (io. return nil, nil, fmt.Errorf("failed to create SPDY executor: %v", err) } - logger.Infof("Attaching to pod %s workload %s...", podName, workloadName) + c.logger.Infof("Attaching to pod %s workload %s...", podName, workloadName) stdinReader, stdinWriter := io.Pipe() stdoutReader, stdoutWriter := io.Pipe() @@ -140,23 +148,23 @@ func (c *Client) AttachToWorkload(ctx context.Context, workloadName string) (io. backoff.WithBackOff(expBackoff), backoff.WithMaxTries(5), backoff.WithNotify(func(err error, duration time.Duration) { - logger.Errorf("Error attaching to workload %s: %v. Retrying in %s...", workloadName, err, duration) + c.logger.Errorf("Error attaching to workload %s: %v. Retrying in %s...", workloadName, err, duration) }), ) if err != nil { if statusErr, ok := err.(*errors.StatusError); ok { - logger.Errorf("Kubernetes API error: Status=%s, Message=%s, Reason=%s, Code=%d", + c.logger.Errorf("Kubernetes API error: Status=%s, Message=%s, Reason=%s, Code=%d", statusErr.ErrStatus.Status, statusErr.ErrStatus.Message, statusErr.ErrStatus.Reason, statusErr.ErrStatus.Code) if statusErr.ErrStatus.Code == 0 && statusErr.ErrStatus.Message == "" { - logger.Info("Empty status error - this typically means the connection was closed unexpectedly") - logger.Info("This often happens when the container terminates or doesn't read from stdin") + c.logger.Info("Empty status error - this typically means the connection was closed unexpectedly") + c.logger.Info("This often happens when the container terminates or doesn't read from stdin") } } else { - logger.Errorf("Non-status error: %v", err) + c.logger.Errorf("Non-status error: %v", err) } } }() @@ -285,7 +293,7 @@ func (c *Client) DeployWorkload(ctx context.Context, return 0, fmt.Errorf("failed to apply statefulset: %v", err) } - logger.Infof("Applied statefulset %s", createdStatefulSet.Name) + 
c.logger.Infof("Applied statefulset %s", createdStatefulSet.Name) if transportTypeRequiresHeadlessService(transportType) && options != nil { // Create a headless service for SSE transport @@ -300,7 +308,7 @@ func (c *Client) DeployWorkload(ctx context.Context, if c.waitForStatefulSetReadyFunc != nil { waitFunc = c.waitForStatefulSetReadyFunc } - err = waitFunc(ctx, c.client, namespace, createdStatefulSet.Name) + err = waitFunc(ctx, c.client, namespace, createdStatefulSet.Name, c.logger) if err != nil { return 0, fmt.Errorf("statefulset applied but failed to become ready: %w", err) } @@ -463,13 +471,13 @@ func (c *Client) RemoveWorkload(ctx context.Context, workloadName string) error if err != nil { if errors.IsNotFound(err) { // If the statefulset doesn't exist, that's fine - logger.Infof("Statefulset %s not found, nothing to remove", workloadName) + c.logger.Infof("Statefulset %s not found, nothing to remove", workloadName) return nil } return fmt.Errorf("failed to delete statefulset %s: %w", workloadName, err) } - logger.Infof("Deleted statefulset %s", workloadName) + c.logger.Infof("Deleted statefulset %s", workloadName) return nil } @@ -492,7 +500,13 @@ func (c *Client) IsRunning(ctx context.Context) error { } // waitForStatefulSetReady waits for a statefulset to be ready using the watch API -func waitForStatefulSetReady(ctx context.Context, clientset kubernetes.Interface, namespace, name string) error { +func waitForStatefulSetReady( + ctx context.Context, + clientset kubernetes.Interface, + namespace, + name string, + logger *zap.SugaredLogger, +) error { // Create a field selector to watch only this specific statefulset fieldSelector := fmt.Sprintf("metadata.name=%s", name) @@ -731,7 +745,7 @@ func (c *Client) createHeadlessService( // If no ports were configured, don't create a service if len(servicePorts) == 0 { - logger.Info("No ports configured for SSE transport, skipping service creation") + c.logger.Info("No ports configured for SSE transport, skipping 
service creation") return nil } @@ -771,7 +785,7 @@ func (c *Client) createHeadlessService( return fmt.Errorf("failed to apply service: %v", err) } - logger.Infof("Created headless service %s for SSE transport", containerName) + c.logger.Infof("Created headless service %s for SSE transport", containerName) options.SSEHeadlessServiceName = svcName return nil diff --git a/pkg/container/kubernetes/client_test.go b/pkg/container/kubernetes/client_test.go index 784b063a2..cabeefe75 100644 --- a/pkg/container/kubernetes/client_test.go +++ b/pkg/container/kubernetes/client_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -15,16 +16,11 @@ import ( "k8s.io/client-go/kubernetes/fake" "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) -func init() { - // Initialize the logger for tests - logger.Initialize() -} - // mockWaitForStatefulSetReady is used to mock the waitForStatefulSetReady function in tests -var mockWaitForStatefulSetReady = func(_ context.Context, _ kubernetes.Interface, _, _ string) error { +var mockWaitForStatefulSetReady = func(_ context.Context, _ kubernetes.Interface, _, _ string, _ *zap.SugaredLogger) error { return nil } @@ -165,6 +161,7 @@ func TestCreateContainerWithPodTemplatePatch(t *testing.T) { runtimeType: runtime.TypeKubernetes, client: clientset, waitForStatefulSetReadyFunc: mockWaitForStatefulSetReady, + logger: log.NewLogger(), } // Create workload options with the pod template patch options := runtime.NewDeployWorkloadOptions() @@ -663,6 +660,7 @@ func TestCreateContainerWithMCP(t *testing.T) { runtimeType: runtime.TypeKubernetes, client: clientset, waitForStatefulSetReadyFunc: mockWaitForStatefulSetReady, + logger: log.NewLogger(), } // Deploy the workload diff --git 
a/pkg/container/verifier/attestations.go b/pkg/container/verifier/attestations.go index 42e6dac64..b4d97be17 100644 --- a/pkg/container/verifier/attestations.go +++ b/pkg/container/verifier/attestations.go @@ -11,14 +11,13 @@ import ( "github.com/google/go-containerregistry/pkg/v1/remote" containerdigest "github.com/opencontainers/go-digest" "github.com/sigstore/sigstore-go/pkg/bundle" - - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) // bundleFromAttestation retrieves the attestation bundles from the image reference. Note that the attestation // bundles are stored as OCI image references. The function uses the referrers API to get the attestation. GitHub supports // discovering the attestations via their API, but this is not supported here for now. -func bundleFromAttestation(imageRef string, keychain authn.Keychain) ([]sigstoreBundle, error) { +func bundleFromAttestation(imageRef string, keychain authn.Keychain, logger *zap.SugaredLogger) ([]sigstoreBundle, error) { var bundles []sigstoreBundle // Get the auth options diff --git a/pkg/container/verifier/sigstore.go b/pkg/container/verifier/sigstore.go index 2dfd93548..c502c4a99 100644 --- a/pkg/container/verifier/sigstore.go +++ b/pkg/container/verifier/sigstore.go @@ -20,8 +20,7 @@ import ( protocommon "github.com/sigstore/protobuf-specs/gen/pb-go/common/v1" protorekor "github.com/sigstore/protobuf-specs/gen/pb-go/rekor/v1" "github.com/sigstore/sigstore-go/pkg/bundle" - - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) type sigstoreBundle struct { @@ -31,7 +30,11 @@ type sigstoreBundle struct { } // bundleFromSigstoreSignedImage returns a bundle from a Sigstore signed image -func bundleFromSigstoreSignedImage(imageRef string, keychain authn.Keychain) ([]sigstoreBundle, error) { +func bundleFromSigstoreSignedImage( + imageRef string, + keychain authn.Keychain, + logger *zap.SugaredLogger, +) ([]sigstoreBundle, error) { // Get the signature manifest from the OCI image reference 
signatureRef, err := getSignatureReferenceFromOCIImage(imageRef, keychain) if err != nil { diff --git a/pkg/container/verifier/utils.go b/pkg/container/verifier/utils.go index 67fd56731..f63aadf07 100644 --- a/pkg/container/verifier/utils.go +++ b/pkg/container/verifier/utils.go @@ -10,6 +10,7 @@ import ( "github.com/google/go-containerregistry/pkg/authn" "github.com/sigstore/sigstore-go/pkg/tuf" "github.com/sigstore/sigstore-go/pkg/verify" + "go.uber.org/zap" ) //go:embed tufroots @@ -112,13 +113,14 @@ func embeddedRootJson(tufRootURL string) ([]byte, error) { func getSigstoreBundles( imageRef string, keychain authn.Keychain, + logger *zap.SugaredLogger, ) ([]sigstoreBundle, error) { // Try to build a bundle from a Sigstore signed image - bundles, err := bundleFromSigstoreSignedImage(imageRef, keychain) + bundles, err := bundleFromSigstoreSignedImage(imageRef, keychain, logger) if errors.Is(err, ErrProvenanceNotFoundOrIncomplete) { // If we get this error, it means that the image is not signed // or the signature is incomplete. Let's try to see if we can find attestation for the image. 
- return bundleFromAttestation(imageRef, keychain) + return bundleFromAttestation(imageRef, keychain, logger) } else if err != nil { return nil, err } diff --git a/pkg/container/verifier/verifier.go b/pkg/container/verifier/verifier.go index e1841f5d3..0e0d37773 100644 --- a/pkg/container/verifier/verifier.go +++ b/pkg/container/verifier/verifier.go @@ -10,9 +10,9 @@ import ( "github.com/sigstore/sigstore-go/pkg/fulcio/certificate" "github.com/sigstore/sigstore-go/pkg/root" "github.com/sigstore/sigstore-go/pkg/verify" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/container/images" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" ) @@ -29,6 +29,7 @@ const ( type Sigstore struct { verifier *verify.Verifier keychain authn.Keychain + logger *zap.SugaredLogger } // Result is the result of the verification @@ -39,7 +40,7 @@ type Result struct { } // New creates a new Sigstore verifier -func New(serverInfo *registry.ImageMetadata) (*Sigstore, error) { +func New(serverInfo *registry.ImageMetadata, logger *zap.SugaredLogger) (*Sigstore, error) { // Fail the verification early if the server information is not set if serverInfo == nil || serverInfo.Provenance == nil { return nil, ErrProvenanceServerInformationNotSet @@ -73,6 +74,7 @@ func New(serverInfo *registry.ImageMetadata) (*Sigstore, error) { return &Sigstore{ verifier: sev, keychain: images.NewCompositeKeychain(), + logger: logger, }, nil } @@ -87,12 +89,12 @@ func (s *Sigstore) GetVerificationResults( imageRef string, ) ([]*verify.VerificationResult, error) { // Construct the bundle(s) for the image reference - bundles, err := getSigstoreBundles(imageRef, s.keychain) + bundles, err := getSigstoreBundles(imageRef, s.keychain, s.logger) if err != nil && !errors.Is(err, ErrProvenanceNotFoundOrIncomplete) { // We got some other unexpected error prior to querying for the signature/attestation return nil, err } - logger.Debugf("Number of sigstore bundles we managed to construct 
is %d", len(bundles)) + s.logger.Debugf("Number of sigstore bundles we managed to construct is %d", len(bundles)) // If we didn't manage to construct any valid bundles, it probably means that the image is not signed. if len(bundles) == 0 || errors.Is(err, ErrProvenanceNotFoundOrIncomplete) { @@ -100,7 +102,7 @@ func (s *Sigstore) GetVerificationResults( } // Construct the verification result for each bundle we managed to generate. - return getVerifiedResults(s.verifier, bundles), nil + return getVerifiedResults(s.verifier, bundles, s.logger), nil } // getVerifiedResults verifies the artifact using the bundles against the configured sigstore instance @@ -108,6 +110,7 @@ func (s *Sigstore) GetVerificationResults( func getVerifiedResults( sev *verify.Verifier, bundles []sigstoreBundle, + logger *zap.SugaredLogger, ) []*verify.VerificationResult { var results []*verify.VerificationResult @@ -144,7 +147,7 @@ func (s *Sigstore) VerifyServer(imageRef string, serverInfo *registry.ImageMetad // Compare the server information with the verification results for _, res := range results { - if !isVerificationResultMatchingServerProvenance(res, serverInfo.Provenance) { + if !isVerificationResultMatchingServerProvenance(res, serverInfo.Provenance, s.logger) { // The server information does not match the verification result, fail the verification return false, nil } @@ -153,13 +156,17 @@ func (s *Sigstore) VerifyServer(imageRef string, serverInfo *registry.ImageMetad return true, nil } -func isVerificationResultMatchingServerProvenance(r *verify.VerificationResult, p *registry.Provenance) bool { +func isVerificationResultMatchingServerProvenance( + r *verify.VerificationResult, + p *registry.Provenance, + logger *zap.SugaredLogger, +) bool { if r == nil || p == nil || r.Signature == nil || r.Signature.Certificate == nil { return false } // Compare the base properties of the verification result and the server provenance - if !compareBaseProperties(r, p) { + if 
!compareBaseProperties(r, p, logger) { return false } @@ -175,7 +182,7 @@ func isVerificationResultMatchingServerProvenance(r *verify.VerificationResult, } // compareBaseProperties compares the base properties of the verification result and the server provenance -func compareBaseProperties(r *verify.VerificationResult, p *registry.Provenance) bool { +func compareBaseProperties(r *verify.VerificationResult, p *registry.Provenance, logger *zap.SugaredLogger) bool { // Extract the signer identity from the certificate siIdentity, err := signerIdentityFromCertificate(r.Signature.Certificate) if err != nil { diff --git a/pkg/groups/group_test.go b/pkg/groups/group_test.go index 8866c2fcf..218381613 100644 --- a/pkg/groups/group_test.go +++ b/pkg/groups/group_test.go @@ -11,15 +11,10 @@ import ( "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/state/mocks" ) -func init() { - // Initialize logger for tests - logger.Initialize() -} - const testGroupName = "testgroup" // TestManager_Create demonstrates using gomock for testing group creation @@ -458,6 +453,8 @@ func TestManager_Exists(t *testing.T) { func TestManager_RegisterClients(t *testing.T) { t.Parallel() + logger := log.NewLogger() + tests := []struct { name string groupName string @@ -519,7 +516,10 @@ func TestManager_RegisterClients(t *testing.T) { defer ctrl.Finish() mockStore := mocks.NewMockStore(ctrl) - manager := &manager{groupStore: mockStore} + manager := &manager{ + groupStore: mockStore, + logger: logger, + } // Set up mock expectations tt.setupMock(mockStore) @@ -542,6 +542,8 @@ func TestManager_RegisterClients(t *testing.T) { func TestManager_UnregisterClients(t *testing.T) { t.Parallel() + logger := log.NewLogger() + tests := []struct { name string groupName string @@ -603,7 +605,10 @@ func TestManager_UnregisterClients(t *testing.T) { defer ctrl.Finish() mockStore 
:= mocks.NewMockStore(ctrl) - manager := &manager{groupStore: mockStore} + manager := &manager{ + groupStore: mockStore, + logger: logger, + } // Set up mock expectations tt.setupMock(mockStore) diff --git a/pkg/groups/manager.go b/pkg/groups/manager.go index 5d9ec4069..143103993 100644 --- a/pkg/groups/manager.go +++ b/pkg/groups/manager.go @@ -7,8 +7,9 @@ import ( "sort" "strings" + "go.uber.org/zap" + thverrors "github.com/stacklok/toolhive/pkg/errors" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/state" ) @@ -20,16 +21,20 @@ const ( // manager implements the Manager interface type manager struct { groupStore state.Store + logger *zap.SugaredLogger } // NewManager creates a new group manager -func NewManager() (Manager, error) { +func NewManager(logger *zap.SugaredLogger) (Manager, error) { store, err := state.NewGroupConfigStore("toolhive") if err != nil { return nil, fmt.Errorf("failed to create group state store: %w", err) } - return &manager{groupStore: store}, nil + return &manager{ + groupStore: store, + logger: logger, + }, nil } // Create creates a new group with the given name @@ -121,14 +126,14 @@ func (m *manager) RegisterClients(ctx context.Context, groupNames []string, clie } if alreadyRegistered { - logger.Infof("Client %s is already registered with group %s, skipping", clientName, groupName) + m.logger.Infof("Client %s is already registered with group %s, skipping", clientName, groupName) continue } // Add the client to the group group.RegisteredClients = append(group.RegisteredClients, clientName) groupModified = true - logger.Infof("Successfully registered client %s with group %s", clientName, groupName) + m.logger.Infof("Successfully registered client %s with group %s", clientName, groupName) } // Only save if the group was actually modified @@ -160,7 +165,7 @@ func (m *manager) UnregisterClients(ctx context.Context, groupNames []string, cl // Remove client from slice group.RegisteredClients = 
append(group.RegisteredClients[:i], group.RegisteredClients[i+1:]...) groupModified = true - logger.Infof("Successfully unregistered client %s from group %s", clientName, groupName) + m.logger.Infof("Successfully unregistered client %s from group %s", clientName, groupName) break } } diff --git a/pkg/healthcheck/healthcheck.go b/pkg/healthcheck/healthcheck.go index 2de0934b5..a7842d183 100644 --- a/pkg/healthcheck/healthcheck.go +++ b/pkg/healthcheck/healthcheck.go @@ -8,7 +8,8 @@ import ( "net/http" "time" - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/versions" ) @@ -64,13 +65,15 @@ type MCPPinger interface { type HealthChecker struct { transport string mcpPinger MCPPinger + logger *zap.SugaredLogger } // NewHealthChecker creates a new health checker instance -func NewHealthChecker(transport string, mcpPinger MCPPinger) *HealthChecker { +func NewHealthChecker(transport string, mcpPinger MCPPinger, logger *zap.SugaredLogger) *HealthChecker { return &HealthChecker{ transport: transport, mcpPinger: mcpPinger, + logger: logger, } } @@ -112,12 +115,12 @@ func (hc *HealthChecker) checkMCPStatus(ctx context.Context) *MCPStatus { if err != nil { status.Available = false status.Error = err.Error() - logger.Debugf("MCP ping failed: %v", err) + hc.logger.Debugf("MCP ping failed: %v", err) } else { status.Available = true responseTimeMs := duration.Milliseconds() status.ResponseTime = &responseTimeMs - logger.Debugf("MCP ping successful: %v", duration) + hc.logger.Debugf("MCP ping successful: %v", duration) } return status @@ -145,7 +148,7 @@ func (hc *HealthChecker) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if err := json.NewEncoder(w).Encode(health); err != nil { - logger.Warnf("Failed to encode health response: %v", err) + hc.logger.Warnf("Failed to encode health response: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } diff --git a/pkg/healthcheck/healthcheck_test.go 
b/pkg/healthcheck/healthcheck_test.go index edb12c254..53e2c00da 100644 --- a/pkg/healthcheck/healthcheck_test.go +++ b/pkg/healthcheck/healthcheck_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/versions" ) @@ -31,8 +31,7 @@ func (m *mockMCPPinger) Ping(_ context.Context) (time.Duration, error) { func TestHealthChecker_CheckHealth(t *testing.T) { t.Parallel() - // Initialize logger for tests - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -68,7 +67,7 @@ func TestHealthChecker_CheckHealth(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - hc := NewHealthChecker(tt.transport, tt.pinger) + hc := NewHealthChecker(tt.transport, tt.pinger, logger) ctx := context.Background() health := hc.CheckHealth(ctx) @@ -100,8 +99,7 @@ func TestHealthChecker_CheckHealth(t *testing.T) { func TestHealthChecker_ServeHTTP(t *testing.T) { t.Parallel() - // Initialize logger for tests - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -159,7 +157,7 @@ func TestHealthChecker_ServeHTTP(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - hc := NewHealthChecker("stdio", tt.pinger) + hc := NewHealthChecker("stdio", tt.pinger, logger) req := httptest.NewRequest(tt.method, "/health", nil) w := httptest.NewRecorder() @@ -182,9 +180,6 @@ func TestHealthChecker_ServeHTTP(t *testing.T) { func TestHealthResponse_JSON(t *testing.T) { t.Parallel() - // Initialize logger for tests - logger.Initialize() - response := &HealthResponse{ Status: StatusHealthy, Timestamp: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), diff --git a/pkg/ignore/processor.go b/pkg/ignore/processor.go index e9d541750..dd5bb5615 100644 --- a/pkg/ignore/processor.go +++ b/pkg/ignore/processor.go @@ -10,8 +10,7 @@ import ( "strings" 
"github.com/adrg/xdg" - - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) // Processor handles loading and processing ignore patterns @@ -20,6 +19,7 @@ type Processor struct { LocalPatterns []string Config *Config sharedEmptyFile string // Cached path to a single shared empty file + logger *zap.SugaredLogger } // Config holds configuration for ignore processing @@ -31,17 +31,19 @@ type Config struct { const ignoreFileName = ".thvignore" // NewProcessor creates a new Processor instance with the given configuration -func NewProcessor(config *Config) *Processor { +func NewProcessor(config *Config, logger *zap.SugaredLogger) *Processor { if config == nil { config = &Config{ LoadGlobal: true, PrintOverlays: false, } } + return &Processor{ GlobalPatterns: make([]string, 0), LocalPatterns: make([]string, 0), Config: config, + logger: logger, } } @@ -49,27 +51,27 @@ func NewProcessor(config *Config) *Processor { func (p *Processor) LoadGlobal() error { // Skip loading global patterns if disabled in config if !p.Config.LoadGlobal { - logger.Debugf("Global ignore patterns disabled by configuration") + p.logger.Debug("Global ignore patterns disabled by configuration") return nil } globalIgnoreFile, err := xdg.ConfigFile("toolhive/thvignore") if err != nil { - logger.Debugf("Failed to get XDG config file path: %v", err) + p.logger.Debugf("Failed to get XDG config file path: %v", err) return nil // Not a fatal error, continue without global patterns } patterns, err := p.loadIgnoreFile(globalIgnoreFile) if err != nil { if os.IsNotExist(err) { - logger.Debugf("Global ignore file not found: %s", globalIgnoreFile) + p.logger.Debugf("Global ignore file not found: %s", globalIgnoreFile) return nil // Not a fatal error } return fmt.Errorf("failed to load global ignore file: %w", err) } p.GlobalPatterns = patterns - logger.Debugf("Loaded %d global ignore patterns from %s", len(patterns), globalIgnoreFile) + p.logger.Debugf("Loaded %d global ignore patterns from %s", 
len(patterns), globalIgnoreFile) return nil } @@ -79,14 +81,14 @@ func (p *Processor) LoadLocal(sourceDir string) error { patterns, err := p.loadIgnoreFile(localIgnoreFile) if err != nil { if os.IsNotExist(err) { - logger.Debugf("Local ignore file not found: %s", localIgnoreFile) + p.logger.Debugf("Local ignore file not found: %s", localIgnoreFile) return nil // Not a fatal error } return fmt.Errorf("failed to load local ignore file: %w", err) } p.LocalPatterns = append(p.LocalPatterns, patterns...) - logger.Debugf("Loaded %d local ignore patterns from %s", len(patterns), localIgnoreFile) + p.logger.Debugf("Loaded %d local ignore patterns from %s", len(patterns), localIgnoreFile) return nil } @@ -165,7 +167,7 @@ func (p *Processor) createOverlayMount( // Calculate relative path from bind mount to matched path relPath, err := filepath.Rel(bindMount, matchPath) if err != nil { - logger.Debugf("Failed to calculate relative path for %s: %v", matchPath, err) + p.logger.Debugf("Failed to calculate relative path for %s: %v", matchPath, err) return nil } @@ -181,13 +183,13 @@ func (p *Processor) createOverlayMount( // Check if the matched path is a directory or file info, err := os.Stat(matchPath) if err != nil { - logger.Debugf("Failed to stat path %s: %v", matchPath, err) + p.logger.Debugf("Failed to stat path %s: %v", matchPath, err) return nil } if info.IsDir() { // For directories, use tmpfs mount - logger.Debugf("Adding tmpfs overlay for directory pattern '%s' at container path: %s", pattern, containerOverlayPath) + p.logger.Debugf("Adding tmpfs overlay for directory pattern '%s' at container path: %s", pattern, containerOverlayPath) return &OverlayMount{ ContainerPath: containerOverlayPath, HostPath: "", // tmpfs doesn't need host path @@ -198,11 +200,11 @@ func (p *Processor) createOverlayMount( // For files, create empty file and bind mount it emptyFilePath, err := p.createEmptyFile() if err != nil { - logger.Debugf("Failed to create empty file for pattern '%s': 
%v", pattern, err) + p.logger.Debugf("Failed to create empty file for pattern '%s': %v", pattern, err) return nil } - logger.Debugf("Adding bind overlay for file pattern '%s' at container path: %s (host: %s)", + p.logger.Debugf("Adding bind overlay for file pattern '%s' at container path: %s (host: %s)", pattern, containerOverlayPath, emptyFilePath) return &OverlayMount{ ContainerPath: containerOverlayPath, @@ -214,12 +216,12 @@ func (p *Processor) createOverlayMount( // printOverlays prints resolved overlays if requested func (p *Processor) printOverlays(overlayMounts []OverlayMount, bindMount, containerPath string) { if p.Config.PrintOverlays && len(overlayMounts) > 0 { - logger.Infof("Resolved overlays for mount %s -> %s:", bindMount, containerPath) + p.logger.Infof("Resolved overlays for mount %s -> %s:", bindMount, containerPath) for _, overlay := range overlayMounts { if overlay.Type == "tmpfs" { - logger.Infof(" - %s (tmpfs)", overlay.ContainerPath) + p.logger.Infof(" - %s (tmpfs)", overlay.ContainerPath) } else { - logger.Infof(" - %s (bind: %s)", overlay.ContainerPath, overlay.HostPath) + p.logger.Infof(" - %s (bind: %s)", overlay.ContainerPath, overlay.HostPath) } } } @@ -248,7 +250,7 @@ func (p *Processor) createEmptyFile() (string, error) { // Cache the path for reuse p.sharedEmptyFile = tmpFile.Name() - logger.Debugf("Created shared empty file for bind mounting: %s", p.sharedEmptyFile) + p.logger.Debugf("Created shared empty file for bind mounting: %s", p.sharedEmptyFile) return p.sharedEmptyFile, nil } @@ -257,10 +259,10 @@ func (p *Processor) createEmptyFile() (string, error) { func (p *Processor) Cleanup() error { if p.sharedEmptyFile != "" { if err := os.Remove(p.sharedEmptyFile); err != nil && !os.IsNotExist(err) { - logger.Debugf("Failed to remove shared empty file %s: %v", p.sharedEmptyFile, err) + p.logger.Debugf("Failed to remove shared empty file %s: %v", p.sharedEmptyFile, err) return fmt.Errorf("failed to remove shared empty file: %w", err) 
} - logger.Debugf("Cleaned up shared empty file: %s", p.sharedEmptyFile) + p.logger.Debugf("Cleaned up shared empty file: %s", p.sharedEmptyFile) p.sharedEmptyFile = "" } return nil @@ -280,7 +282,7 @@ func (p *Processor) GetOverlayPaths(bindMount, containerPath string) []string { } // getMatchingPaths returns all paths that match the given pattern in the directory -func (*Processor) getMatchingPaths(dir, pattern string) []string { +func (p *Processor) getMatchingPaths(dir, pattern string) []string { var matchingPaths []string // Handle directory patterns (ending with /) @@ -303,7 +305,7 @@ func (*Processor) getMatchingPaths(dir, pattern string) []string { // Handle glob patterns matches, err := filepath.Glob(filepath.Join(dir, pattern)) if err != nil { - logger.Debugf("Error matching pattern '%s': %v", pattern, err) + p.logger.Debugf("Error matching pattern '%s': %v", pattern, err) return matchingPaths } diff --git a/pkg/ignore/processor_test.go b/pkg/ignore/processor_test.go index 9d74b17e6..ba6f48e48 100644 --- a/pkg/ignore/processor_test.go +++ b/pkg/ignore/processor_test.go @@ -5,19 +5,18 @@ import ( "path/filepath" "testing" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) -func init() { - logger.Initialize() // ensure logging doesn't panic -} - func TestNewProcessor(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + processor := NewProcessor(&Config{ LoadGlobal: true, PrintOverlays: false, - }) + }, logger) if processor == nil { t.Error("NewProcessor should return a non-nil processor") return @@ -32,6 +31,9 @@ func TestNewProcessor(t *testing.T) { func TestLoadIgnoreFile(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + testCases := []struct { name string fileContent string @@ -94,7 +96,7 @@ temp/ processor := NewProcessor(&Config{ LoadGlobal: true, PrintOverlays: false, - }) + }, logger) patterns, err := processor.loadIgnoreFile(ignoreFile) if tc.expectError { @@ -115,6 +117,9 @@ temp/ func 
TestLoadLocal(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + testCases := []struct { name string createFile bool @@ -153,7 +158,7 @@ func TestLoadLocal(t *testing.T) { processor := NewProcessor(&Config{ LoadGlobal: true, PrintOverlays: false, - }) + }, logger) err := processor.LoadLocal(tmpDir) if tc.expectError { @@ -174,6 +179,10 @@ func TestLoadLocal(t *testing.T) { func TestPatternMatchesInDirectory(t *testing.T) { t.Parallel() + + // Setup logger + logger := log.NewLogger() + // Create test directory structure tmpDir := t.TempDir() @@ -199,7 +208,7 @@ func TestPatternMatchesInDirectory(t *testing.T) { processor := NewProcessor(&Config{ LoadGlobal: true, PrintOverlays: false, - }) + }, logger) testCases := []struct { name string @@ -246,6 +255,9 @@ func TestPatternMatchesInDirectory(t *testing.T) { func TestGetOverlayPaths(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + // Create test directory structure tmpDir := t.TempDir() @@ -265,7 +277,7 @@ func TestGetOverlayPaths(t *testing.T) { processor := NewProcessor(&Config{ LoadGlobal: true, PrintOverlays: false, - }) + }, logger) processor.GlobalPatterns = []string{"node_modules/", ".DS_Store"} processor.LocalPatterns = []string{".ssh/", ".env"} @@ -297,10 +309,13 @@ func TestGetOverlayPaths(t *testing.T) { func TestShouldIgnore(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + processor := NewProcessor(&Config{ LoadGlobal: true, PrintOverlays: false, - }) + }, logger) processor.GlobalPatterns = []string{"node_modules", "*.log"} processor.LocalPatterns = []string{".ssh", ".env"} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 7be4ac945..f6f5812ac 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -2,132 +2,59 @@ package logger import ( - "context" - "fmt" - "log/slog" "os" - "runtime" "strconv" "time" - "github.com/lmittmann/tint" + "github.com/go-logr/logr" + "github.com/go-logr/zapr" "github.com/spf13/viper" + "go.uber.org/zap" + 
"go.uber.org/zap/zapcore" ) -// Log is a global logger instance -var log Logger +// NewLogger creates a new zap sugared logger instance +func NewLogger() *zap.SugaredLogger { + config := buildConfig() + logger, err := config.Build() -// Debug logs a message at debug level using the singleton logger. -func Debug(msg string, args ...any) { - log.Debug(msg, args...) -} - -// Debugf logs a message at debug level using the singleton logger. -func Debugf(msg string, args ...any) { - log.Debugf(msg, args...) -} - -// Info logs a message at info level using the singleton logger. -func Info(msg string, args ...any) { - log.Info(msg, args...) -} - -// Infof logs a message at info level using the singleton logger. -func Infof(msg string, args ...any) { - log.Infof(msg, args...) -} - -// Warn logs a message at warning level using the singleton logger. -func Warn(msg string, args ...any) { - log.Warn(msg, args...) -} - -// Warnf logs a message at warning level using the singleton logger. -func Warnf(msg string, args ...any) { - log.Warnf(msg, args...) -} - -// Error logs a message at error level using the singleton logger. -func Error(msg string, args ...any) { - log.Error(msg, args...) -} - -// Errorf logs a message at error level using the singleton logger. -func Errorf(msg string, args ...any) { - log.Errorf(msg, args...) -} - -// Panic logs a message at error level using the singleton logger and panics the program. -func Panic(msg string) { - log.Panic(msg) -} - -// Panicf logs a message at error level using the singleton logger and panics the program. -func Panicf(msg string, args ...any) { - log.Panicf(msg, args...) 
-} - -// Logger provides a unified interface for logging -type Logger interface { - Debug(msg string, args ...any) - Debugf(msg string, args ...any) - Info(msg string, args ...any) - Infof(msg string, args ...any) - Warn(msg string, args ...any) - Warnf(msg string, args ...any) - Error(msg string, args ...any) - Errorf(msg string, args ...any) - Panic(msg string) - Panicf(msg string, args ...any) -} - -// Implementation using slog -type slogLogger struct { - logger *slog.Logger -} - -func (l *slogLogger) Debugf(msg string, args ...any) { - l.logger.Debug(fmt.Sprintf(msg, args...)) -} - -func (l *slogLogger) Infof(msg string, args ...any) { - l.logger.Info(fmt.Sprintf(msg, args...)) -} - -func (l *slogLogger) Warnf(msg string, args ...any) { - l.logger.Warn(fmt.Sprintf(msg, args...)) -} - -func (l *slogLogger) Errorf(msg string, args ...any) { - l.logger.Error(fmt.Sprintf(msg, args...)) -} + if err != nil { + panic(err) // TODO: handle error appropriately + } -func (l *slogLogger) Panicf(msg string, args ...any) { - l.Panic(fmt.Sprintf(msg, args...)) + return logger.Sugar() } -func (l *slogLogger) Debug(msg string, args ...any) { - l.logger.Debug(msg, args...) -} +// NewLogr returns a logr.Logger which uses zap logger, name to be updated once NewLogr is removed +func NewLogr() logr.Logger { + sugaredLogger := NewLogger() + logger := sugaredLogger.Desugar() -func (l *slogLogger) Info(msg string, args ...any) { - l.logger.Info(msg, args...) + return zapr.NewLogger(logger) } -func (l *slogLogger) Warn(msg string, args ...any) { - l.logger.Warn(msg, args...) 
-} +// TODO: Update the config as per the project's requirements +// buildConfig returns the cached base configuration +func buildConfig() zap.Config { + var config zap.Config + if unstructuredLogs() { + config = zap.NewDevelopmentConfig() + config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + config.EncoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.Kitchen) + config.OutputPaths = []string{"stderr"} + } else { + config = zap.NewProductionConfig() + config.OutputPaths = []string{"stdout"} + } -func (l *slogLogger) Error(msg string, args ...any) { - l.logger.Error(msg, args...) -} + // Set log level based on current debug flag + if viper.GetBool("debug") { + config.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + } else { + config.Level = zap.NewAtomicLevelAt(zap.InfoLevel) + } -func (l *slogLogger) Panic(msg string) { - var pcs [1]uintptr - runtime.Callers(2, pcs[:]) // skip [Callers, Panic] - record := slog.NewRecord(time.Now(), slog.LevelError, msg, pcs[0]) - _ = l.logger.Handler().Handle(context.Background(), record) - panic(msg) + return config } func unstructuredLogs() bool { @@ -139,56 +66,3 @@ func unstructuredLogs() bool { } return unstructuredLogs } - -// Initialize creates and configures the appropriate logger. -// If the UNSTRUCTURED_LOGS is set to true, it will output plain log message -// with only time and LogLevelType (INFO, DEBUG, ERROR, WARN)). 
-// Otherwise it will create a standard structured slog logger -func Initialize() { - if unstructuredLogs() { - w := os.Stderr - - handler := tint.NewHandler(w, &tint.Options{ - Level: getLogLevel(), - TimeFormat: time.Kitchen, - }) - - slogger := slog.New(handler) - - slog.SetDefault(slogger) - log = &slogLogger{logger: slogger} - } else { - w := os.Stdout - - handler := slog.NewJSONHandler(w, &slog.HandlerOptions{ - Level: getLogLevel(), - }) - - slogger := slog.New(handler) - - slog.SetDefault(slogger) - log = &slogLogger{logger: slogger} - } -} - -// GetLogger returns a context-specific logger -func GetLogger(component string) Logger { - if slogger, ok := log.(*slogLogger); ok { - return &slogLogger{ - logger: slogger.logger.With("component", component), - } - } - - return log -} - -// getLogLevel returns the appropriate slog.Level based on the debug flag -func getLogLevel() slog.Level { - var level slog.Level - if viper.GetBool("debug") { - level = slog.LevelDebug - } else { - level = slog.LevelInfo - } - return level -} diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go index 622b186c0..c1fbfe8a3 100644 --- a/pkg/logger/logger_test.go +++ b/pkg/logger/logger_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" ) +// TestUnstructuredLogsCheck tests the unstructuredLogs function func TestUnstructuredLogsCheck(t *testing.T) { //nolint:paralleltest // Uses environment variables tests := []struct { name string @@ -41,25 +42,33 @@ func TestUnstructuredLogsCheck(t *testing.T) { //nolint:paralleltest // Uses env } } +// TestStructuredLogger tests the structured logger functionality +// TODO: Keeping this for migration but can be removed as we don't need really need to test zap func TestStructuredLogger(t *testing.T) { //nolint:paralleltest // Uses environment variables - unformattedLogTestCases := []struct { - level string // The log level to test - message string // The message to log - key string // Key for structured logging - 
value string // Value for structured logging - contains bool // Whether to check if output contains the message + const ( + levelDebug = "debug" + levelInfo = "info" + levelWarn = "warn" + levelError = "error" + levelDPanic = "dpanic" + levelPanic = "panic" + ) + // Test cases for basic logging methods (Debug, Info, Warn, etc.) + basicLogTestCases := []struct { + level string // The log level to test + message string // The message to log }{ - {"DEBUG", "debug message", "key", "value", true}, - {"INFO", "info message", "key", "value", true}, - {"WARN", "warn message", "key", "value", true}, - {"ERROR", "error message", "key", "value", true}, + {levelDebug, "debug message"}, + {levelInfo, "info message"}, + {levelWarn, "warn message"}, + {levelError, "error message"}, + {levelDPanic, "dpanic message"}, + {levelPanic, "panic message"}, } - for _, tc := range unformattedLogTestCases { - t.Run("NonFormattedLogs", func(t *testing.T) { - + for _, tc := range basicLogTestCases { //nolint:paralleltest // Uses environment variables + t.Run("BasicLogs_"+tc.level, func(t *testing.T) { // we create a pipe to capture the output of the log - // so we can test that the logger logs the right message originalStdout := os.Stdout r, w, _ := os.Pipe() os.Stdout = w @@ -69,22 +78,115 @@ func TestStructuredLogger(t *testing.T) { //nolint:paralleltest // Uses environm viper.SetDefault("debug", true) - Initialize() + logger := NewLogger() + + // Handle panic and fatal recovery + defer func() { + if r := recover(); r != nil { + if tc.level != levelPanic && tc.level != levelDPanic { + t.Errorf("Unexpected panic for level %s: %v", tc.level, r) + } + } + }() + + // Log using basic methods + switch tc.level { + case levelDebug: + logger.Debug(tc.message) + case levelInfo: + logger.Info(tc.message) + case levelWarn: + logger.Warn(tc.message) + case levelError: + logger.Error(tc.message) + case levelDPanic: + logger.DPanic(tc.message) + case levelPanic: + logger.Panic(tc.message) + } + + 
w.Close() os.Stdout = originalStdout - // Log the message based on the level + // Read the captured output + var capturedOutput bytes.Buffer + io.Copy(&capturedOutput, r) + output := capturedOutput.String() + + // Parse JSON output + var logEntry map[string]any + if err := json.Unmarshal([]byte(output), &logEntry); err != nil { + t.Fatalf("Failed to parse JSON log output: %v", err) + } + + // Check level + if level, ok := logEntry["level"].(string); !ok || level != tc.level { + t.Errorf("Expected level %s, got %v", tc.level, logEntry["level"]) + } + + // Check message + if msg, ok := logEntry["msg"].(string); !ok || msg != tc.message { + t.Errorf("Expected message %s, got %v", tc.message, logEntry["msg"]) + } + }) + } + + // Test cases for structured logging methods (Debugw, Infow, etc.) + structuredLogTestCases := []struct { + level string // The log level to test + message string // The message to log + key string // Key for structured logging + value string // Value for structured logging + }{ + {levelDebug, "debug message", "key", "value"}, + {levelInfo, "info message", "key", "value"}, + {levelWarn, "warn message", "key", "value"}, + {levelError, "error message", "key", "value"}, + {levelDPanic, "dpanic message", "key", "value"}, + {levelPanic, "panic message", "key", "value"}, + } + + for _, tc := range structuredLogTestCases { //nolint:paralleltest // Uses environment variables + t.Run("StructuredLogs_"+tc.level, func(t *testing.T) { + // we create a pipe to capture the output of the log + originalStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + os.Setenv("UNSTRUCTURED_LOGS", "false") + defer os.Unsetenv("UNSTRUCTURED_LOGS") + + viper.SetDefault("debug", true) + + logger := NewLogger() + + // Handle panic and fatal recovery + defer func() { + if r := recover(); r != nil { + if tc.level != "panic" && tc.level != levelDPanic { + t.Errorf("Unexpected panic for level %s: %v", tc.level, r) + } + } + }() + + // Log using structured methods switch 
tc.level { - case "DEBUG": - log.Debug(tc.message, tc.key, tc.value) - case "INFO": - log.Info(tc.message, tc.key, tc.value) - case "WARN": - log.Warn(tc.message, tc.key, tc.value) - case "ERROR": - log.Error(tc.message, tc.key, tc.value) + case levelDebug: + logger.Debugw(tc.message, tc.key, tc.value) + case levelInfo: + logger.Infow(tc.message, tc.key, tc.value) + case levelWarn: + logger.Warnw(tc.message, tc.key, tc.value) + case levelError: + logger.Errorw(tc.message, tc.key, tc.value) + case levelDPanic: + logger.DPanicw(tc.message, tc.key, tc.value) + case levelPanic: + logger.Panicw(tc.message, tc.key, tc.value) } w.Close() + os.Stdout = originalStdout // Read the captured output var capturedOutput bytes.Buffer @@ -92,7 +194,7 @@ func TestStructuredLogger(t *testing.T) { //nolint:paralleltest // Uses environm output := capturedOutput.String() // Parse JSON output - var logEntry map[string]interface{} + var logEntry map[string]any if err := json.Unmarshal([]byte(output), &logEntry); err != nil { t.Fatalf("Failed to parse JSON log output: %v", err) } @@ -114,6 +216,7 @@ func TestStructuredLogger(t *testing.T) { //nolint:paralleltest // Uses environm }) } + // Test cases for formatted logging methods (Debugf, Infof, etc.) 
formattedLogTestCases := []struct { level string message string @@ -122,16 +225,17 @@ func TestStructuredLogger(t *testing.T) { //nolint:paralleltest // Uses environm expected string contains bool }{ - {"DEBUG", "debug message %s and %s", "key", "value", "debug message key and value", true}, - {"INFO", "info message %s and %s", "key", "value", "info message key and value", true}, - {"WARN", "warn message %s and %s", "key", "value", "warn message key and value", true}, - {"ERROR", "error message %s and %s", "key", "value", "error message key and value", true}, + {levelDebug, "debug message %s and %s", "key", "value", "debug message key and value", true}, + {levelInfo, "info message %s and %s", "key", "value", "info message key and value", true}, + {levelWarn, "warn message %s and %s", "key", "value", "warn message key and value", true}, + {levelError, "error message %s and %s", "key", "value", "error message key and value", true}, + {levelDPanic, "dpanic message %s and %s", "key", "value", "dpanic message key and value", true}, + {levelPanic, "panic message %s and %s", "key", "value", "panic message key and value", true}, } for _, tc := range formattedLogTestCases { //nolint:paralleltest // Uses environment variables - t.Run("FormattedLogs", func(t *testing.T) { + t.Run("FormattedLogs_"+tc.level, func(t *testing.T) { // we create a pipe to capture the output of the log - // so we can test that the logger logs the right message originalStdout := os.Stdout r, w, _ := os.Pipe() os.Stdout = w @@ -141,32 +245,42 @@ func TestStructuredLogger(t *testing.T) { //nolint:paralleltest // Uses environm viper.SetDefault("debug", true) - Initialize() - os.Stdout = originalStdout + logger := NewLogger() - // Log the message based on the level + // Handle panic and fatal recovery + defer func() { + if r := recover(); r != nil { + if tc.level != levelPanic && tc.level != levelDPanic { + t.Errorf("Unexpected panic for level %s: %v", tc.level, r) + } + } + }() + + // Log using 
formatted methods
 		switch tc.level {
-		case "DEBUG":
-			log.Debugf(tc.message, tc.key, tc.value)
-		case "INFO":
-			log.Infof(tc.message, tc.key, tc.value)
-		case "WARN":
-			log.Warnf(tc.message, tc.key, tc.value)
-		case "ERROR":
-			log.Errorf(tc.message, tc.key, tc.value)
+		case levelDebug:
+			logger.Debugf(tc.message, tc.key, tc.value)
+		case levelInfo:
+			logger.Infof(tc.message, tc.key, tc.value)
+		case levelWarn:
+			logger.Warnf(tc.message, tc.key, tc.value)
+		case levelError:
+			logger.Errorf(tc.message, tc.key, tc.value)
+		case levelDPanic:
+			logger.DPanicf(tc.message, tc.key, tc.value)
+		case levelPanic:
+			logger.Panicf(tc.message, tc.key, tc.value)
 		}
 
 		w.Close()
+		os.Stdout = originalStdout
 
 		// Read the captured output
 		var capturedOutput bytes.Buffer
 		io.Copy(&capturedOutput, r)
 		output := capturedOutput.String()
-		capturedOutput.Reset()
-
 		// Parse JSON output
-		var logEntry map[string]interface{}
+		var logEntry map[string]any
 		if err := json.Unmarshal([]byte(output), &logEntry); err != nil {
 			t.Fatalf("Failed to parse JSON log output: %v", err)
 		}
@@ -178,15 +292,26 @@ func TestStructuredLogger(t *testing.T) { //nolint:paralleltest // Uses environm
 
 		// Check message
 		if msg, ok := logEntry["msg"].(string); !ok || msg != tc.expected {
-			t.Errorf(tc.expected, tc.message, logEntry["msg"])
+			t.Errorf("Expected message %s, got %v", tc.expected, logEntry["msg"])
 		}
 	})
 }
 }
 
+// TestUnstructuredLogger tests the unstructured logger functionality
+// TODO: Keeping this for migration but can be removed as we don't really need to test zap
 func TestUnstructuredLogger(t *testing.T) { //nolint:paralleltest // Uses environment variables
 	// we only test for the formatted logs here because the unstructured logs
 	// do not contain the key/value pair format that the structured logs do
+	const (
+		levelDebug  = "DEBUG"
+		levelInfo   = "INFO"
+		levelWarn   = "WARN"
+		levelError  = "ERROR"
+		levelDPanic = "DPANIC"
+		levelPanic  = "PANIC"
+	)
+
 	formattedLogTestCases := []struct {
 		level   string
 		message
string
@@ -194,14 +319,16 @@ func TestUnstructuredLogger(t *testing.T) { //nolint:paralleltest // Uses enviro
 		value    string
 		expected string
 	}{
-		{"DBG", "debug message %s and %s", "key", "value", "debug message key and value"},
-		{"INF", "info message %s and %s", "key", "value", "info message key and value"},
-		{"WRN", "warn message %s and %s", "key", "value", "warn message key and value"},
-		{"ERR", "error message %s and %s", "key", "value", "error message key and value"},
+		{levelDebug, "debug message %s and %s", "key", "value", "debug message key and value"},
+		{levelInfo, "info message %s and %s", "key", "value", "info message key and value"},
+		{levelWarn, "warn message %s and %s", "key", "value", "warn message key and value"},
+		{levelError, "error message %s and %s", "key", "value", "error message key and value"},
+		{levelDPanic, "dpanic message %s and %s", "key", "value", "dpanic message key and value"},
+		{levelPanic, "panic message %s and %s", "key", "value", "panic message key and value"},
 	}
 
 	for _, tc := range formattedLogTestCases { //nolint:paralleltest // Uses environment variables
-		t.Run("FormattedLogs", func(t *testing.T) {
+		t.Run("FormattedLogs_"+tc.level, func(t *testing.T) {
 			// we create a pipe to capture the output of the log
 			// so we can test that the logger logs the right message
@@ -211,22 +338,36 @@ func TestUnstructuredLogger(t *testing.T) { //nolint:paralleltest // Uses enviro
 
 			viper.SetDefault("debug", true)
 
-			Initialize()
-			os.Stderr = originalStderr
+			logger := NewLogger()
+
+			// Handle panic recovery for DPANIC and PANIC levels
+			defer func() {
+				if r := recover(); r != nil {
+					// Expected for panic levels
+					if tc.level != levelPanic && tc.level != levelDPanic {
+						t.Errorf("Unexpected panic for level %s: %v", tc.level, r)
+					}
+				}
+			}()
 
 			// Log the message based on the level
 			switch tc.level {
-			case "DBG":
-				log.Debugf(tc.message, tc.key, tc.value)
-			case "INF":
-				log.Infof(tc.message, tc.key, tc.value)
-			case "WRN":
-				log.Warnf(tc.message,
tc.key, tc.value) - case "ERR": - log.Errorf(tc.message, tc.key, tc.value) + case levelDebug: + logger.Debugf(tc.message, tc.key, tc.value) + case levelInfo: + logger.Infof(tc.message, tc.key, tc.value) + case levelWarn: + logger.Warnf(tc.message, tc.key, tc.value) + case levelError: + logger.Errorf(tc.message, tc.key, tc.value) + case levelDPanic: + logger.DPanicf(tc.message, tc.key, tc.value) + case levelPanic: + logger.Panicf(tc.message, tc.key, tc.value) } w.Close() + os.Stderr = originalStderr // Read the captured output var capturedOutput bytes.Buffer @@ -239,8 +380,8 @@ func TestUnstructuredLogger(t *testing.T) { //nolint:paralleltest // Uses enviro } } -// TestInitialize tests the Initialize function -func TestInitialize(t *testing.T) { //nolint:paralleltest // Uses environment variables +// TestNewLogger tests the NewLogger function +func TestNewLogger(t *testing.T) { //nolint:paralleltest // Uses environment variables // Test structured logs (JSON) t.Run("Structured Logs", func(t *testing.T) { //nolint:paralleltest // Uses environment variables // Set environment to use structured logs @@ -252,11 +393,11 @@ func TestInitialize(t *testing.T) { //nolint:paralleltest // Uses environment va r, w, _ := os.Pipe() os.Stdout = w - // Run initialization - Initialize() + // Initialize the logger + logger := NewLogger() // Log a test message - log.Info("test message", "key", "value") + logger.Info("test message") // Restore stdout w.Close() @@ -268,7 +409,7 @@ func TestInitialize(t *testing.T) { //nolint:paralleltest // Uses environment va output := buf.String() // Verify JSON format - var logEntry map[string]interface{} + var logEntry map[string]any if err := json.Unmarshal([]byte(output), &logEntry); err != nil { t.Fatalf("Failed to parse JSON log output: %v", err) } @@ -289,11 +430,11 @@ func TestInitialize(t *testing.T) { //nolint:paralleltest // Uses environment va r, w, _ := os.Pipe() os.Stderr = w - // Run initialization - Initialize() + // Initialize the 
logger + logger := NewLogger() // Log a test message - log.Info("test message", "key", "value") + logger.Info("test message", "key", "value") // Restore stderr w.Close() @@ -314,42 +455,3 @@ func TestInitialize(t *testing.T) { //nolint:paralleltest // Uses environment va } }) } - -// TestGetLogger tests the GetLogger function -func TestGetLogger(t *testing.T) { //nolint:paralleltest // Uses environment variables - // Set up structured logger for testing - os.Setenv("UNSTRUCTURED_LOGS", "false") - defer os.Unsetenv("UNSTRUCTURED_LOGS") - - // Redirect stdout to capture output - oldStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - // Initialize and get a component logger - Initialize() - componentLogger := GetLogger("test-component") - - // Log a test message - componentLogger.Info("component message") - - // Restore stdout - w.Close() - os.Stdout = oldStdout - - // Read captured output - var buf bytes.Buffer - buf.ReadFrom(r) - output := buf.String() - - // Parse JSON output - var logEntry map[string]interface{} - if err := json.Unmarshal([]byte(output), &logEntry); err != nil { - t.Fatalf("Failed to parse JSON log output: %v", err) - } - - // Verify the component was added - if component, ok := logEntry["component"].(string); !ok || component != "test-component" { - t.Errorf("Expected component='test-component', got %v", logEntry["component"]) - } -} diff --git a/pkg/logger/logr.go b/pkg/logger/logr.go deleted file mode 100644 index 4d4f8c9ea..000000000 --- a/pkg/logger/logr.go +++ /dev/null @@ -1,74 +0,0 @@ -// Package logger provides a logging capability for toolhive for running locally as a CLI and in Kubernetes -package logger - -import ( - "github.com/go-logr/logr" -) - -// NewLogr returns a logr.Logger which uses the singleton logger. 
-func NewLogr() logr.Logger { - return logr.New(&toolhiveLogSink{logger: log}) -} - -// toolhiveLogSink adapts our logger to the logr.LogSink interface -type toolhiveLogSink struct { - logger Logger - name string -} - -// Init implements logr.LogSink -func (*toolhiveLogSink) Init(logr.RuntimeInfo) { - // Nothing to do -} - -// Enabled implements logr.LogSink -func (*toolhiveLogSink) Enabled(int) bool { - // Always enable logging - return true -} - -// Info implements logr.LogSink -func (l *toolhiveLogSink) Info(_ int, msg string, keysAndValues ...interface{}) { - l.logger.Info(msg, keysAndValues...) -} - -// Error implements logr.LogSink -func (l *toolhiveLogSink) Error(err error, msg string, keysAndValues ...interface{}) { - args := append([]interface{}{"error", err}, keysAndValues...) - l.logger.Error(msg, args...) -} - -// WithValues implements logr.LogSink -func (l *toolhiveLogSink) WithValues(keysAndValues ...interface{}) logr.LogSink { - // Create a new logger with the additional key-value pairs - if slogger, ok := l.logger.(*slogLogger); ok { - newLogger := &slogLogger{ - logger: slogger.logger.With(keysAndValues...), - } - return &toolhiveLogSink{ - logger: newLogger, - name: l.name, - } - } - - // If we can't add the values, just return a sink with the same logger - return &toolhiveLogSink{ - logger: l.logger, - name: l.name, - } -} - -// WithName implements logr.LogSink -func (l *toolhiveLogSink) WithName(name string) logr.LogSink { - // If we already have a name, append the new name - newName := name - if l.name != "" { - newName = l.name + "/" + name - } - - // Create a new sink with the component logger - return &toolhiveLogSink{ - logger: GetLogger(newName), - name: newName, - } -} diff --git a/pkg/mcp/tool_filter.go b/pkg/mcp/tool_filter.go index 8a96cf7c8..30cad486d 100644 --- a/pkg/mcp/tool_filter.go +++ b/pkg/mcp/tool_filter.go @@ -9,7 +9,8 @@ import ( "net/http" "strings" - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" + 
"github.com/stacklok/toolhive/pkg/transport/types" ) @@ -73,7 +74,7 @@ func NewToolFilterMiddleware(filterTools []string) (types.MiddlewareFunction, er // This middleware is designed to be used ONLY when tool filtering is enabled, // and expects the list of tools to be "correct" (i.e. not empty and not // containing nonexisting tools). -func NewToolCallFilterMiddleware(filterTools []string) (types.MiddlewareFunction, error) { +func NewToolCallFilterMiddleware(filterTools []string, logger *zap.SugaredLogger) (types.MiddlewareFunction, error) { if len(filterTools) == 0 { return nil, fmt.Errorf("tools list for filtering is empty") } @@ -138,6 +139,7 @@ type toolFilterWriter struct { http.ResponseWriter buffer []byte filterTools map[string]struct{} + logger *zap.SugaredLogger } // WriteHeader captures the status code @@ -159,19 +161,19 @@ func (rw *toolFilterWriter) Flush() { if mimeType == "" { _, err := rw.ResponseWriter.Write(rw.buffer) if err != nil { - logger.Errorf("Error writing buffer: %v", err) + rw.logger.Errorf("Error writing buffer: %v", err) } return } var b bytes.Buffer if err := processBuffer(rw.filterTools, rw.buffer, mimeType, &b); err != nil { - logger.Errorf("Error flushing response: %v", err) + rw.logger.Errorf("Error flushing response: %v", err) } _, err := rw.ResponseWriter.Write(b.Bytes()) if err != nil { - logger.Errorf("Error writing buffer: %v", err) + rw.logger.Errorf("Error writing buffer: %v", err) } rw.buffer = rw.buffer[:0] // Reset buffer } diff --git a/pkg/mcp/tool_filter_test.go b/pkg/mcp/tool_filter_test.go index 30d4dccb0..5958a2552 100644 --- a/pkg/mcp/tool_filter_test.go +++ b/pkg/mcp/tool_filter_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestProcessToolCallRequest(t *testing.T) { @@ -584,6 +584,8 @@ func TestProcessSSEEvents_EdgeCases(t *testing.T) { func 
TestProcessToolCallRequest_EdgeCases(t *testing.T) { t.Parallel() + logger := log.NewLogger() + tests := []struct { name string filterTools map[string]struct{} @@ -656,9 +658,6 @@ func TestProcessToolCallRequest_EdgeCases(t *testing.T) { func TestToolFilterWriter_Flush(t *testing.T) { t.Parallel() - // Initialize logger to avoid panic - logger.Initialize() - tests := []struct { name string writeData []byte @@ -730,6 +729,7 @@ func TestToolFilterWriter_Flush(t *testing.T) { ResponseWriter: mockWriter, buffer: []byte{}, filterTools: tt.filterTools, + logger: log.NewLogger(), } // Set status code using WriteHeader diff --git a/pkg/migration/default_group.go b/pkg/migration/default_group.go index 2dc1b0300..ea919adc0 100644 --- a/pkg/migration/default_group.go +++ b/pkg/migration/default_group.go @@ -4,9 +4,10 @@ import ( "context" "fmt" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -14,6 +15,7 @@ import ( type DefaultGroupMigrator struct { groupManager groups.Manager workloadsManager workloads.Manager + logger *zap.SugaredLogger } // Migrate performs the complete default group migration @@ -47,7 +49,7 @@ func (m *DefaultGroupMigrator) Migrate(ctx context.Context) error { // Mark default group migration as completed err = config.UpdateConfig(func(c *config.Config) { c.DefaultGroupMigration = true - }) + }, m.logger) if err != nil { return fmt.Errorf("failed to update config during migration: %w", err) } @@ -59,12 +61,12 @@ func (m *DefaultGroupMigrator) Migrate(ctx context.Context) error { func (m *DefaultGroupMigrator) initManagers(ctx context.Context) error { var err error - m.groupManager, err = groups.NewManager() + m.groupManager, err = groups.NewManager(m.logger) if err != nil { return fmt.Errorf("failed to create group manager: %w", err) } - m.workloadsManager, err = workloads.NewManager(ctx) + 
m.workloadsManager, err = workloads.NewManager(ctx, m.logger) if err != nil { return fmt.Errorf("failed to create workloads manager: %w", err) } @@ -74,7 +76,7 @@ func (m *DefaultGroupMigrator) initManagers(ctx context.Context) error { // createDefaultGroup creates the default group if it doesn't exist func (m *DefaultGroupMigrator) createDefaultGroup(ctx context.Context) error { - logger.Infof("Creating default group '%s'", groups.DefaultGroupName) + m.logger.Infof("Creating default group '%s'", groups.DefaultGroupName) if err := m.groupManager.Create(ctx, groups.DefaultGroupName); err != nil { return fmt.Errorf("failed to create default group: %w", err) } @@ -93,7 +95,7 @@ func (m *DefaultGroupMigrator) migrateWorkloadsToDefaultGroup(ctx context.Contex for _, workloadName := range workloadsWithoutGroup { // Move workload to default group if err := m.workloadsManager.MoveToDefaultGroup(ctx, []string{workloadName}, ""); err != nil { - logger.Warnf("Failed to migrate workload %s to default group: %v", workloadName, err) + m.logger.Warnf("Failed to migrate workload %s to default group: %v", workloadName, err) continue } migratedCount++ @@ -104,11 +106,11 @@ func (m *DefaultGroupMigrator) migrateWorkloadsToDefaultGroup(ctx context.Contex // migrateClientConfigs migrates client configurations from global config to default group func (m *DefaultGroupMigrator) migrateClientConfigs(ctx context.Context) error { - appConfig := config.GetConfig() + appConfig := config.GetConfig(m.logger) // If there are no registered clients, nothing to migrate if len(appConfig.Clients.RegisteredClients) == 0 { - logger.Infof("No client configurations to migrate") + m.logger.Infof("No client configurations to migrate") return nil } @@ -134,7 +136,7 @@ func (m *DefaultGroupMigrator) migrateClientConfigs(ctx context.Context) error { if !alreadyRegistered { if err := m.groupManager.RegisterClients(ctx, []string{groups.DefaultGroupName}, []string{clientName}); err != nil { - logger.Warnf("Failed 
to register client %s to default group: %v", clientName, err) + m.logger.Warnf("Failed to register client %s to default group: %v", clientName, err) continue } migratedCount++ @@ -147,14 +149,14 @@ func (m *DefaultGroupMigrator) migrateClientConfigs(ctx context.Context) error { // Clear the global client configurations after successful migration err = config.UpdateConfig(func(c *config.Config) { c.Clients.RegisteredClients = []string{} - }) + }, m.logger) if err != nil { - logger.Warnf("Failed to clear global client configurations after migration: %v", err) + m.logger.Warnf("Failed to clear global client configurations after migration: %v", err) } else { - logger.Infof("Cleared global client configurations") + m.logger.Infof("Cleared global client configurations") } } else { - logger.Infof("No client configurations needed migration") + m.logger.Infof("No client configurations needed migration") } return nil diff --git a/pkg/migration/migration.go b/pkg/migration/migration.go index 9aa118e8a..3009edb2b 100644 --- a/pkg/migration/migration.go +++ b/pkg/migration/migration.go @@ -6,8 +6,9 @@ import ( "fmt" "sync" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/logger" ) // migrationOnce ensures the migration only runs once @@ -15,9 +16,9 @@ var migrationOnce sync.Once // CheckAndPerformDefaultGroupMigration checks if default group migration is needed and performs it // This is called once at application startup -func CheckAndPerformDefaultGroupMigration() { +func CheckAndPerformDefaultGroupMigration(logger *zap.SugaredLogger) { migrationOnce.Do(func() { - appConfig := config.GetConfig() + appConfig := config.GetConfig(logger) // Check if default group migration has already been performed if appConfig.DefaultGroupMigration { diff --git a/pkg/networking/port.go b/pkg/networking/port.go index 376a6fa39..3e1e09004 100644 --- a/pkg/networking/port.go +++ b/pkg/networking/port.go @@ -8,7 +8,7 @@ import ( "math/big" "net" - 
"github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" ) const ( @@ -21,7 +21,7 @@ const ( ) // IsAvailable checks if a port is available -func IsAvailable(port int) bool { +func IsAvailable(port int, logger *zap.SugaredLogger) bool { // Check TCP tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err != nil { @@ -91,7 +91,7 @@ func IsIPv6Available() bool { } // FindAvailable finds an available port -func FindAvailable() int { +func FindAvailable(logger *zap.SugaredLogger) int { for i := 0; i < MaxAttempts; i++ { // Generate a cryptographically secure random number n, err := rand.Int(rand.Reader, big.NewInt(int64(MaxPort-MinPort))) @@ -100,14 +100,14 @@ func FindAvailable() int { break } port := int(n.Int64()) + MinPort - if IsAvailable(port) { + if IsAvailable(port, logger) { return port } } // If we can't find a random port, try sequential ports for port := MinPort; port <= MaxPort; port++ { - if IsAvailable(port) { + if IsAvailable(port, logger) { return port } } @@ -120,22 +120,22 @@ func FindAvailable() int { // If port is 0, it will find an available port. // If port is not 0, it will check if the port is available. // Returns the selected port and an error if any. 
-func FindOrUsePort(port int) (int, error) { +func FindOrUsePort(port int, logger *zap.SugaredLogger) (int, error) { if port == 0 { // Find an available port - port = FindAvailable() + port = FindAvailable(logger) if port == 0 { return 0, fmt.Errorf("could not find an available port") } return port, nil } - if IsAvailable(port) { + if IsAvailable(port, logger) { return port, nil } // Requested port is busy — find an alternative - alt := FindAvailable() + alt := FindAvailable(logger) if alt == 0 { return 0, fmt.Errorf("failed to find an alternative port after requested port %d was unavailable", port) } diff --git a/pkg/registry/factory.go b/pkg/registry/factory.go index e4e1fbf66..9dff01979 100644 --- a/pkg/registry/factory.go +++ b/pkg/registry/factory.go @@ -3,6 +3,8 @@ package registry import ( "sync" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/config" ) @@ -25,9 +27,9 @@ func NewRegistryProvider(cfg *config.Config) Provider { // GetDefaultProvider returns the default registry provider instance // This maintains backward compatibility with the existing singleton pattern -func GetDefaultProvider() (Provider, error) { +func GetDefaultProvider(logger *zap.SugaredLogger) (Provider, error) { defaultProviderOnce.Do(func() { - cfg, err := config.LoadOrCreateConfig() + cfg, err := config.LoadOrCreateConfig(logger) if err != nil { defaultProviderErr = err return diff --git a/pkg/registry/provider_test.go b/pkg/registry/provider_test.go index 218b93e8a..96c554fd6 100644 --- a/pkg/registry/provider_test.go +++ b/pkg/registry/provider_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stacklok/toolhive/pkg/config" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestNewRegistryProvider(t *testing.T) { @@ -215,7 +216,8 @@ func getTypeName(v interface{}) string { func TestGetRegistry(t *testing.T) { t.Parallel() - provider, err := GetDefaultProvider() + logger := log.NewLogger() + provider, err := GetDefaultProvider(logger) if err != nil { 
t.Fatalf("Failed to get registry provider: %v", err) } @@ -244,8 +246,9 @@ func TestGetRegistry(t *testing.T) { func TestGetServer(t *testing.T) { t.Parallel() + logger := log.NewLogger() // Test getting an existing server - provider, err := GetDefaultProvider() + provider, err := GetDefaultProvider(logger) if err != nil { t.Fatalf("Failed to get registry provider: %v", err) } @@ -281,8 +284,9 @@ func TestGetServer(t *testing.T) { func TestSearchServers(t *testing.T) { t.Parallel() + logger := log.NewLogger() // Test searching for servers - provider, err := GetDefaultProvider() + provider, err := GetDefaultProvider(logger) if err != nil { t.Fatalf("Failed to get registry provider: %v", err) } @@ -308,7 +312,8 @@ func TestSearchServers(t *testing.T) { func TestListServers(t *testing.T) { t.Parallel() - provider, err := GetDefaultProvider() + logger := log.NewLogger() + provider, err := GetDefaultProvider(logger) if err != nil { t.Fatalf("Failed to get registry provider: %v", err) } diff --git a/pkg/runner/config.go b/pkg/runner/config.go index fa9642b08..baaeb4d6e 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -7,6 +7,8 @@ import ( "fmt" "io" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/authz" @@ -15,7 +17,6 @@ import ( "github.com/stacklok/toolhive/pkg/environment" "github.com/stacklok/toolhive/pkg/ignore" "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/secrets" @@ -207,7 +208,7 @@ func (c *RunConfig) WithTransport(t string) (*RunConfig, error) { } // WithPorts configures the host and target ports -func (c *RunConfig) WithPorts(proxyPort, targetPort int) (*RunConfig, error) { +func (c *RunConfig) WithPorts(proxyPort, targetPort int, logger *zap.SugaredLogger) (*RunConfig, error) { var 
selectedPort int var err error @@ -215,14 +216,14 @@ func (c *RunConfig) WithPorts(proxyPort, targetPort int) (*RunConfig, error) { // If not available - treat as an error, since picking a random port here // is going to lead to confusion. if proxyPort != 0 { - if !networking.IsAvailable(proxyPort) { + if !networking.IsAvailable(proxyPort, logger) { return c, fmt.Errorf("requested proxy port %d is not available", proxyPort) } logger.Debugf("Using requested port: %d", proxyPort) selectedPort = proxyPort } else { // Otherwise - pick a random available port. - selectedPort, err = networking.FindOrUsePort(proxyPort) + selectedPort, err = networking.FindOrUsePort(proxyPort, logger) if err != nil { return c, err } @@ -231,7 +232,7 @@ func (c *RunConfig) WithPorts(proxyPort, targetPort int) (*RunConfig, error) { // Select a target port for the container if using SSE or Streamable HTTP transport if c.Transport == types.TransportTypeSSE || c.Transport == types.TransportTypeStreamableHTTP { - selectedTargetPort, err := networking.FindOrUsePort(targetPort) + selectedTargetPort, err := networking.FindOrUsePort(targetPort, logger) if err != nil { return c, fmt.Errorf("target port error: %w", err) } @@ -338,8 +339,8 @@ func (c *RunConfig) GetBaseName() string { } // SaveState saves the run configuration to the state store -func (c *RunConfig) SaveState(ctx context.Context) error { - return state.SaveRunConfig(ctx, c) +func (c *RunConfig) SaveState(ctx context.Context, logger *zap.SugaredLogger) error { + return state.SaveRunConfig(ctx, c, logger) } // LoadState loads a run configuration from the state store diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index dc9164549..98ad4043e 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -6,13 +6,14 @@ import ( "slices" "strings" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/authz" rt 
"github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/ignore" "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/telemetry" @@ -28,15 +29,18 @@ type RunConfigBuilder struct { // Store ports separately for proper validation port int targetPort int + + logger *zap.SugaredLogger } // NewRunConfigBuilder creates a new RunConfigBuilder with default values -func NewRunConfigBuilder() *RunConfigBuilder { +func NewRunConfigBuilder(logger *zap.SugaredLogger) *RunConfigBuilder { return &RunConfigBuilder{ config: &RunConfig{ ContainerLabels: make(map[string]string), EnvVars: make(map[string]string), }, + logger: logger, } } @@ -175,7 +179,7 @@ func (b *RunConfigBuilder) WithLabels(labelStrings []string) *RunConfigBuilder { for _, labelString := range labelStrings { key, value, err := labels.ParseLabelWithValidation(labelString) if err != nil { - logger.Warnf("Skipping invalid label: %s (%v)", labelString, err) + b.logger.Warnf("Skipping invalid label: %s (%v)", labelString, err) continue } b.config.ContainerLabels[key] = value @@ -343,10 +347,10 @@ func (b *RunConfigBuilder) validateConfig(imageMetadata *registry.ImageMetadata) mcpTransport := b.transportString if mcpTransport == "" { if imageMetadata != nil && imageMetadata.Transport != "" { - logger.Debugf("Using registry mcpTransport: %s", imageMetadata.Transport) + b.logger.Debugf("Using registry mcpTransport: %s", imageMetadata.Transport) mcpTransport = imageMetadata.Transport } else { - logger.Debugf("Defaulting mcpTransport to stdio") + b.logger.Debugf("Defaulting mcpTransport to stdio") mcpTransport = types.TransportTypeStdio.String() } } @@ -361,12 +365,12 @@ func (b *RunConfigBuilder) validateConfig(imageMetadata *registry.ImageMetadata) isHTTPServer := mcpTransport == types.TransportTypeSSE.String() || mcpTransport == 
types.TransportTypeStreamableHTTP.String() if targetPort == 0 && isHTTPServer && imageMetadata.TargetPort > 0 { - logger.Debugf("Using registry target port: %d", imageMetadata.TargetPort) + b.logger.Debugf("Using registry target port: %d", imageMetadata.TargetPort) targetPort = imageMetadata.TargetPort } } // Configure ports and target host - if _, err = c.WithPorts(b.port, targetPort); err != nil { + if _, err = c.WithPorts(b.port, targetPort, b.logger); err != nil { return err } @@ -410,13 +414,13 @@ func (b *RunConfigBuilder) validateConfig(imageMetadata *registry.ImageMetadata) if imageMetadata != nil && len(imageMetadata.Args) > 0 { if len(c.CmdArgs) == 0 { // No user args provided, use registry defaults - logger.Debugf("Using registry default args: %v", imageMetadata.Args) + b.logger.Debugf("Using registry default args: %v", imageMetadata.Args) c.CmdArgs = append(c.CmdArgs, imageMetadata.Args...) } } if c.ToolsFilter != nil && imageMetadata != nil && imageMetadata.Tools != nil { - logger.Debugf("Using tools filter: %v", c.ToolsFilter) + b.logger.Debugf("Using tools filter: %v", c.ToolsFilter) for _, tool := range c.ToolsFilter { if !slices.Contains(imageMetadata.Tools, tool) { return fmt.Errorf("tool %s not found in registry", tool) @@ -449,12 +453,12 @@ func (b *RunConfigBuilder) loadPermissionProfile(imageMetadata *registry.ImageMe // If a profile was not set by name or path, check the image metadata. if imageMetadata != nil && imageMetadata.Permissions != nil { - logger.Debugf("Using registry permission profile: %v", imageMetadata.Permissions) + b.logger.Debugf("Using registry permission profile: %v", imageMetadata.Permissions) return imageMetadata.Permissions, nil } // If no metadata is available, use the network permission profile as default. 
- logger.Debugf("Using default permission profile: %s", permissions.ProfileNetwork) + b.logger.Debugf("Using default permission profile: %s", permissions.ProfileNetwork) return permissions.BuiltinNetworkProfile(), nil } @@ -504,7 +508,7 @@ func (b *RunConfigBuilder) processVolumeMounts() error { // Check for duplicate mount target if existingSource, isDuplicate := existingMounts[target]; isDuplicate { - logger.Warnf("Skipping duplicate mount target: %s (already mounted from %s)", + b.logger.Warnf("Skipping duplicate mount target: %s (already mounted from %s)", target, existingSource) continue } @@ -519,7 +523,7 @@ func (b *RunConfigBuilder) processVolumeMounts() error { // Add to the map of existing mounts existingMounts[target] = source - logger.Infof("Adding volume mount: %s -> %s (%s)", + b.logger.Infof("Adding volume mount: %s -> %s (%s)", source, target, map[bool]string{true: "read-only", false: "read-write"}[readOnly]) } diff --git a/pkg/runner/config_builder_test.go b/pkg/runner/config_builder_test.go index ea2a5c60a..d17cb99c2 100644 --- a/pkg/runner/config_builder_test.go +++ b/pkg/runner/config_builder_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/registry" ) @@ -17,8 +17,8 @@ import ( func TestRunConfigBuilder_Build_WithPermissionProfile(t *testing.T) { t.Parallel() - // Needed to prevent a nil pointer dereference in the logger. 
- logger.Initialize() + // Setup logger + logger := log.NewLogger() // Create a mock environment variable validator mockValidator := &mockEnvVarValidator{} @@ -177,7 +177,7 @@ func TestRunConfigBuilder_Build_WithPermissionProfile(t *testing.T) { t.Parallel() // Create a new builder and apply the setup - builder := tc.setupBuilder(NewRunConfigBuilder()) + builder := tc.setupBuilder(NewRunConfigBuilder(logger)) require.NotNil(t, builder, "Builder should not be nil") // Create a temporary profile file if needed @@ -234,8 +234,8 @@ func TestRunConfigBuilder_Build_WithPermissionProfile(t *testing.T) { func TestRunConfigBuilder_Build_WithVolumeMounts(t *testing.T) { t.Parallel() - // Initialize logger to prevent nil pointer dereference when processing volume mounts - logger.Initialize() + // Setup logger + logger := log.NewLogger() // Create a mock environment variable validator mockValidator := &mockEnvVarValidator{} @@ -314,7 +314,7 @@ func TestRunConfigBuilder_Build_WithVolumeMounts(t *testing.T) { t.Parallel() // Create a new builder and apply the setup - builder := tc.setupBuilder(NewRunConfigBuilder()) + builder := tc.setupBuilder(NewRunConfigBuilder(logger)) require.NotNil(t, builder, "Builder should not be nil") // Save original read/write mounts count if there's a permission profile diff --git a/pkg/runner/config_test.go b/pkg/runner/config_test.go index b79d73da5..d40e6e325 100644 --- a/pkg/runner/config_test.go +++ b/pkg/runner/config_test.go @@ -13,7 +13,7 @@ import ( "github.com/stacklok/toolhive/pkg/authz" "github.com/stacklok/toolhive/pkg/container/runtime/mocks" "github.com/stacklok/toolhive/pkg/ignore" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/secrets" @@ -84,6 +84,9 @@ func TestRunConfig_WithTransport(t *testing.T) { // Note: This test uses actual port finding logic, so it may fail 
if ports are in use func TestRunConfig_WithPorts(t *testing.T) { t.Parallel() + + logger := log.NewLogger() + testCases := []struct { name string config *RunConfig @@ -121,12 +124,10 @@ func TestRunConfig_WithPorts(t *testing.T) { }, } - logger.Initialize() - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - result, err := tc.config.WithPorts(tc.port, tc.targetPort) + result, err := tc.config.WithPorts(tc.port, tc.targetPort, logger) if tc.expectError { assert.Error(t, err, "WithPorts should return an error") @@ -558,8 +559,7 @@ func (*mockEnvVarValidator) Validate(_ context.Context, _ *registry.ImageMetadat func TestRunConfigBuilder(t *testing.T) { t.Parallel() - // Needed to prevent a nil pointer dereference in the logger. - logger.Initialize() + logger := log.NewLogger() runtime := &mocks.MockRuntime{} cmdArgs := []string{"arg1", "arg2"} name := "test-server" @@ -590,7 +590,7 @@ func TestRunConfigBuilder(t *testing.T) { k8sPodPatch := `{"spec":{"containers":[{"name":"test","resources":{"limits":{"memory":"512Mi"}}}]}}` envVarValidator := &mockEnvVarValidator{} - config, err := NewRunConfigBuilder(). + config, err := NewRunConfigBuilder(logger). WithRuntime(runtime). WithCmdArgs(cmdArgs). WithName(name). @@ -771,8 +771,7 @@ func TestCommaSeparatedEnvVars(t *testing.T) { func TestRunConfigBuilder_MetadataOverrides(t *testing.T) { t.Parallel() - // Needed to prevent a nil pointer dereference in the logger. - logger.Initialize() + logger := log.NewLogger() tests := []struct { name string @@ -838,7 +837,7 @@ func TestRunConfigBuilder_MetadataOverrides(t *testing.T) { runtime := &mocks.MockRuntime{} validator := &mockEnvVarValidator{} - config, err := NewRunConfigBuilder(). + config, err := NewRunConfigBuilder(logger). WithRuntime(runtime). WithCmdArgs(nil). WithName("test-server"). 
@@ -880,12 +879,11 @@ func TestRunConfigBuilder_MetadataOverrides(t *testing.T) { func TestRunConfigBuilder_EnvironmentVariableTransportDependency(t *testing.T) { t.Parallel() - // Needed to prevent a nil pointer dereference in the logger. - logger.Initialize() + logger := log.NewLogger() runtime := &mocks.MockRuntime{} validator := &mockEnvVarValidator{} - config, err := NewRunConfigBuilder(). + config, err := NewRunConfigBuilder(logger). WithRuntime(runtime). WithCmdArgs(nil). WithName("test-server"). @@ -926,9 +924,7 @@ func TestRunConfigBuilder_EnvironmentVariableTransportDependency(t *testing.T) { func TestRunConfigBuilder_CmdArgsMetadataOverride(t *testing.T) { t.Parallel() - // Needed to prevent a nil pointer dereference in the logger. - logger.Initialize() - + logger := log.NewLogger() runtime := &mocks.MockRuntime{} validator := &mockEnvVarValidator{} @@ -937,7 +933,7 @@ func TestRunConfigBuilder_CmdArgsMetadataOverride(t *testing.T) { Args: []string{"--metadata-arg1", "--metadata-arg2"}, } - config, err := NewRunConfigBuilder(). + config, err := NewRunConfigBuilder(logger). WithRuntime(runtime). WithCmdArgs(userArgs). WithName("test-server"). @@ -980,8 +976,7 @@ func TestRunConfigBuilder_CmdArgsMetadataOverride(t *testing.T) { func TestRunConfigBuilder_CmdArgsMetadataDefaults(t *testing.T) { t.Parallel() - // Needed to prevent a nil pointer dereference in the logger. - logger.Initialize() + logger := log.NewLogger() runtime := &mocks.MockRuntime{} validator := &mockEnvVarValidator{} @@ -992,7 +987,7 @@ func TestRunConfigBuilder_CmdArgsMetadataDefaults(t *testing.T) { Args: []string{"--metadata-arg1", "--metadata-arg2"}, } - config, err := NewRunConfigBuilder(). + config, err := NewRunConfigBuilder(logger). WithRuntime(runtime). WithCmdArgs(userArgs). WithName("test-server"). 
@@ -1033,8 +1028,7 @@ func TestRunConfigBuilder_CmdArgsMetadataDefaults(t *testing.T) { func TestRunConfigBuilder_VolumeProcessing(t *testing.T) { t.Parallel() - // Needed to prevent a nil pointer dereference in the logger. - logger.Initialize() + logger := log.NewLogger() runtime := &mocks.MockRuntime{} validator := &mockEnvVarValidator{} @@ -1043,7 +1037,7 @@ func TestRunConfigBuilder_VolumeProcessing(t *testing.T) { "/host/write:/container/write", } - config, err := NewRunConfigBuilder(). + config, err := NewRunConfigBuilder(logger). WithRuntime(runtime). WithCmdArgs(nil). WithName("test-server"). @@ -1102,8 +1096,7 @@ func TestRunConfigBuilder_VolumeProcessing(t *testing.T) { func TestRunConfigBuilder_FilesystemMCPScenario(t *testing.T) { t.Parallel() - // Needed to prevent a nil pointer dereference in the logger. - logger.Initialize() + logger := log.NewLogger() runtime := &mocks.MockRuntime{} validator := &mockEnvVarValidator{} @@ -1116,7 +1109,7 @@ func TestRunConfigBuilder_FilesystemMCPScenario(t *testing.T) { // Simulate user providing their own arguments userArgs := []string{"/Users/testuser/repos/github.com/stacklok/toolhive"} - config, err := NewRunConfigBuilder(). + config, err := NewRunConfigBuilder(logger). WithRuntime(runtime). WithCmdArgs(userArgs). WithName("filesystem"). diff --git a/pkg/runner/env.go b/pkg/runner/env.go index 0b49974ff..721303dce 100644 --- a/pkg/runner/env.go +++ b/pkg/runner/env.go @@ -6,10 +6,10 @@ import ( "os" "strings" + "go.uber.org/zap" "golang.org/x/term" "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -32,11 +32,18 @@ type EnvVarValidator interface { // DetachedEnvVarValidator implements the EnvVarValidator interface for // scenarios where the user cannot be prompted for input. Any missing, // mandatory variables will result in an error being returned. 
-type DetachedEnvVarValidator struct{} +type DetachedEnvVarValidator struct { + logger *zap.SugaredLogger +} + +// NewDetachedEnvVarValidator creates a new DetachedEnvVarValidator instance +func NewDetachedEnvVarValidator(logger *zap.SugaredLogger) *DetachedEnvVarValidator { + return &DetachedEnvVarValidator{logger: logger} +} // Validate checks that all required environment variables and secrets are provided // and returns the processed environment variables to be set. -func (*DetachedEnvVarValidator) Validate( +func (v *DetachedEnvVarValidator) Validate( _ context.Context, metadata *registry.ImageMetadata, runConfig *RunConfig, @@ -54,7 +61,7 @@ func (*DetachedEnvVarValidator) Validate( } else if envVar.Secret { return nil, fmt.Errorf("missing required secret environment variable: %s", envVar.Name) } else if envVar.Default != "" { - addAsEnvironmentVariable(envVar, envVar.Default, &suppliedEnvVars) + addAsEnvironmentVariable(envVar, envVar.Default, &suppliedEnvVars, v.logger) } } } @@ -65,11 +72,18 @@ func (*DetachedEnvVarValidator) Validate( // CLIEnvVarValidator implements the EnvVarValidator interface for // CLI usage. If any missing, mandatory variables are found, this code will // prompt the user to supply them through stdin. -type CLIEnvVarValidator struct{} +type CLIEnvVarValidator struct { + logger *zap.SugaredLogger +} + +// NewCLIEnvVarValidator creates a new CLIEnvVarValidator instance +func NewCLIEnvVarValidator(logger *zap.SugaredLogger) *CLIEnvVarValidator { + return &CLIEnvVarValidator{logger: logger} +} // Validate checks that all required environment variables and secrets are provided // and returns the processed environment variables to be set. 
-func (*CLIEnvVarValidator) Validate( +func (v *CLIEnvVarValidator) Validate( ctx context.Context, metadata *registry.ImageMetadata, runConfig *RunConfig, @@ -92,7 +106,7 @@ func (*CLIEnvVarValidator) Validate( // Create a new slice with capacity for all env vars // Initialize secrets manager if needed - secretsManager := initializeSecretsManagerIfNeeded(registryEnvVars) + secretsManager := initializeSecretsManagerIfNeeded(registryEnvVars, v.logger) // Process each environment variable from the registry for _, envVar := range registryEnvVars { @@ -101,27 +115,26 @@ func (*CLIEnvVarValidator) Validate( } if envVar.Required { - if envVar.Secret { value, err := secretsManager.GetSecret(ctx, envVar.Name) if err != nil { - logger.Warnf("Unable to find secret %s in the secrets manager: %v", envVar.Name, err) + v.logger.Warnf("Unable to find secret %s in the secrets manager: %v", envVar.Name, err) } else { - addNewVariable(ctx, envVar, value, secretsManager, &envVars, &secretsList) + addNewVariable(ctx, envVar, value, secretsManager, &envVars, &secretsList, v.logger) continue } } - value, err := promptForEnvironmentVariable(envVar) + value, err := promptForEnvironmentVariable(envVar, v.logger) if err != nil { - logger.Warnf("Warning: Failed to read input for %s: %v", envVar.Name, err) + v.logger.Warnf("Warning: Failed to read input for %s: %v", envVar.Name, err) continue } if value != "" { - addNewVariable(ctx, envVar, value, secretsManager, &envVars, &secretsList) + addNewVariable(ctx, envVar, value, secretsManager, &envVars, &secretsList, v.logger) } } else if envVar.Default != "" { - addNewVariable(ctx, envVar, envVar.Default, secretsManager, &envVars, &secretsList) + addNewVariable(ctx, envVar, envVar.Default, secretsManager, &envVars, &secretsList, v.logger) } } @@ -132,7 +145,7 @@ func (*CLIEnvVarValidator) Validate( } // promptForEnvironmentVariable prompts the user for an environment variable value -func promptForEnvironmentVariable(envVar *registry.EnvVar) 
(string, error) { +func promptForEnvironmentVariable(envVar *registry.EnvVar, logger *zap.SugaredLogger) (string, error) { var byteValue []byte var err error if envVar.Secret { @@ -167,11 +180,12 @@ func addNewVariable( secretsManager secrets.Provider, envVars *[]string, secretsList *[]string, + logger *zap.SugaredLogger, ) { if envVar.Secret && secretsManager != nil { - addAsSecret(ctx, envVar, value, secretsManager, secretsList, envVars) + addAsSecret(ctx, envVar, value, secretsManager, secretsList, envVars, logger) } else { - addAsEnvironmentVariable(envVar, value, envVars) + addAsEnvironmentVariable(envVar, value, envVars, logger) } } @@ -183,6 +197,7 @@ func addAsSecret( secretsManager secrets.Provider, secretsList *[]string, envVars *[]string, + logger *zap.SugaredLogger, ) { var secretName string if envVar.Required { @@ -209,7 +224,7 @@ func addAsSecret( } // initializeSecretsManagerIfNeeded initializes the secrets manager if there are secret environment variables -func initializeSecretsManagerIfNeeded(registryEnvVars []*registry.EnvVar) secrets.Provider { +func initializeSecretsManagerIfNeeded(registryEnvVars []*registry.EnvVar, logger *zap.SugaredLogger) secrets.Provider { // Check if we have any secret environment variables hasSecrets := false for _, envVar := range registryEnvVars { @@ -223,7 +238,7 @@ func initializeSecretsManagerIfNeeded(registryEnvVars []*registry.EnvVar) secret return nil } - secretsManager, err := getSecretsManager() + secretsManager, err := getSecretsManager(logger) if err != nil { logger.Warnf("Warning: Failed to initialize secrets manager: %v", err) logger.Warnf("Secret environment variables will be stored as regular environment variables") @@ -235,8 +250,8 @@ func initializeSecretsManagerIfNeeded(registryEnvVars []*registry.EnvVar) secret // Duplicated from cmd/thv/app/app.go // It may be possible to de-duplicate this in future. 
-func getSecretsManager() (secrets.Provider, error) { - cfg := config.GetConfig() +func getSecretsManager(logger *zap.SugaredLogger) (secrets.Provider, error) { + cfg := config.GetConfig(logger) // Check if secrets setup has been completed if !cfg.Secrets.SetupCompleted { @@ -248,7 +263,7 @@ func getSecretsManager() (secrets.Provider, error) { return nil, fmt.Errorf("failed to get secrets provider type: %w", err) } - manager, err := secrets.CreateSecretProvider(providerType) + manager, err := secrets.CreateSecretProvider(providerType, logger) if err != nil { return nil, fmt.Errorf("failed to create secrets manager: %w", err) } @@ -304,6 +319,7 @@ func addAsEnvironmentVariable( envVar *registry.EnvVar, value string, envVars *[]string, + logger *zap.SugaredLogger, ) { *envVars = append(*envVars, fmt.Sprintf("%s=%s", envVar.Name, value)) diff --git a/pkg/runner/permissions.go b/pkg/runner/permissions.go index fa9d7fe56..fa913f282 100644 --- a/pkg/runner/permissions.go +++ b/pkg/runner/permissions.go @@ -7,7 +7,8 @@ import ( "path/filepath" "strings" - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/permissions" ) @@ -15,7 +16,7 @@ import ( // It will likely be moved elsewhere in a future PR. 
// CreatePermissionProfileFile creates a temporary file with the permission profile -func CreatePermissionProfileFile(serverName string, permProfile *permissions.Profile) (string, error) { +func CreatePermissionProfileFile(serverName string, permProfile *permissions.Profile, logger *zap.SugaredLogger) (string, error) { tempFile, err := os.CreateTemp("", fmt.Sprintf("toolhive-%s-permissions-*.json", serverName)) if err != nil { return "", fmt.Errorf("failed to create temporary file: %v", err) @@ -42,7 +43,7 @@ func CreatePermissionProfileFile(serverName string, permProfile *permissions.Pro } // CleanupTempPermissionProfile removes a temporary permission profile file if it was created by toolhive -func CleanupTempPermissionProfile(permissionProfilePath string) error { +func CleanupTempPermissionProfile(permissionProfilePath string, logger *zap.SugaredLogger) error { if permissionProfilePath == "" { return nil } diff --git a/pkg/runner/permissions_test.go b/pkg/runner/permissions_test.go index 4dcc274c5..63c7d75b5 100644 --- a/pkg/runner/permissions_test.go +++ b/pkg/runner/permissions_test.go @@ -5,12 +5,12 @@ import ( "path/filepath" "testing" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) func TestIsTempPermissionProfile(t *testing.T) { t.Parallel() - logger.Initialize() + tests := []struct { name string filePath string @@ -66,7 +66,10 @@ func TestIsTempPermissionProfile(t *testing.T) { func TestCleanupTempPermissionProfile(t *testing.T) { t.Parallel() - logger.Initialize() + + // Create logger + logger := log.NewLogger() + // Create a temporary file that matches our pattern tempFile, err := os.CreateTemp("", "toolhive-test-permissions-*.json") if err != nil { @@ -81,7 +84,7 @@ func TestCleanupTempPermissionProfile(t *testing.T) { } // Clean up the temp file - err = CleanupTempPermissionProfile(tempPath) + err = CleanupTempPermissionProfile(tempPath, logger) if err != nil { t.Fatalf("CleanupTempPermissionProfile 
failed: %v", err) } @@ -94,12 +97,15 @@ func TestCleanupTempPermissionProfile(t *testing.T) { func TestCleanupTempPermissionProfile_NonTempFile(t *testing.T) { t.Parallel() - logger.Initialize() + + // Create logger + logger := log.NewLogger() + // Test with a non-temp file path nonTempPath := "/home/user/my-permissions.json" // This should not fail and should not attempt to remove the file - err := CleanupTempPermissionProfile(nonTempPath) + err := CleanupTempPermissionProfile(nonTempPath, logger) if err != nil { t.Errorf("CleanupTempPermissionProfile should not fail for non-temp files: %v", err) } @@ -107,12 +113,15 @@ func TestCleanupTempPermissionProfile_NonTempFile(t *testing.T) { func TestCleanupTempPermissionProfile_NonExistentFile(t *testing.T) { t.Parallel() - logger.Initialize() + + // Create logger + logger := log.NewLogger() + // Test with a temp file pattern that doesn't exist nonExistentPath := filepath.Join(os.TempDir(), "toolhive-nonexistent-permissions-999.json") // This should not fail - err := CleanupTempPermissionProfile(nonExistentPath) + err := CleanupTempPermissionProfile(nonExistentPath, logger) if err != nil { t.Errorf("CleanupTempPermissionProfile should not fail for non-existent files: %v", err) } diff --git a/pkg/runner/protocol.go b/pkg/runner/protocol.go index f8dbfb7a6..3703fb44e 100644 --- a/pkg/runner/protocol.go +++ b/pkg/runner/protocol.go @@ -9,11 +9,11 @@ import ( "time" nameref "github.com/google/go-containerregistry/pkg/name" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/certs" "github.com/stacklok/toolhive/pkg/container/images" "github.com/stacklok/toolhive/pkg/container/templates" - "github.com/stacklok/toolhive/pkg/logger" ) // Protocol schemes @@ -31,8 +31,9 @@ func HandleProtocolScheme( imageManager images.ImageManager, serverOrImage string, caCertPath string, + logger *zap.SugaredLogger, ) (string, error) { - return BuildFromProtocolSchemeWithName(ctx, imageManager, serverOrImage, caCertPath, "", false) + return 
BuildFromProtocolSchemeWithName(ctx, imageManager, serverOrImage, caCertPath, "", false, logger) } // BuildFromProtocolSchemeWithName checks if the serverOrImage string contains a protocol scheme (uvx://, npx://, or go://) @@ -47,13 +48,14 @@ func BuildFromProtocolSchemeWithName( caCertPath string, imageName string, dryRun bool, + logger *zap.SugaredLogger, ) (string, error) { transportType, packageName, err := parseProtocolScheme(serverOrImage) if err != nil { return "", err } - templateData, err := createTemplateData(transportType, packageName, caCertPath) + templateData, err := createTemplateData(transportType, packageName, caCertPath, logger) if err != nil { return "", err } @@ -67,7 +69,7 @@ func BuildFromProtocolSchemeWithName( return dockerfileContent, nil } - return buildImageFromTemplateWithName(ctx, imageManager, transportType, packageName, templateData, imageName) + return buildImageFromTemplateWithName(ctx, imageManager, transportType, packageName, templateData, imageName, logger) } // parseProtocolScheme extracts the transport type and package name from the protocol scheme. @@ -85,7 +87,12 @@ func parseProtocolScheme(serverOrImage string) (templates.TransportType, string, } // createTemplateData creates the template data with optional CA certificate. 
-func createTemplateData(transportType templates.TransportType, packageName, caCertPath string) (templates.TemplateData, error) { +func createTemplateData( + transportType templates.TransportType, + packageName, + caCertPath string, + logger *zap.SugaredLogger, +) (templates.TemplateData, error) { // Check if this is a local path (for Go packages only) isLocalPath := transportType == templates.TransportTypeGO && isLocalGoPath(packageName) @@ -96,7 +103,7 @@ func createTemplateData(transportType templates.TransportType, packageName, caCe } if caCertPath != "" { - if err := addCACertToTemplate(caCertPath, &templateData); err != nil { + if err := addCACertToTemplate(caCertPath, &templateData, logger); err != nil { return templateData, err } } @@ -105,7 +112,7 @@ func createTemplateData(transportType templates.TransportType, packageName, caCe } // addCACertToTemplate reads and validates a CA certificate, adding it to the template data. -func addCACertToTemplate(caCertPath string, templateData *templates.TemplateData) error { +func addCACertToTemplate(caCertPath string, templateData *templates.TemplateData, logger *zap.SugaredLogger) error { logger.Debugf("Using custom CA certificate from: %s", caCertPath) // Read the CA certificate file @@ -116,7 +123,7 @@ func addCACertToTemplate(caCertPath string, templateData *templates.TemplateData } // Validate that the file contains a valid PEM certificate - if err := certs.ValidateCACertificate(caCertContent); err != nil { + if err := certs.ValidateCACertificate(caCertContent, logger); err != nil { return fmt.Errorf("invalid CA certificate: %w", err) } @@ -135,15 +142,15 @@ type buildContext struct { // setupBuildContext sets up the appropriate build context directory based on whether // we're dealing with a local path or remote package. 
-func setupBuildContext(packageName string, isLocalPath bool) (*buildContext, error) { +func setupBuildContext(packageName string, isLocalPath bool, logger *zap.SugaredLogger) (*buildContext, error) { if isLocalPath { - return setupLocalBuildContext(packageName) + return setupLocalBuildContext(packageName, logger) } - return setupTempBuildContext() + return setupTempBuildContext(logger) } // setupLocalBuildContext sets up a build context using the local directory directly. -func setupLocalBuildContext(packageName string) (*buildContext, error) { +func setupLocalBuildContext(packageName string, logger *zap.SugaredLogger) (*buildContext, error) { absPath, err := filepath.Abs(packageName) if err != nil { return nil, fmt.Errorf("failed to get absolute path for %s: %w", packageName, err) @@ -186,7 +193,7 @@ func setupLocalBuildContext(packageName string) (*buildContext, error) { } // setupTempBuildContext sets up a temporary build context directory. -func setupTempBuildContext() (*buildContext, error) { +func setupTempBuildContext(logger *zap.SugaredLogger) (*buildContext, error) { tempDir, err := os.MkdirTemp("", "toolhive-docker-build-") if err != nil { return nil, fmt.Errorf("failed to create temporary directory: %w", err) @@ -209,7 +216,7 @@ func setupTempBuildContext() (*buildContext, error) { // writeDockerfile writes the Dockerfile content to the build context. // For local paths, it checks if a Dockerfile already exists and avoids overwriting it. -func writeDockerfile(dockerfilePath, dockerfileContent string, isLocalPath bool) error { +func writeDockerfile(dockerfilePath, dockerfileContent string, isLocalPath bool, logger *zap.SugaredLogger) error { if isLocalPath { // Check if a Dockerfile already exists if _, err := os.Stat(dockerfilePath); err == nil { @@ -233,7 +240,7 @@ func writeDockerfile(dockerfilePath, dockerfileContent string, isLocalPath bool) } // writeCACertificate writes the CA certificate to the build context if provided. 
-func writeCACertificate(buildContextDir, caCertContent string, isLocalPath bool) (func(), error) { +func writeCACertificate(buildContextDir, caCertContent string, isLocalPath bool, logger *zap.SugaredLogger) (func(), error) { if caCertContent == "" { return func() {}, nil } @@ -279,6 +286,7 @@ func buildImageFromTemplateWithName( packageName string, templateData templates.TemplateData, imageName string, + logger *zap.SugaredLogger, ) (string, error) { // Get the Dockerfile content @@ -288,19 +296,19 @@ func buildImageFromTemplateWithName( } // Set up the build context - buildCtx, err := setupBuildContext(packageName, templateData.IsLocalPath) + buildCtx, err := setupBuildContext(packageName, templateData.IsLocalPath, logger) if err != nil { return "", err } defer buildCtx.CleanupFunc() // Write the Dockerfile - if err := writeDockerfile(buildCtx.DockerfilePath, dockerfileContent, templateData.IsLocalPath); err != nil { + if err := writeDockerfile(buildCtx.DockerfilePath, dockerfileContent, templateData.IsLocalPath, logger); err != nil { return "", err } // Write CA certificate if provided - caCertCleanup, err := writeCACertificate(buildCtx.Dir, templateData.CACertContent, templateData.IsLocalPath) + caCertCleanup, err := writeCACertificate(buildCtx.Dir, templateData.CACertContent, templateData.IsLocalPath, logger) if err != nil { return "", err } diff --git a/pkg/runner/retriever/retriever.go b/pkg/runner/retriever/retriever.go index 615896bea..2bd99b710 100644 --- a/pkg/runner/retriever/retriever.go +++ b/pkg/runner/retriever/retriever.go @@ -7,11 +7,11 @@ import ( "fmt" nameref "github.com/google/go-containerregistry/pkg/name" + "go.uber.org/zap" "github.com/stacklok/toolhive/pkg/config" "github.com/stacklok/toolhive/pkg/container/images" "github.com/stacklok/toolhive/pkg/container/verifier" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/registry" "github.com/stacklok/toolhive/pkg/runner" ) @@ -38,17 +38,18 @@ func GetMCPServer( 
serverOrImage string, rawCACertPath string, verificationType string, + logger *zap.SugaredLogger, ) (string, *registry.ImageMetadata, error) { var imageMetadata *registry.ImageMetadata var imageToUse string - imageManager := images.NewImageManager(ctx) + imageManager := images.NewImageManager(ctx, logger) // Check if the serverOrImage is a protocol scheme, e.g., uvx://, npx://, or go:// if runner.IsImageProtocolScheme(serverOrImage) { logger.Debugf("Detected protocol scheme: %s", serverOrImage) // Process the protocol scheme and build the image - caCertPath := resolveCACertPath(rawCACertPath) - generatedImage, err := runner.HandleProtocolScheme(ctx, imageManager, serverOrImage, caCertPath) + caCertPath := resolveCACertPath(rawCACertPath, logger) + generatedImage, err := runner.HandleProtocolScheme(ctx, imageManager, serverOrImage, caCertPath, logger) if err != nil { return "", nil, errors.Join(ErrBadProtocolScheme, err) } @@ -58,7 +59,7 @@ func GetMCPServer( } else { logger.Debugf("No protocol scheme detected, using image: %s", serverOrImage) // Try to find the server in the registry - provider, err := registry.GetDefaultProvider() + provider, err := registry.GetDefaultProvider(logger) if err != nil { return "", nil, fmt.Errorf("failed to get registry provider: %v", err) } @@ -88,12 +89,12 @@ func GetMCPServer( } // Verify the image against the expected provenance info (if applicable) - if err := verifyImage(imageToUse, imageMetadata, verificationType); err != nil { + if err := verifyImage(imageToUse, imageMetadata, verificationType, logger); err != nil { return "", nil, err } // Pull the image if necessary - if err := pullImage(ctx, imageToUse, imageManager); err != nil { + if err := pullImage(ctx, imageToUse, imageManager, logger); err != nil { return "", nil, fmt.Errorf("failed to retrieve or pull image: %v", err) } @@ -105,9 +106,9 @@ func GetMCPServer( // If the image has the latest tag, it will be pulled to ensure we have the most recent version. 
// however, if there is a failure in pulling the "latest" tag, it will check if the image exists locally // as it is possible that the image was locally built. -func pullImage(ctx context.Context, image string, imageManager images.ImageManager) error { +func pullImage(ctx context.Context, image string, imageManager images.ImageManager, logger *zap.SugaredLogger) error { // Check if the image has the "latest" tag - isLatestTag := hasLatestTag(image) + isLatestTag := hasLatestTag(image, logger) if isLatestTag { // For "latest" tag, try to pull first @@ -155,14 +156,14 @@ func pullImage(ctx context.Context, image string, imageManager images.ImageManag } // resolveCACertPath determines the CA certificate path to use, prioritizing command-line flag over configuration -func resolveCACertPath(flagValue string) string { +func resolveCACertPath(flagValue string, logger *zap.SugaredLogger) string { // If command-line flag is provided, use it (highest priority) if flagValue != "" { return flagValue } // Otherwise, check configuration - cfg := config.GetConfig() + cfg := config.GetConfig(logger) if cfg.CACertificatePath != "" { logger.Debugf("Using configured CA certificate: %s", cfg.CACertificatePath) return cfg.CACertificatePath @@ -173,13 +174,13 @@ func resolveCACertPath(flagValue string) string { } // verifyImage verifies the image using the specified verification setting (warn, enabled, or disabled) -func verifyImage(image string, server *registry.ImageMetadata, verifySetting string) error { +func verifyImage(image string, server *registry.ImageMetadata, verifySetting string, logger *zap.SugaredLogger) error { switch verifySetting { case VerifyImageDisabled: logger.Warn("Image verification is disabled") case VerifyImageWarn, VerifyImageEnabled: // Create a new verifier - v, err := verifier.New(server) + v, err := verifier.New(server, logger) if err != nil { // This happens if we have no provenance entry in the registry for this server. 
// Not finding provenance info in the registry is not a fatal error if the setting is "warn". @@ -211,7 +212,7 @@ func verifyImage(image string, server *registry.ImageMetadata, verifySetting str } // hasLatestTag checks if the given image reference has the "latest" tag or no tag (which defaults to "latest") -func hasLatestTag(imageRef string) bool { +func hasLatestTag(imageRef string, logger *zap.SugaredLogger) bool { ref, err := nameref.ParseReference(imageRef) if err != nil { // If we can't parse the reference, assume it's not "latest" diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 3c35ac1a8..7461f1233 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -9,12 +9,13 @@ import ( "syscall" "time" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/config" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/mcp" "github.com/stacklok/toolhive/pkg/process" "github.com/stacklok/toolhive/pkg/secrets" @@ -36,13 +37,16 @@ type Runner struct { supportedMiddleware map[string]types.MiddlewareFactory statusManager statuses.StatusManager + + logger *zap.SugaredLogger } // NewRunner creates a new Runner with the provided configuration -func NewRunner(runConfig *RunConfig, statusManager statuses.StatusManager) *Runner { +func NewRunner(runConfig *RunConfig, statusManager statuses.StatusManager, logger *zap.SugaredLogger) *Runner { return &Runner{ Config: runConfig, statusManager: statusManager, + logger: logger, } } @@ -78,7 +82,7 @@ func (r *Runner) Run(ctx context.Context) error { // Ensure middleware is cleaned up on shutdown. 
defer func() { if err := middleware.Close(); err != nil { - logger.Warnf("Failed to close middleware of type %s: %v", middlewareConfig.Type, err) + r.logger.Warnf("Failed to close middleware of type %s: %v", middlewareConfig.Type, err) } }() transportConfig.Middlewares = append(transportConfig.Middlewares, middleware.Handler()) @@ -91,14 +95,14 @@ func (r *Runner) Run(ctx context.Context) error { } transportConfig.Middlewares = append(transportConfig.Middlewares, toolsFilterMiddleware) - toolsCallFilterMiddleware, err := mcp.NewToolCallFilterMiddleware(r.Config.ToolsFilter) + toolsCallFilterMiddleware, err := mcp.NewToolCallFilterMiddleware(r.Config.ToolsFilter, r.logger) if err != nil { return fmt.Errorf("failed to create tools call filter middleware: %v", err) } transportConfig.Middlewares = append(transportConfig.Middlewares, toolsCallFilterMiddleware) } - authMiddleware, authInfoHandler, err := auth.GetAuthenticationMiddleware(ctx, r.Config.OIDCConfig) + authMiddleware, authInfoHandler, err := auth.GetAuthenticationMiddleware(ctx, r.Config.OIDCConfig, r.logger) if err != nil { return fmt.Errorf("failed to create authentication middleware: %v", err) } @@ -106,12 +110,12 @@ func (r *Runner) Run(ctx context.Context) error { transportConfig.AuthInfoHandler = authInfoHandler // Add MCP parsing middleware after authentication - logger.Info("MCP parsing middleware enabled for transport") + r.logger.Info("MCP parsing middleware enabled for transport") transportConfig.Middlewares = append(transportConfig.Middlewares, mcp.ParsingMiddleware) // Add telemetry middleware if telemetry configuration is provided if r.Config.TelemetryConfig != nil { - logger.Info("OpenTelemetry instrumentation enabled for transport") + r.logger.Info("OpenTelemetry instrumentation enabled for transport") // Create telemetry provider telemetryProvider, err := telemetry.NewProvider(ctx, *r.Config.TelemetryConfig) @@ -126,7 +130,7 @@ func (r *Runner) Run(ctx context.Context) error { // Add 
Prometheus handler to transport config if metrics port is configured if r.Config.TelemetryConfig.EnablePrometheusMetricsPath { transportConfig.PrometheusHandler = telemetryProvider.PrometheusHandler() - logger.Infof("Prometheus metrics will be exposed on port %d at /metrics", r.Config.Port) + r.logger.Infof("Prometheus metrics will be exposed on port %d at /metrics", r.Config.Port) } // Store provider for cleanup @@ -135,10 +139,10 @@ func (r *Runner) Run(ctx context.Context) error { // Add authorization middleware if authorization configuration is provided if r.Config.AuthzConfig != nil { - logger.Info("Authorization enabled for transport") + r.logger.Info("Authorization enabled for transport") // Get the middleware from the configuration - middleware, err := r.Config.AuthzConfig.CreateMiddleware() + middleware, err := r.Config.AuthzConfig.CreateMiddleware(r.logger) if err != nil { return fmt.Errorf("failed to get authorization middleware: %v", err) } @@ -149,7 +153,7 @@ func (r *Runner) Run(ctx context.Context) error { // Add audit middleware if audit configuration is provided if r.Config.AuditConfig != nil { - logger.Info("Audit logging enabled for transport") + r.logger.Info("Audit logging enabled for transport") // Set the component name if not already set if r.Config.AuditConfig.Component == "" { @@ -157,7 +161,7 @@ func (r *Runner) Run(ctx context.Context) error { } // Get the middleware from the configuration - middleware, err := r.Config.AuditConfig.CreateMiddleware() + middleware, err := r.Config.AuditConfig.CreateMiddleware(r.logger) if err != nil { return fmt.Errorf("failed to create audit middleware: %w", err) } @@ -169,21 +173,21 @@ func (r *Runner) Run(ctx context.Context) error { // Set proxy mode for stdio transport transportConfig.ProxyMode = r.Config.ProxyMode - transportHandler, err := transport.NewFactory().Create(transportConfig) + transportHandler, err := transport.NewFactory(r.logger).Create(transportConfig) if err != nil { return 
fmt.Errorf("failed to create transport: %v", err) } // Process secrets if provided if len(r.Config.Secrets) > 0 { - cfg := config.GetConfig() + cfg := config.GetConfig(r.logger) providerType, err := cfg.Secrets.GetProviderType() if err != nil { return fmt.Errorf("error determining secrets provider type: %w", err) } - secretManager, err := secrets.CreateSecretProvider(providerType) + secretManager, err := secrets.CreateSecretProvider(providerType, r.logger) if err != nil { return fmt.Errorf("error instantiating secret manager %v", err) } @@ -195,7 +199,7 @@ func (r *Runner) Run(ctx context.Context) error { } // Set up the transport - logger.Infof("Setting up %s transport...", r.Config.Transport) + r.logger.Infof("Setting up %s transport...", r.Config.Transport) if err := transportHandler.Setup( ctx, r.Config.Deployer, r.Config.ContainerName, r.Config.Image, r.Config.CmdArgs, r.Config.EnvVars, r.Config.ContainerLabels, r.Config.PermissionProfile, r.Config.K8sPodTemplatePatch, @@ -205,61 +209,61 @@ func (r *Runner) Run(ctx context.Context) error { } // Start the transport (which also starts the container and monitoring) - logger.Infof("Starting %s transport for %s...", r.Config.Transport, r.Config.ContainerName) + r.logger.Infof("Starting %s transport for %s...", r.Config.Transport, r.Config.ContainerName) if err := transportHandler.Start(ctx); err != nil { return fmt.Errorf("failed to start transport: %v", err) } - logger.Infof("MCP server %s started successfully", r.Config.ContainerName) + r.logger.Infof("MCP server %s started successfully", r.Config.ContainerName) // Update client configurations with the MCP server URL. // Note that this function checks the configuration to determine which // clients should be updated, if any. 
- clientManager, err := client.NewManager(ctx) + clientManager, err := client.NewManager(ctx, r.logger) if err != nil { - logger.Warnf("Warning: Failed to create client manager: %v", err) + r.logger.Warnf("Warning: Failed to create client manager: %v", err) } else { transportType := labels.GetTransportType(r.Config.ContainerLabels) serverURL := transport.GenerateMCPServerURL(transportType, "localhost", r.Config.Port, r.Config.ContainerName) if err := clientManager.AddServerToClients(ctx, r.Config.ContainerName, serverURL, transportType, r.Config.Group); err != nil { - logger.Warnf("Warning: Failed to add server to client configurations: %v", err) + r.logger.Warnf("Warning: Failed to add server to client configurations: %v", err) } } // Define a function to stop the MCP server stopMCPServer := func(reason string) { - logger.Infof("Stopping MCP server: %s", reason) + r.logger.Infof("Stopping MCP server: %s", reason) // Stop the transport (which also stops the container, monitoring, and handles removal) - logger.Infof("Stopping %s transport...", r.Config.Transport) + r.logger.Infof("Stopping %s transport...", r.Config.Transport) if err := transportHandler.Stop(ctx); err != nil { - logger.Warnf("Warning: Failed to stop transport: %v", err) + r.logger.Warnf("Warning: Failed to stop transport: %v", err) } // Cleanup telemetry provider if err := r.Cleanup(ctx); err != nil { - logger.Warnf("Warning: Failed to cleanup telemetry: %v", err) + r.logger.Warnf("Warning: Failed to cleanup telemetry: %v", err) } // Remove the PID file if it exists if err := process.RemovePIDFile(r.Config.BaseName); err != nil { - logger.Warnf("Warning: Failed to remove PID file: %v", err) + r.logger.Warnf("Warning: Failed to remove PID file: %v", err) } - logger.Infof("MCP server %s stopped", r.Config.ContainerName) + r.logger.Infof("MCP server %s stopped", r.Config.ContainerName) } if process.IsDetached() { // We're a detached process running in foreground mode // Write the PID to a file so the 
stop command can kill the process if err := process.WriteCurrentPIDFile(r.Config.BaseName); err != nil { - logger.Warnf("Warning: Failed to write PID file: %v", err) + r.logger.Warnf("Warning: Failed to write PID file: %v", err) } - logger.Infof("Running as detached process (PID: %d)", os.Getpid()) + r.logger.Infof("Running as detached process (PID: %d)", os.Getpid()) } else { - logger.Info("Press Ctrl+C to stop or wait for container to exit") + r.logger.Info("Press Ctrl+C to stop or wait for container to exit") } // Set up signal handling @@ -274,7 +278,7 @@ func (r *Runner) Run(ctx context.Context) error { for { // Safely check if transportHandler is nil if transportHandler == nil { - logger.Info("Transport handler is nil, exiting monitoring routine...") + r.logger.Info("Transport handler is nil, exiting monitoring routine...") close(doneCh) return } @@ -282,14 +286,14 @@ func (r *Runner) Run(ctx context.Context) error { // Check if the transport is still running running, err := transportHandler.IsRunning(ctx) if err != nil { - logger.Errorf("Error checking transport status: %v", err) + r.logger.Errorf("Error checking transport status: %v", err) // Don't exit immediately on error, try again after pause time.Sleep(1 * time.Second) continue } if !running { // Transport is no longer running (container exited or was stopped) - logger.Info("Transport is no longer running, exiting...") + r.logger.Info("Transport is no longer running, exiting...") close(doneCh) return } @@ -313,10 +317,10 @@ func (r *Runner) Run(ctx context.Context) error { // The transport has already been stopped (likely by the container monitor) // Clean up the PID file and state if err := process.RemovePIDFile(r.Config.BaseName); err != nil { - logger.Warnf("Warning: Failed to remove PID file: %v", err) + r.logger.Warnf("Warning: Failed to remove PID file: %v", err) } - logger.Infof("MCP server %s stopped", r.Config.ContainerName) + r.logger.Infof("MCP server %s stopped", r.Config.ContainerName) } 
return nil @@ -325,9 +329,9 @@ func (r *Runner) Run(ctx context.Context) error { // Cleanup performs cleanup operations for the runner, including shutting down telemetry. func (r *Runner) Cleanup(ctx context.Context) error { if r.telemetryProvider != nil { - logger.Debug("Shutting down telemetry provider") + r.logger.Debug("Shutting down telemetry provider") if err := r.telemetryProvider.Shutdown(ctx); err != nil { - logger.Warnf("Warning: Failed to shutdown telemetry provider: %v", err) + r.logger.Warnf("Warning: Failed to shutdown telemetry provider: %v", err) return err } } diff --git a/pkg/secrets/factory.go b/pkg/secrets/factory.go index 6866360b7..d919d184f 100644 --- a/pkg/secrets/factory.go +++ b/pkg/secrets/factory.go @@ -11,9 +11,9 @@ import ( "github.com/adrg/xdg" "github.com/zalando/go-keyring" + "go.uber.org/zap" "golang.org/x/term" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/process" ) @@ -57,20 +57,25 @@ type SetupResult struct { } // ValidateProvider validates that a provider can be created and performs basic functionality tests -func ValidateProvider(ctx context.Context, providerType ProviderType) *SetupResult { - return ValidateProviderWithPassword(ctx, providerType, "") +func ValidateProvider(ctx context.Context, providerType ProviderType, logger *zap.SugaredLogger) *SetupResult { + return ValidateProviderWithPassword(ctx, providerType, "", logger) } // ValidateProviderWithPassword validates that a provider can be created and performs basic functionality tests. // If password is provided for encrypted provider, it uses that password instead of reading from stdin. 
-func ValidateProviderWithPassword(ctx context.Context, providerType ProviderType, password string) *SetupResult { +func ValidateProviderWithPassword( + ctx context.Context, + providerType ProviderType, + password string, + logger *zap.SugaredLogger, +) *SetupResult { result := &SetupResult{ ProviderType: providerType, Success: false, } // Test that we can create the provider - provider, err := CreateSecretProviderWithPassword(providerType, password) + provider, err := CreateSecretProviderWithPassword(providerType, password, logger) if err != nil { result.Error = fmt.Errorf("failed to create provider: %w", err) result.Message = fmt.Sprintf("Failed to initialize %s provider", providerType) @@ -175,14 +180,14 @@ func IsKeyringAvailable() bool { // CreateSecretProvider creates the specified type of secrets provider. // TODO CREATE function does not actually create anything, refactor or rename -func CreateSecretProvider(managerType ProviderType) (Provider, error) { - return CreateSecretProviderWithPassword(managerType, "") +func CreateSecretProvider(managerType ProviderType, logger *zap.SugaredLogger) (Provider, error) { + return CreateSecretProviderWithPassword(managerType, "", logger) } // CreateSecretProviderWithPassword creates the specified type of secrets provider with an optional password. // If password is empty, it uses the current functionality (read from keyring or stdin). // If password is provided, it uses that password and stores it in the keyring if not already setup. 
-func CreateSecretProviderWithPassword(managerType ProviderType, password string) (Provider, error) { +func CreateSecretProviderWithPassword(managerType ProviderType, password string, logger *zap.SugaredLogger) (Provider, error) { switch managerType { case EncryptedType: // Enforce keyring availability for encrypted provider @@ -190,7 +195,7 @@ func CreateSecretProviderWithPassword(managerType ProviderType, password string) return nil, ErrKeyringNotAvailable } - secretsPassword, err := GetSecretsPassword(password) + secretsPassword, err := GetSecretsPassword(password, logger) if err != nil { return nil, fmt.Errorf("failed to get secrets password: %w", err) } @@ -213,7 +218,7 @@ func CreateSecretProviderWithPassword(managerType ProviderType, password string) // GetSecretsPassword returns the password to use for encrypting and decrypting secrets. // If optionalPassword is provided and keyring is not yet setup, it uses that password and stores it. // Otherwise, it uses the current functionality (read from keyring or stdin). -func GetSecretsPassword(optionalPassword string) ([]byte, error) { +func GetSecretsPassword(optionalPassword string, logger *zap.SugaredLogger) ([]byte, error) { // Attempt to load the password from the OS keyring. 
keyringSecret, err := keyring.Get(keyringService, keyringService) if err == nil { diff --git a/pkg/secrets/none_test.go b/pkg/secrets/none_test.go index bde6551a2..6c878235f 100644 --- a/pkg/secrets/none_test.go +++ b/pkg/secrets/none_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + log "github.com/stacklok/toolhive/pkg/logger" ) func TestNewNoneManager(t *testing.T) { @@ -115,7 +117,10 @@ func TestNoneManager_Capabilities(t *testing.T) { func TestCreateSecretProvider_None(t *testing.T) { t.Parallel() - provider, err := CreateSecretProvider(NoneType) + + logger := log.NewLogger() + + provider, err := CreateSecretProvider(NoneType, logger) require.NoError(t, err) assert.NotNil(t, provider) diff --git a/pkg/state/runconfig.go b/pkg/state/runconfig.go index 672648c69..d3c484c72 100644 --- a/pkg/state/runconfig.go +++ b/pkg/state/runconfig.go @@ -6,8 +6,9 @@ import ( "fmt" "io" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/errors" - "github.com/stacklok/toolhive/pkg/logger" ) // LoadRunConfigJSON loads a run configuration from the state store and returns the raw reader @@ -37,7 +38,7 @@ func LoadRunConfigJSON(ctx context.Context, name string) (io.ReadCloser, error) } // DeleteSavedRunConfig deletes a saved run configuration -func DeleteSavedRunConfig(ctx context.Context, name string) error { +func DeleteSavedRunConfig(ctx context.Context, name string, logger *zap.SugaredLogger) error { // Create a state store store, err := NewRunConfigStore(DefaultAppName) if err != nil { @@ -74,7 +75,7 @@ type RunConfigPersister interface { type ReadJSONFunc[T any] func(r io.Reader) (T, error) // SaveRunConfig saves a run configuration to the state store -func SaveRunConfig[T RunConfigPersister](ctx context.Context, config T) error { +func SaveRunConfig[T RunConfigPersister](ctx context.Context, config T, logger *zap.SugaredLogger) error { // Create a state store store, err := NewRunConfigStore(DefaultAppName) if 
err != nil { diff --git a/pkg/transport/factory.go b/pkg/transport/factory.go index 2ace50804..2b3d4ff41 100644 --- a/pkg/transport/factory.go +++ b/pkg/transport/factory.go @@ -3,24 +3,28 @@ package transport import ( + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/transport/errors" "github.com/stacklok/toolhive/pkg/transport/types" ) // Factory creates transports -type Factory struct{} +type Factory struct { + logger *zap.SugaredLogger +} // NewFactory creates a new transport factory -func NewFactory() *Factory { - return &Factory{} +func NewFactory(logger *zap.SugaredLogger) *Factory { + return &Factory{logger} } // Create creates a transport based on the provided configuration -func (*Factory) Create(config types.Config) (types.Transport, error) { +func (f *Factory) Create(config types.Config) (types.Transport, error) { switch config.Type { case types.TransportTypeStdio: tr := NewStdioTransport( - config.Host, config.ProxyPort, config.Deployer, config.Debug, config.PrometheusHandler, config.Middlewares..., + config.Host, config.ProxyPort, config.Deployer, config.Debug, config.PrometheusHandler, f.logger, config.Middlewares..., ) tr.SetProxyMode(config.ProxyMode) return tr, nil @@ -35,6 +39,7 @@ func (*Factory) Create(config types.Config) (types.Transport, error) { config.TargetHost, config.AuthInfoHandler, config.PrometheusHandler, + f.logger, config.Middlewares..., ), nil case types.TransportTypeStreamableHTTP: @@ -48,6 +53,7 @@ func (*Factory) Create(config types.Config) (types.Transport, error) { config.TargetHost, config.AuthInfoHandler, config.PrometheusHandler, + f.logger, config.Middlewares..., ), nil case types.TransportTypeInspector: diff --git a/pkg/transport/http.go b/pkg/transport/http.go index 578851564..b7008be86 100644 --- a/pkg/transport/http.go +++ b/pkg/transport/http.go @@ -6,10 +6,11 @@ import ( "net/http" "sync" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/container" rt "github.com/stacklok/toolhive/pkg/container/runtime" 
"github.com/stacklok/toolhive/pkg/ignore" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/transport/errors" "github.com/stacklok/toolhive/pkg/transport/proxy/transparent" @@ -49,6 +50,8 @@ type HTTPTransport struct { // Container monitor monitor rt.Monitor errorCh <-chan error + + logger *zap.SugaredLogger } // NewHTTPTransport creates a new HTTP transport. @@ -62,6 +65,7 @@ func NewHTTPTransport( targetHost string, authInfoHandler http.Handler, prometheusHandler http.Handler, + logger *zap.SugaredLogger, middlewares ...types.MiddlewareFunction, ) *HTTPTransport { if host == "" { @@ -85,6 +89,7 @@ func NewHTTPTransport( prometheusHandler: prometheusHandler, authInfoHandler: authInfoHandler, shutdownCh: make(chan struct{}), + logger: logger, } } @@ -156,7 +161,7 @@ func (t *HTTPTransport) Setup(ctx context.Context, runtime rt.Deployer, containe containerOptions.AttachStdio = false // Create the container - logger.Infof("Deploying workload %s from image %s...", containerName, image) + t.logger.Infof("Deploying workload %s from image %s...", containerName, image) exposedPort, err := t.deployer.DeployWorkload( ctx, image, @@ -172,7 +177,7 @@ func (t *HTTPTransport) Setup(ctx context.Context, runtime rt.Deployer, containe if err != nil { return fmt.Errorf("failed to create container: %v", err) } - logger.Infof("Container created: %s", containerName) + t.logger.Infof("Container created: %s", containerName) if (t.Mode() == types.TransportTypeSSE || t.Mode() == types.TransportTypeStreamableHTTP) && rt.IsKubernetesRuntime() { // If the SSEHeadlessServiceName is set, use it as the target host @@ -225,7 +230,7 @@ func (t *HTTPTransport) Start(ctx context.Context) error { // Use the target port for the container containerPort := t.targetPort targetURI := fmt.Sprintf("http://%s:%d", targetHost, containerPort) - logger.Infof("Setting up transparent proxy to forward from host port %d to %s", + 
t.logger.Infof("Setting up transparent proxy to forward from host port %d to %s", t.proxyPort, targetURI) // Create the transparent proxy with middlewares @@ -233,15 +238,16 @@ func (t *HTTPTransport) Start(ctx context.Context) error { t.host, t.proxyPort, t.containerName, targetURI, t.prometheusHandler, t.authInfoHandler, true, + t.logger, t.middlewares...) if err := t.proxy.Start(ctx); err != nil { return err } - logger.Infof("HTTP transport started for container %s on port %d", t.containerName, t.proxyPort) + t.logger.Infof("HTTP transport started for container %s on port %d", t.containerName, t.proxyPort) // Create a container monitor - monitorRuntime, err := container.NewFactory().Create(ctx) + monitorRuntime, err := container.NewFactory(t.logger).Create(ctx) if err != nil { return fmt.Errorf("failed to create container monitor: %v", err) } @@ -276,7 +282,7 @@ func (t *HTTPTransport) Stop(ctx context.Context) error { // Stop the transparent proxy if t.proxy != nil { if err := t.proxy.Stop(ctx); err != nil { - logger.Warnf("Warning: Failed to stop proxy: %v", err) + t.logger.Warnf("Warning: Failed to stop proxy: %v", err) } } @@ -296,10 +302,10 @@ func (t *HTTPTransport) handleContainerExit(ctx context.Context) { case <-ctx.Done(): return case err := <-t.errorCh: - logger.Infof("Container %s exited: %v", t.containerName, err) + t.logger.Infof("Container %s exited: %v", t.containerName, err) // Stop the transport when the container exits if stopErr := t.Stop(ctx); stopErr != nil { - logger.Errorf("Error stopping transport after container exit: %v", stopErr) + t.logger.Errorf("Error stopping transport after container exit: %v", stopErr) } } } diff --git a/pkg/transport/proxy/httpsse/http_proxy.go b/pkg/transport/proxy/httpsse/http_proxy.go index 4631c2095..d260df6b5 100644 --- a/pkg/transport/proxy/httpsse/http_proxy.go +++ b/pkg/transport/proxy/httpsse/http_proxy.go @@ -11,10 +11,10 @@ import ( "time" "github.com/google/uuid" + "go.uber.org/zap" 
"golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/healthcheck" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport/ssecommon" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -74,11 +74,19 @@ type HTTPSSEProxy struct { // Health checker healthChecker *healthcheck.HealthChecker + + // Logger + logger *zap.SugaredLogger } // NewHTTPSSEProxy creates a new HTTP SSE proxy for transports. func NewHTTPSSEProxy( - host string, port int, containerName string, prometheusHandler http.Handler, middlewares ...types.MiddlewareFunction, + host string, + port int, + containerName string, + prometheusHandler http.Handler, + logger *zap.SugaredLogger, + middlewares ...types.MiddlewareFunction, ) *HTTPSSEProxy { proxy := &HTTPSSEProxy{ middlewares: middlewares, @@ -90,11 +98,12 @@ func NewHTTPSSEProxy( sseClients: make(map[string]*ssecommon.SSEClient), pendingMessages: []*ssecommon.PendingSSEMessage{}, prometheusHandler: prometheusHandler, + logger: logger, } // Create MCP pinger and health checker mcpPinger := NewMCPPinger(proxy) - proxy.healthChecker = healthcheck.NewHealthChecker("stdio", mcpPinger) + proxy.healthChecker = healthcheck.NewHealthChecker("stdio", mcpPinger, logger) return proxy } @@ -135,7 +144,7 @@ func (p *HTTPSSEProxy) Start(_ context.Context) error { // Add Prometheus metrics endpoint if handler is provided (no middlewares) if p.prometheusHandler != nil { mux.Handle("/metrics", p.prometheusHandler) - logger.Info("Prometheus metrics endpoint enabled at /metrics") + p.logger.Info("Prometheus metrics endpoint enabled at /metrics") } // Create the server @@ -147,12 +156,12 @@ func (p *HTTPSSEProxy) Start(_ context.Context) error { // Start the server in a goroutine go func() { - logger.Infof("HTTP proxy started for container %s on port %d", p.containerName, p.port) - logger.Infof("SSE endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPSSEEndpoint) - logger.Infof("JSON-RPC endpoint: http://%s:%d%s", p.host, 
p.port, ssecommon.HTTPMessagesEndpoint) + p.logger.Infof("HTTP proxy started for container %s on port %d", p.containerName, p.port) + p.logger.Infof("SSE endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPSSEEndpoint) + p.logger.Infof("JSON-RPC endpoint: http://%s:%d%s", p.host, p.port, ssecommon.HTTPMessagesEndpoint) if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Errorf("HTTP server error: %v", err) + p.logger.Errorf("HTTP server error: %v", err) } }() @@ -285,7 +294,7 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques delete(p.sseClients, clientID) p.sseClientsMutex.Unlock() close(messageCh) - logger.Infof("Client %s disconnected", clientID) + p.logger.Infof("Client %s disconnected", clientID) }() // Send messages to the client @@ -348,7 +357,7 @@ func (p *HTTPSSEProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) } // Log the message - logger.Infof("Received JSON-RPC message: %T", msg) + p.logger.Infof("Received JSON-RPC message: %T", msg) // Send the message to the destination if err := p.SendMessageToDestination(msg); err != nil { @@ -359,7 +368,7 @@ func (p *HTTPSSEProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) // Return a success response w.WriteHeader(http.StatusAccepted) if _, err := w.Write([]byte("Accepted")); err != nil { - logger.Warnf("Warning: Failed to write response: %v", err) + p.logger.Warnf("Warning: Failed to write response: %v", err) } } @@ -380,7 +389,7 @@ func (p *HTTPSSEProxy) sendSSEEvent(msg *ssecommon.SSEMessage) error { // Channel is full or closed, remove the client delete(p.sseClients, clientID) close(client.MessageCh) - logger.Infof("Client %s removed (channel full or closed)", clientID) + p.logger.Infof("Client %s removed (channel full or closed)", clientID) } } @@ -407,7 +416,7 @@ func (p *HTTPSSEProxy) processPendingMessages(clientID string, messageCh chan<- // Message sent successfully default: // Channel is full, stop 
sending - logger.Errorf("Failed to send pending message to client %s (channel full)", clientID) + p.logger.Errorf("Failed to send pending message to client %s (channel full)", clientID) return } } diff --git a/pkg/transport/proxy/httpsse/pinger.go b/pkg/transport/proxy/httpsse/pinger.go index 40b379304..e7973411a 100644 --- a/pkg/transport/proxy/httpsse/pinger.go +++ b/pkg/transport/proxy/httpsse/pinger.go @@ -10,7 +10,6 @@ import ( "golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/healthcheck" - "github.com/stacklok/toolhive/pkg/logger" ) // MCPPinger implements healthcheck.MCPPinger for HTTP SSE proxies @@ -51,7 +50,7 @@ func (p *MCPPinger) Ping(ctx context.Context) (time.Duration, error) { // Send the ping request select { case messageCh <- pingRequest: - logger.Debugf("Sent MCP ping request with ID: %s", pingID) + p.proxy.logger.Debugf("Sent MCP ping request with ID: %s", pingID) case <-ctx.Done(): return 0, ctx.Err() default: @@ -64,6 +63,6 @@ func (p *MCPPinger) Ping(ctx context.Context) (time.Duration, error) { // In a real implementation, you might want to set up a response listener duration := time.Since(start) - logger.Debugf("MCP ping request sent in %v", duration) + p.proxy.logger.Debugf("MCP ping request sent in %v", duration) return duration, nil } diff --git a/pkg/transport/proxy/manager.go b/pkg/transport/proxy/manager.go index 80326754f..df0841e08 100644 --- a/pkg/transport/proxy/manager.go +++ b/pkg/transport/proxy/manager.go @@ -2,7 +2,8 @@ package proxy import ( - "github.com/stacklok/toolhive/pkg/logger" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/process" ) @@ -10,7 +11,7 @@ import ( // have been moved to this package to keep proxy-related logic grouped together. 
// StopProcess stops the proxy process associated with the container -func StopProcess(containerBaseName string) { +func StopProcess(containerBaseName string, logger *zap.SugaredLogger) { if containerBaseName == "" { logger.Warnf("Warning: Could not find base container name in labels") return @@ -38,7 +39,7 @@ func StopProcess(containerBaseName string) { } // IsRunning checks if the proxy process is running -func IsRunning(containerBaseName string) bool { +func IsRunning(containerBaseName string, logger *zap.SugaredLogger) bool { if containerBaseName == "" { return false } diff --git a/pkg/transport/proxy/streamable/streamable_proxy.go b/pkg/transport/proxy/streamable/streamable_proxy.go index 6ac6e7b17..e5f4d392d 100644 --- a/pkg/transport/proxy/streamable/streamable_proxy.go +++ b/pkg/transport/proxy/streamable/streamable_proxy.go @@ -10,9 +10,9 @@ import ( "net/http" "time" + "go.uber.org/zap" "golang.org/x/exp/jsonrpc2" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -36,6 +36,8 @@ type HTTPProxy struct { responseCh chan jsonrpc2.Message server *http.Server + + logger *zap.SugaredLogger } // NewHTTPProxy creates a new HTTPProxy for streamable HTTP transport. 
@@ -44,6 +46,7 @@ func NewHTTPProxy( port int, containerName string, prometheusHandler http.Handler, + logger *zap.SugaredLogger, middlewares ...types.MiddlewareFunction, ) *HTTPProxy { return &HTTPProxy{ @@ -55,6 +58,7 @@ func NewHTTPProxy( middlewares: middlewares, messageCh: make(chan jsonrpc2.Message, 100), responseCh: make(chan jsonrpc2.Message, 100), + logger: logger, } } @@ -75,10 +79,10 @@ func (p *HTTPProxy) Start(_ context.Context) error { } go func() { - logger.Infof("Streamable HTTP proxy started for container %s on port %d", p.containerName, p.port) - logger.Infof("Streamable HTTP endpoint: http://%s:%d%s", p.host, p.port, StreamableHTTPEndpoint) + p.logger.Infof("Streamable HTTP proxy started for container %s on port %d", p.containerName, p.port) + p.logger.Infof("Streamable HTTP endpoint: http://%s:%d%s", p.host, p.port, StreamableHTTPEndpoint) if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Errorf("Streamable HTTP server error: %v", err) + p.logger.Errorf("Streamable HTTP server error: %v", err) } }() @@ -153,7 +157,7 @@ func (p *HTTPProxy) handleBatchRequest(w http.ResponseWriter, body []byte) bool // Decode batch var rawMessages []json.RawMessage if err := json.Unmarshal(trimmed, &rawMessages); err != nil { - logger.Warnf("Failed to decode batch JSON-RPC: %s", string(body)) + p.logger.Warnf("Failed to decode batch JSON-RPC: %s", string(body)) http.Error(w, "Invalid batch JSON-RPC", http.StatusBadRequest) return true } @@ -162,13 +166,13 @@ func (p *HTTPProxy) handleBatchRequest(w http.ResponseWriter, body []byte) bool for _, raw := range rawMessages { msg, err := jsonrpc2.DecodeMessage(raw) if err != nil { - logger.Warnf("Skipping invalid message in batch: %s", string(raw)) + p.logger.Warnf("Skipping invalid message in batch: %s", string(raw)) continue } // Send each message to the container if err := p.SendMessageToDestination(msg); err != nil { - logger.Errorf("Failed to send message to destination: %v", 
err) + p.logger.Errorf("Failed to send message to destination: %v", err) continue } @@ -178,13 +182,13 @@ func (p *HTTPProxy) handleBatchRequest(w http.ResponseWriter, body []byte) bool if r, ok := resp.(*jsonrpc2.Response); ok && r.ID.IsValid() { data, err := jsonrpc2.EncodeMessage(r) if err != nil { - logger.Errorf("Failed to encode JSON-RPC response: %v", err) + p.logger.Errorf("Failed to encode JSON-RPC response: %v", err) continue } responses = append(responses, data) } case <-time.After(10 * time.Second): - logger.Warnf("Timeout waiting for response from container for batch message") + p.logger.Warnf("Timeout waiting for response from container for batch message") // Optionally, append a JSON-RPC error response here } } @@ -197,12 +201,12 @@ func (p *HTTPProxy) handleBatchRequest(w http.ResponseWriter, body []byte) bool } respBytes, err := json.Marshal(responses) if err != nil { - logger.Errorf("Failed to marshal batch response: %v", err) + p.logger.Errorf("Failed to marshal batch response: %v", err) http.Error(w, "Failed to encode batch response", http.StatusInternalServerError) return true } if _, err := w.Write(respBytes); err != nil { - logger.Errorf("Failed to write batch response: %v", err) + p.logger.Errorf("Failed to write batch response: %v", err) } return true } @@ -224,7 +228,7 @@ func (p *HTTPProxy) handleStreamableRequest(w http.ResponseWriter, r *http.Reque return } - msg, ok := decodeJSONRPCMessage(w, body) + msg, ok := decodeJSONRPCMessage(w, body, p.logger) if !ok { return } @@ -239,7 +243,7 @@ func (p *HTTPProxy) handleStreamableRequest(w http.ResponseWriter, r *http.Reque func (p *HTTPProxy) handleNotificationOrResponse(w http.ResponseWriter, msg jsonrpc2.Message) bool { if isNotification(msg) || (func() bool { _, ok := msg.(*jsonrpc2.Response); return ok })() { if err := p.SendMessageToDestination(msg); err != nil { - logger.Errorf("Failed to send message to destination: %v", err) + p.logger.Errorf("Failed to send message to destination: 
%v", err) } w.WriteHeader(http.StatusAccepted) return true @@ -258,18 +262,18 @@ func (p *HTTPProxy) handleRequestResponse(w http.ResponseWriter, msg jsonrpc2.Me w.Header().Set("Content-Type", "application/json") data, err := jsonrpc2.EncodeMessage(r) if err != nil { - logger.Errorf("Failed to encode JSON-RPC response: %v", err) + p.logger.Errorf("Failed to encode JSON-RPC response: %v", err) http.Error(w, "Failed to encode response", http.StatusInternalServerError) return } if _, err := w.Write(data); err != nil { - logger.Errorf("Failed to write response: %v", err) + p.logger.Errorf("Failed to write response: %v", err) } } else { w.Header().Set("Content-Type", "application/json") errResp := getInvalidJsonrpcError() if err := json.NewEncoder(w).Encode(errResp); err != nil { - logger.Errorf("Failed to encode error response: %v", err) + p.logger.Errorf("Failed to encode error response: %v", err) } } case <-time.After(10 * time.Second): @@ -290,7 +294,7 @@ func getInvalidJsonrpcError() map[string]interface{} { } // decodeJSONRPCMessage decodes a JSON-RPC message from the request body. 
-func decodeJSONRPCMessage(w http.ResponseWriter, body []byte) (jsonrpc2.Message, bool) { +func decodeJSONRPCMessage(w http.ResponseWriter, body []byte, logger *zap.SugaredLogger) (jsonrpc2.Message, bool) { msg, err := jsonrpc2.DecodeMessage(body) if err != nil { logger.Warnf("Skipping message that failed to decode: %s", string(body)) diff --git a/pkg/transport/proxy/transparent/pinger.go b/pkg/transport/proxy/transparent/pinger.go index e183e6a06..396a02bac 100644 --- a/pkg/transport/proxy/transparent/pinger.go +++ b/pkg/transport/proxy/transparent/pinger.go @@ -7,23 +7,26 @@ import ( "net/http" "time" + "go.uber.org/zap" + "github.com/stacklok/toolhive/pkg/healthcheck" - "github.com/stacklok/toolhive/pkg/logger" ) // MCPPinger implements healthcheck.MCPPinger for transparent proxies type MCPPinger struct { targetURL string client *http.Client + logger *zap.SugaredLogger } // NewMCPPinger creates a new MCP pinger for transparent proxies -func NewMCPPinger(targetURL string) healthcheck.MCPPinger { +func NewMCPPinger(targetURL string, logger *zap.SugaredLogger) healthcheck.MCPPinger { return &MCPPinger{ targetURL: targetURL, client: &http.Client{ Timeout: 5 * time.Second, }, + logger: logger, } } @@ -40,7 +43,7 @@ func (p *MCPPinger) Ping(ctx context.Context) (time.Duration, error) { return 0, fmt.Errorf("failed to create HTTP request: %w", err) } - logger.Debugf("Checking SSE server health at %s", p.targetURL) + p.logger.Debugf("Checking SSE server health at %s", p.targetURL) // Send the request resp, err := p.client.Do(req) @@ -56,7 +59,7 @@ func (p *MCPPinger) Ping(ctx context.Context) (time.Duration, error) { // - 404 for non-existent endpoints (but server is still alive) // - Other 4xx/5xx may indicate server issues if resp.StatusCode >= 200 && resp.StatusCode < 500 { - logger.Debugf("SSE server health check successful in %v (status: %d)", duration, resp.StatusCode) + p.logger.Debugf("SSE server health check successful in %v (status: %d)", duration, 
resp.StatusCode) return duration, nil } diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 2b043dc17..2f6c801e7 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -20,10 +20,10 @@ import ( "sync" "time" + "go.uber.org/zap" "golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/healthcheck" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -69,6 +69,9 @@ type TransparentProxy struct { // Listener for the HTTP server listener net.Listener + + // Logger for the proxy + logger *zap.SugaredLogger } // NewTransparentProxy creates a new transparent proxy with optional middlewares. @@ -80,6 +83,7 @@ func NewTransparentProxy( prometheusHandler http.Handler, authInfoHandler http.Handler, enableHealthCheck bool, + logger *zap.SugaredLogger, middlewares ...types.MiddlewareFunction, ) *TransparentProxy { proxy := &TransparentProxy{ @@ -92,12 +96,13 @@ func NewTransparentProxy( prometheusHandler: prometheusHandler, authInfoHandler: authInfoHandler, sessionManager: session.NewManager(30*time.Minute, session.NewProxySession), + logger: logger, } // Create MCP pinger and health checker only if enabled if enableHealthCheck { - mcpPinger := NewMCPPinger(targetURI) - proxy.healthChecker = healthcheck.NewHealthChecker("sse", mcpPinger) + mcpPinger := NewMCPPinger(targetURI, logger) + proxy.healthChecker = healthcheck.NewHealthChecker("sse", mcpPinger, logger) } return proxy @@ -113,7 +118,7 @@ func (p *TransparentProxy) setServerInitialized() { p.mutex.Lock() p.IsServerInitialized = true p.mutex.Unlock() - logger.Infof("Server was initialized successfully for %s", p.containerName) + p.logger.Infof("Server was initialized successfully for %s", p.containerName) } } @@ -126,7 +131,7 @@ func (t *tracingTransport) forward(req 
*http.Request) (*http.Response, error) { } func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) { - reqBody := readRequestBody(req) + reqBody := readRequestBody(req, t.p.logger) path := req.URL.Path isMCP := strings.HasPrefix(path, "/mcp") @@ -143,17 +148,17 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) // Expected during shutdown or client disconnect—silently ignore return nil, err } - logger.Errorf("Failed to forward request: %v", err) + t.p.logger.Errorf("Failed to forward request: %v", err) return nil, err } if resp.StatusCode == http.StatusOK { // check if we saw a valid mcp header ct := resp.Header.Get("Mcp-Session-Id") if ct != "" { - logger.Infof("Detected Mcp-Session-Id header: %s", ct) + t.p.logger.Infof("Detected Mcp-Session-Id header: %s", ct) if _, ok := t.p.sessionManager.Get(ct); !ok { if err := t.p.sessionManager.AddWithID(ct); err != nil { - logger.Errorf("Failed to create session from header %s: %v", ct, err) + t.p.logger.Errorf("Failed to create session from header %s: %v", ct, err) } } t.p.setServerInitialized() @@ -169,7 +174,7 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) return resp, nil } -func readRequestBody(req *http.Request) []byte { +func readRequestBody(req *http.Request, logger *zap.SugaredLogger) []byte { reqBody := []byte{} if req.Body != nil { buf, err := io.ReadAll(req.Body) @@ -188,11 +193,11 @@ func (t *tracingTransport) detectInitialize(body []byte) bool { Method string `json:"method"` } if err := json.Unmarshal(body, &rpc); err != nil { - logger.Errorf("Failed to parse JSON-RPC body: %v", err) + t.p.logger.Errorf("Failed to parse JSON-RPC body: %v", err) return false } if rpc.Method == "initialize" { - logger.Infof("Detected initialize method call for %s", t.p.containerName) + t.p.logger.Infof("Detected initialize method call for %s", t.p.containerName) return true } return false @@ -226,7 +231,7 @@ func (p *TransparentProxy) 
modifyForSessionID(resp *http.Response) error { p.setServerInitialized() err := p.sessionManager.AddWithID(sid) if err != nil { - logger.Errorf("Failed to create session from SSE line: %v", err) + p.logger.Errorf("Failed to create session from SSE line: %v", err) } found = true } @@ -237,7 +242,7 @@ func (p *TransparentProxy) modifyForSessionID(resp *http.Response) error { } _, err := io.Copy(pw, originalBody) if err != nil && err != io.EOF { - logger.Errorf("Failed to copy response body: %v", err) + p.logger.Errorf("Failed to copy response body: %v", err) } }() @@ -265,7 +270,7 @@ func (p *TransparentProxy) Start(ctx context.Context) error { // Create a handler that logs requests handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, targetURL) + p.logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, targetURL) proxy.ServeHTTP(w, r) }) @@ -276,7 +281,7 @@ func (p *TransparentProxy) Start(ctx context.Context) error { var finalHandler http.Handler = handler for i := len(p.middlewares) - 1; i >= 0; i-- { finalHandler = p.middlewares[i](finalHandler) - logger.Infof("Applied middleware %d\n", i+1) + p.logger.Infof("Applied middleware %d\n", i+1) } // Add the proxy handler for all paths except /health @@ -297,7 +302,7 @@ func (p *TransparentProxy) Start(ctx context.Context) error { // Add Prometheus metrics endpoint if handler is provided (no middlewares) if p.prometheusHandler != nil { mux.Handle("/metrics", p.prometheusHandler) - logger.Info("Prometheus metrics endpoint enabled at /metrics") + p.logger.Info("Prometheus metrics endpoint enabled at /metrics") } ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", p.host, p.port)) if err != nil { @@ -317,7 +322,7 @@ func (p *TransparentProxy) Start(ctx context.Context) error { } }) mux.Handle("/.well-known/", wellKnownHandler) - logger.Info("Well-known discovery endpoints enabled at /.well-known/ (no middlewares)") + 
p.logger.Info("Well-known discovery endpoints enabled at /.well-known/ (no middlewares)") } // Create the server @@ -335,7 +340,7 @@ func (p *TransparentProxy) Start(ctx context.Context) error { // Expected when listener is closed—silently return return } - logger.Errorf("Transparent proxy error: %v", err) + p.logger.Errorf("Transparent proxy error: %v", err) } }() // Start health-check monitoring only if health checker is enabled @@ -361,24 +366,24 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) { for { select { case <-parentCtx.Done(): - logger.Infof("Context cancelled, stopping health monitor for %s", p.containerName) + p.logger.Infof("Context cancelled, stopping health monitor for %s", p.containerName) return case <-p.shutdownCh: - logger.Infof("Shutdown initiated, stopping health monitor for %s", p.containerName) + p.logger.Infof("Shutdown initiated, stopping health monitor for %s", p.containerName) return case <-ticker.C: // Perform health check only if mcp server has been initialized if p.IsServerInitialized { alive := p.healthChecker.CheckHealth(parentCtx) if alive.Status != healthcheck.StatusHealthy { - logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName) + p.logger.Infof("Health check failed for %s; initiating proxy shutdown", p.containerName) if err := p.Stop(parentCtx); err != nil { - logger.Errorf("Failed to stop proxy for %s: %v", p.containerName, err) + p.logger.Errorf("Failed to stop proxy for %s: %v", p.containerName, err) } return } } else { - logger.Infof("MCP server not initialized yet, skipping health check for %s", p.containerName) + p.logger.Infof("MCP server not initialized yet, skipping health check for %s", p.containerName) } } } @@ -396,10 +401,10 @@ func (p *TransparentProxy) Stop(ctx context.Context) error { if p.server != nil { err := p.server.Shutdown(ctx) if err != nil && err != http.ErrServerClosed && err != context.DeadlineExceeded { - logger.Warnf("Error during proxy 
shutdown: %v", err) + p.logger.Warnf("Error during proxy shutdown: %v", err) return err } - logger.Infof("Server for %s stopped successfully", p.containerName) + p.logger.Infof("Server for %s stopped successfully", p.containerName) p.server = nil } diff --git a/pkg/transport/proxy/transparent/transparent_test.go b/pkg/transport/proxy/transparent/transparent_test.go index 409cd01cf..bde4356f7 100644 --- a/pkg/transport/proxy/transparent/transparent_test.go +++ b/pkg/transport/proxy/transparent/transparent_test.go @@ -11,16 +11,14 @@ import ( "github.com/stretchr/testify/assert" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) -func init() { - logger.Initialize() // ensure logging doesn't panic -} - func TestStreamingSessionIDDetection(t *testing.T) { t.Parallel() - proxy := NewTransparentProxy("127.0.0.1", 0, "test", "http://example.com", nil, nil, true) + + logger := log.NewLogger() + proxy := NewTransparentProxy("127.0.0.1", 0, "test", "http://example.com", nil, nil, true, logger) target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") w.WriteHeader(200) @@ -77,7 +75,8 @@ func createBasicProxy(p *TransparentProxy, targetURL *url.URL) *httputil.Reverse func TestNoSessionIDInNonSSE(t *testing.T) { t.Parallel() - p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil, nil, false) + logger := log.NewLogger() + p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil, nil, false, logger) target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // Set both content-type and also optionally MCP header to test behavior @@ -103,7 +102,8 @@ func TestNoSessionIDInNonSSE(t *testing.T) { func TestHeaderBasedSessionInitialization(t *testing.T) { t.Parallel() - p := NewTransparentProxy("127.0.0.1", 0, "test", "", nil, nil, false) + logger := log.NewLogger() + p := NewTransparentProxy("127.0.0.1", 
0, "test", "", nil, nil, false, logger) target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // Set both content-type and also optionally MCP header to test behavior diff --git a/pkg/transport/stdio.go b/pkg/transport/stdio.go index d4c56d851..e12b79831 100644 --- a/pkg/transport/stdio.go +++ b/pkg/transport/stdio.go @@ -11,12 +11,12 @@ import ( "time" "unicode" + "go.uber.org/zap" "golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/container" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/ignore" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/permissions" "github.com/stacklok/toolhive/pkg/transport/errors" "github.com/stacklok/toolhive/pkg/transport/proxy/httpsse" @@ -52,6 +52,8 @@ type StdioTransport struct { // Container monitor monitor rt.Monitor + + logger *zap.SugaredLogger } // NewStdioTransport creates a new stdio transport. @@ -61,6 +63,7 @@ func NewStdioTransport( deployer rt.Deployer, debug bool, prometheusHandler http.Handler, + logger *zap.SugaredLogger, middlewares ...types.MiddlewareFunction, ) *StdioTransport { return &StdioTransport{ @@ -72,6 +75,7 @@ func NewStdioTransport( prometheusHandler: prometheusHandler, shutdownCh: make(chan struct{}), proxyMode: types.ProxyModeSSE, // default to SSE for backward compatibility + logger: logger, } } @@ -119,7 +123,7 @@ func (t *StdioTransport) Setup( containerOptions.IgnoreConfig = ignoreConfig // Create the container - logger.Infof("Deploying workload %s from image %s...", containerName, image) + t.logger.Infof("Deploying workload %s from image %s...", containerName, image) _, err := t.deployer.DeployWorkload( ctx, image, @@ -135,7 +139,7 @@ func (t *StdioTransport) Setup( if err != nil { return fmt.Errorf("failed to create container: %v", err) } - logger.Infof("Container created: %s", containerName) + t.logger.Infof("Container created: %s", containerName) return nil } @@ -167,17 
+171,17 @@ func (t *StdioTransport) Start(ctx context.Context) error { // Create and start the correct proxy with middlewares switch t.proxyMode { case types.ProxyModeStreamableHTTP: - t.httpProxy = streamable.NewHTTPProxy(t.host, t.proxyPort, t.containerName, t.prometheusHandler, t.middlewares...) + t.httpProxy = streamable.NewHTTPProxy(t.host, t.proxyPort, t.containerName, t.prometheusHandler, t.logger, t.middlewares...) if err := t.httpProxy.Start(ctx); err != nil { return err } - logger.Info("Streamable HTTP proxy started, processing messages...") + t.logger.Info("Streamable HTTP proxy started, processing messages...") case types.ProxyModeSSE: - t.httpProxy = httpsse.NewHTTPSSEProxy(t.host, t.proxyPort, t.containerName, t.prometheusHandler, t.middlewares...) + t.httpProxy = httpsse.NewHTTPSSEProxy(t.host, t.proxyPort, t.containerName, t.prometheusHandler, t.logger, t.middlewares...) if err := t.httpProxy.Start(ctx); err != nil { return err } - logger.Info("HTTP SSE proxy started, processing messages...") + t.logger.Info("HTTP SSE proxy started, processing messages...") default: return fmt.Errorf("unsupported proxy mode: %v", t.proxyMode) } @@ -186,7 +190,7 @@ func (t *StdioTransport) Start(ctx context.Context) error { go t.processMessages(ctx, t.stdin, t.stdout) // Create a container monitor - monitorRuntime, err := container.NewFactory().Create(ctx) + monitorRuntime, err := container.NewFactory(t.logger).Create(ctx) if err != nil { return fmt.Errorf("failed to create container monitor: %v", err) } @@ -240,14 +244,14 @@ func (t *StdioTransport) Stop(ctx context.Context) error { // Stop the HTTP proxy if t.httpProxy != nil { if err := t.httpProxy.Stop(ctx); err != nil { - logger.Warnf("Warning: Failed to stop HTTP proxy: %v", err) + t.logger.Warnf("Warning: Failed to stop HTTP proxy: %v", err) } } // Close stdin and stdout if they're open if t.stdin != nil { if err := t.stdin.Close(); err != nil { - logger.Warnf("Warning: Failed to close stdin: %v", err) + 
t.logger.Warnf("Warning: Failed to close stdin: %v", err) } t.stdin = nil } @@ -258,11 +262,11 @@ func (t *StdioTransport) Stop(ctx context.Context) error { running, err := t.deployer.IsWorkloadRunning(ctx, t.containerName) if err != nil { // If there's an error checking the workload status, it might be gone already - logger.Warnf("Warning: Failed to check workload status: %v", err) + t.logger.Warnf("Warning: Failed to check workload status: %v", err) } else if running { // Only try to stop the workload if it's still running if err := t.deployer.StopWorkload(ctx, t.containerName); err != nil { - logger.Warnf("Warning: Failed to stop workload: %v", err) + t.logger.Warnf("Warning: Failed to stop workload: %v", err) } } } @@ -310,11 +314,11 @@ func (t *StdioTransport) processMessages(ctx context.Context, stdin io.WriteClos case <-ctx.Done(): return case msg := <-messageCh: - logger.Info("Process incoming messages and sending message to container") + t.logger.Info("Process incoming messages and sending message to container") if err := t.sendMessageToContainer(ctx, stdin, msg); err != nil { - logger.Errorf("Error sending message to container: %v", err) + t.logger.Errorf("Error sending message to container: %v", err) } - logger.Info("Messages processed") + t.logger.Info("Messages processed") } } } @@ -336,9 +340,9 @@ func (t *StdioTransport) processStdout(ctx context.Context, stdout io.ReadCloser n, err := stdout.Read(readBuffer) if err != nil { if err == io.EOF { - logger.Info("Container stdout closed") + t.logger.Info("Container stdout closed") } else { - logger.Errorf("Error reading from container stdout: %v", err) + t.logger.Errorf("Error reading from container stdout: %v", err) } return } @@ -421,9 +425,9 @@ func isSpace(r rune) bool { // parseAndForwardJSONRPC parses a JSON-RPC message and forwards it. 
func (t *StdioTransport) parseAndForwardJSONRPC(ctx context.Context, line string) { // Log the raw line for debugging - logger.Infof("JSON-RPC raw: %s", line) + t.logger.Infof("JSON-RPC raw: %s", line) jsonData := sanitizeJSONString(line) - logger.Infof("Sanitized JSON: %s", jsonData) + t.logger.Infof("Sanitized JSON: %s", jsonData) if jsonData == "" || jsonData == "[]" { return @@ -432,24 +436,24 @@ func (t *StdioTransport) parseAndForwardJSONRPC(ctx context.Context, line string // Try to parse the JSON msg, err := jsonrpc2.DecodeMessage([]byte(jsonData)) if err != nil { - logger.Errorf("Error parsing JSON-RPC message: %v", err) + t.logger.Errorf("Error parsing JSON-RPC message: %v", err) return } // Log the message - logger.Infof("Received JSON-RPC message: %T", msg) + t.logger.Infof("Received JSON-RPC message: %T", msg) if err := t.httpProxy.ForwardResponseToClients(ctx, msg); err != nil { if t.proxyMode == types.ProxyModeStreamableHTTP { - logger.Errorf("Error forwarding to streamable-http client: %v", err) + t.logger.Errorf("Error forwarding to streamable-http client: %v", err) } else { - logger.Errorf("Error forwarding to SSE clients: %v", err) + t.logger.Errorf("Error forwarding to SSE clients: %v", err) } } } // sendMessageToContainer sends a JSON-RPC message to the container. 
-func (*StdioTransport) sendMessageToContainer(_ context.Context, stdin io.Writer, msg jsonrpc2.Message) error { +func (t *StdioTransport) sendMessageToContainer(_ context.Context, stdin io.Writer, msg jsonrpc2.Message) error { // Serialize the message data, err := jsonrpc2.EncodeMessage(msg) if err != nil { @@ -460,11 +464,11 @@ func (*StdioTransport) sendMessageToContainer(_ context.Context, stdin io.Writer data = append(data, '\n') // Write to stdin - logger.Info("Writing to container stdin") + t.logger.Info("Writing to container stdin") if _, err := stdin.Write(data); err != nil { return fmt.Errorf("failed to write to container stdin: %w", err) } - logger.Info("Wrote to container stdin") + t.logger.Info("Wrote to container stdin") return nil } @@ -477,17 +481,17 @@ func (t *StdioTransport) handleContainerExit(ctx context.Context) { case err, ok := <-t.errorCh: // Check if the channel is closed if !ok { - logger.Infof("Container monitor channel closed for %s", t.containerName) + t.logger.Infof("Container monitor channel closed for %s", t.containerName) return } - logger.Infof("Container %s exited: %v", t.containerName, err) + t.logger.Infof("Container %s exited: %v", t.containerName, err) // Check if the transport is already stopped before trying to stop it select { case <-t.shutdownCh: // Transport is already stopping or stopped - logger.Infof("Transport for %s is already stopping or stopped", t.containerName) + t.logger.Infof("Transport for %s is already stopping or stopped", t.containerName) return default: // Transport is still running, stop it @@ -496,7 +500,7 @@ func (t *StdioTransport) handleContainerExit(ctx context.Context) { defer cancel() if stopErr := t.Stop(stopCtx); stopErr != nil { - logger.Errorf("Error stopping transport after container exit: %v", stopErr) + t.logger.Errorf("Error stopping transport after container exit: %v", stopErr) } } } diff --git a/pkg/transport/stdio_test.go b/pkg/transport/stdio_test.go index 9bd83f78f..fbbe77048 100644 
--- a/pkg/transport/stdio_test.go +++ b/pkg/transport/stdio_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/mock" "golang.org/x/exp/jsonrpc2" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) // MockHTTPProxy is a mock implementation of types.Proxy @@ -100,8 +100,6 @@ func TestSanitizeJSONString(t *testing.T) { func TestParseAndForwardJSONRPC(t *testing.T) { t.Parallel() - // Initialize logger for testing - logger.Initialize() tests := []struct { name string @@ -148,6 +146,7 @@ func TestParseAndForwardJSONRPC(t *testing.T) { // Create transport with mock proxy transport := &StdioTransport{ httpProxy: mockProxy, + logger: log.NewLogger(), } // Set up expectations if the message should be forwarded diff --git a/pkg/transport/tunnel/ngrok/tunnel_provider.go b/pkg/transport/tunnel/ngrok/tunnel_provider.go index 47fb40cd8..bd70e9be5 100644 --- a/pkg/transport/tunnel/ngrok/tunnel_provider.go +++ b/pkg/transport/tunnel/ngrok/tunnel_provider.go @@ -8,15 +8,15 @@ import ( "path/filepath" "strings" + "go.uber.org/zap" "golang.ngrok.com/ngrok/v2" "gopkg.in/yaml.v3" - - "github.com/stacklok/toolhive/pkg/logger" ) // TunnelProvider implements the TunnelProvider interface for ngrok. type TunnelProvider struct { config TunnelConfig + logger *zap.SugaredLogger } // TunnelConfig holds configuration options for the ngrok tunnel provider. 
@@ -92,12 +92,12 @@ func (p *TunnelProvider) StartTunnel(ctx context.Context, name, targetURI string <-ctx.Done() return nil } - logger.Infof("[ngrok] Starting tunnel %q → %s", name, targetURI) + p.logger.Infof("[ngrok] Starting tunnel %q → %s", name, targetURI) agent, err := ngrok.NewAgent( ngrok.WithAuthtoken(p.config.AuthToken), ngrok.WithEventHandler(func(e ngrok.Event) { - logger.Infof("ngrok event: %s at %s", e.EventType(), e.Timestamp()) + p.logger.Infof("ngrok event: %s at %s", e.EventType(), e.Timestamp()) }), ) @@ -127,12 +127,12 @@ func (p *TunnelProvider) StartTunnel(ctx context.Context, name, targetURI string return fmt.Errorf("ngrok.Forward error: %w", err) } - logger.Infof("ngrok forwarding live at %s", forwarder.URL()) + p.logger.Infof("ngrok forwarding live at %s", forwarder.URL()) // Run in background, non-blocking on `.Done()` go func() { <-forwarder.Done() - logger.Infof("ngrok forwarding stopped: %s", forwarder.URL()) + p.logger.Infof("ngrok forwarding stopped: %s", forwarder.URL()) }() // Return immediately diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index c7b5549a8..b4f58b936 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -12,6 +12,7 @@ import ( "time" "github.com/adrg/xdg" + "go.uber.org/zap" "golang.org/x/sync/errgroup" "github.com/stacklok/toolhive/pkg/client" @@ -21,7 +22,6 @@ import ( "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/process" "github.com/stacklok/toolhive/pkg/runner" "github.com/stacklok/toolhive/pkg/secrets" @@ -67,6 +67,7 @@ type Manager interface { type defaultManager struct { runtime rt.Runtime statuses statuses.StatusManager + logger *zap.SugaredLogger } // ErrWorkloadNotRunning is returned when a container cannot be found by name. @@ -78,13 +79,13 @@ const ( ) // NewManager creates a new container manager instance. 
-func NewManager(ctx context.Context) (Manager, error) { - runtime, err := ct.NewFactory().Create(ctx) +func NewManager(ctx context.Context, logger *zap.SugaredLogger) (Manager, error) { + runtime, err := ct.NewFactory(logger).Create(ctx) if err != nil { return nil, err } - statusManager, err := statuses.NewStatusManager(runtime) + statusManager, err := statuses.NewStatusManager(runtime, logger) if err != nil { return nil, fmt.Errorf("failed to create status manager: %w", err) } @@ -92,12 +93,13 @@ func NewManager(ctx context.Context) (Manager, error) { return &defaultManager{ runtime: runtime, statuses: statusManager, + logger: logger, }, nil } // NewManagerFromRuntime creates a new container manager instance from an existing runtime. -func NewManagerFromRuntime(runtime rt.Runtime) (Manager, error) { - statusManager, err := statuses.NewStatusManager(runtime) +func NewManagerFromRuntime(runtime rt.Runtime, logger *zap.SugaredLogger) (Manager, error) { + statusManager, err := statuses.NewStatusManager(runtime, logger) if err != nil { return nil, fmt.Errorf("failed to create status manager: %w", err) } @@ -105,6 +107,7 @@ func NewManagerFromRuntime(runtime rt.Runtime) (Manager, error) { return &defaultManager{ runtime: runtime, statuses: statusManager, + logger: logger, }, nil } @@ -139,7 +142,7 @@ func (d *defaultManager) StopWorkloads(ctx context.Context, names []string) (*er if err != nil { if errors.Is(err, rt.ErrWorkloadNotFound) { // Log but don't fail the entire operation for not found containers - logger.Warnf("Warning: Failed to stop workload %s: %v", name, err) + d.logger.Warnf("Warning: Failed to stop workload %s: %v", name, err) continue } return nil, fmt.Errorf("failed to find workload %s: %v", name, err) @@ -148,13 +151,13 @@ func (d *defaultManager) StopWorkloads(ctx context.Context, names []string) (*er running := container.IsRunning() if !running { // Log but don't fail the entire operation for not running containers - logger.Warnf("Warning: Failed 
to stop workload %s: %v", name, ErrWorkloadNotRunning) + d.logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning) continue } // Transition workload to `stopping` state. if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { - logger.Warnf("Failed to set workload %s status to stopping: %v", name, err) + d.logger.Warnf("Failed to set workload %s status to stopping: %v", name, err) } containers = append(containers, &container) } @@ -169,28 +172,28 @@ func (d *defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunC return fmt.Errorf("failed to create workload status: %v", err) } - mcpRunner := runner.NewRunner(runConfig, d.statuses) + mcpRunner := runner.NewRunner(runConfig, d.statuses, d.logger) err := mcpRunner.Run(ctx) if err != nil { // If the run failed, we should set the status to error. if statusErr := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr) + d.logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr) } } return err } -func validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { +func validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig, logger *zap.SugaredLogger) error { // If there are run secrets, validate them if len(runConfig.Secrets) > 0 { - cfg := config.GetConfig() + cfg := config.GetConfig(logger) providerType, err := cfg.Secrets.GetProviderType() if err != nil { return fmt.Errorf("error determining secrets provider type: %w", err) } - secretManager, err := secrets.CreateSecretProvider(providerType) + secretManager, err := secrets.CreateSecretProvider(providerType, logger) if err != nil { return fmt.Errorf("error instantiating secret manager: %w", err) } @@ -205,7 +208,7 @@ func validateSecretParameters(ctx context.Context, 
runConfig *runner.RunConfig) func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { // before running, validate the parameters for the workload - err := validateSecretParameters(ctx, runConfig) + err := validateSecretParameters(ctx, runConfig, d.logger) if err != nil { return fmt.Errorf("failed to validate workload parameters: %w", err) } @@ -224,10 +227,10 @@ func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *run // #nosec G304 - This is safe as baseName is generated by the application logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) if err != nil { - logger.Warnf("Warning: Failed to create log file: %v", err) + d.logger.Warnf("Warning: Failed to create log file: %v", err) } else { defer logFile.Close() - logger.Infof("Logging to: %s", logFilePath) + d.logger.Infof("Logging to: %s", logFilePath) } // Use the restart command to start the detached process @@ -245,8 +248,8 @@ func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *run // NOTE: This breaks the abstraction slightly since this is only relevant for the CLI, but there // are checks inside `GetSecretsPassword` to ensure this does not get called in a detached process. // This will be addressed in a future re-think of the secrets manager interface. 
- if needSecretsPassword(runConfig.Secrets) { - password, err := secrets.GetSecretsPassword("") + if needSecretsPassword(runConfig.Secrets, d.logger) { + password, err := secrets.GetSecretsPassword("", d.logger) if err != nil { return fmt.Errorf("failed to get secrets password: %v", err) } @@ -280,11 +283,11 @@ func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *run // Write the PID to a file so the stop command can kill the process if err := process.WritePIDFile(runConfig.BaseName, detachedCmd.Process.Pid); err != nil { - logger.Warnf("Warning: Failed to write PID file: %v", err) + d.logger.Warnf("Warning: Failed to write PID file: %v", err) } - logger.Infof("MCP server is running in the background (PID: %d)", detachedCmd.Process.Pid) - logger.Infof("Use 'thv stop %s' to stop the server", runConfig.ContainerName) + d.logger.Infof("MCP server is running in the background (PID: %d)", detachedCmd.Process.Pid) + d.logger.Infof("Use 'thv stop %s' to stop the server", runConfig.ContainerName) return nil } @@ -320,7 +323,7 @@ func (d *defaultManager) deleteWorkload(ctx context.Context, name string) error // Set status to removing if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil { - logger.Warnf("Failed to set workload %s status to removing: %v", name, err) + d.logger.Warnf("Failed to set workload %s status to removing: %v", name, err) } containerLabels := container.Labels @@ -341,7 +344,7 @@ func (d *defaultManager) deleteWorkload(ctx context.Context, name string) error // Remove the workload status from the status store if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil { - logger.Warnf("failed to delete workload status for %s: %v", name, err) + d.logger.Warnf("failed to delete workload status for %s: %v", name, err) } return nil @@ -353,11 +356,11 @@ func (d *defaultManager) getWorkloadContainer(childCtx, ctx context.Context, nam if err != nil { if errors.Is(err, rt.ErrWorkloadNotFound) 
{ // Log but don't fail the entire operation for not found containers - logger.Warnf("Warning: Failed to delete workload %s: %v", name, err) + d.logger.Warnf("Warning: Failed to delete workload %s: %v", name, err) return nil, nil } if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) + d.logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) } return nil, fmt.Errorf("failed to find workload %s: %v", name, err) } @@ -365,19 +368,19 @@ func (d *defaultManager) getWorkloadContainer(childCtx, ctx context.Context, nam } // stopProxyIfNeeded stops the proxy process if the workload has a base name -func (*defaultManager) stopProxyIfNeeded(name, baseName string) { - logger.Infof("Removing proxy process for %s...", name) +func (d *defaultManager) stopProxyIfNeeded(name, baseName string) { + d.logger.Infof("Removing proxy process for %s...", name) if baseName != "" { - proxy.StopProcess(baseName) + proxy.StopProcess(baseName, d.logger) } } // removeContainer removes the container from the runtime func (d *defaultManager) removeContainer(childCtx, ctx context.Context, name string) error { - logger.Infof("Removing container %s...", name) + d.logger.Infof("Removing container %s...", name) if err := d.runtime.RemoveWorkload(childCtx, name); err != nil { if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) + d.logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) } return fmt.Errorf("failed to remove container: %v", err) } @@ -392,23 +395,23 @@ func (d *defaultManager) cleanupWorkloadResources(childCtx context.Context, name // Clean up temporary permission profile if err := d.cleanupTempPermissionProfile(childCtx, baseName); err != nil { - 
logger.Warnf("Warning: Failed to cleanup temporary permission profile: %v", err) + d.logger.Warnf("Warning: Failed to cleanup temporary permission profile: %v", err) } // Delete the saved state - if err := state.DeleteSavedRunConfig(childCtx, baseName); err != nil { - logger.Warnf("Warning: Failed to delete saved state: %v", err) + if err := state.DeleteSavedRunConfig(childCtx, baseName, d.logger); err != nil { + d.logger.Warnf("Warning: Failed to delete saved state: %v", err) } else { - logger.Infof("Saved state for %s removed", baseName) + d.logger.Infof("Saved state for %s removed", baseName) } - logger.Infof("Container %s removed", name) + d.logger.Infof("Container %s removed", name) // Remove client configurations - if err := removeClientConfigurations(name); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) + if err := removeClientConfigurations(name, d.logger); err != nil { + d.logger.Warnf("Warning: Failed to remove client configurations: %v", err) } else { - logger.Infof("Client configurations for %s removed", name) + d.logger.Infof("Client configurations for %s removed", name) } } @@ -457,8 +460,8 @@ func (d *defaultManager) RestartWorkloads(ctx context.Context, names []string, f container, err := d.runtime.GetWorkloadInfo(childCtx, name) if err != nil { if errors.Is(err, rt.ErrWorkloadNotFound) { - logger.Warnf("Warning: Failed to find container: %v", err) - logger.Warnf("Trying to find state with name %s directly...", name) + d.logger.Warnf("Warning: Failed to find container: %v", err) + d.logger.Warnf("Trying to find state with name %s directly...", name) // Try to use the provided name as the base name containerBaseName = name @@ -473,10 +476,10 @@ func (d *defaultManager) RestartWorkloads(ctx context.Context, names []string, f } // Check if the proxy process is running - proxyRunning := proxy.IsRunning(containerBaseName) + proxyRunning := proxy.IsRunning(containerBaseName, d.logger) if running && proxyRunning { - 
logger.Infof("Container %s and proxy are already running", name) + d.logger.Infof("Container %s and proxy are already running", name) return nil } @@ -493,20 +496,20 @@ func (d *defaultManager) RestartWorkloads(ctx context.Context, names []string, f // At this point we're sure that the workload exists but is not running. // Transition workload to `starting` state. if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStarting, ""); err != nil { - logger.Warnf("Failed to set workload %s status to starting: %v", name, err) + d.logger.Warnf("Failed to set workload %s status to starting: %v", name, err) } - logger.Infof("Loaded configuration from state for %s", containerBaseName) + d.logger.Infof("Loaded configuration from state for %s", containerBaseName) // Run the tooling server inside a detached process. - logger.Infof("Starting tooling server %s...", name) + d.logger.Infof("Starting tooling server %s...", name) // If the container is running but the proxy is not, stop the container first if running { // && !proxyRunning was previously here but is implied by previous if statement. - logger.Infof("Container %s is running but proxy is not. Stopping container...", name) + d.logger.Infof("Container %s is running but proxy is not. Stopping container...", name) if err = d.runtime.StopWorkload(childCtx, name); err != nil { return fmt.Errorf("failed to stop container %s: %v", name, err) } - logger.Infof("Container %s stopped", name) + d.logger.Infof("Container %s stopped", name) } if foreground { @@ -521,7 +524,7 @@ func (d *defaultManager) RestartWorkloads(ctx context.Context, names []string, f // TODO: Move to dedicated config management interface. 
// updateClientConfigurations updates client configuration files with the MCP server URL -func removeClientConfigurations(containerName string) error { +func removeClientConfigurations(containerName string, logger *zap.SugaredLogger) error { // Get the workload's group by loading its run config runConfig, err := runner.LoadState(context.Background(), containerName) var group string @@ -532,7 +535,7 @@ func removeClientConfigurations(containerName string) error { group = runConfig.Group } - clientManager, err := client.NewManager(context.Background()) + clientManager, err := client.NewManager(context.Background(), logger) if err != nil { logger.Warnf("Warning: Failed to create client manager for %s, skipping client config removal: %v", containerName, err) return nil @@ -553,33 +556,33 @@ func (d *defaultManager) loadRunnerFromState(ctx context.Context, baseName strin runConfig.Deployer = d.runtime // Create a new runner with the loaded configuration - return runner.NewRunner(runConfig, d.statuses), nil + return runner.NewRunner(runConfig, d.statuses, d.logger), nil } -func needSecretsPassword(secretOptions []string) bool { +func needSecretsPassword(secretOptions []string, logger *zap.SugaredLogger) bool { // If the user did not ask for any secrets, then don't attempt to instantiate // the secrets manager. if len(secretOptions) == 0 { return false } // Ignore err - if the flag is not set, it's not needed. 
- providerType, _ := config.GetConfig().Secrets.GetProviderType() + providerType, _ := config.GetConfig(logger).Secrets.GetProviderType() return providerType == secrets.EncryptedType } // cleanupTempPermissionProfile cleans up temporary permission profile files for a given base name -func (*defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { +func (d *defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { // Try to load the saved configuration to get the permission profile path runConfig, err := runner.LoadState(ctx, baseName) if err != nil { // If we can't load the state, there's nothing to clean up - logger.Debugf("Could not load state for %s, skipping permission profile cleanup: %v", baseName, err) + d.logger.Debugf("Could not load state for %s, skipping permission profile cleanup: %v", baseName, err) return nil } // Clean up the temporary permission profile if it exists if runConfig.PermissionProfileNameOrPath != "" { - if err := runner.CleanupTempPermissionProfile(runConfig.PermissionProfileNameOrPath); err != nil { + if err := runner.CleanupTempPermissionProfile(runConfig.PermissionProfileNameOrPath, d.logger); err != nil { return fmt.Errorf("failed to cleanup temporary permission profile: %v", err) } } @@ -598,27 +601,27 @@ func (d *defaultManager) stopWorkloads(ctx context.Context, workloads []*rt.Cont name := labels.GetContainerBaseName(workload.Labels) // Stop the proxy process - proxy.StopProcess(name) + proxy.StopProcess(name, d.logger) - logger.Infof("Stopping containers for %s...", name) + d.logger.Infof("Stopping containers for %s...", name) // Stop the container if err := d.runtime.StopWorkload(childCtx, workload.Name); err != nil { if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) + d.logger.Warnf("Failed to set workload %s status to 
error: %v", name, statusErr) } return fmt.Errorf("failed to stop container: %w", err) } - if err := removeClientConfigurations(name); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) + if err := removeClientConfigurations(name, d.logger); err != nil { + d.logger.Warnf("Warning: Failed to remove client configurations: %v", err) } else { - logger.Infof("Client configurations for %s removed", name) + d.logger.Infof("Client configurations for %s removed", name) } if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil { - logger.Warnf("Failed to set workload %s status to stopped: %v", name, err) + d.logger.Warnf("Failed to set workload %s status to stopped: %v", name, err) } - logger.Infof("Successfully stopped %s...", name) + d.logger.Infof("Successfully stopped %s...", name) return nil }) } @@ -627,7 +630,7 @@ func (d *defaultManager) stopWorkloads(ctx context.Context, workloads []*rt.Cont } // MoveToDefaultGroup moves the specified workloads to the default group by updating their runconfig. 
-func (*defaultManager) MoveToDefaultGroup(ctx context.Context, workloadNames []string, groupName string) error { +func (d *defaultManager) MoveToDefaultGroup(ctx context.Context, workloadNames []string, groupName string) error { for _, workloadName := range workloadNames { // Validate workload name if err := types.ValidateWorkloadName(workloadName); err != nil { @@ -642,7 +645,7 @@ func (*defaultManager) MoveToDefaultGroup(ctx context.Context, workloadNames []s // Check if the workload is actually in the specified group if runnerConfig.Group != groupName { - logger.Debugf("Workload %s is not in group %s (current group: %s), skipping", + d.logger.Debugf("Workload %s is not in group %s (current group: %s), skipping", workloadName, groupName, runnerConfig.Group) continue } @@ -651,11 +654,11 @@ func (*defaultManager) MoveToDefaultGroup(ctx context.Context, workloadNames []s runnerConfig.Group = groups.DefaultGroup // Save the updated configuration - if err = runnerConfig.SaveState(ctx); err != nil { + if err = runnerConfig.SaveState(ctx, d.logger); err != nil { return fmt.Errorf("failed to save updated configuration for workload %s: %w", workloadName, err) } - logger.Infof("Moved workload %s to default group", workloadName) + d.logger.Infof("Moved workload %s to default group", workloadName) } return nil diff --git a/pkg/workloads/statuses/file_status.go b/pkg/workloads/statuses/file_status.go index 69d75ccf6..6d1903227 100644 --- a/pkg/workloads/statuses/file_status.go +++ b/pkg/workloads/statuses/file_status.go @@ -11,11 +11,11 @@ import ( "github.com/adrg/xdg" "github.com/gofrs/flock" + "go.uber.org/zap" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport/proxy" "github.com/stacklok/toolhive/pkg/workloads/types" ) @@ -31,7 +31,7 @@ const ( // NewFileStatusManager creates a new file-based 
StatusManager. // Status files will be stored in the XDG data directory under "statuses/". -func NewFileStatusManager(runtime rt.Runtime) (StatusManager, error) { +func NewFileStatusManager(runtime rt.Runtime, logger *zap.SugaredLogger) (StatusManager, error) { // Get the base directory using XDG data directory baseDir, err := xdg.DataFile(statusesPrefix) if err != nil { @@ -46,6 +46,7 @@ func NewFileStatusManager(runtime rt.Runtime) (StatusManager, error) { return &fileStatusManager{ baseDir: baseDir, runtime: runtime, + logger: logger, }, nil } @@ -55,6 +56,7 @@ func NewFileStatusManager(runtime rt.Runtime) (StatusManager, error) { type fileStatusManager struct { baseDir string runtime rt.Runtime + logger *zap.SugaredLogger } // workloadStatusFile represents the JSON structure stored on disk @@ -140,7 +142,7 @@ func (f *fileStatusManager) ListWorkloads(ctx context.Context, listAll bool, lab for _, container := range runtimeContainers { workload, err := types.WorkloadFromContainerInfo(&container) if err != nil { - logger.Warnf("failed to convert container info for workload %s: %v", container.Name, err) + f.logger.Warnf("failed to convert container info for workload %s: %v", container.Name, err) continue } workloadMap[container.Name] = workload @@ -152,7 +154,7 @@ func (f *fileStatusManager) ListWorkloads(ctx context.Context, listAll bool, lab // Validate running workloads similar to GetWorkload validatedWorkload, err := f.validateWorkloadInList(ctx, name, fileWorkload, runtimeContainer) if err != nil { - logger.Warnf("failed to validate workload %s in list: %v", name, err) + f.logger.Warnf("failed to validate workload %s in list: %v", name, err) // Fall back to basic merge without validation runtimeWorkload := workloadMap[name] runtimeWorkload.Status = fileWorkload.Status @@ -231,12 +233,12 @@ func (f *fileStatusManager) SetWorkloadStatus( return fmt.Errorf("failed to write updated status for workload %s: %w", workloadName, err) } - logger.Debugf("workload %s set 
to status %s (context: %s)", workloadName, status, contextMsg) + f.logger.Debugf("workload %s set to status %s (context: %s)", workloadName, status, contextMsg) return nil }) if err != nil { - logger.Errorf("error updating workload %s status: %v", workloadName, err) + f.logger.Errorf("error updating workload %s status: %v", workloadName, err) } return err } @@ -250,7 +252,7 @@ func (f *fileStatusManager) DeleteWorkloadStatus(ctx context.Context, workloadNa } // Remove lock file (best effort) - done by withFileLock after this function returns - logger.Debugf("workload %s status deleted", workloadName) + f.logger.Debugf("workload %s status deleted", workloadName) return nil }) } @@ -287,11 +289,11 @@ func (f *fileStatusManager) withFileLock(ctx context.Context, workloadName strin fileLock := flock.New(lockFilePath) defer func() { if err := fileLock.Unlock(); err != nil { - logger.Warnf("failed to unlock file %s: %v", lockFilePath, err) + f.logger.Warnf("failed to unlock file %s: %v", lockFilePath, err) } // Attempt to remove lock file (best effort) if err := os.Remove(lockFilePath); err != nil && !os.IsNotExist(err) { - logger.Warnf("failed to remove lock file for workload %s: %v", workloadName, err) + f.logger.Warnf("failed to remove lock file for workload %s: %v", workloadName, err) } }() @@ -320,7 +322,7 @@ func (f *fileStatusManager) withFileReadLock(ctx context.Context, workloadName s fileLock := flock.New(lockFilePath) defer func() { if err := fileLock.Unlock(); err != nil { - logger.Warnf("failed to unlock file %s: %v", lockFilePath, err) + f.logger.Warnf("failed to unlock file %s: %v", lockFilePath, err) } }() @@ -400,7 +402,7 @@ func (f *fileStatusManager) getWorkloadsFromFiles() (map[string]core.Workload, e // Read the status file statusFile, err := f.readStatusFile(file) if err != nil { - logger.Warnf("failed to read status file %s: %v", file, err) + f.logger.Warnf("failed to read status file %s: %v", file, err) continue } @@ -449,7 +451,7 @@ func (f 
*fileStatusManager) handleRuntimeMismatch( ) (core.Workload, error) { contextMsg := fmt.Sprintf("workload status mismatch: file indicates running, but runtime shows %s", containerInfo.State) if err := f.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusUnhealthy, contextMsg); err != nil { - logger.Warnf("failed to update workload %s status to unhealthy: %v", workloadName, err) + f.logger.Warnf("failed to update workload %s status to unhealthy: %v", workloadName, err) } // Convert to workload and return unhealthy status @@ -475,7 +477,7 @@ func (f *fileStatusManager) checkProxyHealth( return core.Workload{}, false // No proxy check needed } - proxyRunning := proxy.IsRunning(baseName) + proxyRunning := proxy.IsRunning(baseName, f.logger) if proxyRunning { return core.Workload{}, false // Proxy is healthy } @@ -484,13 +486,13 @@ func (f *fileStatusManager) checkProxyHealth( contextMsg := fmt.Sprintf("proxy process not running: workload shows running but proxy process for %s is not active", baseName) if err := f.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusUnhealthy, contextMsg); err != nil { - logger.Warnf("failed to update workload %s status to unhealthy: %v", workloadName, err) + f.logger.Warnf("failed to update workload %s status to unhealthy: %v", workloadName, err) } // Convert to workload and return unhealthy status runtimeResult, err := types.WorkloadFromContainerInfo(&containerInfo) if err != nil { - logger.Warnf("failed to convert container info for unhealthy workload %s: %v", workloadName, err) + f.logger.Warnf("failed to convert container info for unhealthy workload %s: %v", workloadName, err) return core.Workload{}, false // Return false to avoid double error handling } diff --git a/pkg/workloads/statuses/file_status_test.go b/pkg/workloads/statuses/file_status_test.go index e69b67fb5..335794d4b 100644 --- a/pkg/workloads/statuses/file_status_test.go +++ b/pkg/workloads/statuses/file_status_test.go @@ -17,19 +17,17 @@ import ( rt 
"github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/container/runtime/mocks" "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" ) -func init() { - // Initialize logger for all tests - logger.Initialize() -} - func TestFileStatusManager_SetWorkloadStatus_Create(t *testing.T) { t.Parallel() // Create temporary directory for tests tempDir := t.TempDir() - manager := &fileStatusManager{baseDir: tempDir} + manager := &fileStatusManager{ + baseDir: tempDir, + logger: log.NewLogger(), + } ctx := context.Background() // Test creating a new workload status @@ -57,7 +55,10 @@ func TestFileStatusManager_SetWorkloadStatus_Create(t *testing.T) { func TestFileStatusManager_SetWorkloadStatus_Update(t *testing.T) { t.Parallel() tempDir := t.TempDir() - manager := &fileStatusManager{baseDir: tempDir} + manager := &fileStatusManager{ + baseDir: tempDir, + logger: log.NewLogger(), + } ctx := context.Background() // Create workload first time @@ -80,6 +81,7 @@ func TestFileStatusManager_GetWorkload(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -106,6 +108,7 @@ func TestFileStatusManager_GetWorkload_NotFound(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -129,6 +132,7 @@ func TestFileStatusManager_GetWorkload_RuntimeFallback(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -168,6 +172,7 @@ func TestFileStatusManager_GetWorkload_FileAndRuntimeCombination(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -210,7 +215,10 @@ func 
TestFileStatusManager_GetWorkload_FileAndRuntimeCombination(t *testing.T) { func TestFileStatusManager_SetWorkloadStatus(t *testing.T) { t.Parallel() tempDir := t.TempDir() - manager := &fileStatusManager{baseDir: tempDir} + manager := &fileStatusManager{ + baseDir: tempDir, + logger: log.NewLogger(), + } ctx := context.Background() // Create a workload status @@ -244,7 +252,10 @@ func TestFileStatusManager_SetWorkloadStatus(t *testing.T) { func TestFileStatusManager_SetWorkloadStatus_NotFound(t *testing.T) { t.Parallel() tempDir := t.TempDir() - manager := &fileStatusManager{baseDir: tempDir} + manager := &fileStatusManager{ + baseDir: tempDir, + logger: log.NewLogger(), + } ctx := context.Background() // Try to set status for non-existent workload - creates file since no runtime check @@ -272,7 +283,10 @@ func TestFileStatusManager_SetWorkloadStatus_NotFound(t *testing.T) { func TestFileStatusManager_DeleteWorkloadStatus(t *testing.T) { t.Parallel() tempDir := t.TempDir() - manager := &fileStatusManager{baseDir: tempDir} + manager := &fileStatusManager{ + baseDir: tempDir, + logger: log.NewLogger(), + } ctx := context.Background() // Create a workload status @@ -293,7 +307,10 @@ func TestFileStatusManager_DeleteWorkloadStatus(t *testing.T) { func TestFileStatusManager_DeleteWorkloadStatus_NotFound(t *testing.T) { t.Parallel() tempDir := t.TempDir() - manager := &fileStatusManager{baseDir: tempDir} + manager := &fileStatusManager{ + baseDir: tempDir, + logger: log.NewLogger(), + } ctx := context.Background() // Try to delete non-existent workload - should not error @@ -312,6 +329,7 @@ func TestFileStatusManager_ConcurrentAccess(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -352,6 +370,7 @@ func TestFileStatusManager_FullLifecycle(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := 
context.Background() @@ -600,6 +619,7 @@ func TestFileStatusManager_ListWorkloads(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } // Setup test data @@ -635,6 +655,7 @@ func TestFileStatusManager_GetWorkload_UnhealthyDetection(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -704,6 +725,7 @@ func TestFileStatusManager_GetWorkload_HealthyRunningWorkload(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -752,6 +774,7 @@ func TestFileStatusManager_GetWorkload_ProxyNotRunning(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -826,6 +849,7 @@ func TestFileStatusManager_GetWorkload_HealthyWithProxy(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() @@ -873,6 +897,7 @@ func TestFileStatusManager_ListWorkloads_WithValidation(t *testing.T) { manager := &fileStatusManager{ baseDir: tempDir, runtime: mockRuntime, + logger: log.NewLogger(), } ctx := context.Background() diff --git a/pkg/workloads/statuses/status.go b/pkg/workloads/statuses/status.go index 92d10366e..6a6482578 100644 --- a/pkg/workloads/statuses/status.go +++ b/pkg/workloads/statuses/status.go @@ -5,9 +5,10 @@ import ( "context" "fmt" + "go.uber.org/zap" + rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads/types" ) @@ -28,20 +29,21 @@ type StatusManager interface { } // NewStatusManagerFromRuntime creates a new instance of StatusManager from an existing runtime. 
-func NewStatusManagerFromRuntime(runtime rt.Runtime) StatusManager { +func NewStatusManagerFromRuntime(runtime rt.Runtime, logger *zap.SugaredLogger) StatusManager { return &runtimeStatusManager{ runtime: runtime, + logger: logger, } } // NewStatusManager creates a new status manager instance using the appropriate implementation // based on the runtime environment. If running in Kubernetes, it returns the runtime-based // implementation. Otherwise, it returns the file-based implementation. -func NewStatusManager(runtime rt.Runtime) (StatusManager, error) { +func NewStatusManager(runtime rt.Runtime, logger *zap.SugaredLogger) (StatusManager, error) { if rt.IsKubernetesRuntime() { - return NewStatusManagerFromRuntime(runtime), nil + return NewStatusManagerFromRuntime(runtime, logger), nil } - return NewFileStatusManager(runtime) + return NewFileStatusManager(runtime, logger) } // runtimeStatusManager is an implementation of StatusManager that uses the state @@ -49,6 +51,7 @@ func NewStatusManager(runtime rt.Runtime) (StatusManager, error) { // ToolHive at the time of writing. type runtimeStatusManager struct { runtime rt.Runtime + logger *zap.SugaredLogger } func (r *runtimeStatusManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { @@ -97,14 +100,14 @@ func (r *runtimeStatusManager) ListWorkloads(ctx context.Context, listAll bool, return workloads, nil } -func (*runtimeStatusManager) SetWorkloadStatus( +func (r *runtimeStatusManager) SetWorkloadStatus( _ context.Context, workloadName string, status rt.WorkloadStatus, contextMsg string, ) error { // TODO: This will need to handle concurrent updates. 
- logger.Debugf("workload %s set to status %s (context: %s)", workloadName, status, contextMsg) + r.logger.Debugf("workload %s set to status %s (context: %s)", workloadName, status, contextMsg) return nil } diff --git a/pkg/workloads/statuses/status_test.go b/pkg/workloads/statuses/status_test.go index 6f3897b27..85ee2cd7a 100644 --- a/pkg/workloads/statuses/status_test.go +++ b/pkg/workloads/statuses/status_test.go @@ -13,23 +13,21 @@ import ( rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/container/runtime/mocks" "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/logger" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/workloads/types" ) const testWorkloadName = "test-workload" -func init() { - logger.Initialize() -} - //nolint:paralleltest // Cannot use t.Parallel() with t.Setenv() in Go 1.24+ func TestNewStatusManagerFromRuntime(t *testing.T) { + logger := log.NewLogger() + ctrl := gomock.NewController(t) defer ctrl.Finish() mockRuntime := mocks.NewMockRuntime(ctrl) - manager := NewStatusManagerFromRuntime(mockRuntime) + manager := NewStatusManagerFromRuntime(mockRuntime, logger) assert.NotNil(t, manager) assert.IsType(t, &runtimeStatusManager{}, manager) @@ -45,7 +43,10 @@ func TestRuntimeStatusManager_CreateWorkloadStatus(t *testing.T) { defer ctrl.Finish() mockRuntime := mocks.NewMockRuntime(ctrl) - manager := &runtimeStatusManager{runtime: mockRuntime} + manager := &runtimeStatusManager{ + runtime: mockRuntime, + logger: log.NewLogger(), + } ctx := context.Background() @@ -266,7 +267,10 @@ func TestRuntimeStatusManager_SetWorkloadStatus(t *testing.T) { defer ctrl.Finish() mockRuntime := mocks.NewMockRuntime(ctrl) - manager := &runtimeStatusManager{runtime: mockRuntime} + manager := &runtimeStatusManager{ + runtime: mockRuntime, + logger: log.NewLogger(), + } ctx := context.Background() status := rt.WorkloadStatusRunning @@ -420,6 +424,8 @@ func 
TestMatchesLabelFilters(t *testing.T) { func TestNewStatusManager(t *testing.T) { t.Parallel() + logger := log.NewLogger() + ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) @@ -465,7 +471,7 @@ func TestNewStatusManager(t *testing.T) { os.Unsetenv("KUBERNETES_SERVICE_HOST") } - manager, err := NewStatusManager(mockRuntime) + manager, err := NewStatusManager(mockRuntime, logger) assert.NoError(t, err) assert.NotNil(t, manager) diff --git a/test/e2e/client_test.go b/test/e2e/client_test.go index 484155007..8aface994 100644 --- a/test/e2e/client_test.go +++ b/test/e2e/client_test.go @@ -10,6 +10,7 @@ import ( . "github.com/onsi/gomega" "github.com/stacklok/toolhive/pkg/config" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/test/e2e" ) @@ -81,10 +82,11 @@ var _ = Describe("Client Management", func() { Describe("client list-registered command", func() { BeforeEach(func() { // Pre-populate temporary config with multiple registered clients in non-alphabetical order + logger := log.NewLogger() testClients := []string{"vscode", "cursor", "roo-code", "cline", "claude-code"} err := config.UpdateConfigAtPath(tempConfigPath, func(c *config.Config) { c.Clients.RegisteredClients = testClients - }) + }, logger) Expect(err).ToNot(HaveOccurred()) }) diff --git a/test/e2e/proxy_oauth_test.go b/test/e2e/proxy_oauth_test.go index 6872fad9e..ef9405de4 100644 --- a/test/e2e/proxy_oauth_test.go +++ b/test/e2e/proxy_oauth_test.go @@ -17,6 +17,7 @@ import ( . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" + log "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/test/e2e" ) @@ -38,6 +39,7 @@ var _ = Describe("Proxy OAuth Authentication E2E", Serial, func() { clientID = "test-client" clientSecret = "test-secret" mockOIDCBaseURL string + logger = log.NewLogger() ) BeforeEach(func() { @@ -52,10 +54,10 @@ var _ = Describe("Proxy OAuth Authentication E2E", Serial, func() { proxyServerName = generateUniqueOIDCServerName("proxy-oauth-test") // Find available ports for our mock servers using networking utilities - mockOIDCPort, err = networking.FindOrUsePort(0) + mockOIDCPort, err = networking.FindOrUsePort(0, logger) Expect(err).ToNot(HaveOccurred()) - proxyPort, err = networking.FindOrUsePort(0) + proxyPort, err = networking.FindOrUsePort(0, logger) Expect(err).ToNot(HaveOccurred()) mockOIDCBaseURL = fmt.Sprintf("http://localhost:%d", mockOIDCPort)