diff --git a/api/v1alpha1/temporalworker_webhook.go b/api/v1alpha1/temporalworker_webhook.go index 5037eda0..d0bb028c 100644 --- a/api/v1alpha1/temporalworker_webhook.go +++ b/api/v1alpha1/temporalworker_webhook.go @@ -22,6 +22,7 @@ const ( defaultScaledownDelay = 1 * time.Hour defaultDeleteDelay = 24 * time.Hour maxTemporalWorkerDeploymentNameLen = 63 + ConnectionSpecHashAnnotation = "temporal.io/connection-spec-hash" ) func (r *TemporalWorkerDeployment) SetupWebhookWithManager(mgr ctrl.Manager) error { diff --git a/internal/controller/genplan.go b/internal/controller/genplan.go index 394927e7..46eb680a 100644 --- a/internal/controller/genplan.go +++ b/internal/controller/genplan.go @@ -94,6 +94,7 @@ func (r *TemporalWorkerDeploymentReconciler) generatePlan( l, k8sState, plannerConfig, + connection, ) if err != nil { return nil, fmt.Errorf("error generating plan: %w", err) diff --git a/internal/controller/worker_controller.go b/internal/controller/worker_controller.go index a3e849ae..4a822a40 100644 --- a/internal/controller/worker_controller.go +++ b/internal/controller/worker_controller.go @@ -17,7 +17,9 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/reconcile" temporaliov1alpha1 "github.com/temporalio/temporal-worker-controller/api/v1alpha1" "github.com/temporalio/temporal-worker-controller/internal/controller/clientpool" @@ -74,21 +76,6 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req return ctrl.Result{}, client.IgnoreNotFound(err) } - // TODO(jlegrone): Set defaults via webhook rather than manually - if err := workerDeploy.Default(ctx, &workerDeploy); err != nil { - l.Error(err, "TemporalWorkerDeployment defaulter failed") - return ctrl.Result{}, err - } - - // TODO(carlydf): Handle warnings once we have 
some, handle ValidateUpdate once it is different from ValidateCreate - if _, err := workerDeploy.ValidateCreate(ctx, &workerDeploy); err != nil { - l.Error(err, "invalid TemporalWorkerDeployment") - return ctrl.Result{ - Requeue: true, - RequeueAfter: 5 * time.Minute, // user needs time to fix this, if it changes, it will be re-queued immediately - }, nil - } - // Verify that a connection is configured if workerDeploy.Spec.WorkerOptions.TemporalConnection == "" { err := fmt.Errorf("TemporalConnection must be set") @@ -106,6 +93,23 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req return ctrl.Result{}, err } + // TODO (Shivam): Do we validate TemporalConnection here as well? + + // TODO(jlegrone): Set defaults via webhook rather than manually + if err := workerDeploy.Default(ctx, &workerDeploy); err != nil { + l.Error(err, "TemporalWorkerDeployment defaulter failed") + return ctrl.Result{}, err + } + + // TODO(carlydf): Handle warnings once we have some, handle ValidateUpdate once it is different from ValidateCreate + if _, err := workerDeploy.ValidateCreate(ctx, &workerDeploy); err != nil { + l.Error(err, "invalid TemporalWorkerDeployment") + return ctrl.Result{ + Requeue: true, + RequeueAfter: 5 * time.Minute, // user needs time to fix this, if it changes, it will be re-queued immediately + }, nil + } + // Get or update temporal client for connection temporalClient, ok := r.TemporalClientPool.GetSDKClient(clientpool.ClientPoolKey{ HostPort: temporalConnection.Spec.HostPort, @@ -195,8 +199,34 @@ func (r *TemporalWorkerDeploymentReconciler) SetupWithManager(mgr ctrl.Manager) return ctrl.NewControllerManagedBy(mgr). For(&temporaliov1alpha1.TemporalWorkerDeployment{}). Owns(&appsv1.Deployment{}). + Watches(&temporaliov1alpha1.TemporalConnection{}, handler.EnqueueRequestsFromMapFunc(r.findTWDsUsingConnection)). WithOptions(controller.Options{ MaxConcurrentReconciles: 100, }). 
Complete(r) }
+
+// findTWDsUsingConnection maps a TemporalConnection event to reconcile requests for
+// every TemporalWorkerDeployment in the same namespace that references it by name.
+func (r *TemporalWorkerDeploymentReconciler) findTWDsUsingConnection(ctx context.Context, tc client.Object) []reconcile.Request {
+	var workers temporaliov1alpha1.TemporalWorkerDeploymentList
+	if err := r.List(ctx, &workers, client.InNamespace(tc.GetNamespace())); err != nil {
+		// A map function cannot return an error, so log it rather than dropping it silently.
+		log.FromContext(ctx).Error(err, "unable to list TemporalWorkerDeployments for TemporalConnection",
+			"namespace", tc.GetNamespace(), "name", tc.GetName())
+		return nil
+	}
+
+	var requests []reconcile.Request
+	for i := range workers.Items {
+		worker := &workers.Items[i] // index to avoid copying each item
+		if worker.Spec.WorkerOptions.TemporalConnection != tc.GetName() {
+			continue
+		}
+		requests = append(requests, reconcile.Request{
+			NamespacedName: types.NamespacedName{Name: worker.Name, Namespace: worker.Namespace},
+		})
+	}
+
+	return requests
+}
diff --git a/internal/k8s/deployments.go b/internal/k8s/deployments.go index c286da5b..0d92d010 100644 --- a/internal/k8s/deployments.go +++ b/internal/k8s/deployments.go @@ -6,15 +6,18 @@ package k8s import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" + "regexp" + "sort" + "strings" + "github.com/distribution/reference" appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "regexp" "sigs.k8s.io/controller-runtime/pkg/client" - "sort" - "strings" temporaliov1alpha1 "github.com/temporalio/temporal-worker-controller/api/v1alpha1" "github.com/temporalio/temporal-worker-controller/internal/controller/k8s.io/utils" @@ -254,6 +257,13 @@ func NewDeploymentWithOwnerRef( }) } + // Build pod annotations + podAnnotations := make(map[string]string) + for k, v := range spec.Template.Annotations { + podAnnotations[k] = v + } + podAnnotations[temporaliov1alpha1.ConnectionSpecHashAnnotation] = ComputeConnectionSpecHash(connection) + blockOwnerDeletion := true return &appsv1.Deployment{ @@ -282,7 +292,7 @@ func NewDeploymentWithOwnerRef( Template: v1.PodTemplateSpec{ ObjectMeta: metav1.ObjectMeta{ Labels: podLabels, - Annotations:
spec.Template.Annotations, + Annotations: podAnnotations, }, Spec: *podSpec, }, @@ -290,3 +300,18 @@ func NewDeploymentWithOwnerRef( }, } }
+
+// ComputeConnectionSpecHash returns the hex-encoded SHA-256 digest of the connection's
+// HostPort and MutualTLSSecret. It returns "" when HostPort is empty, because there is
+// no endpoint to fingerprint in that case.
+func ComputeConnectionSpecHash(connection temporaliov1alpha1.TemporalConnectionSpec) string {
+	// An empty MutualTLSSecret is still hashed so HostPort changes are detected without one.
+	if connection.HostPort == "" {
+		return ""
+	}
+
+	// A NUL separator keeps distinct (HostPort, secret) pairs from hashing the
+	// same concatenated bytes.
+	sum := sha256.Sum256([]byte(connection.HostPort + "\x00" + connection.MutualTLSSecret))
+	return hex.EncodeToString(sum[:])
+}
diff --git a/internal/planner/planner.go b/internal/planner/planner.go index 044eaec5..538b09de 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -87,6 +87,7 @@ func GeneratePlan( l logr.Logger, k8sState *k8s.DeploymentState, config *Config, + connection temporaliov1alpha1.TemporalConnectionSpec, ) (*Plan, error) { plan := &Plan{ ScaleDeployments: make(map[*v1.ObjectReference]uint32), @@ -95,7 +96,7 @@ func GeneratePlan( // Add delete/scale operations based on version status plan.DeleteDeployments = getDeleteDeployments(k8sState, config) plan.ScaleDeployments = getScaleDeployments(k8sState, config) - plan.ShouldCreateDeployment = shouldCreateDeployment(k8sState, config) + plan.ShouldCreateDeployment = shouldCreateOrUpdateDeployment(k8sState, config, connection) // Determine if we need to start any test workflows plan.TestWorkflows = getTestWorkflows(config) @@ -218,9 +219,10 @@ func getScaleDeployments( } // shouldCreateDeployment determines if a new deployment needs to be created -func shouldCreateDeployment( +func shouldCreateOrUpdateDeployment( k8sState *k8s.DeploymentState, config *Config, + connection temporaliov1alpha1.TemporalConnectionSpec, ) bool { if config.Status.TargetVersion == nil { return true @@ -232,8 +234,10 @@ func shouldCreateDeployment( // If the target version already has a deployment, we don't need to create another one if
config.Status.TargetVersion.VersionID == config.TargetVersionID { - if _, exists := k8sState.Deployments[config.TargetVersionID]; exists { - return false + if d, exists := k8sState.Deployments[config.TargetVersionID]; exists { + // If the deployment already exists, we need to check if the secret hash has changed + connectionSpecHash := k8s.ComputeConnectionSpecHash(connection) + return connectionSpecHash != d.Spec.Template.Annotations[temporaliov1alpha1.ConnectionSpecHashAnnotation] } } diff --git a/internal/planner/planner_test.go b/internal/planner/planner_test.go index 2f46ed0b..2c62c788 100644 --- a/internal/planner/planner_test.go +++ b/internal/planner/planner_test.go @@ -32,6 +32,7 @@ func TestGeneratePlan(t *testing.T) { expectConfig bool expectConfigSetCurrent *bool // pointer to distinguish between false and not set expectConfigRampPercent *float32 + connectionSpec temporaliov1alpha1.TemporalConnectionSpec }{ { name: "empty state creates new deployment", @@ -48,16 +49,17 @@ func TestGeneratePlan(t *testing.T) { Replicas: 1, ConflictToken: []byte{}, }, - expectCreate: true, + expectCreate: true, + connectionSpec: createDefaultConnectionSpec(), }, { name: "drained version gets deleted", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(0), + "test/namespace.123": createDeployment(0, createDefaultConnectionSpec()), }, DeploymentsByTime: []*appsv1.Deployment{ - createDeploymentWithReplicas(0), + createDeployment(0, createDefaultConnectionSpec()), }, DeploymentRefs: map[string]*v1.ObjectReference{ "test/namespace.123": {Name: "test-123"}, @@ -84,17 +86,18 @@ func TestGeneratePlan(t *testing.T) { Replicas: 1, ConflictToken: []byte{}, }, - expectDelete: 1, - expectCreate: true, + expectDelete: 1, + expectCreate: true, + connectionSpec: createDefaultConnectionSpec(), }, { name: "deployment needs to be scaled", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ 
- "test/namespace.123": createDeploymentWithReplicas(1), + "test/namespace.123": createDeployment(1, createDefaultConnectionSpec()), }, DeploymentsByTime: []*appsv1.Deployment{ - createDeploymentWithReplicas(1), + createDeployment(1, createDefaultConnectionSpec()), }, DeploymentRefs: map[string]*v1.ObjectReference{ "test/namespace.123": {Name: "test-123"}, @@ -123,19 +126,20 @@ func TestGeneratePlan(t *testing.T) { Replicas: 2, ConflictToken: []byte{}, }, - expectScale: 2, - expectCreate: false, + expectScale: 2, + expectCreate: false, + connectionSpec: createDefaultConnectionSpec(), }, { name: "rollback scenario - target equals current but deprecated version is ramping", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(3), - "test/namespace.456": createDeploymentWithReplicas(3), + "test/namespace.123": createDeployment(3, createDefaultConnectionSpec()), + "test/namespace.456": createDeployment(3, createDefaultConnectionSpec()), }, DeploymentsByTime: []*appsv1.Deployment{ - createDeploymentWithReplicas(3), - createDeploymentWithReplicas(3), + createDeployment(3, createDefaultConnectionSpec()), + createDeployment(3, createDefaultConnectionSpec()), }, DeploymentRefs: map[string]*v1.ObjectReference{ "test/namespace.123": {Name: "test-123"}, @@ -189,12 +193,13 @@ func TestGeneratePlan(t *testing.T) { expectConfig: true, // Should generate config to reset ramp expectConfigSetCurrent: func() *bool { b := false; return &b }(), // Should NOT set current (already current) expectConfigRampPercent: func() *float32 { f := float32(0); return &f }(), // Should reset ramp to 0 + connectionSpec: createDefaultConnectionSpec(), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - plan, err := GeneratePlan(logr.Discard(), tc.k8sState, tc.config) + plan, err := GeneratePlan(logr.Discard(), tc.k8sState, tc.config, tc.connectionSpec) require.NoError(t, err) assert.Equal(t, 
tc.expectDelete, len(plan.DeleteDeployments), "unexpected number of deletions") @@ -227,7 +232,7 @@ func TestGetDeleteDeployments(t *testing.T) { name: "drained version should be deleted", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(0), + "test/namespace.123": createDeployment(0, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -257,7 +262,7 @@ func TestGetDeleteDeployments(t *testing.T) { name: "not yet drained long enough", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(0), + "test/namespace.123": createDeployment(0, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -293,7 +298,7 @@ func TestGetDeleteDeployments(t *testing.T) { name: "not registered version should be deleted", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(1), + "test/namespace.123": createDeployment(1, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -320,7 +325,7 @@ func TestGetDeleteDeployments(t *testing.T) { name: "delete target deployment when version ID has changed", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.b": createDeploymentWithReplicas(3), + "test/namespace.b": createDeployment(3, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -362,7 +367,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "default version needs scaling", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(1), + "test/namespace.123": createDeployment(1, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -387,7 +392,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "drained version needs scaling down", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - 
"test/namespace.123": createDeploymentWithReplicas(1), + "test/namespace.123": createDeployment(1, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -417,7 +422,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "inactive version needs scaling up", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.a": createDeploymentWithReplicas(0), + "test/namespace.a": createDeployment(0, createDefaultConnectionSpec()), }, DeploymentRefs: map[string]*v1.ObjectReference{ "test/namespace.a": {Name: "test-a"}, @@ -447,7 +452,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "ramping version needs scaling up", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.b": createDeploymentWithReplicas(0), + "test/namespace.b": createDeployment(0, createDefaultConnectionSpec()), }, DeploymentRefs: map[string]*v1.ObjectReference{ "test/namespace.b": {Name: "test-b"}, @@ -475,7 +480,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "current version needs scaling up", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.a": createDeploymentWithReplicas(0), + "test/namespace.a": createDeployment(0, createDefaultConnectionSpec()), }, DeploymentRefs: map[string]*v1.ObjectReference{ "test/namespace.a": {Name: "test-a"}, @@ -505,7 +510,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "don't scale down drained deployment before delay", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.b": createDeploymentWithReplicas(3), + "test/namespace.b": createDeployment(3, createDefaultConnectionSpec()), }, DeploymentRefs: map[string]*v1.ObjectReference{ "test/namespace.b": {Name: "test-b"}, @@ -559,10 +564,11 @@ func TestGetScaleDeployments(t *testing.T) { func TestShouldCreateDeployment(t *testing.T) { testCases := []struct { - name string - k8sState *k8s.DeploymentState - config *Config - expectCreates bool + 
name string + k8sState *k8s.DeploymentState + config *Config + expectCreates bool + connectionSpec temporaliov1alpha1.TemporalConnectionSpec }{ { name: "no target version should create", @@ -579,13 +585,14 @@ func TestShouldCreateDeployment(t *testing.T) { Replicas: 1, ConflictToken: []byte{}, }, - expectCreates: true, + expectCreates: true, + connectionSpec: createDefaultConnectionSpec(), }, { - name: "existing deployment should not create", + name: "existing deployment, with a non-outdated connection spec, should not result in the creation of a new deployment", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(1), + "test/namespace.123": createDeployment(1, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -604,7 +611,34 @@ func TestShouldCreateDeployment(t *testing.T) { Replicas: 1, ConflictToken: []byte{}, }, - expectCreates: false, + expectCreates: false, + connectionSpec: createDefaultConnectionSpec(), + }, + { + name: "existing deployment, with an outdated connection spec, should result in the creation of a new deployment", + k8sState: &k8s.DeploymentState{ + Deployments: map[string]*appsv1.Deployment{ + "test/namespace.123": createDeployment(1, createOutdatedConnectionSpec()), + }, + }, + config: &Config{ + TargetVersionID: "test/namespace.123", + Status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ + TargetVersion: &temporaliov1alpha1.TargetWorkerDeploymentVersion{ + BaseWorkerDeploymentVersion: temporaliov1alpha1.BaseWorkerDeploymentVersion{ + VersionID: "test/namespace.123", + Status: temporaliov1alpha1.VersionStatusInactive, + Deployment: &v1.ObjectReference{Name: "test-123"}, + }, + }, + }, + Spec: &temporaliov1alpha1.TemporalWorkerDeploymentSpec{}, + RolloutStrategy: temporaliov1alpha1.RolloutStrategy{}, + Replicas: 1, + ConflictToken: []byte{}, + }, + expectCreates: true, + connectionSpec: createDefaultConnectionSpec(), }, { name: "target version without 
deployment should create", @@ -627,13 +661,14 @@ func TestShouldCreateDeployment(t *testing.T) { Replicas: 1, ConflictToken: []byte{}, }, - expectCreates: true, + expectCreates: true, + connectionSpec: createDefaultConnectionSpec(), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - creates := shouldCreateDeployment(tc.k8sState, tc.config) + creates := shouldCreateOrUpdateDeployment(tc.k8sState, tc.config, tc.connectionSpec) assert.Equal(t, tc.expectCreates, creates, "unexpected create decision") }) } @@ -1504,10 +1539,10 @@ func TestComplexVersionStateScenarios(t *testing.T) { name: "multiple deprecated versions in different states", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.a": createDeploymentWithReplicas(3), - "test/namespace.b": createDeploymentWithReplicas(3), - "test/namespace.c": createDeploymentWithReplicas(1), - "test/namespace.d": createDeploymentWithReplicas(0), + "test/namespace.a": createDeployment(3, createDefaultConnectionSpec()), + "test/namespace.b": createDeployment(3, createDefaultConnectionSpec()), + "test/namespace.c": createDeployment(1, createDefaultConnectionSpec()), + "test/namespace.d": createDeployment(0, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -1568,7 +1603,7 @@ func TestComplexVersionStateScenarios(t *testing.T) { name: "draining version not scaled down before delay", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.a": createDeploymentWithReplicas(3), + "test/namespace.a": createDeployment(3, createDefaultConnectionSpec()), }, }, config: &Config{ @@ -1603,7 +1638,7 @@ func TestComplexVersionStateScenarios(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - plan, err := GeneratePlan(logr.Discard(), tc.k8sState, tc.config) + plan, err := GeneratePlan(logr.Discard(), tc.k8sState, tc.config, temporaliov1alpha1.TemporalConnectionSpec{}) // TODO (Shivam): Come back to this 
require.NoError(t, err) assert.Equal(t, tc.expectDeletes, len(plan.DeleteDeployments), "unexpected number of deletes") @@ -1678,13 +1713,20 @@ func TestGetTestWorkflowID(t *testing.T) { } // Helper function to create a deployment with specified replicas -func createDeploymentWithReplicas(replicas int32) *appsv1.Deployment { +func createDeployment(replicas int32, connectionSpec temporaliov1alpha1.TemporalConnectionSpec) *appsv1.Deployment { return &appsv1.Deployment{ ObjectMeta: metav1.ObjectMeta{ Name: "test-deployment", }, Spec: appsv1.DeploymentSpec{ Replicas: &replicas, + Template: v1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + temporaliov1alpha1.ConnectionSpecHashAnnotation: k8s.ComputeConnectionSpecHash(connectionSpec), + }, + }, + }, }, } } @@ -1703,3 +1745,27 @@ func rolloutStep(ramp float32, d time.Duration) temporaliov1alpha1.RolloutStep { PauseDuration: metav1Duration(d), } } + +// createDefaultConnectionSpec returns the TemporalConnectionSpec built from defaultHostPort and defaultMutualTLSSecret. +func createDefaultConnectionSpec() temporaliov1alpha1.TemporalConnectionSpec { + return temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: defaultMutualTLSSecret(), + } +} + +// createOutdatedConnectionSpec returns a spec that keeps defaultHostPort but uses the secret name "outdated-secret" instead of the default. +func createOutdatedConnectionSpec() temporaliov1alpha1.TemporalConnectionSpec { + return temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: "outdated-secret", + } +} + +func defaultHostPort() string { + return "default-host:7233" +} + +func defaultMutualTLSSecret() string { + return "default-secret" +}