diff --git a/internal/controller/clientpool/clientpool.go b/internal/controller/clientpool/clientpool.go index 56f2149a..a74873ef 100644 --- a/internal/controller/clientpool/clientpool.go +++ b/internal/controller/clientpool/clientpool.go @@ -22,8 +22,9 @@ import ( ) type ClientPoolKey struct { - HostPort string - Namespace string + HostPort string + Namespace string + MutualTLSSecret string // Include secret name in key to invalidate cache when the secret name changes } type ClientInfo struct { @@ -142,8 +143,9 @@ func (cp *ClientPool) UpsertClient(ctx context.Context, opts NewClientOptions) ( defer cp.mux.Unlock() key := ClientPoolKey{ - HostPort: opts.Spec.HostPort, - Namespace: opts.TemporalNamespace, + HostPort: opts.Spec.HostPort, + Namespace: opts.TemporalNamespace, + MutualTLSSecret: opts.Spec.MutualTLSSecret, } cp.clients[key] = ClientInfo{ client: c, diff --git a/internal/controller/execplan.go b/internal/controller/execplan.go index 044abff1..84d9da03 100644 --- a/internal/controller/execplan.go +++ b/internal/controller/execplan.go @@ -54,6 +54,15 @@ func (r *TemporalWorkerDeploymentReconciler) executePlan(ctx context.Context, l } } + // Update deployments + for _, d := range p.UpdateDeployments { + l.Info("updating deployment", "deployment", d.Name, "namespace", d.Namespace) + if err := r.Update(ctx, d); err != nil { + l.Error(err, "unable to update deployment", "deployment", d) + return fmt.Errorf("unable to update deployment: %w", err) + } + } + // Get deployment handler deploymentHandler := temporalClient.WorkerDeploymentClient().GetHandle(p.WorkerDeploymentName) diff --git a/internal/controller/genplan.go b/internal/controller/genplan.go index d6586ebc..35ecca22 100644 --- a/internal/controller/genplan.go +++ b/internal/controller/genplan.go @@ -27,6 +27,7 @@ type plan struct { DeleteDeployments []*appsv1.Deployment CreateDeployment *appsv1.Deployment ScaleDeployments map[*corev1.ObjectReference]uint32 + UpdateDeployments []*appsv1.Deployment // Register new versions as current or with ramp UpdateVersionConfig *planner.VersionConfig @@ -90,6 +91,7 @@ func (r *TemporalWorkerDeploymentReconciler) generatePlan( &w.Status, &w.Spec, temporalState, + connection, plannerConfig, ) if err != nil { @@ -99,6 +101,7 @@ func (r *TemporalWorkerDeploymentReconciler) generatePlan( // Convert planner result to controller plan plan.DeleteDeployments = planResult.DeleteDeployments plan.ScaleDeployments = planResult.ScaleDeployments + plan.UpdateDeployments = planResult.UpdateDeployments // Convert version config plan.UpdateVersionConfig = planResult.VersionConfig diff --git a/internal/controller/worker_controller.go b/internal/controller/worker_controller.go index 54228145..cddae68a 100644 --- a/internal/controller/worker_controller.go +++ b/internal/controller/worker_controller.go @@ -21,7 +21,9 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/reconcile" ) var ( @@ -58,7 +60,6 @@ type TemporalWorkerDeploymentReconciler struct { // // For more details, check Reconcile and its Result here: // - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.15.0/pkg/reconcile -// TODO(carlydf): Add watching of temporal connection custom resource (may have issue) func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { // 
TODO(Shivam): Monitor if the time taken for a successful reconciliation loop is closing in on 5 minutes. If so, we // may need to increase the timeout value. @@ -112,8 +113,9 @@ func (r *TemporalWorkerDeploymentReconciler) Reconcile(ctx context.Context, req // Get or update temporal client for connection temporalClient, ok := r.TemporalClientPool.GetSDKClient(clientpool.ClientPoolKey{ - HostPort: temporalConnection.Spec.HostPort, - Namespace: workerDeploy.Spec.WorkerOptions.TemporalNamespace, + HostPort: temporalConnection.Spec.HostPort, + Namespace: workerDeploy.Spec.WorkerOptions.TemporalNamespace, + MutualTLSSecret: temporalConnection.Spec.MutualTLSSecret, }, temporalConnection.Spec.MutualTLSSecret != "") if !ok { c, err := r.TemporalClientPool.UpsertClient(ctx, clientpool.NewClientOptions{ @@ -212,9 +214,35 @@ func (r *TemporalWorkerDeploymentReconciler) SetupWithManager(mgr ctrl.Manager) return ctrl.NewControllerManagedBy(mgr). For(&temporaliov1alpha1.TemporalWorkerDeployment{}). Owns(&appsv1.Deployment{}). + Watches(&temporaliov1alpha1.TemporalConnection{}, handler.EnqueueRequestsFromMapFunc(r.findTWDsUsingConnection)). WithOptions(controller.Options{ MaxConcurrentReconciles: 100, RecoverPanic: &recoverPanic, }). Complete(r) } + +func (r *TemporalWorkerDeploymentReconciler) findTWDsUsingConnection(ctx context.Context, tc client.Object) []reconcile.Request { + var requests []reconcile.Request + + // Find all TWDs in same namespace that reference this TC + var twds temporaliov1alpha1.TemporalWorkerDeploymentList + if err := r.List(ctx, &twds, client.InNamespace(tc.GetNamespace())); err != nil { + return requests + } + + // Filter to ones using this connection + for _, twd := range twds.Items { + if twd.Spec.WorkerOptions.TemporalConnection == tc.GetName() { + // Enqueue a reconcile request for this TWD + requests = append(requests, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Name: twd.Name, + Namespace: twd.Namespace, + }, + }) + } + } + + return requests +} diff --git a/internal/demo/README.md b/internal/demo/README.md index 13c8a5f1..a3ff7100 100644 --- a/internal/demo/README.md +++ b/internal/demo/README.md @@ -79,6 +79,9 @@ minikube kubectl -- get pods -n temporal-worker-controller -w # Describe the controller pod's status minikube kubectl -- describe pod -n temporal-worker-controller +# Output the controller pod's logs +minikube kubectl -- logs -n temporal-system -f pod/ + # View TemporalWorkerDeployment status kubectl get twd ``` diff --git a/internal/demo/helloworld/temporal_worker_deployment.yaml b/internal/demo/helloworld/temporal_worker_deployment.yaml index 7de468af..3fdb199f 100644 --- a/internal/demo/helloworld/temporal_worker_deployment.yaml +++ b/internal/demo/helloworld/temporal_worker_deployment.yaml @@ -18,11 +18,11 @@ spec: steps: # Increase traffic from 1% to 10% over 15 seconds - rampPercentage: 1 - pauseDuration: 5s + pauseDuration: 30s - rampPercentage: 5 - pauseDuration: 5s + pauseDuration: 30s - rampPercentage: 10 - pauseDuration: 5s + pauseDuration: 30s # Increase traffic to 50% and wait 1 minute - rampPercentage: 50 pauseDuration: 1m diff --git a/internal/k8s/deployments.go b/internal/k8s/deployments.go index 6ab4bd4d..7367510f 100644 --- a/internal/k8s/deployments.go +++ b/internal/k8s/deployments.go @@ -6,6 +6,8 @@ package k8s import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "regexp" "sort" @@ -25,11 +27,12 @@ import ( const ( DeployOwnerKey = ".metadata.controller" // BuildIDLabel is the label that identifies the build ID for a 
deployment - BuildIDLabel = "temporal.io/build-id" - DeploymentNameSeparator = "/" // TODO(carlydf): change this to "." once the server accepts `.` in deployment names - VersionIDSeparator = "." // TODO(carlydf): change this to ":" - K8sResourceNameSeparator = "-" - MaxBuildIdLen = 63 + BuildIDLabel = "temporal.io/build-id" + DeploymentNameSeparator = "/" // TODO(carlydf): change this to "." once the server accepts `.` in deployment names + VersionIDSeparator = "." // TODO(carlydf): change this to ":" + K8sResourceNameSeparator = "-" + MaxBuildIdLen = 63 + ConnectionSpecHashAnnotation = "temporal.io/connection-spec-hash" ) // DeploymentState represents the Kubernetes state of all deployments for a temporal worker deployment @@ -256,6 +259,12 @@ func NewDeploymentWithOwnerRef( }) } + // Build pod annotations + podAnnotations := make(map[string]string) + for k, v := range spec.Template.Annotations { + podAnnotations[k] = v + } + podAnnotations[ConnectionSpecHashAnnotation] = ComputeConnectionSpecHash(connection) blockOwnerDeletion := true return &appsv1.Deployment{ @@ -284,7 +293,7 @@ func NewDeploymentWithOwnerRef( Template: corev1.PodTemplateSpec{ ObjectMeta: metav1.ObjectMeta{ Labels: podLabels, - Annotations: spec.Template.Annotations, + Annotations: podAnnotations, }, Spec: *podSpec, }, @@ -293,6 +302,21 @@ func NewDeploymentWithOwnerRef( } } +func ComputeConnectionSpecHash(connection temporaliov1alpha1.TemporalConnectionSpec) string { + // HostPort is required, but MutualTLSSecret can be empty for non-mTLS connections + if connection.HostPort == "" { + return "" + } + + hasher := sha256.New() + + // Hash connection spec fields in deterministic order + _, _ = hasher.Write([]byte(connection.HostPort)) + _, _ = hasher.Write([]byte(connection.MutualTLSSecret)) + + return hex.EncodeToString(hasher.Sum(nil)) +} + func NewDeploymentWithControllerRef( w *temporaliov1alpha1.TemporalWorkerDeployment, buildID string, diff --git a/internal/k8s/deployments_test.go b/internal/k8s/deployments_test.go index e4856093..24edc5b2 100644 --- a/internal/k8s/deployments_test.go +++ b/internal/k8s/deployments_test.go @@ -451,3 +451,109 @@ func TestComputeWorkerDeploymentName_Integration_WithVersionedName(t *testing.T) assert.Equal(t, expectedVersionID, versionID) assert.Equal(t, "hello-world"+k8s.DeploymentNameSeparator+"demo"+k8s.VersionIDSeparator+"v1-0-0-dd84", versionID) } + +// TestNewDeploymentWithPodAnnotations tests that every new pod created has a connection spec hash annotation +func TestNewDeploymentWithPodAnnotations(t *testing.T) { + connection := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "localhost:7233", + MutualTLSSecret: "my-secret", + } + + deployment := k8s.NewDeploymentWithOwnerRef( + &metav1.TypeMeta{}, + &metav1.ObjectMeta{Name: "test", Namespace: "default"}, + &temporaliov1alpha1.TemporalWorkerDeploymentSpec{}, + "test-deployment", + "build123", + connection, + ) + + expectedHash := k8s.ComputeConnectionSpecHash(connection) + actualHash := deployment.Spec.Template.Annotations[k8s.ConnectionSpecHashAnnotation] + + assert.Equal(t, expectedHash, actualHash, "Deployment should have correct connection spec hash annotation") +} + +func TestComputeConnectionSpecHash(t *testing.T) { + t.Run("generates non-empty hash for valid connection spec", func(t *testing.T) { + spec := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "localhost:7233", + MutualTLSSecret: "my-tls-secret", + } + + result := k8s.ComputeConnectionSpecHash(spec) + assert.NotEmpty(t, result, "Hash should not be 
empty for valid spec") + assert.Len(t, result, 64, "SHA256 hash should be 64 characters") // hex encoded SHA256 + }) + + t.Run("returns empty hash when hostport is empty", func(t *testing.T) { + spec := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "", + MutualTLSSecret: "secret", + } + + result := k8s.ComputeConnectionSpecHash(spec) + assert.Empty(t, result, "Hash should be empty when hostport is empty") + }) + + t.Run("is deterministic - same input produces same hash", func(t *testing.T) { + spec := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "localhost:7233", + MutualTLSSecret: "my-secret", + } + + hash1 := k8s.ComputeConnectionSpecHash(spec) + hash2 := k8s.ComputeConnectionSpecHash(spec) + + assert.Equal(t, hash1, hash2, "Same input should produce identical hashes") + }) + + t.Run("different hostports produce different hashes", func(t *testing.T) { + spec1 := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "localhost:7233", + MutualTLSSecret: "same-secret", + } + spec2 := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "different-host:7233", + MutualTLSSecret: "same-secret", + } + + hash1 := k8s.ComputeConnectionSpecHash(spec1) + hash2 := k8s.ComputeConnectionSpecHash(spec2) + + assert.NotEqual(t, hash1, hash2, "Different hostports should produce different hashes") + }) + + t.Run("different mTLS secrets produce different hashes", func(t *testing.T) { + spec1 := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "localhost:7233", + MutualTLSSecret: "secret1", + } + spec2 := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "localhost:7233", + MutualTLSSecret: "secret2", + } + + hash1 := k8s.ComputeConnectionSpecHash(spec1) + hash2 := k8s.ComputeConnectionSpecHash(spec2) + + assert.NotEqual(t, hash1, hash2, "Different mTLS secrets should produce different hashes") + }) + + t.Run("empty mTLS secret vs non-empty produce different hashes", func(t *testing.T) { + spec1 := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "localhost:7233", + MutualTLSSecret: "", + } + spec2 := temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "localhost:7233", + MutualTLSSecret: "some-secret", + } + + hash1 := k8s.ComputeConnectionSpecHash(spec1) + hash2 := k8s.ComputeConnectionSpecHash(spec2) + + assert.NotEqual(t, hash1, hash2, "Empty vs non-empty mTLS secret should produce different hashes") + assert.NotEmpty(t, hash1, "Hash should still be generated even with empty mTLS secret") + }) +} diff --git a/internal/planner/planner.go b/internal/planner/planner.go index 3e7d4e8f..9daf2dac 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -21,6 +21,7 @@ type Plan struct { // Which actions to take DeleteDeployments []*appsv1.Deployment ScaleDeployments map[*corev1.ObjectReference]uint32 + UpdateDeployments []*appsv1.Deployment ShouldCreateDeployment bool VersionConfig *VersionConfig TestWorkflows []WorkflowConfig @@ -62,6 +63,7 @@ func GeneratePlan( status *temporaliov1alpha1.TemporalWorkerDeploymentStatus, spec *temporaliov1alpha1.TemporalWorkerDeploymentSpec, temporalState *temporal.TemporalWorkerState, + connection temporaliov1alpha1.TemporalConnectionSpec, config *Config, ) (*Plan, error) { plan := &Plan{ @@ -72,6 +74,7 @@ func GeneratePlan( plan.DeleteDeployments = getDeleteDeployments(k8sState, status, spec) plan.ScaleDeployments = getScaleDeployments(k8sState, status, spec) plan.ShouldCreateDeployment = shouldCreateDeployment(status, spec) + plan.UpdateDeployments = getUpdateDeployments(k8sState, status, connection) // Determine 
if we need to start any test workflows plan.TestWorkflows = getTestWorkflows(status, config) @@ -85,6 +88,88 @@ func GeneratePlan( return plan, nil } +// checkAndUpdateDeploymentConnectionSpec determines whether the Deployment for the given versionID is +// out-of-date with respect to the provided TemporalConnectionSpec. If an update is required, it mutates +// the existing Deployment in-place and returns a pointer to that Deployment. If no update is needed or +// the Deployment does not exist, it returns nil. +func checkAndUpdateDeploymentConnectionSpec( + versionID string, + k8sState *k8s.DeploymentState, + connection temporaliov1alpha1.TemporalConnectionSpec, +) *appsv1.Deployment { + existingDeployment, exists := k8sState.Deployments[versionID] + if !exists { + return nil + } + + // If the connection spec hash has changed, update the deployment + currentHash := k8s.ComputeConnectionSpecHash(connection) + if currentHash != existingDeployment.Spec.Template.Annotations[k8s.ConnectionSpecHashAnnotation] { + + // Update the deployment in-place with new connection info + updateDeploymentWithConnection(existingDeployment, connection) + return existingDeployment // Return the modified deployment + } + + return nil +} + +// updateDeploymentWithConnection updates an existing deployment with new TemporalConnectionSpec +func updateDeploymentWithConnection(deployment *appsv1.Deployment, connection temporaliov1alpha1.TemporalConnectionSpec) { + // Update the connection spec hash annotation + deployment.Spec.Template.Annotations[k8s.ConnectionSpecHashAnnotation] = k8s.ComputeConnectionSpecHash(connection) + + // Update secret volume if mTLS is enabled + if connection.MutualTLSSecret != "" { + for i, volume := range deployment.Spec.Template.Spec.Volumes { + if volume.Name == "temporal-tls" && volume.Secret != nil { + deployment.Spec.Template.Spec.Volumes[i].Secret.SecretName = connection.MutualTLSSecret + break + } + } + } + + // Update any environment variables that reference the connection + for i, container := range deployment.Spec.Template.Spec.Containers { + for j, env := range container.Env { + if env.Name == "TEMPORAL_HOST_PORT" { + deployment.Spec.Template.Spec.Containers[i].Env[j].Value = connection.HostPort + } + } + } +} + +func getUpdateDeployments( + k8sState *k8s.DeploymentState, + status *temporaliov1alpha1.TemporalWorkerDeploymentStatus, + connection temporaliov1alpha1.TemporalConnectionSpec, +) []*appsv1.Deployment { + var updateDeployments []*appsv1.Deployment + + // Check target version deployment if it has an expired connection spec hash + if status.TargetVersion.VersionID != "" { + if deployment := checkAndUpdateDeploymentConnectionSpec(status.TargetVersion.VersionID, k8sState, connection); deployment != nil { + updateDeployments = append(updateDeployments, deployment) + } + } + + // Check current version deployment if it has an expired connection spec hash + if status.CurrentVersion != nil && status.CurrentVersion.VersionID != "" { + if deployment := checkAndUpdateDeploymentConnectionSpec(status.CurrentVersion.VersionID, k8sState, connection); deployment != nil { + updateDeployments = append(updateDeployments, deployment) + } + } + + // Check deprecated versions for expired connection spec hashes + for _, version := range status.DeprecatedVersions { + if deployment := checkAndUpdateDeploymentConnectionSpec(version.VersionID, k8sState, connection); deployment != nil { + updateDeployments = append(updateDeployments, deployment) + } + } + + return updateDeployments +} + // 
getDeleteDeployments determines which deployments should be deleted func getDeleteDeployments( k8sState *k8s.DeploymentState, diff --git a/internal/planner/planner_test.go b/internal/planner/planner_test.go index 1d7acfd2..db645498 100644 --- a/internal/planner/planner_test.go +++ b/internal/planner/planner_test.go @@ -32,6 +32,7 @@ func TestGeneratePlan(t *testing.T) { expectDelete int expectCreate bool expectScale int + expectUpdate int expectWorkflow int expectConfig bool expectConfigSetCurrent *bool // pointer so we can test nil @@ -57,12 +58,12 @@ func TestGeneratePlan(t *testing.T) { name: "drained version gets deleted", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(0), - "test/namespace.456": createDeploymentWithReplicas(1), + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(0), + "test/namespace.456": createDeploymentWithDefaultConnectionSpecHash(1), }, DeploymentsByTime: []*appsv1.Deployment{ - createDeploymentWithReplicas(0), - createDeploymentWithReplicas(1), + createDeploymentWithDefaultConnectionSpecHash(0), + createDeploymentWithDefaultConnectionSpecHash(1), }, }, status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ @@ -110,10 +111,10 @@ func TestGeneratePlan(t *testing.T) { name: "deployment needs to be scaled", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(1), + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(1), }, DeploymentsByTime: []*appsv1.Deployment{ - createDeploymentWithReplicas(1), + createDeploymentWithDefaultConnectionSpecHash(1), }, DeploymentRefs: map[string]*corev1.ObjectReference{ "test/namespace.123": {Name: "test-123"}, @@ -148,12 +149,12 @@ func TestGeneratePlan(t *testing.T) { name: "rollback scenario - target equals current but deprecated version is ramping", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(3), - "test/namespace.456": createDeploymentWithReplicas(3), + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(3), + "test/namespace.456": createDeploymentWithDefaultConnectionSpecHash(3), }, DeploymentsByTime: []*appsv1.Deployment{ - createDeploymentWithReplicas(3), - createDeploymentWithReplicas(3), + createDeploymentWithDefaultConnectionSpecHash(3), + createDeploymentWithDefaultConnectionSpecHash(3), }, DeploymentRefs: map[string]*corev1.ObjectReference{ "test/namespace.123": {Name: "test-123"}, @@ -235,6 +236,115 @@ func TestGeneratePlan(t *testing.T) { }, expectCreate: false, }, + { + name: "update deployment when target version, with an existing deployment, has an expired connection spec hash", + k8sState: &k8s.DeploymentState{ + Deployments: map[string]*appsv1.Deployment{ + "test/namespace.123": createDeploymentWithExpiredConnectionSpecHash(1), + "test/namespace.456": createDeploymentWithDefaultConnectionSpecHash(1), + }, + }, + status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ + TargetVersion: temporaliov1alpha1.TargetWorkerDeploymentVersion{ + BaseWorkerDeploymentVersion: temporaliov1alpha1.BaseWorkerDeploymentVersion{ + VersionID: "test/namespace.123", + Status: temporaliov1alpha1.VersionStatusRamping, + Deployment: &corev1.ObjectReference{Name: "test-123"}, + }, + }, + CurrentVersion: &temporaliov1alpha1.CurrentWorkerDeploymentVersion{ + BaseWorkerDeploymentVersion: temporaliov1alpha1.BaseWorkerDeploymentVersion{ + 
VersionID: "test/namespace.456", + Status: temporaliov1alpha1.VersionStatusCurrent, + Deployment: &corev1.ObjectReference{Name: "test-456"}, + }, + }, + }, + spec: &temporaliov1alpha1.TemporalWorkerDeploymentSpec{ + Replicas: func() *int32 { r := int32(1); return &r }(), + }, + state: &temporal.TemporalWorkerState{}, + config: &Config{ + RolloutStrategy: temporaliov1alpha1.RolloutStrategy{}, + }, + expectUpdate: 1, + }, + { + name: "update deployment when a deprecated version has an expired connection spec hash", + k8sState: &k8s.DeploymentState{ + Deployments: map[string]*appsv1.Deployment{ + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(1), + "test/namespace.456": createDeploymentWithDefaultConnectionSpecHash(1), + "test/namespace.789": createDeploymentWithExpiredConnectionSpecHash(1), + }, + }, + status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ + TargetVersion: temporaliov1alpha1.TargetWorkerDeploymentVersion{ + BaseWorkerDeploymentVersion: temporaliov1alpha1.BaseWorkerDeploymentVersion{ + VersionID: "test/namespace.123", + Status: temporaliov1alpha1.VersionStatusRamping, + Deployment: &corev1.ObjectReference{Name: "test-123"}, + }, + }, + CurrentVersion: &temporaliov1alpha1.CurrentWorkerDeploymentVersion{ + BaseWorkerDeploymentVersion: temporaliov1alpha1.BaseWorkerDeploymentVersion{ + VersionID: "test/namespace.456", + Status: temporaliov1alpha1.VersionStatusCurrent, + Deployment: &corev1.ObjectReference{Name: "test-456"}, + }, + }, + DeprecatedVersions: []*temporaliov1alpha1.DeprecatedWorkerDeploymentVersion{ + { + BaseWorkerDeploymentVersion: temporaliov1alpha1.BaseWorkerDeploymentVersion{ + VersionID: "test/namespace.789", + Status: temporaliov1alpha1.VersionStatusDraining, + Deployment: &corev1.ObjectReference{Name: "test-789"}, + }, + }, + }, + }, + spec: &temporaliov1alpha1.TemporalWorkerDeploymentSpec{ + Replicas: func() *int32 { r := int32(1); return &r }(), + }, + state: &temporal.TemporalWorkerState{}, + config: &Config{ + RolloutStrategy: temporaliov1alpha1.RolloutStrategy{}, + }, + expectUpdate: 1, + }, + { + name: "update deployment when current version has an expired connection spec hash", + k8sState: &k8s.DeploymentState{ + Deployments: map[string]*appsv1.Deployment{ + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(1), + "test/namespace.456": createDeploymentWithExpiredConnectionSpecHash(1), + }, + }, + status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ + TargetVersion: temporaliov1alpha1.TargetWorkerDeploymentVersion{ + BaseWorkerDeploymentVersion: temporaliov1alpha1.BaseWorkerDeploymentVersion{ + VersionID: "test/namespace.123", + Status: temporaliov1alpha1.VersionStatusRamping, + Deployment: &corev1.ObjectReference{Name: "test-123"}, + }, + }, + CurrentVersion: &temporaliov1alpha1.CurrentWorkerDeploymentVersion{ + BaseWorkerDeploymentVersion: temporaliov1alpha1.BaseWorkerDeploymentVersion{ + VersionID: "test/namespace.456", + Status: temporaliov1alpha1.VersionStatusCurrent, + Deployment: &corev1.ObjectReference{Name: "test-456"}, + }, + }, + }, + spec: &temporaliov1alpha1.TemporalWorkerDeploymentSpec{ + Replicas: func() *int32 { r := int32(1); return &r }(), + }, + state: &temporal.TemporalWorkerState{}, + config: &Config{ + RolloutStrategy: temporaliov1alpha1.RolloutStrategy{}, + }, + expectUpdate: 1, + }, } for _, tc := range testCases { @@ -242,12 +352,13 @@ func TestGeneratePlan(t *testing.T) { if tc.status == nil { tc.status = &temporaliov1alpha1.TemporalWorkerDeploymentStatus{} } - plan, err := 
GeneratePlan(logr.Discard(), tc.k8sState, tc.status, tc.spec, tc.state, tc.config) + plan, err := GeneratePlan(logr.Discard(), tc.k8sState, tc.status, tc.spec, tc.state, createDefaultConnectionSpec(), tc.config) require.NoError(t, err) assert.Equal(t, tc.expectDelete, len(plan.DeleteDeployments), "unexpected number of deletions") assert.Equal(t, tc.expectScale, len(plan.ScaleDeployments), "unexpected number of scales") assert.Equal(t, tc.expectCreate, plan.ShouldCreateDeployment, "unexpected create flag") + assert.Equal(t, tc.expectUpdate, len(plan.UpdateDeployments), "unexpected number of updates") assert.Equal(t, tc.expectWorkflow, len(plan.TestWorkflows), "unexpected number of test workflows") assert.Equal(t, tc.expectConfig, plan.VersionConfig != nil, "unexpected version config presence") @@ -277,7 +388,7 @@ func TestGetDeleteDeployments(t *testing.T) { name: "drained version should be deleted", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(0), + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(0), }, }, status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ @@ -309,7 +420,7 @@ func TestGetDeleteDeployments(t *testing.T) { name: "not yet drained long enough", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.456": createDeploymentWithReplicas(0), + "test/namespace.456": createDeploymentWithDefaultConnectionSpecHash(0), }, }, status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ @@ -338,8 +449,8 @@ func TestGetDeleteDeployments(t *testing.T) { name: "not registered version should be deleted", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(1), - "test/namespace.456": createDeploymentWithReplicas(1), + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(1), + "test/namespace.456": createDeploymentWithDefaultConnectionSpecHash(1), }, }, status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ @@ -390,7 +501,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "current version needs scaling", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(1), + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(1), }, }, status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ @@ -418,8 +529,8 @@ func TestGetScaleDeployments(t *testing.T) { name: "drained version needs scaling down", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.123": createDeploymentWithReplicas(1), - "test/namespace.456": createDeploymentWithReplicas(2), + "test/namespace.123": createDeploymentWithDefaultConnectionSpecHash(1), + "test/namespace.456": createDeploymentWithDefaultConnectionSpecHash(2), }, }, status: &temporaliov1alpha1.TemporalWorkerDeploymentStatus{ @@ -455,7 +566,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "inactive version needs scaling up", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.a": createDeploymentWithReplicas(0), + "test/namespace.a": createDeploymentWithDefaultConnectionSpecHash(0), }, DeploymentRefs: map[string]*corev1.ObjectReference{ "test/namespace.a": {Name: "test-a"}, @@ -488,7 +599,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "ramping version needs scaling up", k8sState: &k8s.DeploymentState{ Deployments: 
map[string]*appsv1.Deployment{ - "test/namespace.b": createDeploymentWithReplicas(0), + "test/namespace.b": createDeploymentWithDefaultConnectionSpecHash(0), }, DeploymentRefs: map[string]*corev1.ObjectReference{ "test/namespace.b": {Name: "test-b"}, @@ -515,7 +626,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "target version needs scaling up", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.a": createDeploymentWithReplicas(0), + "test/namespace.a": createDeploymentWithDefaultConnectionSpecHash(0), }, DeploymentRefs: map[string]*corev1.ObjectReference{ "test/namespace.a": {Name: "test-a"}, @@ -548,7 +659,7 @@ func TestGetScaleDeployments(t *testing.T) { name: "don't scale down drained deployment before delay", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.b": createDeploymentWithReplicas(3), + "test/namespace.b": createDeploymentWithDefaultConnectionSpecHash(3), }, DeploymentRefs: map[string]*corev1.ObjectReference{ "test/namespace.b": {Name: "test-b"}, @@ -1543,11 +1654,11 @@ func TestComplexVersionStateScenarios(t *testing.T) { name: "multiple deprecated versions in different states", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.a": createDeploymentWithReplicas(5), - "test/namespace.b": createDeploymentWithReplicas(3), - "test/namespace.c": createDeploymentWithReplicas(3), - "test/namespace.d": createDeploymentWithReplicas(1), - "test/namespace.e": createDeploymentWithReplicas(0), + "test/namespace.a": createDeploymentWithDefaultConnectionSpecHash(5), + "test/namespace.b": createDeploymentWithDefaultConnectionSpecHash(3), + "test/namespace.c": createDeploymentWithDefaultConnectionSpecHash(3), + "test/namespace.d": createDeploymentWithDefaultConnectionSpecHash(1), + "test/namespace.e": createDeploymentWithDefaultConnectionSpecHash(0), }, }, config: &Config{ @@ -1613,7 +1724,7 @@ func TestComplexVersionStateScenarios(t *testing.T) { name: "draining version not scaled down before delay", k8sState: &k8s.DeploymentState{ Deployments: map[string]*appsv1.Deployment{ - "test/namespace.a": createDeploymentWithReplicas(3), + "test/namespace.a": createDeploymentWithDefaultConnectionSpecHash(3), }, }, config: &Config{ @@ -1654,7 +1765,7 @@ func TestComplexVersionStateScenarios(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - plan, err := GeneratePlan(logr.Discard(), tc.k8sState, tc.status, tc.spec, tc.state, tc.config) + plan, err := GeneratePlan(logr.Discard(), tc.k8sState, tc.status, tc.spec, tc.state, createDefaultConnectionSpec(), tc.config) require.NoError(t, err) assert.Equal(t, tc.expectDeletes, len(plan.DeleteDeployments), "unexpected number of deletes") @@ -1718,14 +1829,219 @@ func TestGetTestWorkflowID(t *testing.T) { } } -// Helper function to create a deployment with specified replicas -func createDeploymentWithReplicas(replicas int32) *appsv1.Deployment { +func TestCheckAndUpdateDeploymentConnectionSpec(t *testing.T) { + tests := []struct { + name string + versionID string + existingDeployment *appsv1.Deployment + newConnection temporaliov1alpha1.TemporalConnectionSpec + expectUpdate bool + expectSecretName string + expectHostPortEnv string + expectConnectionHash string + }{ + { + name: "non-existing deployment does not result in an update", + versionID: "non-existent-version", + existingDeployment: nil, + newConnection: temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "new-host:7233", + 
MutualTLSSecret: "new-secret", + }, + expectUpdate: false, + }, + { + name: "same connection spec hash does not update the existing deployment", + versionID: "test-version", + existingDeployment: createTestDeploymentWithConnection( + "test-version", + temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: defaultMutualTLSSecret(), + }, + ), + newConnection: temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: defaultMutualTLSSecret(), + }, + expectUpdate: false, + }, + { + name: "different secret name triggers update", + versionID: "test-version", + existingDeployment: createTestDeploymentWithConnection( + "test-version", + temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: defaultMutualTLSSecret(), + }, + ), + newConnection: temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: "new-secret", + }, + expectUpdate: true, + expectSecretName: "new-secret", + expectHostPortEnv: defaultHostPort(), + expectConnectionHash: k8s.ComputeConnectionSpecHash(temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: "new-secret", + }), + }, + { + name: "different host port triggers update", + versionID: "test-version", + existingDeployment: createTestDeploymentWithConnection( + "test-version", + temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: defaultMutualTLSSecret(), + }, + ), + newConnection: temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "new-host:7233", + MutualTLSSecret: defaultMutualTLSSecret(), + }, + expectUpdate: true, + expectSecretName: defaultMutualTLSSecret(), + expectHostPortEnv: "new-host:7233", + expectConnectionHash: k8s.ComputeConnectionSpecHash(temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "new-host:7233", + MutualTLSSecret: defaultMutualTLSSecret(), + }), + }, + { + name: "both hostport and secret change triggers update", + versionID: "test-version", + existingDeployment: createTestDeploymentWithConnection( + "test-version", + temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: defaultMutualTLSSecret(), + }, + ), + newConnection: temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "new-host:7233", + MutualTLSSecret: "new-secret", + }, + expectUpdate: true, + expectSecretName: "new-secret", + expectHostPortEnv: "new-host:7233", + expectConnectionHash: k8s.ComputeConnectionSpecHash(temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: "new-host:7233", + MutualTLSSecret: "new-secret", + }), + }, + { + name: "empty mutual tls secret updates correctly", + versionID: "test-version", + existingDeployment: createTestDeploymentWithConnection( + "test-version", + temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: defaultMutualTLSSecret(), + }, + ), + newConnection: temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: "", + }, + expectUpdate: true, + expectSecretName: "", // Should not update secret volume when empty + expectHostPortEnv: defaultHostPort(), + expectConnectionHash: k8s.ComputeConnectionSpecHash(temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: "", + }), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k8sState := &k8s.DeploymentState{ + Deployments: map[string]*appsv1.Deployment{}, + } + + if tt.existingDeployment != nil { + 
k8sState.Deployments[tt.versionID] = tt.existingDeployment + } + + result := checkAndUpdateDeploymentConnectionSpec(tt.versionID, k8sState, tt.newConnection) + + if !tt.expectUpdate { + assert.Nil(t, result, "Expected no update, but got deployment") + return + } + + require.NotNil(t, result, "Expected deployment update, but got nil") + + // Check that the connection hash annotation was updated + actualHash := result.Spec.Template.Annotations[k8s.ConnectionSpecHashAnnotation] + assert.Equal(t, tt.expectConnectionHash, actualHash, "Connection spec hash should be updated") + + // Check secret volume update (only if mTLS secret is not empty) + if tt.newConnection.MutualTLSSecret != "" { + found := false + for _, volume := range result.Spec.Template.Spec.Volumes { + if volume.Name == "temporal-tls" && volume.Secret != nil { + assert.Equal(t, tt.expectSecretName, volume.Secret.SecretName, "Secret name should be updated") + found = true + break + } + } + assert.True(t, found, "Should find temporal-tls volume with updated secret") + } + + // Check environment variable update + found := false + for _, container := range result.Spec.Template.Spec.Containers { + for _, env := range container.Env { + if env.Name == "TEMPORAL_HOST_PORT" { + assert.Equal(t, tt.expectHostPortEnv, env.Value, "TEMPORAL_HOST_PORT should be updated") + found = true + break + } + } + } + assert.True(t, found, "Should find TEMPORAL_HOST_PORT environment variable") + }) + } +} + +// Helper function to create a deployment with the specified replicas and the default connection spec hash +func createDeploymentWithDefaultConnectionSpecHash(replicas int32) *appsv1.Deployment { return &appsv1.Deployment{ ObjectMeta: metav1.ObjectMeta{ Name: "test-deployment", }, Spec: appsv1.DeploymentSpec{ Replicas: &replicas, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + k8s.ConnectionSpecHashAnnotation: k8s.ComputeConnectionSpecHash(createDefaultConnectionSpec()), + }, + }, + }, + }, + } +} + +// Helper function to create a deployment with the specified replicas and with a non-default connection spec hash +func createDeploymentWithExpiredConnectionSpecHash(replicas int32) *appsv1.Deployment { + return &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-deployment", + }, + Spec: appsv1.DeploymentSpec{ + Replicas: &replicas, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + k8s.ConnectionSpecHashAnnotation: k8s.ComputeConnectionSpecHash(createOutdatedConnectionSpec()), + }, + }, + }, }, } } @@ -1744,3 +2060,58 @@ func rolloutStep(ramp float32, d time.Duration) temporaliov1alpha1.RolloutStep { PauseDuration: metav1Duration(d), } } + +// Helper function to create a default TemporalConnectionSpec used for testing purposes +func createDefaultConnectionSpec() temporaliov1alpha1.TemporalConnectionSpec { + return temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: defaultMutualTLSSecret(), + } +} + +// Helper function to create a TemporalConnectionSpec with an outdated secret +func createOutdatedConnectionSpec() temporaliov1alpha1.TemporalConnectionSpec { + return temporaliov1alpha1.TemporalConnectionSpec{ + HostPort: defaultHostPort(), + MutualTLSSecret: "outdated-secret", + } +} + +func defaultHostPort() string { + return "default-host:7233" +} + +func defaultMutualTLSSecret() string { + return "default-secret" +} + +// createDefaultWorkerSpec creates a default 
TemporalWorkerDeploymentSpec for testing +func createDefaultWorkerSpec() *temporaliov1alpha1.TemporalWorkerDeploymentSpec { + return &temporaliov1alpha1.TemporalWorkerDeploymentSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "worker", + Image: "test-image:latest", + }, + }, + }, + }, + WorkerOptions: temporaliov1alpha1.WorkerOptions{ + TemporalNamespace: "test-namespace", + }, + } +} + +// createTestDeploymentWithConnection creates a test deployment with the specified connection spec +func createTestDeploymentWithConnection(versionID string, connection temporaliov1alpha1.TemporalConnectionSpec) *appsv1.Deployment { + return k8s.NewDeploymentWithOwnerRef( + &metav1.TypeMeta{}, + &metav1.ObjectMeta{Name: "test-worker", Namespace: "default"}, + createDefaultWorkerSpec(), + versionID, + "build123", + connection, + ) +}
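
For a quick manual check of the behavior introduced above (the deployment name below is a placeholder, not part of this change), the pod-template annotation written by the controller can be read back with kubectl's jsonpath output; after the referenced TemporalConnection spec changes, a successful reconcile should rewrite this value and roll the worker pods:

# Inspect the connection spec hash annotation on a worker Deployment's pod template
kubectl get deployment <worker-deployment-name> -o jsonpath='{.spec.template.metadata.annotations.temporal\.io/connection-spec-hash}'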