diff --git a/api/v1alpha1/modelvalidation_types.go b/api/v1alpha1/modelvalidation_types.go index 47a97831..e2fe0eba 100644 --- a/api/v1alpha1/modelvalidation_types.go +++ b/api/v1alpha1/modelvalidation_types.go @@ -17,6 +17,9 @@ limitations under the License. package v1alpha1 import ( + "crypto/sha256" + "fmt" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -69,15 +72,54 @@ type ModelValidationSpec struct { Config ValidationConfig `json:"config"` } +// PodTrackingInfo contains information about a tracked pod +type PodTrackingInfo struct { + // Name is the name of the pod + Name string `json:"name"` + // UID is the unique identifier of the pod + UID string `json:"uid"` + // InjectedAt is when the pod was injected + InjectedAt metav1.Time `json:"injectedAt"` +} + // ModelValidationStatus defines the observed state of ModelValidation type ModelValidationStatus struct { // INSERT ADDITIONAL STATUS FIELD - define observed state of cluster // Important: Run "make" to regenerate code after modifying this file Conditions []metav1.Condition `json:"conditions,omitempty"` + + // InjectedPodCount is the number of pods that have been injected with validation + InjectedPodCount int32 `json:"injectedPodCount"` + + // UninjectedPodCount is the number of pods that have the label but were not injected + UninjectedPodCount int32 `json:"uninjectedPodCount"` + + // OrphanedPodCount is the number of injected pods that reference this CR but are inconsistent + OrphanedPodCount int32 `json:"orphanedPodCount"` + + // AuthMethod indicates which authentication method is being used + AuthMethod string `json:"authMethod,omitempty"` + + // InjectedPods contains detailed information about injected pods + InjectedPods []PodTrackingInfo `json:"injectedPods,omitempty"` + + // UninjectedPods contains detailed information about pods that should have been injected but weren't + UninjectedPods []PodTrackingInfo `json:"uninjectedPods,omitempty"` + + // OrphanedPods contains detailed information about pods that are injected but inconsistent + OrphanedPods []PodTrackingInfo `json:"orphanedPods,omitempty"` + + // LastUpdated is the timestamp of the last status update + LastUpdated metav1.Time `json:"lastUpdated,omitempty"` } // +kubebuilder:object:root=true // +kubebuilder:subresource:status +// +kubebuilder:printcolumn:name="Auth Method",type=string,JSONPath=`.status.authMethod` +// +kubebuilder:printcolumn:name="Injected Pods",type=integer,JSONPath=`.status.injectedPodCount` +// +kubebuilder:printcolumn:name="Uninjected Pods",type=integer,JSONPath=`.status.uninjectedPodCount` +// +kubebuilder:printcolumn:name="Orphaned Pods",type=integer,JSONPath=`.status.orphanedPodCount` +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" // ModelValidation is the Schema for the modelvalidations API type ModelValidation struct { @@ -97,6 +139,42 @@ type ModelValidationList struct { Items []ModelValidation `json:"items"` } +// GetAuthMethod returns the authentication method being used +func (mv *ModelValidation) GetAuthMethod() string { + if mv.Spec.Config.SigstoreConfig != nil { + return "sigstore" + } else if mv.Spec.Config.PkiConfig != nil { + return "pki" + } else if mv.Spec.Config.PrivateKeyConfig != nil { + return "private-key" + } + return "unknown" +} + +// GetConfigHash returns a hash of the validation configuration for drift detection +func (mv *ModelValidation) GetConfigHash() string { + return mv.Spec.Config.GetConfigHash() +} + +// GetConfigHash returns a hash of the validation 
configuration for drift detection +func (vc *ValidationConfig) GetConfigHash() string { + hasher := sha256.New() + + if vc.SigstoreConfig != nil { + hasher.Write([]byte("sigstore")) + hasher.Write([]byte(vc.SigstoreConfig.CertificateIdentity)) + hasher.Write([]byte(vc.SigstoreConfig.CertificateOidcIssuer)) + } else if vc.PkiConfig != nil { + hasher.Write([]byte("pki")) + hasher.Write([]byte(vc.PkiConfig.CertificateAuthority)) + } else if vc.PrivateKeyConfig != nil { + hasher.Write([]byte("privatekey")) + hasher.Write([]byte(vc.PrivateKeyConfig.KeyPath)) + } + + return fmt.Sprintf("%x", hasher.Sum(nil))[:16] // Use first 16 chars for brevity +} + func init() { SchemeBuilder.Register(&ModelValidation{}, &ModelValidationList{}) } diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index 47af49d6..c23c35d6 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -126,6 +126,28 @@ func (in *ModelValidationStatus) DeepCopyInto(out *ModelValidationStatus) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.InjectedPods != nil { + in, out := &in.InjectedPods, &out.InjectedPods + *out = make([]PodTrackingInfo, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.UninjectedPods != nil { + in, out := &in.UninjectedPods, &out.UninjectedPods + *out = make([]PodTrackingInfo, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.OrphanedPods != nil { + in, out := &in.OrphanedPods, &out.OrphanedPods + *out = make([]PodTrackingInfo, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + in.LastUpdated.DeepCopyInto(&out.LastUpdated) } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelValidationStatus. @@ -153,6 +175,22 @@ func (in *PkiConfig) DeepCopy() *PkiConfig { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *PodTrackingInfo) DeepCopyInto(out *PodTrackingInfo) { + *out = *in + in.InjectedAt.DeepCopyInto(&out.InjectedAt) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PodTrackingInfo. +func (in *PodTrackingInfo) DeepCopy() *PodTrackingInfo { + if in == nil { + return nil + } + out := new(PodTrackingInfo) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
func (in *PrivateKeyConfig) DeepCopyInto(out *PrivateKeyConfig) { *out = *in diff --git a/cmd/main.go b/cmd/main.go index ccf56a20..242c3016 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -22,8 +22,11 @@ import ( "flag" "os" "path/filepath" + "time" "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/controller" + "github.com/sigstore/model-validation-operator/internal/tracker" "github.com/sigstore/model-validation-operator/internal/utils" "github.com/sigstore/model-validation-operator/internal/webhooks" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -47,6 +50,16 @@ import ( // +kubebuilder:scaffold:imports ) +const ( + // Default configuration values for the status tracker + defaultDebounceDuration = 500 * time.Millisecond + defaultRetryBaseDelay = 100 * time.Millisecond + defaultRetryMaxDelay = 16 * time.Second + defaultRateLimitQPS = 10.0 + defaultRateLimitBurst = 100 + defaultStatusUpdateTimeout = 30 * time.Second +) + var ( scheme = runtime.NewScheme() setupLog = ctrl.Log.WithName("setup") @@ -69,6 +82,14 @@ func main() { var secureMetrics bool var enableHTTP2 bool var tlsOpts []func(*tls.Config) + + // Status tracker configuration + var debounceDuration time.Duration + var retryBaseDelay time.Duration + var retryMaxDelay time.Duration + var rateLimitQPS float64 + var rateLimitBurst int + var statusUpdateTimeout time.Duration flag.StringVar(&metricsAddr, "metrics-bind-address", "0", "The address the metrics endpoint binds to. "+ "Use :8443 for HTTPS or :8080 for HTTP, or leave as 0 to disable the metrics service.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") @@ -91,6 +112,20 @@ func main() { "MODEL_TRANSPARENCY_CLI_IMAGE", constants.ModelTransparencyCliImage, "Model transparency CLI image to be used.") + + // Status tracker configuration flags + flag.DurationVar(&debounceDuration, "debounce-duration", defaultDebounceDuration, + "Time to wait for more changes before updating status") + flag.DurationVar(&retryBaseDelay, "retry-base-delay", defaultRetryBaseDelay, + "Base delay for exponential backoff retries") + flag.DurationVar(&retryMaxDelay, "retry-max-delay", defaultRetryMaxDelay, + "Maximum delay for exponential backoff retries") + flag.Float64Var(&rateLimitQPS, "rate-limit-qps", defaultRateLimitQPS, + "Overall rate limit for status updates (queries per second)") + flag.IntVar(&rateLimitBurst, "rate-limit-burst", defaultRateLimitBurst, + "Burst capacity for overall rate limit") + flag.DurationVar(&statusUpdateTimeout, "status-update-timeout", defaultStatusUpdateTimeout, + "Timeout for status update operations") opts := zap.Options{ Development: true, } @@ -246,6 +281,36 @@ func main() { Handler: interceptor, }) + statusTracker := tracker.NewStatusTracker(mgr.GetClient(), tracker.StatusTrackerConfig{ + DebounceDuration: debounceDuration, + RetryBaseDelay: retryBaseDelay, + RetryMaxDelay: retryMaxDelay, + RateLimitQPS: rateLimitQPS, + RateLimitBurst: rateLimitBurst, + StatusUpdateTimeout: statusUpdateTimeout, + }) + defer statusTracker.Stop() + + podReconciler := &controller.PodReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Tracker: statusTracker, + } + if err := podReconciler.SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create pod controller") + os.Exit(1) + } + + mvReconciler := &controller.ModelValidationReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Tracker: 
statusTracker, + } + if err := mvReconciler.SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create ModelValidation controller") + os.Exit(1) + } + setupLog.Info("starting manager") if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil { setupLog.Error(err, "problem running manager") diff --git a/config/crd/bases/ml.sigstore.dev_modelvalidations.yaml b/config/crd/bases/ml.sigstore.dev_modelvalidations.yaml index 8996ff4c..bad41bb2 100644 --- a/config/crd/bases/ml.sigstore.dev_modelvalidations.yaml +++ b/config/crd/bases/ml.sigstore.dev_modelvalidations.yaml @@ -14,7 +14,23 @@ spec: singular: modelvalidation scope: Namespaced versions: - - name: v1alpha1 + - additionalPrinterColumns: + - jsonPath: .status.authMethod + name: Auth Method + type: string + - jsonPath: .status.injectedPodCount + name: Injected Pods + type: integer + - jsonPath: .status.uninjectedPodCount + name: Uninjected Pods + type: integer + - jsonPath: .status.orphanedPodCount + name: Orphaned Pods + type: integer + - jsonPath: .metadata.creationTimestamp + name: Age + type: date + name: v1alpha1 schema: openAPIV3Schema: description: ModelValidation is the Schema for the modelvalidations API @@ -89,6 +105,10 @@ spec: status: description: ModelValidationStatus defines the observed state of ModelValidation properties: + authMethod: + description: AuthMethod indicates which authentication method is being + used + type: string conditions: description: |- INSERT ADDITIONAL STATUS FIELD - define observed state of cluster @@ -148,6 +168,98 @@ spec: - type type: object type: array + injectedPodCount: + description: InjectedPodCount is the number of pods that have been + injected with validation + format: int32 + type: integer + injectedPods: + description: InjectedPods contains detailed information about injected + pods + items: + description: PodTrackingInfo contains information about a tracked + pod + properties: + injectedAt: + description: InjectedAt is when the pod was injected + format: date-time + type: string + name: + description: Name is the name of the pod + type: string + uid: + description: UID is the unique identifier of the pod + type: string + required: + - injectedAt + - name + - uid + type: object + type: array + lastUpdated: + description: LastUpdated is the timestamp of the last status update + format: date-time + type: string + orphanedPodCount: + description: OrphanedPodCount is the number of injected pods that + reference this CR but are inconsistent + format: int32 + type: integer + orphanedPods: + description: OrphanedPods contains detailed information about pods + that are injected but inconsistent + items: + description: PodTrackingInfo contains information about a tracked + pod + properties: + injectedAt: + description: InjectedAt is when the pod was injected + format: date-time + type: string + name: + description: Name is the name of the pod + type: string + uid: + description: UID is the unique identifier of the pod + type: string + required: + - injectedAt + - name + - uid + type: object + type: array + uninjectedPodCount: + description: UninjectedPodCount is the number of pods that have the + label but were not injected + format: int32 + type: integer + uninjectedPods: + description: UninjectedPods contains detailed information about pods + that should have been injected but weren't + items: + description: PodTrackingInfo contains information about a tracked + pod + properties: + injectedAt: + description: InjectedAt is when the pod was injected + format: date-time + type: 
string + name: + description: Name is the name of the pod + type: string + uid: + description: UID is the unique identifier of the pod + type: string + required: + - injectedAt + - name + - uid + type: object + type: array + required: + - injectedPodCount + - orphanedPodCount + - uninjectedPodCount type: object type: object served: true diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index eee3ef0e..b32aa176 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -12,6 +12,21 @@ rules: - get - list - watch +- apiGroups: + - "" + resources: + - pods + verbs: + - get + - list + - update + - watch +- apiGroups: + - "" + resources: + - pods/finalizers + verbs: + - update - apiGroups: - ml.sigstore.dev resources: @@ -25,6 +40,4 @@ rules: resources: - modelvalidations/status verbs: - - get - - patch - update diff --git a/go.mod b/go.mod index 9ad93c20..af625f26 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,10 @@ require ( github.com/go-logr/logr v1.4.2 github.com/onsi/ginkgo/v2 v2.22.0 github.com/onsi/gomega v1.36.1 + github.com/prometheus/client_golang v1.19.1 + github.com/prometheus/client_model v0.6.1 + github.com/stretchr/testify v1.9.0 + golang.org/x/time v0.7.0 k8s.io/api v0.32.1 k8s.io/apimachinery v0.32.1 k8s.io/client-go v0.32.1 @@ -53,8 +57,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/client_golang v1.19.1 // indirect - github.com/prometheus/client_model v0.6.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/spf13/cobra v1.8.1 // indirect @@ -78,7 +81,6 @@ require ( golang.org/x/sys v0.26.0 // indirect golang.org/x/term v0.25.0 // indirect golang.org/x/text v0.19.0 // indirect - golang.org/x/time v0.7.0 // indirect golang.org/x/tools v0.26.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 // indirect diff --git a/internal/constants/images.go b/internal/constants/images.go index ded4d049..6901dbff 100644 --- a/internal/constants/images.go +++ b/internal/constants/images.go @@ -1,6 +1,11 @@ // Package constants provides shared constants used throughout the model validation operator package constants +const ( + // ModelValidationInitContainerName is the name of the init container injected for model validation + ModelValidationInitContainerName = "model-validation" +) + var ( // ModelTransparencyCliImage is the default image for the model transparency CLI // used as an init container to validate model signatures diff --git a/internal/constants/labels.go b/internal/constants/labels.go index 1e12ff1c..3e96f546 100644 --- a/internal/constants/labels.go +++ b/internal/constants/labels.go @@ -10,6 +10,18 @@ const ( // IgnoreNamespaceLabel is the label used to ignore a namespace for model validation IgnoreNamespaceLabel = ModelValidationDomain + "/ignore" - // ModelValidationInitContainerName is the name of the init container injected for model validation - ModelValidationInitContainerName = "model-validation" + // ModelValidationFinalizer is the finalizer used to track model validation pods + ModelValidationFinalizer = ModelValidationDomain + "/finalizer" + + // InjectedAnnotationKey is the annotation key used to track injected pods + InjectedAnnotationKey = ModelValidationDomain + 
"/injected-at" + + // AuthMethodAnnotationKey is the annotation key used to track the auth method used during injection + AuthMethodAnnotationKey = ModelValidationDomain + "/auth-method" + + // ConfigHashAnnotationKey is the annotation key used to track the configuration hash during injection + ConfigHashAnnotationKey = ModelValidationDomain + "/config-hash" + + // IgnoreNamespaceValue is the value for the ignore namespace label + IgnoreNamespaceValue = "true" ) diff --git a/internal/controller/modelvalidation_controller.go b/internal/controller/modelvalidation_controller.go new file mode 100644 index 00000000..337cf154 --- /dev/null +++ b/internal/controller/modelvalidation_controller.go @@ -0,0 +1,59 @@ +// Package controller provides controllers for managing ModelValidation resources +package controller + +import ( + "context" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/tracker" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +// ModelValidationReconciler reconciles ModelValidation objects +type ModelValidationReconciler struct { + client.Client + Scheme *runtime.Scheme + Tracker tracker.StatusTracker +} + +// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations,verbs=get;list;watch +// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations/status,verbs=update + +// Reconcile handles ModelValidation events to track creation/updates/deletion +func (r *ModelValidationReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + logger := log.FromContext(ctx) + + mv := &v1alpha1.ModelValidation{} + if err := r.Get(ctx, req.NamespacedName, mv); err != nil { + if errors.IsNotFound(err) { + logger.Info("ModelValidation deleted, removing from tracking", "modelvalidation", req.NamespacedName) + r.Tracker.RemoveModelValidation(req.NamespacedName) + return reconcile.Result{}, nil + } + logger.Error(err, "Failed to get ModelValidation") + return reconcile.Result{}, err + } + + if !mv.DeletionTimestamp.IsZero() { + logger.Info("ModelValidation being deleted, removing from tracking", "modelvalidation", req.NamespacedName) + r.Tracker.RemoveModelValidation(req.NamespacedName) + return reconcile.Result{}, nil + } + + logger.Info("ModelValidation created/updated, adding to tracking", "modelvalidation", req.NamespacedName) + r.Tracker.AddModelValidation(ctx, mv) + + return reconcile.Result{}, nil +} + +// SetupWithManager sets up the controller with the Manager +func (r *ModelValidationReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&v1alpha1.ModelValidation{}). + Complete(r) +} diff --git a/internal/controller/modelvalidation_controller_test.go b/internal/controller/modelvalidation_controller_test.go new file mode 100644 index 00000000..6abfab7a --- /dev/null +++ b/internal/controller/modelvalidation_controller_test.go @@ -0,0 +1,121 @@ +package controller + +import ( + "context" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . 
"github.com/onsi/gomega" //nolint:revive + + "github.com/sigstore/model-validation-operator/internal/testutil" + "github.com/sigstore/model-validation-operator/internal/tracker" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +var _ = Describe("ModelValidationReconciler", func() { + var ( + ctx context.Context + reconciler *ModelValidationReconciler + mockTracker *tracker.MockStatusTracker + ) + + BeforeEach(func() { + ctx = context.Background() + mockTracker = tracker.NewMockStatusTracker() + }) + + Context("when reconciling ModelValidation resources", func() { + It("should call tracker for each reconcile", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + + fakeClient := testutil.SetupFakeClientWithObjects(mv) + + reconciler = &ModelValidationReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + namespace := mv.Namespace + name := mv.Name + req := testutil.CreateReconcileRequest(namespace, name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Requeue).To(BeFalse()) + + addCalls := mockTracker.GetAddModelValidationCalls() + Expect(addCalls).To(HaveLen(1)) + Expect(addCalls[0].Name).To(Equal(name)) + Expect(addCalls[0].Namespace).To(Equal(namespace)) + + // Second reconcile should also call AddModelValidation + // Controller is not idempotent - it calls tracker each time + result, err = reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Requeue).To(BeFalse()) + + addCalls = mockTracker.GetAddModelValidationCalls() + Expect(addCalls).To(HaveLen(2)) + Expect(addCalls[1].Name).To(Equal(name)) + Expect(addCalls[1].Namespace).To(Equal(namespace)) + }) + + It("should remove ModelValidation with DeletionTimestamp from tracking", func() { + now := metav1.Now() + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-deleting", + Namespace: "default", + DeletionTimestamp: &now, + Finalizers: []string{"test-finalizer"}, + }) + + fakeClient := testutil.SetupFakeClientWithObjects(mv) + + reconciler = &ModelValidationReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + namespace := mv.Namespace + name := mv.Name + req := testutil.CreateReconcileRequest(namespace, name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Requeue).To(BeFalse()) + + removeCalls := mockTracker.GetRemoveModelValidationCalls() + Expect(removeCalls).To(HaveLen(1)) + Expect(removeCalls[0].Namespace).To(Equal(namespace)) + Expect(removeCalls[0].Name).To(Equal(name)) + }) + + It("should remove deleted ModelValidation from tracking", func() { + fakeClient := testutil.SetupFakeClientWithObjects() + + reconciler = &ModelValidationReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + namespace := "default" + name := "deleted-mv" + req := testutil.CreateReconcileRequest(namespace, name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Requeue).To(BeFalse()) + + removeCalls := mockTracker.GetRemoveModelValidationCalls() + Expect(removeCalls).To(HaveLen(1)) + Expect(removeCalls[0].Namespace).To(Equal(namespace)) + Expect(removeCalls[0].Name).To(Equal(name)) + }) + }) +}) diff --git a/internal/controller/pod_controller.go 
b/internal/controller/pod_controller.go new file mode 100644 index 00000000..aedc11ac --- /dev/null +++ b/internal/controller/pod_controller.go @@ -0,0 +1,90 @@ +// Package controller provides controllers for managing ModelValidation resources +package controller + +import ( + "context" + + "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/tracker" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +// PodReconciler reconciles Pod objects to track injected pods +type PodReconciler struct { + client.Client + Scheme *runtime.Scheme + Tracker tracker.StatusTracker +} + +// +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;update +// +kubebuilder:rbac:groups="",resources=pods/finalizers,verbs=update +// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations,verbs=get;list +// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations/status,verbs=update + +// Reconcile handles pod events to update ModelValidation status +func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + logger := log.FromContext(ctx) + + pod := &corev1.Pod{} + if err := r.Get(ctx, req.NamespacedName, pod); err != nil { + if errors.IsNotFound(err) { + logger.V(1).Info("Pod deleted, removing from tracking", "pod", req.NamespacedName) + if err := r.Tracker.RemovePodByName(ctx, req.NamespacedName); err != nil { + logger.Error(err, "Failed to remove deleted pod from tracking", "pod", req.NamespacedName) + return reconcile.Result{}, err + } + return reconcile.Result{}, nil + } + return reconcile.Result{}, err + } + + if !pod.DeletionTimestamp.IsZero() { + logger.Info("Handling pod deletion", "pod", req.NamespacedName) + + if err := r.Tracker.RemovePodEvent(ctx, pod.UID); err != nil { + logger.Error(err, "Failed to remove pod from tracking") + return reconcile.Result{}, err + } + + if controllerutil.ContainsFinalizer(pod, constants.ModelValidationFinalizer) { + controllerutil.RemoveFinalizer(pod, constants.ModelValidationFinalizer) + if err := r.Update(ctx, pod); err != nil { + logger.Error(err, "Failed to remove finalizer from pod") + return reconcile.Result{}, err + } + } + + logger.Info("Successfully handled pod deletion", "pod", req.NamespacedName) + return reconcile.Result{}, nil + } + + modelValidationName, ok := pod.Labels[constants.ModelValidationLabel] + if !ok || modelValidationName == "" { + // Try to remove the pod in case it was previously tracked but label was removed + if err := r.Tracker.RemovePodByName(ctx, req.NamespacedName); err != nil { + logger.Error(err, "Failed to remove pod without label from tracking", "pod", req.NamespacedName) + return reconcile.Result{}, err + } + return reconcile.Result{}, nil + } + if err := r.Tracker.ProcessPodEvent(ctx, pod); err != nil { + logger.Error(err, "Failed to process pod event") + return reconcile.Result{}, err + } + + return reconcile.Result{}, nil +} + +// SetupWithManager sets up the controller with the Manager +func (r *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&corev1.Pod{}). 
+ Complete(r) +} diff --git a/internal/controller/pod_controller_test.go b/internal/controller/pod_controller_test.go new file mode 100644 index 00000000..2ac5e58a --- /dev/null +++ b/internal/controller/pod_controller_test.go @@ -0,0 +1,187 @@ +package controller + +import ( + "context" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . "github.com/onsi/gomega" //nolint:revive + + "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/testutil" + "github.com/sigstore/model-validation-operator/internal/tracker" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + ctrl "sigs.k8s.io/controller-runtime" +) + +var _ = Describe("PodReconciler", func() { + var ( + ctx context.Context + reconciler *PodReconciler + mockTracker *tracker.MockStatusTracker + ) + + BeforeEach(func() { + ctx = context.Background() + mockTracker = tracker.NewMockStatusTracker() + }) + + Context("when reconciling Pod resources", func() { + It("should try to remove pods without ModelValidation label if previously tracked", func() { + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + }) + + fakeClient := testutil.SetupFakeClientWithObjects(pod) + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := testutil.CreateReconcileRequest(pod.Namespace, pod.Name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + processEvents := mockTracker.GetProcessPodEventCalls() + Expect(processEvents).To(BeEmpty()) + + // Should try to remove the pod in case it was previously tracked + removeByNameCalls := mockTracker.GetRemovePodByNameCalls() + Expect(removeByNameCalls).To(HaveLen(1)) + Expect(removeByNameCalls[0].Name).To(Equal("test-pod")) + Expect(removeByNameCalls[0].Namespace).To(Equal("default")) + }) + + It("should process pods with finalizer but not being deleted", func() { + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + fakeClient := testutil.SetupFakeClientWithObjects(pod) + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := testutil.CreateReconcileRequest(pod.Namespace, pod.Name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + processEvents := mockTracker.GetProcessPodEventCalls() + Expect(processEvents).To(HaveLen(1)) + Expect(processEvents[0].Pod.Name).To(Equal(pod.Name)) + Expect(processEvents[0].Pod.Namespace).To(Equal(pod.Namespace)) + }) + + It("should handle pod deletion by removing finalizer", func() { + now := metav1.Now() + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + UID: types.UID("test-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + DeletionTimestamp: &now, + }) + + fakeClient := testutil.SetupFakeClientWithObjects(pod) + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := 
testutil.CreateReconcileRequest(pod.Namespace, pod.Name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + removeEvents := mockTracker.GetRemovePodEventCalls() + Expect(removeEvents).To(HaveLen(1)) + Expect(removeEvents[0]).To(Equal(pod.UID)) + + Eventually(func() []string { + updatedPod := &corev1.Pod{} + err := fakeClient.Get(ctx, req.NamespacedName, updatedPod) + if err != nil { + return []string{} + } + return updatedPod.Finalizers + }, "2s", "100ms").ShouldNot(ContainElement(constants.ModelValidationFinalizer)) + }) + + It("should handle pod with multiple finalizers being deleted", func() { + now := metav1.Now() + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + UID: types.UID("test-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{"other-finalizer", constants.ModelValidationFinalizer, "another-finalizer"}, + DeletionTimestamp: &now, + }) + + fakeClient := testutil.SetupFakeClientWithObjects(pod) + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := testutil.CreateReconcileRequest(pod.Namespace, pod.Name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + removeEvents := mockTracker.GetRemovePodEventCalls() + Expect(removeEvents).To(HaveLen(1)) + Expect(removeEvents[0]).To(Equal(pod.UID)) + + updatedPod := &corev1.Pod{} + err = fakeClient.Get(ctx, req.NamespacedName, updatedPod) + Expect(err).NotTo(HaveOccurred()) + Expect(updatedPod.Finalizers).NotTo(ContainElement(constants.ModelValidationFinalizer)) + Expect(updatedPod.Finalizers).To(ContainElement("other-finalizer")) + Expect(updatedPod.Finalizers).To(ContainElement("another-finalizer")) + }) + + It("should handle reconciling deleted pods gracefully", func() { + fakeClient := testutil.SetupFakeClientWithObjects() + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := testutil.CreateReconcileRequest("default", "deleted-pod") + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + removeByNameCalls := mockTracker.GetRemovePodByNameCalls() + Expect(removeByNameCalls).To(HaveLen(1)) + Expect(removeByNameCalls[0].Name).To(Equal("deleted-pod")) + Expect(removeByNameCalls[0].Namespace).To(Equal("default")) + }) + }) +}) diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go new file mode 100644 index 00000000..6481b57e --- /dev/null +++ b/internal/controller/suite_test.go @@ -0,0 +1,13 @@ +package controller + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . "github.com/onsi/gomega" //nolint:revive +) + +func TestControllers(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Controller Suite") +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 00000000..54c3c91c --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,171 @@ +// Package metrics provides Prometheus metrics for the model validation operator. 
+package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "sigs.k8s.io/controller-runtime/pkg/metrics" +) + +const ( + // Metric label names + labelNamespace = "namespace" + labelModelValidation = "model_validation" + labelPodState = "pod_state" + labelStatusUpdateResult = "result" + labelDriftType = "drift_type" + + // PodStateInjected represents pods with model validation finalizers + PodStateInjected = "injected" + // PodStateUninjected represents pods without model validation finalizers + PodStateUninjected = "uninjected" + // PodStateOrphaned represents pods with configuration drift + PodStateOrphaned = "orphaned" + + // StatusUpdateSuccess indicates a successful status update + StatusUpdateSuccess = "success" + // StatusUpdateFailure indicates a failed status update + StatusUpdateFailure = "failure" +) + +var ( + // ModelValidationPodCounts tracks the current number of pods in each state per ModelValidation + ModelValidationPodCounts = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "model_validation_operator", + Name: "modelvalidation_pod_count", + Help: "Current number of pods tracked per ModelValidation by state", + }, + []string{labelNamespace, labelModelValidation, labelPodState}, + ) + + // PodStateTransitionsTotal tracks pod state transitions + PodStateTransitionsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "model_validation_operator", + Name: "pod_state_transitions_total", + Help: "Total number of pod state transitions", + }, + []string{labelNamespace, labelModelValidation, "from_state", "to_state"}, + ) + + // StatusUpdatesTotal tracks ModelValidation status updates + StatusUpdatesTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "model_validation_operator", + Name: "status_updates_total", + Help: "Total number of ModelValidation status updates", + }, + []string{labelNamespace, labelModelValidation, labelStatusUpdateResult}, + ) + + // ConfigurationDriftEventsTotal tracks configuration drift events + ConfigurationDriftEventsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "model_validation_operator", + Name: "configuration_drift_events_total", + Help: "Total number of configuration drift events detected", + }, + []string{labelNamespace, labelModelValidation, labelDriftType}, + ) + + // ModelValidationCRsTotal tracks total number of ModelValidation CRs per namespace + // Does not include authMethod for namespace-level tracking + ModelValidationCRsTotal = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "model_validation_operator", + Name: "modelvalidation_crs_total", + Help: "Total number of ModelValidation CRs being tracked per namespace", + }, + []string{labelNamespace}, + ) + + // StatusUpdateDuration tracks the duration of status update operations + StatusUpdateDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "model_validation_operator", + Name: "status_update_duration_seconds", + Help: "Duration of ModelValidation status update operations", + Buckets: prometheus.DefBuckets, + }, + []string{labelNamespace, labelModelValidation, labelStatusUpdateResult}, + ) + + // QueueSize tracks the current size of the status update queue + QueueSize = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "model_validation_operator", + Name: "status_update_queue_size", + Help: "Current size of the status update queue", + }, + ) + + // RetryAttemptsTotal tracks retry attempts for status updates + RetryAttemptsTotal = 
prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "model_validation_operator", + Name: "status_update_retry_attempts_total", + Help: "Total number of status update retry attempts", + }, + []string{labelNamespace, labelModelValidation}, + ) +) + +func init() { + metrics.Registry.MustRegister( + ModelValidationPodCounts, + PodStateTransitionsTotal, + StatusUpdatesTotal, + ConfigurationDriftEventsTotal, + ModelValidationCRsTotal, + StatusUpdateDuration, + QueueSize, + RetryAttemptsTotal, + ) +} + +// RecordPodCount records the current pod count for a ModelValidation +func RecordPodCount(namespace, modelValidation, podState string, count float64) { + ModelValidationPodCounts.WithLabelValues(namespace, modelValidation, podState).Set(count) +} + +// RecordPodStateTransition records a pod state transition +func RecordPodStateTransition(namespace, modelValidation, fromState, toState string) { + PodStateTransitionsTotal.WithLabelValues(namespace, modelValidation, fromState, toState).Inc() +} + +// RecordStatusUpdate records a status update result +func RecordStatusUpdate(namespace, modelValidation, result string) { + StatusUpdatesTotal.WithLabelValues(namespace, modelValidation, result).Inc() +} + +// RecordConfigurationDrift records a configuration drift event +func RecordConfigurationDrift(namespace, modelValidation, driftType string) { + ConfigurationDriftEventsTotal.WithLabelValues(namespace, modelValidation, driftType).Inc() +} + +// RecordModelValidationCR records the current number of ModelValidation CRs per namespace +func RecordModelValidationCR(namespace string, count float64) { + ModelValidationCRsTotal.WithLabelValues(namespace).Set(count) +} + +// RecordStatusUpdateDuration records the duration of a status update +func RecordStatusUpdateDuration(namespace, modelValidation, result string, duration float64) { + StatusUpdateDuration.WithLabelValues(namespace, modelValidation, result).Observe(duration) +} + +// SetQueueSize sets the current queue size +func SetQueueSize(size float64) { + QueueSize.Set(size) +} + +// RecordRetryAttempt records a retry attempt +func RecordRetryAttempt(namespace, modelValidation string) { + RetryAttemptsTotal.WithLabelValues(namespace, modelValidation).Inc() +} + +// RecordMultiplePodStateTransitions records multiple identical pod state transitions +func RecordMultiplePodStateTransitions(namespace, modelValidation, fromState, toState string, count int) { + if count > 0 { + PodStateTransitionsTotal.WithLabelValues(namespace, modelValidation, fromState, toState).Add(float64(count)) + } +} diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go new file mode 100644 index 00000000..11ab7d12 --- /dev/null +++ b/internal/metrics/metrics_test.go @@ -0,0 +1,144 @@ +package metrics + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/controller-runtime/pkg/metrics" +) + +const ( + testNamespace = "test-namespace" + testModelValidation = "test-mv" +) + +func TestMetricsDefinition(t *testing.T) { + // Test that all metric variables are defined + assert.NotNil(t, ModelValidationPodCounts) + assert.NotNil(t, PodStateTransitionsTotal) + assert.NotNil(t, StatusUpdatesTotal) + assert.NotNil(t, ConfigurationDriftEventsTotal) + assert.NotNil(t, ModelValidationCRsTotal) + assert.NotNil(t, StatusUpdateDuration) + assert.NotNil(t, QueueSize) + assert.NotNil(t, RetryAttemptsTotal) +} + +// Test helper to verify gauge 
metrics +func verifyGaugeMetric(t *testing.T, gauge *prometheus.GaugeVec, labels []string, expectedValue float64) { + metric := gauge.WithLabelValues(labels...) + metricDto := &dto.Metric{} + err := metric.Write(metricDto) + assert.NoError(t, err) + assert.Equal(t, expectedValue, metricDto.GetGauge().GetValue()) +} + +// Test helper to verify counter metrics +func verifyCounterIncrement(t *testing.T, counter *prometheus.CounterVec, labels []string) float64 { + metric := counter.WithLabelValues(labels...) + metricDto := &dto.Metric{} + err := metric.Write(metricDto) + assert.NoError(t, err) + return metricDto.GetCounter().GetValue() +} + +func TestRecordPodCount(t *testing.T) { + podState := PodStateInjected + count := float64(5) + + RecordPodCount(testNamespace, testModelValidation, podState, count) + + verifyGaugeMetric(t, ModelValidationPodCounts, []string{testNamespace, testModelValidation, podState}, count) +} + +func TestRecordPodStateTransition(t *testing.T) { + fromState := PodStateUninjected + toState := PodStateInjected + + labels := []string{testNamespace, testModelValidation, fromState, toState} + initialValue := verifyCounterIncrement(t, PodStateTransitionsTotal, labels) + RecordPodStateTransition(testNamespace, testModelValidation, fromState, toState) + finalValue := verifyCounterIncrement(t, PodStateTransitionsTotal, labels) + + assert.Equal(t, initialValue+1, finalValue) +} + +func TestRecordStatusUpdate(t *testing.T) { + result := StatusUpdateSuccess + + labels := []string{testNamespace, testModelValidation, result} + initialValue := verifyCounterIncrement(t, StatusUpdatesTotal, labels) + RecordStatusUpdate(testNamespace, testModelValidation, result) + finalValue := verifyCounterIncrement(t, StatusUpdatesTotal, labels) + + assert.Equal(t, initialValue+1, finalValue) +} + +func TestRecordConfigurationDrift(t *testing.T) { + driftType := "config_hash" + + labels := []string{testNamespace, testModelValidation, driftType} + initialValue := verifyCounterIncrement(t, ConfigurationDriftEventsTotal, labels) + RecordConfigurationDrift(testNamespace, testModelValidation, driftType) + finalValue := verifyCounterIncrement(t, ConfigurationDriftEventsTotal, labels) + + assert.Equal(t, initialValue+1, finalValue) +} + +func TestRecordModelValidationCR(t *testing.T) { + count := float64(3) + + RecordModelValidationCR(testNamespace, count) + + verifyGaugeMetric(t, ModelValidationCRsTotal, []string{testNamespace}, count) +} + +func TestSetQueueSize(t *testing.T) { + size := float64(10) + + SetQueueSize(size) + + metricDto := &dto.Metric{} + err := QueueSize.Write(metricDto) + assert.NoError(t, err) + assert.Equal(t, size, metricDto.GetGauge().GetValue()) +} + +func TestRecordStatusUpdateDuration(t *testing.T) { + result := StatusUpdateSuccess + duration := 0.5 // 500ms + + RecordStatusUpdateDuration(testNamespace, testModelValidation, result, duration) + + // Verify the histogram was recorded by checking the metric family + metricFamilies, err := metrics.Registry.Gather() + assert.NoError(t, err) + + var found bool + for _, mf := range metricFamilies { + if mf.GetName() == "model_validation_operator_status_update_duration_seconds" { + for _, metric := range mf.GetMetric() { + if metric.GetHistogram().GetSampleCount() > 0 { + found = true + break + } + } + } + } + assert.True(t, found, "Expected histogram metric to be recorded") +} + +func TestRecordMultiplePodStateTransitions(t *testing.T) { + fromState := PodStateUninjected + toState := PodStateInjected + count := 3 + + labels := 
[]string{testNamespace, testModelValidation, fromState, toState} + initialValue := verifyCounterIncrement(t, PodStateTransitionsTotal, labels) + RecordMultiplePodStateTransitions(testNamespace, testModelValidation, fromState, toState, count) + finalValue := verifyCounterIncrement(t, PodStateTransitionsTotal, labels) + + assert.Equal(t, initialValue+float64(count), finalValue) +} diff --git a/internal/testutil/ginkgo_utils.go b/internal/testutil/ginkgo_utils.go new file mode 100644 index 00000000..16eb4543 --- /dev/null +++ b/internal/testutil/ginkgo_utils.go @@ -0,0 +1,41 @@ +// Package testutil provides Ginkgo-specific test utilities +package testutil + +import ( + "context" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/onsi/gomega" +) + +// ExpectModelValidationTracking verifies a ModelValidation's tracking state +func ExpectModelValidationTracking( + statusTracker ModelValidationTracker, + namespacedName types.NamespacedName, + shouldBeTracked bool, +) { + isTracked := statusTracker.IsModelValidationTracked(namespacedName) + gomega.Expect(isTracked).To(gomega.Equal(shouldBeTracked), + "Expected ModelValidation %s tracking state to be %t", namespacedName, shouldBeTracked) +} + +// GetModelValidationFromClientExpected retrieves a ModelValidation from the client with Ginkgo expectations +func GetModelValidationFromClientExpected( + ctx context.Context, + client client.Client, + namespacedName types.NamespacedName, +) *v1alpha1.ModelValidation { + mv, err := GetModelValidationFromClient(ctx, client, namespacedName) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + return mv +} + +// ExpectAttemptCount waits for and verifies expected attempt counts in retry testing +func ExpectAttemptCount(fakeClient *FailingClient, expectedAttempts int) { + gomega.Eventually(func() int { + return fakeClient.GetAttemptCount() + }, "2s", "10ms").Should(gomega.Equal(expectedAttempts)) +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 00000000..ccc5eeb4 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,263 @@ +// Package testutil provides shared test utilities for the model validation operator +package testutil + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +// DefaultNamespace is the default namespace used in tests +const DefaultNamespace = "default" + +// TestModelValidationOptions holds configuration for creating test ModelValidation resources +type TestModelValidationOptions struct { + Name string + Namespace string + DeletionTimestamp *metav1.Time + Finalizers []string + ConfigType string + CertificateCA string + CertIdentity string + CertOidcIssuer string +} + +// TestPodOptions holds configuration for creating test Pod resources +type TestPodOptions struct { + Name string + Namespace string + UID types.UID + Labels map[string]string + Annotations map[string]string + Finalizers []string + DeletionTimestamp *metav1.Time +} + +// CreateTestModelValidation creates a test ModelValidation resource with the given options +func CreateTestModelValidation(opts 
TestModelValidationOptions) *v1alpha1.ModelValidation { + if opts.Name == "" { + opts.Name = fmt.Sprintf("test-mv-%d", time.Now().UnixNano()) + } + if opts.Namespace == "" { + opts.Namespace = DefaultNamespace + } + if opts.ConfigType == "" { + opts.ConfigType = "sigstore" + } + + mv := &v1alpha1.ModelValidation{ + ObjectMeta: metav1.ObjectMeta{ + Name: opts.Name, + Namespace: opts.Namespace, + DeletionTimestamp: opts.DeletionTimestamp, + Finalizers: opts.Finalizers, + }, + Spec: v1alpha1.ModelValidationSpec{ + Model: v1alpha1.Model{ + Path: "test-model", + SignaturePath: "test-signature", + }, + }, + } + + // Configure auth method + switch opts.ConfigType { + case "pki": + certCA := opts.CertificateCA + if certCA == "" { + certCA = "test-ca" + } + mv.Spec.Config = v1alpha1.ValidationConfig{ + PkiConfig: &v1alpha1.PkiConfig{ + CertificateAuthority: certCA, + }, + } + case "sigstore": + fallthrough + default: + certIdentity := opts.CertIdentity + if certIdentity == "" { + certIdentity = "test@example.com" + } + certOidcIssuer := opts.CertOidcIssuer + if certOidcIssuer == "" { + certOidcIssuer = "https://accounts.google.com" + } + mv.Spec.Config = v1alpha1.ValidationConfig{ + SigstoreConfig: &v1alpha1.SigstoreConfig{ + CertificateIdentity: certIdentity, + CertificateOidcIssuer: certOidcIssuer, + }, + } + } + + return mv +} + +// CreateTestPod creates a test Pod resource with the given options +func CreateTestPod(opts TestPodOptions) *corev1.Pod { + if opts.Name == "" { + opts.Name = fmt.Sprintf("test-pod-%d", time.Now().UnixNano()) + } + if opts.Namespace == "" { + opts.Namespace = DefaultNamespace + } + if opts.UID == types.UID("") { + opts.UID = types.UID(fmt.Sprintf("uid-%d", time.Now().UnixNano())) + } + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: opts.Name, + Namespace: opts.Namespace, + UID: opts.UID, + Labels: opts.Labels, + Annotations: opts.Annotations, + Finalizers: opts.Finalizers, + DeletionTimestamp: opts.DeletionTimestamp, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + }, + }, + }, + } + + return pod +} + +// CreateTestNamespacedName creates a NamespacedName for testing +func CreateTestNamespacedName(name, namespace string) types.NamespacedName { + if name == "" { + name = fmt.Sprintf("test-name-%d", time.Now().UnixNano()) + } + if namespace == "" { + namespace = DefaultNamespace + } + return types.NamespacedName{Name: name, Namespace: namespace} +} + +// SetupFakeClientWithObjects creates a fake Kubernetes client with the given objects +func SetupFakeClientWithObjects(objects ...client.Object) client.Client { + scheme := runtime.NewScheme() + if err := corev1.AddToScheme(scheme); err != nil { + panic(err) // This should not happen in tests + } + if err := v1alpha1.AddToScheme(scheme); err != nil { + panic(err) // This should not happen in tests + } + + builder := fake.NewClientBuilder().WithScheme(scheme) + if len(objects) > 0 { + builder = builder.WithObjects(objects...) 
+ // Add status subresource for ModelValidation objects + for _, obj := range objects { + if _, ok := obj.(*v1alpha1.ModelValidation); ok { + builder = builder.WithStatusSubresource(obj) + } + } + } + + return builder.Build() +} + +// CreateReconcileRequest creates a reconcile request for testing +func CreateReconcileRequest(namespace, name string) reconcile.Request { + return reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: name}} +} + +// ModelValidationTracker interface for tracking functionality +type ModelValidationTracker interface { + IsModelValidationTracked(namespacedName types.NamespacedName) bool +} + +// GetModelValidationFromClient retrieves a ModelValidation from the client +func GetModelValidationFromClient( + ctx context.Context, + client client.Client, + namespacedName types.NamespacedName, +) (*v1alpha1.ModelValidation, error) { + mv := &v1alpha1.ModelValidation{} + err := client.Get(ctx, namespacedName, mv) + return mv, err +} + +// FailingClient is a mock client that fails status updates for testing retry behavior +type FailingClient struct { + client.Client + FailureCount int + MaxFailures int + AttemptCount int + mu sync.Mutex +} + +// Status returns a failing sub-resource writer for testing +func (f *FailingClient) Status() client.SubResourceWriter { + return &FailingSubResourceWriter{ + SubResourceWriter: f.Client.Status(), + parent: f, + } +} + +// FailingSubResourceWriter wraps a SubResourceWriter to simulate failures +type FailingSubResourceWriter struct { + client.SubResourceWriter + parent *FailingClient +} + +// Update simulates failures for the first MaxFailures attempts, then succeeds +func (f *FailingSubResourceWriter) Update( + ctx context.Context, obj client.Object, opts ...client.SubResourceUpdateOption, +) error { + f.parent.mu.Lock() + defer f.parent.mu.Unlock() + + f.parent.AttemptCount++ + if f.parent.FailureCount < f.parent.MaxFailures { + f.parent.FailureCount++ + return errors.New("simulated status update failure") + } + + // After max failures, delegate to real client + return f.SubResourceWriter.Update(ctx, obj, opts...) +} + +// GetAttemptCount returns the current attempt count (thread-safe) +func (f *FailingClient) GetAttemptCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.AttemptCount +} + +// GetFailureCount returns the current failure count (thread-safe) +func (f *FailingClient) GetFailureCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.FailureCount +} + +// CreateFailingClientWithObjects creates a FailingClient with the given objects +// that fails the first maxFailures attempts +func CreateFailingClientWithObjects(maxFailures int, objects ...client.Object) *FailingClient { + fakeClient := SetupFakeClientWithObjects(objects...) 
+ return &FailingClient{ + Client: fakeClient, + MaxFailures: maxFailures, + FailureCount: 0, + AttemptCount: 0, + } +} diff --git a/internal/tracker/debounced_queue.go b/internal/tracker/debounced_queue.go new file mode 100644 index 00000000..653d816a --- /dev/null +++ b/internal/tracker/debounced_queue.go @@ -0,0 +1,219 @@ +package tracker + +import ( + "sync" + "time" + + "github.com/sigstore/model-validation-operator/internal/metrics" + "golang.org/x/time/rate" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/util/workqueue" +) + +// DebouncedQueue provides a queue with built-in debouncing functionality +// It encapsulates both debouncing logic and workqueue implementation +type DebouncedQueue interface { + // Add adds an item to the queue with debouncing + // If the same item is already pending debounce, it resets the timer + Add(item types.NamespacedName) + + // Get gets the next item for processing (blocks if queue is empty) + Get() (types.NamespacedName, bool) + + // Done marks an item as done processing + Done(item types.NamespacedName) + + // AddWithRetry adds an item to the queue with rate limiting for retries + AddWithRetry(item types.NamespacedName) + + // ForgetRetries forgets retry tracking for an item + ForgetRetries(item types.NamespacedName) + + // GetRetryCount returns the number of retries for an item + GetRetryCount(item types.NamespacedName) int + + // Len returns the number of items currently in the queue (not pending debounce) + Len() int + + // WaitForUpdates waits for all pending debounced updates to complete + WaitForUpdates() + + // WaitForCompletion waits for all pending debounced updates to complete + // and for the queue to be fully drained (all items processed) + WaitForCompletion() + + // ShutDown shuts down the queue and stops all timers + ShutDown() +} + +// DebouncedQueueImpl implements DebouncedQueue using workqueue and timers +type DebouncedQueueImpl struct { + duration time.Duration + queue workqueue.TypedRateLimitingInterface[types.NamespacedName] + debounceTimers map[types.NamespacedName]*time.Timer + debounceWg sync.WaitGroup + mu sync.Mutex + stopCh chan struct{} + // Protects metric updates to ensure consistency with queue state + metricMu sync.Mutex +} + +// DebouncedQueueConfig holds configuration for the debounced queue +type DebouncedQueueConfig struct { + DebounceDuration time.Duration + RetryBaseDelay time.Duration + RetryMaxDelay time.Duration + RateLimitQPS float64 + RateLimitBurst int +} + +// NewDebouncedQueue creates a new debounced queue with the specified configuration +func NewDebouncedQueue(config DebouncedQueueConfig) DebouncedQueue { + // Create custom rate limiter with configurable retry parameters + rateLimiter := workqueue.NewTypedMaxOfRateLimiter( + workqueue.NewTypedItemExponentialFailureRateLimiter[types.NamespacedName]( + config.RetryBaseDelay, config.RetryMaxDelay), + &workqueue.TypedBucketRateLimiter[types.NamespacedName]{ + Limiter: rate.NewLimiter(rate.Limit(config.RateLimitQPS), config.RateLimitBurst)}, + ) + + workQueue := workqueue.NewTypedRateLimitingQueue(rateLimiter) + + return &DebouncedQueueImpl{ + duration: config.DebounceDuration, + queue: workQueue, + debounceTimers: make(map[types.NamespacedName]*time.Timer), + stopCh: make(chan struct{}), + } +} + +// updateQueueSizeMetric updates the queue size metric +// Uses a mutex to ensure metric consistency with queue state +func (dq *DebouncedQueueImpl) updateQueueSizeMetric() { + dq.metricMu.Lock() + defer dq.metricMu.Unlock() + + // Skip metric updates if 
shutdown has been initiated + if dq.isShutDown() { + return + } + + metrics.SetQueueSize(float64(dq.queue.Len())) +} + +// isShutDown checks if the queue has been shut down (non-blocking) +func (dq *DebouncedQueueImpl) isShutDown() bool { + select { + case <-dq.stopCh: + return true + default: + return false + } +} + +// Add adds an item to the queue with debouncing +// If the queue has been shut down, this method will return without adding the item +func (dq *DebouncedQueueImpl) Add(item types.NamespacedName) { + dq.mu.Lock() + defer dq.mu.Unlock() + if dq.isShutDown() { + return + } + if timer, ok := dq.debounceTimers[item]; ok { + timer.Reset(dq.duration) + } else { + dq.debounceWg.Add(1) + dq.debounceTimers[item] = time.AfterFunc(dq.duration, func() { + defer dq.debounceWg.Done() + if dq.isShutDown() { + return + } + dq.queue.Add(item) + dq.updateQueueSizeMetric() + dq.mu.Lock() + delete(dq.debounceTimers, item) + dq.mu.Unlock() + }) + } +} + +// Get gets the next item for processing +func (dq *DebouncedQueueImpl) Get() (types.NamespacedName, bool) { + return dq.queue.Get() +} + +// Done marks an item as done processing +func (dq *DebouncedQueueImpl) Done(item types.NamespacedName) { + dq.queue.Done(item) + dq.updateQueueSizeMetric() +} + +// AddWithRetry adds an item to the queue with rate limiting for retries +// If the queue has been shut down, this method will return without adding the item +func (dq *DebouncedQueueImpl) AddWithRetry(item types.NamespacedName) { + if dq.isShutDown() { + return + } + dq.queue.AddRateLimited(item) + dq.updateQueueSizeMetric() +} + +// ForgetRetries forgets retry tracking for an item +func (dq *DebouncedQueueImpl) ForgetRetries(item types.NamespacedName) { + dq.queue.Forget(item) +} + +// GetRetryCount returns the number of retries for an item +func (dq *DebouncedQueueImpl) GetRetryCount(item types.NamespacedName) int { + return dq.queue.NumRequeues(item) +} + +// Len returns the number of items currently in the queue +func (dq *DebouncedQueueImpl) Len() int { + return dq.queue.Len() +} + +// WaitForUpdates waits for all pending debounced updates to complete +func (dq *DebouncedQueueImpl) WaitForUpdates() { + dq.debounceWg.Wait() +} + +// WaitForCompletion waits for all pending debounced updates to complete +// and for the queue to be fully drained (all items processed) +func (dq *DebouncedQueueImpl) WaitForCompletion() { + dq.WaitForUpdates() + + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if dq.queue.Len() == 0 { + return + } + case <-dq.stopCh: + return + } + } +} + +// ShutDown shuts down the queue and stops all timers +func (dq *DebouncedQueueImpl) ShutDown() { + dq.mu.Lock() + defer dq.mu.Unlock() + + for _, timer := range dq.debounceTimers { + timer.Stop() + } + dq.debounceTimers = make(map[types.NamespacedName]*time.Timer) + + if !dq.isShutDown() { + close(dq.stopCh) + } + + dq.queue.ShutDown() + + // Update metrics to reflect shutdown + metrics.SetQueueSize(0) +} diff --git a/internal/tracker/debounced_queue_test.go b/internal/tracker/debounced_queue_test.go new file mode 100644 index 00000000..d97152e9 --- /dev/null +++ b/internal/tracker/debounced_queue_test.go @@ -0,0 +1,334 @@ +package tracker + +import ( + "time" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . 
"github.com/onsi/gomega" //nolint:revive + + "k8s.io/apimachinery/pkg/types" +) + +var _ = Describe("DebouncedQueue", func() { + var ( + debouncedQueue DebouncedQueue + testKey1 types.NamespacedName + testKey2 types.NamespacedName + duration time.Duration + ) + + BeforeEach(func() { + testKey1 = types.NamespacedName{Name: "test1", Namespace: "default"} + testKey2 = types.NamespacedName{Name: "test2", Namespace: "default"} + duration = 50 * time.Millisecond + debouncedQueue = NewDebouncedQueue(DebouncedQueueConfig{ + DebounceDuration: duration, + RetryBaseDelay: 100 * time.Millisecond, + RetryMaxDelay: 1000 * time.Millisecond, + RateLimitQPS: 10, + RateLimitBurst: 100, + }) + }) + + AfterEach(func() { + debouncedQueue.ShutDown() + }) + + Context("when debouncing single updates", func() { + It("should add update to queue after debounce duration", func() { + // Initially queue should be empty + Expect(debouncedQueue.Len()).To(Equal(0)) + + // Trigger debounce + debouncedQueue.Add(testKey1) + + // Should still be empty immediately + Expect(debouncedQueue.Len()).To(Equal(0)) + + // Wait for debounce duration plus buffer + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(1)) + + // Verify the correct key was added + item, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + debouncedQueue.Done(item) + }) + + It("should reset timer on subsequent updates", func() { + // First update + debouncedQueue.Add(testKey1) + + // Wait half the debounce duration + time.Sleep(duration / 2) + + // Queue should still be empty + Expect(debouncedQueue.Len()).To(Equal(0)) + + // Second update should reset the timer + debouncedQueue.Add(testKey1) + + // Wait another half duration (original timer would have fired by now) + time.Sleep(duration / 2) + + // Queue should still be empty because timer was reset + Expect(debouncedQueue.Len()).To(Equal(0)) + + // Wait for full duration from second update + Eventually(func() int { + return debouncedQueue.Len() + }, duration*2, 10*time.Millisecond).Should(Equal(1)) + }) + }) + + Context("when debouncing multiple keys", func() { + It("should handle multiple keys independently", func() { + // Trigger debounce for both keys + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // Both should be processed independently + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(2)) + + // Verify both keys are in queue + receivedKeys := make(map[types.NamespacedName]bool) + + item1, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + receivedKeys[item1] = true + debouncedQueue.Done(item1) + + item2, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + receivedKeys[item2] = true + debouncedQueue.Done(item2) + + Expect(receivedKeys).To(HaveKey(testKey1)) + Expect(receivedKeys).To(HaveKey(testKey2)) + }) + + It("should reset timers independently for different keys", func() { + // Start timer for key1 + debouncedQueue.Add(testKey1) + + // Wait and start timer for key2 + time.Sleep(duration / 2) + debouncedQueue.Add(testKey2) + + // Key1 should fire first (started earlier) + Eventually(func() int { + return debouncedQueue.Len() + }, duration*2, 10*time.Millisecond).Should(BeNumerically(">=", 1)) + + // Get first item (should be key1) + item, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + debouncedQueue.Done(item) + + // Key2 should fire 
shortly after + Eventually(func() int { + return debouncedQueue.Len() + }, duration, 10*time.Millisecond).Should(Equal(1)) + + item, shutdown = debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey2)) + debouncedQueue.Done(item) + }) + }) + + Context("when waiting for updates", func() { + It("should wait for all pending updates to complete", func() { + // Trigger multiple updates + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // WaitForUpdates should block until all timers complete + done := make(chan bool) + go func() { + debouncedQueue.WaitForUpdates() + done <- true + }() + + // Should not complete immediately + Consistently(done, duration/2).ShouldNot(Receive()) + + // Should complete after debounce duration + Eventually(done, duration*3).Should(Receive()) + + // Queue should have both items + Expect(debouncedQueue.Len()).To(Equal(2)) + }) + + It("should handle WaitForUpdates with no pending updates", func() { + // WaitForUpdates should return immediately when no updates pending + done := make(chan bool) + go func() { + debouncedQueue.WaitForUpdates() + done <- true + }() + + // Should complete immediately + Eventually(done, 100*time.Millisecond).Should(Receive()) + }) + + It("should wait for queue to drain with WaitForCompletion", func() { + // Trigger multiple updates + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // WaitForCompletion should block until all timers complete and queue is drained + done := make(chan bool) + go func() { + debouncedQueue.WaitForCompletion() + done <- true + }() + + // Should not complete immediately + Consistently(done, duration/2).ShouldNot(Receive()) + + // Should still not complete after just debounce duration (queue not drained) + time.Sleep(duration * 2) + Consistently(done, 50*time.Millisecond).ShouldNot(Receive()) + + // Drain the queue + for debouncedQueue.Len() > 0 { + item, shutdown := debouncedQueue.Get() + if shutdown { + break + } + debouncedQueue.Done(item) + } + + // Now it should complete + Eventually(done, 100*time.Millisecond).Should(Receive()) + }) + + It("should stop waiting when queue is shut down", func() { + // Trigger updates but don't process them + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // Wait for items to be in queue + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(2)) + + // Start WaitForCompletion which should block + done := make(chan bool) + go func() { + debouncedQueue.WaitForCompletion() + done <- true + }() + + // Should not complete immediately (queue not drained) + Consistently(done, 50*time.Millisecond).ShouldNot(Receive()) + + // Shut down the queue - this should cause WaitForCompletion to return + debouncedQueue.ShutDown() + + // Should complete quickly now + Eventually(done, 100*time.Millisecond).Should(Receive()) + }) + }) + + Context("when shutting down", func() { + It("should stop all pending timers", func() { + // Trigger updates + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // Shut down before timers fire + debouncedQueue.ShutDown() + + // Wait longer than debounce duration + time.Sleep(duration * 2) + + // Queue should remain empty (timers were stopped) + Expect(debouncedQueue.Len()).To(Equal(0)) + }) + }) + + Context("when handling rapid updates", func() { + It("should coalesce rapid updates into single queue item", func() { + // Rapid fire updates for same key + for i := 0; i < 10; i++ { + debouncedQueue.Add(testKey1) + 
time.Sleep(duration / 20) // Much shorter than debounce duration + } + + // Should result in only one queued item + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(1)) + + // Verify it's the correct key + item, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + debouncedQueue.Done(item) + + // No more items should be queued + Consistently(func() int { + return debouncedQueue.Len() + }, duration).Should(Equal(0)) + }) + + It("should handle interleaved updates for different keys", func() { + // Interleave updates for different keys + for i := 0; i < 5; i++ { + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + time.Sleep(duration / 10) + } + + // Should result in one item for each key + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(2)) + }) + }) + + Context("when handling retries", func() { + It("should support retry functionality", func() { + // Add item normally first + debouncedQueue.Add(testKey1) + + // Wait for debounced item to be available + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(1)) + + // Get and process item + item, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + + // Check retry count (should be 0 initially) + retries := debouncedQueue.GetRetryCount(testKey1) + Expect(retries).To(Equal(0)) + + // Add back with retry (simulating failure) + debouncedQueue.AddWithRetry(testKey1) + debouncedQueue.Done(item) + + // Get item again + item, shutdown = debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + + // Retry count should be incremented + retries = debouncedQueue.GetRetryCount(testKey1) + Expect(retries).To(BeNumerically(">", 0)) + + // Forget retries + debouncedQueue.ForgetRetries(testKey1) + debouncedQueue.Done(item) + }) + }) +}) diff --git a/internal/tracker/mock.go b/internal/tracker/mock.go new file mode 100644 index 00000000..f22b1798 --- /dev/null +++ b/internal/tracker/mock.go @@ -0,0 +1,185 @@ +package tracker + +import ( + "context" + "sync" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" +) + +// MockStatusTracker is a mock implementation of StatusTracker for testing +type MockStatusTracker struct { + mu sync.RWMutex + + // Track method calls for verification + AddModelValidationCalls []*v1alpha1.ModelValidation + RemoveModelValidationCalls []types.NamespacedName + ProcessPodEventCalls []ProcessPodEventCall + RemovePodEventCalls []types.UID + RemovePodByNameCalls []types.NamespacedName + StopCalls int + + // Optional error responses for testing error scenarios + ProcessPodEventError error + RemovePodEventError error + RemovePodByNameError error + + // Track ModelValidations that are considered "tracked" + TrackedModelValidations map[types.NamespacedName]bool +} + +// ProcessPodEventCall captures the parameters of a ProcessPodEvent call +type ProcessPodEventCall struct { + Ctx context.Context + Pod *corev1.Pod +} + +// NewMockStatusTracker creates a new mock status tracker +func NewMockStatusTracker() *MockStatusTracker { + return &MockStatusTracker{ + TrackedModelValidations: make(map[types.NamespacedName]bool), + } +} + +// AddModelValidation implements StatusTracker interface +func (m *MockStatusTracker) AddModelValidation(_ context.Context, mv 
*v1alpha1.ModelValidation) { + m.mu.Lock() + defer m.mu.Unlock() + + m.AddModelValidationCalls = append(m.AddModelValidationCalls, mv) + mvKey := types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + m.TrackedModelValidations[mvKey] = true +} + +// RemoveModelValidation implements StatusTracker interface +func (m *MockStatusTracker) RemoveModelValidation(mvKey types.NamespacedName) { + m.mu.Lock() + defer m.mu.Unlock() + + m.RemoveModelValidationCalls = append(m.RemoveModelValidationCalls, mvKey) + delete(m.TrackedModelValidations, mvKey) +} + +// ProcessPodEvent implements StatusTracker interface +func (m *MockStatusTracker) ProcessPodEvent(ctx context.Context, pod *corev1.Pod) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.ProcessPodEventCalls = append(m.ProcessPodEventCalls, ProcessPodEventCall{ + Ctx: ctx, + Pod: pod, + }) + + return m.ProcessPodEventError +} + +// RemovePodEvent implements StatusTracker interface +func (m *MockStatusTracker) RemovePodEvent(_ context.Context, podUID types.UID) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.RemovePodEventCalls = append(m.RemovePodEventCalls, podUID) + return m.RemovePodEventError +} + +// RemovePodByName implements StatusTracker interface +func (m *MockStatusTracker) RemovePodByName(_ context.Context, podName types.NamespacedName) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.RemovePodByNameCalls = append(m.RemovePodByNameCalls, podName) + return m.RemovePodByNameError +} + +// Stop implements StatusTracker interface +func (m *MockStatusTracker) Stop() { + m.mu.Lock() + defer m.mu.Unlock() + + m.StopCalls++ +} + +// Reset clears all recorded calls +func (m *MockStatusTracker) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + + m.AddModelValidationCalls = nil + m.RemoveModelValidationCalls = nil + m.ProcessPodEventCalls = nil + m.RemovePodEventCalls = nil + m.RemovePodByNameCalls = nil + m.StopCalls = 0 + m.ProcessPodEventError = nil + m.RemovePodEventError = nil + m.RemovePodByNameError = nil + m.TrackedModelValidations = make(map[types.NamespacedName]bool) +} + +// GetAddModelValidationCalls returns a copy of the AddModelValidation calls +func (m *MockStatusTracker) GetAddModelValidationCalls() []*v1alpha1.ModelValidation { + m.mu.RLock() + defer m.mu.RUnlock() + + calls := make([]*v1alpha1.ModelValidation, len(m.AddModelValidationCalls)) + copy(calls, m.AddModelValidationCalls) + return calls +} + +// GetRemoveModelValidationCalls returns a copy of the RemoveModelValidation calls +func (m *MockStatusTracker) GetRemoveModelValidationCalls() []types.NamespacedName { + m.mu.RLock() + defer m.mu.RUnlock() + + calls := make([]types.NamespacedName, len(m.RemoveModelValidationCalls)) + copy(calls, m.RemoveModelValidationCalls) + return calls +} + +// GetProcessPodEventCalls returns a copy of the ProcessPodEvent calls +func (m *MockStatusTracker) GetProcessPodEventCalls() []ProcessPodEventCall { + m.mu.RLock() + defer m.mu.RUnlock() + + calls := make([]ProcessPodEventCall, len(m.ProcessPodEventCalls)) + copy(calls, m.ProcessPodEventCalls) + return calls +} + +// GetRemovePodEventCalls returns a copy of the RemovePodEvent calls +func (m *MockStatusTracker) GetRemovePodEventCalls() []types.UID { + m.mu.RLock() + defer m.mu.RUnlock() + + calls := make([]types.UID, len(m.RemovePodEventCalls)) + copy(calls, m.RemovePodEventCalls) + return calls +} + +// GetRemovePodByNameCalls returns a copy of the RemovePodByName calls +func (m *MockStatusTracker) GetRemovePodByNameCalls() []types.NamespacedName { + m.mu.RLock() + defer 
m.mu.RUnlock() + + calls := make([]types.NamespacedName, len(m.RemovePodByNameCalls)) + copy(calls, m.RemovePodByNameCalls) + return calls +} + +// GetStopCalls returns the number of Stop calls +func (m *MockStatusTracker) GetStopCalls() int { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.StopCalls +} + +// IsModelValidationTracked returns whether a ModelValidation is being tracked +func (m *MockStatusTracker) IsModelValidationTracked(mvKey types.NamespacedName) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.TrackedModelValidations[mvKey] +} diff --git a/internal/tracker/model_validation_info.go b/internal/tracker/model_validation_info.go new file mode 100644 index 00000000..2177cf9e --- /dev/null +++ b/internal/tracker/model_validation_info.go @@ -0,0 +1,161 @@ +// Package tracker provides status tracking functionality for ModelValidation resources +package tracker + +import ( + "github.com/sigstore/model-validation-operator/internal/metrics" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +// PodInfo stores information about a tracked pod +type PodInfo struct { + Name string + Namespace string + UID types.UID + Timestamp metav1.Time + ConfigHash string + AuthMethod string +} + +// ModelValidationInfo consolidates all tracking information for a ModelValidation resource +type ModelValidationInfo struct { + // Name of the ModelValidation CR + Name string + // Current configuration hash for drift detection + ConfigHash string + // Current authentication method + AuthMethod string + // Observed generation for detecting spec changes + ObservedGeneration int64 + // Pods with finalizer and matching configuration + InjectedPods map[types.UID]*PodInfo + // Pods with label but no finalizer + UninjectedPods map[types.UID]*PodInfo + // Pods with finalizer but configuration drift + OrphanedPods map[types.UID]*PodInfo +} + +// NewModelValidationInfo creates a new ModelValidationInfo with initialized maps +func NewModelValidationInfo(name, configHash, authMethod string, observedGeneration int64) *ModelValidationInfo { + return &ModelValidationInfo{ + Name: name, + ConfigHash: configHash, + AuthMethod: authMethod, + ObservedGeneration: observedGeneration, + InjectedPods: make(map[types.UID]*PodInfo), + UninjectedPods: make(map[types.UID]*PodInfo), + OrphanedPods: make(map[types.UID]*PodInfo), + } +} + +// getPodState returns the current state of a pod, or empty string if not found +func (mvi *ModelValidationInfo) getPodState(podUID types.UID) string { + if _, exists := mvi.InjectedPods[podUID]; exists { + return metrics.PodStateInjected + } + if _, exists := mvi.OrphanedPods[podUID]; exists { + return metrics.PodStateOrphaned + } + if _, exists := mvi.UninjectedPods[podUID]; exists { + return metrics.PodStateUninjected + } + return "" +} + +// movePodToState removes pod from current state, assigns to new state, and records transition +func (mvi *ModelValidationInfo) movePodToState(podInfo *PodInfo, newState string, targetMap map[types.UID]*PodInfo) { + prevState := mvi.getPodState(podInfo.UID) + mvi.RemovePod(podInfo.UID) + + targetMap[podInfo.UID] = podInfo + + if prevState != "" && prevState != newState { + metrics.RecordPodStateTransition(podInfo.Namespace, mvi.Name, prevState, newState) + } +} + +// AddInjectedPod adds a pod with finalizer, automatically determining if it's injected or orphaned +func (mvi *ModelValidationInfo) AddInjectedPod(podInfo *PodInfo) { + // Determine if pod is orphaned based on configuration drift + isOrphaned := 
(podInfo.ConfigHash != "" && podInfo.ConfigHash != mvi.ConfigHash) || + (podInfo.AuthMethod != "" && podInfo.AuthMethod != mvi.AuthMethod) + + if isOrphaned { + mvi.movePodToState(podInfo, metrics.PodStateOrphaned, mvi.OrphanedPods) + } else { + mvi.movePodToState(podInfo, metrics.PodStateInjected, mvi.InjectedPods) + } +} + +// AddUninjectedPod adds a pod to the uninjected category +func (mvi *ModelValidationInfo) AddUninjectedPod(podInfo *PodInfo) { + mvi.movePodToState(podInfo, metrics.PodStateUninjected, mvi.UninjectedPods) +} + +// RemovePod removes a pod from all categories +func (mvi *ModelValidationInfo) RemovePod(podUID types.UID) bool { + // A pod should only exist in one of these maps, so we can stop after first match + // Check in order of likelihood: injected, orphaned, uninjected + if _, exists := mvi.InjectedPods[podUID]; exists { + delete(mvi.InjectedPods, podUID) + return true + } + if _, exists := mvi.OrphanedPods[podUID]; exists { + delete(mvi.OrphanedPods, podUID) + return true + } + if _, exists := mvi.UninjectedPods[podUID]; exists { + delete(mvi.UninjectedPods, podUID) + return true + } + return false +} + +// UpdateConfig updates the configuration information and returns any drifted pods +// This safely handles configuration changes by detecting pods that are now orphaned +func (mvi *ModelValidationInfo) UpdateConfig(configHash, authMethod string, observedGeneration int64) []*PodInfo { + // If config hasn't changed, no drift possible + if mvi.ConfigHash == configHash && mvi.AuthMethod == authMethod { + // still update observed generation for tracking + mvi.ObservedGeneration = observedGeneration + return nil + } + + // Get drifted pods before updating config using new parameters + var driftedPods []*PodInfo + for _, podInfo := range mvi.InjectedPods { + if (podInfo.ConfigHash != "" && podInfo.ConfigHash != configHash) || + (podInfo.AuthMethod != "" && podInfo.AuthMethod != authMethod) { + driftedPods = append(driftedPods, podInfo) + } + } + + // Update the config and observed generation + mvi.ConfigHash = configHash + mvi.AuthMethod = authMethod + mvi.ObservedGeneration = observedGeneration + + // Move drifted pods to orphaned + for _, podInfo := range driftedPods { + delete(mvi.InjectedPods, podInfo.UID) + mvi.OrphanedPods[podInfo.UID] = podInfo + } + + return driftedPods +} + +// GetAllPods returns all pods from all categories for cleanup +func (mvi *ModelValidationInfo) GetAllPods() []*PodInfo { + totalPods := len(mvi.InjectedPods) + len(mvi.UninjectedPods) + len(mvi.OrphanedPods) + allPods := make([]*PodInfo, 0, totalPods) + for _, podInfo := range mvi.InjectedPods { + allPods = append(allPods, podInfo) + } + for _, podInfo := range mvi.UninjectedPods { + allPods = append(allPods, podInfo) + } + for _, podInfo := range mvi.OrphanedPods { + allPods = append(allPods, podInfo) + } + return allPods +} diff --git a/internal/tracker/pod_mapping.go b/internal/tracker/pod_mapping.go new file mode 100644 index 00000000..dd7ed58a --- /dev/null +++ b/internal/tracker/pod_mapping.go @@ -0,0 +1,84 @@ +// Package tracker provides status tracking functionality for ModelValidation resources +package tracker + +import ( + "k8s.io/apimachinery/pkg/types" +) + +// PodMapping manages bidirectional mapping between pod names and UIDs +type PodMapping struct { + nameToUID map[types.NamespacedName]types.UID + uidToName map[types.UID]types.NamespacedName +} + +// NewPodMapping creates a new bidirectional pod mapping +func NewPodMapping() *PodMapping { + return &PodMapping{ + nameToUID: 
make(map[types.NamespacedName]types.UID), + uidToName: make(map[types.UID]types.NamespacedName), + } +} + +// AddPod adds a pod to the mapping +func (pm *PodMapping) AddPod(name types.NamespacedName, uid types.UID) { + // Remove any existing mappings for this name or UID + pm.removePod(name, uid) + + pm.nameToUID[name] = uid + pm.uidToName[uid] = name +} + +// GetUIDByName returns the UID for a given pod name +func (pm *PodMapping) GetUIDByName(name types.NamespacedName) (types.UID, bool) { + uid, exists := pm.nameToUID[name] + return uid, exists +} + +// GetNameByUID returns the name for a given pod UID +func (pm *PodMapping) GetNameByUID(uid types.UID) (types.NamespacedName, bool) { + name, exists := pm.uidToName[uid] + return name, exists +} + +// RemovePodsByName removes multiple pod mappings by name +func (pm *PodMapping) RemovePodsByName(names ...types.NamespacedName) { + if len(names) == 0 { + return + } + + for _, name := range names { + if uid, exists := pm.nameToUID[name]; exists { + pm.removePod(name, uid) + } + } +} + +// RemovePodByUID removes a pod mapping by UID +func (pm *PodMapping) RemovePodByUID(uid types.UID) bool { + name, exists := pm.uidToName[uid] + if exists { + pm.removePod(name, uid) + } + return exists +} + +// removePod removes mappings +func (pm *PodMapping) removePod(name types.NamespacedName, uid types.UID) { + delete(pm.nameToUID, name) + delete(pm.uidToName, uid) +} + +// GetAllPodNames returns all tracked pod names +func (pm *PodMapping) GetAllPodNames() []types.NamespacedName { + names := make([]types.NamespacedName, 0, len(pm.nameToUID)) + for name := range pm.nameToUID { + names = append(names, name) + } + return names +} + +// Clear removes all mappings +func (pm *PodMapping) Clear() { + pm.nameToUID = make(map[types.NamespacedName]types.UID) + pm.uidToName = make(map[types.UID]types.NamespacedName) +} diff --git a/internal/tracker/status_tracker.go b/internal/tracker/status_tracker.go new file mode 100644 index 00000000..5ee31c0a --- /dev/null +++ b/internal/tracker/status_tracker.go @@ -0,0 +1,568 @@ +// Package tracker provides status tracking functionality for ModelValidation resources +package tracker + +import ( + "context" + "fmt" + "slices" + "sort" + "sync" + "time" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/metrics" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +// StatusTrackerImpl tracks injected pods and namespaces for ModelValidation resources +type StatusTrackerImpl struct { + client client.Client + mu sync.RWMutex + // Consolidated tracking information for each ModelValidation + modelValidations map[types.NamespacedName]*ModelValidationInfo + // Count of ModelValidation resources per namespace + mvNamespaces map[string]int + // Pod name/UID bidirectional mapping + podMapping *PodMapping + + // Configuration + statusUpdateTimeout time.Duration + + // Async update components + debouncedQueue DebouncedQueue + stopCh chan struct{} + wg sync.WaitGroup +} + +// StatusTrackerConfig holds configuration options for the status tracker +type StatusTrackerConfig struct { + DebounceDuration time.Duration + RetryBaseDelay time.Duration + 
RetryMaxDelay time.Duration + RateLimitQPS float64 + RateLimitBurst int + StatusUpdateTimeout time.Duration +} + +// NewStatusTracker creates a new status tracker with explicit configuration +func NewStatusTracker(client client.Client, config StatusTrackerConfig) StatusTracker { + debouncedQueue := NewDebouncedQueue(DebouncedQueueConfig{ + DebounceDuration: config.DebounceDuration, + RetryBaseDelay: config.RetryBaseDelay, + RetryMaxDelay: config.RetryMaxDelay, + RateLimitQPS: config.RateLimitQPS, + RateLimitBurst: config.RateLimitBurst, + }) + + st := &StatusTrackerImpl{ + client: client, + modelValidations: make(map[types.NamespacedName]*ModelValidationInfo), + mvNamespaces: make(map[string]int), + podMapping: NewPodMapping(), + statusUpdateTimeout: config.StatusUpdateTimeout, + debouncedQueue: debouncedQueue, + stopCh: make(chan struct{}), + } + + st.wg.Add(1) + go st.asyncUpdateWorker() + + return st +} + +// collectPodTrackingInfo is a helper to convert PodInfo to PodTrackingInfo +func collectPodTrackingInfo(pods map[types.UID]*PodInfo, result *[]v1alpha1.PodTrackingInfo) { + if pods == nil { + return + } + for _, podInfo := range pods { + *result = append(*result, v1alpha1.PodTrackingInfo{ + Name: podInfo.Name, + UID: string(podInfo.UID), + InjectedAt: podInfo.Timestamp, + }) + } +} + +// collectPodsForCleanup is a helper to collect pod names for cleanup +func collectPodsForCleanup(allPods []*PodInfo, result *[]types.NamespacedName) { + if allPods == nil { + return + } + for _, podInfo := range allPods { + *result = append(*result, types.NamespacedName{ + Name: podInfo.Name, + Namespace: podInfo.Namespace, + }) + } +} + +// Stop stops the async update worker +func (st *StatusTrackerImpl) Stop() { + // Wait for pending updates before stopping + st.debouncedQueue.WaitForUpdates() + + // Stop components + close(st.stopCh) + st.debouncedQueue.ShutDown() + st.wg.Wait() +} + +// asyncUpdateWorker processes status updates asynchronously +func (st *StatusTrackerImpl) asyncUpdateWorker() { + defer st.wg.Done() + + for { + select { + case <-st.stopCh: + return + default: + st.processNextUpdate() + } + } +} + +// processNextUpdate processes the next update from the queue +func (st *StatusTrackerImpl) processNextUpdate() { + mvKey, shutdown := st.debouncedQueue.Get() + if shutdown { + return + } + defer st.debouncedQueue.Done(mvKey) + + ctx, cancel := context.WithTimeout(context.Background(), st.statusUpdateTimeout) + defer cancel() + + if err := st.doUpdateStatus(ctx, mvKey); err != nil { + logger := log.FromContext(ctx) + logger.V(1).Info("Status update failed, will retry", + "modelvalidation", mvKey, + "attempts", st.debouncedQueue.GetRetryCount(mvKey)+1, + "error", err) + metrics.RecordRetryAttempt(mvKey.Namespace, mvKey.Name) + st.debouncedQueue.AddWithRetry(mvKey) + } else { + st.debouncedQueue.ForgetRetries(mvKey) + } +} + +// WaitForUpdates waits for all pending status updates to complete +// This is useful in tests to ensure async operations finish before assertions +func (st *StatusTrackerImpl) WaitForUpdates() { + st.debouncedQueue.WaitForCompletion() +} + +// doUpdateStatus updates the ModelValidation status with current metrics +// This is called asynchronously from the update worker +func (st *StatusTrackerImpl) doUpdateStatus(ctx context.Context, mvKey types.NamespacedName) error { + logger := log.FromContext(ctx) + startTime := time.Now() + + mv := &v1alpha1.ModelValidation{} + if err := st.client.Get(ctx, mvKey, mv); err != nil { + if errors.IsNotFound(err) { + st.mu.Lock() + 
delete(st.modelValidations, mvKey) + st.mu.Unlock() + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateFailure, time.Since(startTime)) + return nil + } + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateFailure, time.Since(startTime)) + return err + } + + trackedPods := []v1alpha1.PodTrackingInfo{} + uninjectedPods := []v1alpha1.PodTrackingInfo{} + orphanedPods := []v1alpha1.PodTrackingInfo{} + var trackedConfigHash, trackedAuthMethod string + + st.mu.RLock() + mvInfo := st.modelValidations[mvKey] + if mvInfo != nil { + collectPodTrackingInfo(mvInfo.InjectedPods, &trackedPods) + collectPodTrackingInfo(mvInfo.UninjectedPods, &uninjectedPods) + collectPodTrackingInfo(mvInfo.OrphanedPods, &orphanedPods) + trackedConfigHash = mvInfo.ConfigHash + trackedAuthMethod = mvInfo.AuthMethod + } + st.mu.RUnlock() + if mvInfo == nil { + return nil + } + sort.Slice(trackedPods, func(i, j int) bool { + return trackedPods[i].Name < trackedPods[j].Name + }) + sort.Slice(uninjectedPods, func(i, j int) bool { + return uninjectedPods[i].Name < uninjectedPods[j].Name + }) + sort.Slice(orphanedPods, func(i, j int) bool { + return orphanedPods[i].Name < orphanedPods[j].Name + }) + + // Check if ModelValidation parameters have changed since tracking began + currentConfigHash := mv.GetConfigHash() + currentAuthMethod := mv.GetAuthMethod() + + // Check for configuration drift between tracked and current state + if trackedConfigHash != currentConfigHash || trackedAuthMethod != currentAuthMethod { + logger.Error(fmt.Errorf("configuration drift detected"), "ModelValidation config drift during status update", + "modelvalidation", mvKey, + "oldConfigHash", trackedConfigHash, + "newConfigHash", currentConfigHash, + "oldAuthMethod", trackedAuthMethod, + "newAuthMethod", currentAuthMethod) + + if trackedConfigHash != currentConfigHash { + metrics.RecordConfigurationDrift(mvKey.Namespace, mvKey.Name, "config_hash") + } + if trackedAuthMethod != currentAuthMethod { + metrics.RecordConfigurationDrift(mvKey.Namespace, mvKey.Name, "auth_method") + } + + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateFailure, time.Since(startTime)) + return fmt.Errorf( + "configuration drift detected for ModelValidation %s: hash changed from %s to %s, auth method changed from %s to %s", + mvKey, trackedConfigHash, currentConfigHash, trackedAuthMethod, currentAuthMethod) + } + + newStatus := v1alpha1.ModelValidationStatus{ + Conditions: mv.Status.Conditions, + InjectedPodCount: int32(len(trackedPods)), + UninjectedPodCount: int32(len(uninjectedPods)), + OrphanedPodCount: int32(len(orphanedPods)), + AuthMethod: trackedAuthMethod, + InjectedPods: trackedPods, + UninjectedPods: uninjectedPods, + OrphanedPods: orphanedPods, + LastUpdated: metav1.Now(), + } + + if statusEqual(mv.Status, newStatus) { + logger.V(2).Info("Status unchanged, skipping update", "modelvalidation", mvKey) + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateSuccess, time.Since(startTime)) + return nil + } + + mv.Status = newStatus + if err := st.client.Status().Update(ctx, mv); err != nil { + logger.Error(err, "Failed to update ModelValidation status", "modelvalidation", mvKey) + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateFailure, time.Since(startTime)) + return err + } + + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateSuccess, time.Since(startTime)) + + recordPodCounts(mvKey.Namespace, mvKey.Name, len(trackedPods), 
len(uninjectedPods), len(orphanedPods)) + + logger.Info("Updated ModelValidation status", + "modelvalidation", mvKey, + "injectedPods", mv.Status.InjectedPodCount, + "uninjectedPods", mv.Status.UninjectedPodCount, + "orphanedPods", mv.Status.OrphanedPodCount, + "authMethod", mv.Status.AuthMethod) + + return nil +} + +// comparePodSlices compares two slices of PodTrackingInfo for equality +func comparePodSlices(a, b []v1alpha1.PodTrackingInfo) bool { + return slices.EqualFunc(a, b, func(x, y v1alpha1.PodTrackingInfo) bool { + return x.Name == y.Name && x.UID == y.UID + }) +} + +// recordStatusUpdateResult records both the status update result and duration metrics +func recordStatusUpdateResult(namespace, modelValidation, result string, duration time.Duration) { + metrics.RecordStatusUpdate(namespace, modelValidation, result) + metrics.RecordStatusUpdateDuration(namespace, modelValidation, result, duration.Seconds()) +} + +// recordPodCounts records current pod counts for all states +func recordPodCounts(namespace, modelValidation string, injected, uninjected, orphaned int) { + metrics.RecordPodCount(namespace, modelValidation, metrics.PodStateInjected, float64(injected)) + metrics.RecordPodCount(namespace, modelValidation, metrics.PodStateUninjected, float64(uninjected)) + metrics.RecordPodCount(namespace, modelValidation, metrics.PodStateOrphaned, float64(orphaned)) +} + +// statusEqual compares two ModelValidationStatus objects for equality +// ignoring LastUpdated timestamp. +func statusEqual(a, b v1alpha1.ModelValidationStatus) bool { + if a.InjectedPodCount != b.InjectedPodCount || + a.UninjectedPodCount != b.UninjectedPodCount || + a.OrphanedPodCount != b.OrphanedPodCount || + a.AuthMethod != b.AuthMethod { + return false + } + + return comparePodSlices(a.InjectedPods, b.InjectedPods) && + comparePodSlices(a.UninjectedPods, b.UninjectedPods) && + comparePodSlices(a.OrphanedPods, b.OrphanedPods) +} + +// AddModelValidation adds a ModelValidation to tracking using the provided ModelValidation +func (st *StatusTrackerImpl) AddModelValidation(ctx context.Context, mv *v1alpha1.ModelValidation) { + mvKey := types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + + st.mu.Lock() + mvi, alreadyTracking := st.modelValidations[mvKey] + if !alreadyTracking { + mvInfo := NewModelValidationInfo(mv.Name, mv.GetConfigHash(), mv.GetAuthMethod(), mv.Generation) + st.modelValidations[mvKey] = mvInfo + st.mvNamespaces[mvKey.Namespace]++ + + metrics.RecordModelValidationCR(mvKey.Namespace, float64(st.mvNamespaces[mvKey.Namespace])) + } else if mv.Generation != mvi.ObservedGeneration { + // Update existing tracking with current config and handle drift + driftedPods := st.modelValidations[mvKey].UpdateConfig(mv.GetConfigHash(), mv.GetAuthMethod(), mv.Generation) + + if len(driftedPods) > 0 { + logger := log.FromContext(ctx) + logger.Info("Detected configuration drift, moved pods to orphaned status", + "modelvalidation", mvKey, "driftedPods", len(driftedPods)) + + metrics.RecordMultiplePodStateTransitions(mvKey.Namespace, mvKey.Name, + metrics.PodStateInjected, metrics.PodStateOrphaned, len(driftedPods)) + } + } + st.mu.Unlock() + + if !alreadyTracking { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), st.statusUpdateTimeout) + defer cancel() + + if err := st.seedExistingPods(ctx, mvKey); err != nil { + logger := log.FromContext(ctx) + logger.Error(err, "Failed to seed existing pods", "modelvalidation", mvKey) + } + }() + } + + st.debouncedQueue.Add(mvKey) +} + +// 
RemoveModelValidation removes a ModelValidation from tracking +func (st *StatusTrackerImpl) RemoveModelValidation(mvKey types.NamespacedName) { + st.mu.Lock() + defer st.mu.Unlock() + + // Only decrement namespace counter if ModelValidation was actually tracked + mvInfo, wasTracked := st.modelValidations[mvKey] + if wasTracked { + delete(st.modelValidations, mvKey) + st.mvNamespaces[mvKey.Namespace]-- + + if st.mvNamespaces[mvKey.Namespace] <= 0 { + delete(st.mvNamespaces, mvKey.Namespace) + metrics.RecordModelValidationCR(mvKey.Namespace, 0) + } else { + metrics.RecordModelValidationCR(mvKey.Namespace, float64(st.mvNamespaces[mvKey.Namespace])) + } + + allPods := mvInfo.GetAllPods() + var podsToCleanup []types.NamespacedName + collectPodsForCleanup(allPods, &podsToCleanup) + + st.podMapping.RemovePodsByName(podsToCleanup...) + + // Reset pod count metrics for removed ModelValidation + metrics.RecordPodCount(mvKey.Namespace, mvKey.Name, metrics.PodStateInjected, 0) + metrics.RecordPodCount(mvKey.Namespace, mvKey.Name, metrics.PodStateUninjected, 0) + metrics.RecordPodCount(mvKey.Namespace, mvKey.Name, metrics.PodStateOrphaned, 0) + } +} + +// IsModelValidationTracked checks if a ModelValidation is being tracked +func (st *StatusTrackerImpl) IsModelValidationTracked(mvKey types.NamespacedName) bool { + st.mu.RLock() + defer st.mu.RUnlock() + + return st.modelValidations[mvKey] != nil +} + +// GetObservedGeneration returns the observed generation for a tracked ModelValidation +// Returns the generation and a boolean indicating whether the ModelValidation is tracked +func (st *StatusTrackerImpl) GetObservedGeneration(mvKey types.NamespacedName) (int64, bool) { + st.mu.RLock() + defer st.mu.RUnlock() + + if mvInfo := st.modelValidations[mvKey]; mvInfo != nil { + return mvInfo.ObservedGeneration, true + } + return 0, false +} + +// seedExistingPods processes existing pods using the provided ModelValidation +func (st *StatusTrackerImpl) seedExistingPods(ctx context.Context, mvKey types.NamespacedName) error { + logger := log.FromContext(ctx) + + // List all pods in the namespace with the ModelValidation label + podList := &corev1.PodList{} + listOpts := []client.ListOption{ + client.InNamespace(mvKey.Namespace), + client.MatchingLabels{constants.ModelValidationLabel: mvKey.Name}, + } + + if err := st.client.List(ctx, podList, listOpts...); err != nil { + logger.Error(err, "Failed to list existing pods for seeding", "modelvalidation", mvKey) + return err + } + + logger.Info("Seeding existing pods", "modelvalidation", mvKey, "podCount", len(podList.Items)) + + if err := st.processSeedPods(ctx, podList.Items, mvKey); err != nil { + logger.Error(err, "Failed to process some pods during seeding", "modelvalidation", mvKey) + } + + return nil +} + +// processSeedPods processes seed pods using the provided ModelValidation +func (st *StatusTrackerImpl) processSeedPods(ctx context.Context, pods []corev1.Pod, mvKey types.NamespacedName) error { + logger := log.FromContext(ctx) + var errors []error + + for i := range pods { + pod := &pods[i] + if err := st.processSeedPodEvent(pod, mvKey); err != nil { + logger.Error(err, "Failed to process existing pod during batch seeding", "pod", pod.Name, "modelvalidation", mvKey) + errors = append(errors, err) + } + } + + // Return an aggregate error if any individual pod processing failed + if len(errors) > 0 { + return fmt.Errorf("failed to process %d out of %d pods", len(errors), len(pods)) + } + + return nil +} + +// processSeedPodEvent processes a pod event with a 
provided ModelValidation +func (st *StatusTrackerImpl) processSeedPodEvent(pod *corev1.Pod, mvKey types.NamespacedName) error { + modelValidationName, hasLabel := pod.Labels[constants.ModelValidationLabel] + if !hasLabel || modelValidationName == "" { + return nil + } + + st.mu.Lock() + defer st.mu.Unlock() + + mvInfo := st.modelValidations[mvKey] + if mvInfo != nil { + st.processPodEventCommon(pod, mvKey, mvInfo) + } + + return nil +} + +// processPodEventCommon contains the common logic for processing pod events +func (st *StatusTrackerImpl) processPodEventCommon( + pod *corev1.Pod, + mvKey types.NamespacedName, + mvInfo *ModelValidationInfo, +) { + hasOurFinalizer := controllerutil.ContainsFinalizer(pod, constants.ModelValidationFinalizer) + + podInfo := &PodInfo{ + Name: pod.Name, + Namespace: pod.Namespace, + UID: pod.UID, + Timestamp: pod.CreationTimestamp, + ConfigHash: pod.Annotations[constants.ConfigHashAnnotationKey], + AuthMethod: pod.Annotations[constants.AuthMethodAnnotationKey], + } + + podNamespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} + st.podMapping.AddPod(podNamespacedName, pod.UID) + + if hasOurFinalizer { + mvInfo.AddInjectedPod(podInfo) + } else { + mvInfo.AddUninjectedPod(podInfo) + } + + st.debouncedQueue.Add(mvKey) +} + +// ProcessPodEvent processes a pod event from controllers +// This determines how to categorize the pod based on its state and configuration consistency +func (st *StatusTrackerImpl) ProcessPodEvent(_ context.Context, pod *corev1.Pod) error { + if pod == nil { + return fmt.Errorf("pod cannot be nil") + } + + modelValidationName, hasLabel := pod.Labels[constants.ModelValidationLabel] + if !hasLabel || modelValidationName == "" { + return nil + } + + mvKey := types.NamespacedName{ + Name: modelValidationName, + Namespace: pod.Namespace, + } + + st.mu.Lock() + defer st.mu.Unlock() + + mvInfo := st.modelValidations[mvKey] + if mvInfo != nil { + st.processPodEventCommon(pod, mvKey, mvInfo) + } + return nil +} + +// RemovePodEvent removes a pod from tracking when it's deleted +func (st *StatusTrackerImpl) RemovePodEvent(_ context.Context, podUID types.UID) error { + st.mu.Lock() + defer st.mu.Unlock() + + updatedMVs := st.removePodFromTrackingMapsUnsafe(podUID) + + for mvKey := range updatedMVs { + st.debouncedQueue.Add(mvKey) + } + + return nil +} + +// RemovePodByName removes a pod from tracking by its NamespacedName +// This is useful when we know a pod was deleted but don't have its UID +func (st *StatusTrackerImpl) RemovePodByName(_ context.Context, podName types.NamespacedName) error { + st.mu.Lock() + defer st.mu.Unlock() + + podUID, exists := st.podMapping.GetUIDByName(podName) + if !exists { + return nil + } + + updatedMVs := st.removePodFromTrackingMapsUnsafe(podUID) + st.podMapping.RemovePodsByName(podName) + + for mvKey := range updatedMVs { + st.debouncedQueue.Add(mvKey) + } + + return nil +} + +// removePodFromTrackingMapsUnsafe removes a pod from tracking maps and returns updated MVs +func (st *StatusTrackerImpl) removePodFromTrackingMapsUnsafe(podUID types.UID) map[types.NamespacedName]bool { + updatedMVs := make(map[types.NamespacedName]bool) + + for mvKey, mvInfo := range st.modelValidations { + if mvInfo.RemovePod(podUID) { + updatedMVs[mvKey] = true + } + } + + return updatedMVs +} diff --git a/internal/tracker/status_tracker_intf.go b/internal/tracker/status_tracker_intf.go new file mode 100644 index 00000000..b74ec498 --- /dev/null +++ b/internal/tracker/status_tracker_intf.go @@ -0,0 +1,24 @@ 
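Taken together, status_tracker.go gives controllers a single entry point for CR and pod events, with status writes debounced, deduplicated, and retried in the background. An illustrative wiring sketch follows (assumed durations and variable names such as `mgr`, `mv`, `pod`, and the `ctrl.Result` return are placeholders, not part of this patch):

    // Construct once at startup and stop on shutdown.
    st := tracker.NewStatusTracker(mgr.GetClient(), tracker.StatusTrackerConfig{
        DebounceDuration:    500 * time.Millisecond,
        RetryBaseDelay:      time.Second,
        RetryMaxDelay:       30 * time.Second,
        RateLimitQPS:        10,
        RateLimitBurst:      100,
        StatusUpdateTimeout: 30 * time.Second,
    })
    defer st.Stop()

    // ModelValidation reconciler: begin or refresh tracking for the CR.
    st.AddModelValidation(ctx, mv)

    // Pod reconciler: categorize a labeled pod (injected/uninjected/orphaned).
    if err := st.ProcessPodEvent(ctx, pod); err != nil {
        return ctrl.Result{}, err
    }

    // Pod deletion observed by name only (UID no longer available).
    _ = st.RemovePodByName(ctx, types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace})
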
+package tracker + +import ( + "context" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" +) + +// StatusTracker defines the interface for tracking ModelValidation and Pod events +type StatusTracker interface { + // ModelValidation tracking methods + AddModelValidation(ctx context.Context, mv *v1alpha1.ModelValidation) + RemoveModelValidation(mvKey types.NamespacedName) + + // Pod tracking methods + ProcessPodEvent(ctx context.Context, pod *corev1.Pod) error + RemovePodEvent(ctx context.Context, podUID types.UID) error + RemovePodByName(ctx context.Context, podName types.NamespacedName) error + + // Lifecycle methods + Stop() +} diff --git a/internal/tracker/status_tracker_test.go b/internal/tracker/status_tracker_test.go new file mode 100644 index 00000000..2bc8777e --- /dev/null +++ b/internal/tracker/status_tracker_test.go @@ -0,0 +1,762 @@ +package tracker + +import ( + "context" + "sync" + "time" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . "github.com/onsi/gomega" //nolint:revive + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/testutil" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// NewStatusTrackerForTesting creates a status tracker with fast test-friendly defaults +func NewStatusTrackerForTesting(client client.Client) StatusTracker { + return NewStatusTracker(client, StatusTrackerConfig{ + DebounceDuration: 50 * time.Millisecond, // Faster for tests + RetryBaseDelay: 10 * time.Millisecond, // Faster retries for tests + RetryMaxDelay: 1 * time.Second, // Lower max delay for tests + RateLimitQPS: 100, // Higher QPS for tests + RateLimitBurst: 1000, // Higher burst for tests + StatusUpdateTimeout: 5 * time.Second, // Shorter timeout for tests + }) +} + +// setupTestEnvironment creates a common test setup with ModelValidation and StatusTracker +func setupTestEnvironment(mv *v1alpha1.ModelValidation) (*StatusTrackerImpl, client.Client, types.NamespacedName) { + fakeClient := testutil.SetupFakeClientWithObjects(mv) + statusTracker := NewStatusTrackerForTesting(fakeClient).(*StatusTrackerImpl) + mvKey := types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + return statusTracker, fakeClient, mvKey +} + +// expectPodCount waits for and verifies expected pod counts in ModelValidation status +func expectPodCount( + ctx context.Context, + fakeClient client.Client, + mvKey types.NamespacedName, + injected, uninjected, orphaned int32, +) { + Eventually(func() bool { + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + return updatedMV.Status.InjectedPodCount == injected && + updatedMV.Status.UninjectedPodCount == uninjected && + updatedMV.Status.OrphanedPodCount == orphaned + }, "2s", "50ms").Should(BeTrue()) +} + +// getNamespaceCount returns the count of ModelValidations in the given namespace +func getNamespaceCount(tracker *StatusTrackerImpl, namespace string) int { + tracker.mu.RLock() + defer tracker.mu.RUnlock() + return tracker.mvNamespaces[namespace] +} + +// namespaceExists checks if a namespace exists in the tracker +func namespaceExists(tracker *StatusTrackerImpl, namespace string) bool { + tracker.mu.RLock() + defer tracker.mu.RUnlock() + _, exists := tracker.mvNamespaces[namespace] + return exists +} + +var _ = 
Describe("StatusTracker", func() { + var ( + ctx context.Context + statusTracker *StatusTrackerImpl + fakeClient client.Client + mvKey types.NamespacedName + ) + + BeforeEach(func() { + ctx = context.Background() + }) + + AfterEach(func() { + if statusTracker != nil { + statusTracker.Stop() + } + }) + + Context("when tracking pod injections", func() { + It("should track injected pods and update ModelValidation status", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + UID: types.UID("test-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(HaveLen(1)) + Expect(updatedMV.Status.InjectedPods[0].Name).To(Equal("test-pod")) + }) + + It("should categorize pods correctly based on finalizer and CR existence", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + injectedPod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "injected-pod", + Namespace: "default", + UID: types.UID("uid-1"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + uninjectedPod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "uninjected-pod", + Namespace: "default", + UID: types.UID("uid-2"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + }) + + err := statusTracker.ProcessPodEvent(ctx, injectedPod) + Expect(err).NotTo(HaveOccurred()) + + err = statusTracker.ProcessPodEvent(ctx, uninjectedPod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 1, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(HaveLen(1)) + Expect(updatedMV.Status.UninjectedPods).To(HaveLen(1)) + }) + }) + + Context("when removing pods", func() { + It("should remove pods from tracking and update status", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod1 := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod-1", + Namespace: "default", + UID: types.UID("test-uid-1"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + pod2 := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod-2", + Namespace: "default", + UID: types.UID("test-uid-2"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + err := 
statusTracker.ProcessPodEvent(ctx, pod1) + Expect(err).NotTo(HaveOccurred()) + err = statusTracker.ProcessPodEvent(ctx, pod2) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 2, 0, 0) + + err = statusTracker.RemovePodEvent(ctx, pod1.UID) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(HaveLen(1)) + Expect(updatedMV.Status.InjectedPods[0].Name).To(Equal("test-pod-2")) + }) + }) + + Context("when managing ModelValidation tracking", func() { + It("should add and remove ModelValidations from tracking", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeFalse()) + + statusTracker.AddModelValidation(ctx, mv) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeTrue()) + + statusTracker.RemoveModelValidation(mvKey) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeFalse()) + }) + + It("should handle AddModelValidation idempotently", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + mvKey = testutil.CreateTestNamespacedName("test-mv", "default") + + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeFalse()) + + initialCount := getNamespaceCount(statusTracker, "default") + + statusTracker.AddModelValidation(ctx, mv) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeTrue()) + + firstCount := getNamespaceCount(statusTracker, "default") + Expect(firstCount).To(Equal(initialCount + 1)) + + statusTracker.AddModelValidation(ctx, mv) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeTrue()) + + secondCount := getNamespaceCount(statusTracker, "default") + Expect(secondCount).To(Equal(firstCount), "Second AddModelValidation should not increment namespace counter") + + statusTracker.AddModelValidation(ctx, mv) + thirdCount := getNamespaceCount(statusTracker, "default") + Expect(thirdCount).To(Equal(firstCount), "Third AddModelValidation should not increment namespace counter") + }) + + It("should skip AddModelValidation when generation hasn't changed", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-generation", + Namespace: "default", + }) + statusTracker, _, mvKey = setupTestEnvironment(mv) + + // Set initial generation + mv.Generation = 1 + statusTracker.AddModelValidation(ctx, mv) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeTrue()) + + initialObservedGen, tracked := statusTracker.GetObservedGeneration(mvKey) + Expect(tracked).To(BeTrue(), "ModelValidation should be tracked") + Expect(initialObservedGen).To(Equal(int64(1))) + + // Call AddModelValidation again with same generation - should be skipped + statusTracker.AddModelValidation(ctx, mv) + + currentObservedGen, tracked := statusTracker.GetObservedGeneration(mvKey) + Expect(tracked).To(BeTrue(), "ModelValidation should still be tracked") + Expect(currentObservedGen).To(Equal(int64(1)), + "Generation should remain unchanged when same generation is processed") + + // Updated generation should trigger processing + 
mv.Generation = 2 + statusTracker.AddModelValidation(ctx, mv) + + updatedObservedGen, tracked := statusTracker.GetObservedGeneration(mvKey) + Expect(tracked).To(BeTrue(), "ModelValidation should still be tracked") + Expect(updatedObservedGen).To(Equal(int64(2)), "Generation should be updated when new generation is processed") + }) + + It("should return false for untracked ModelValidation in GetObservedGeneration", func() { + untrackedKey := types.NamespacedName{Name: "untracked-mv", Namespace: "default"} + + // Create a statusTracker but don't add any ModelValidation + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "tracked-mv", + Namespace: "default", + }) + statusTracker, _, _ = setupTestEnvironment(mv) + + _, tracked := statusTracker.GetObservedGeneration(untrackedKey) + Expect(tracked).To(BeFalse(), "Untracked ModelValidation should return false") + }) + + It("should handle RemoveModelValidation idempotently and manage namespace counts correctly", func() { + mv1 := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-1", + Namespace: "default", + }) + mv2 := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-2", + Namespace: "default", + }) + mv3 := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-3", + Namespace: "other-namespace", + }) + fakeClient = testutil.SetupFakeClientWithObjects(mv1, mv2, mv3) + statusTracker = NewStatusTrackerForTesting(fakeClient).(*StatusTrackerImpl) + + mvKey1 := testutil.CreateTestNamespacedName("test-mv-1", "default") + mvKey2 := testutil.CreateTestNamespacedName("test-mv-2", "default") + mvKey3 := testutil.CreateTestNamespacedName("test-mv-3", "other-namespace") + + statusTracker.AddModelValidation(ctx, mv1) + statusTracker.AddModelValidation(ctx, mv2) + statusTracker.AddModelValidation(ctx, mv3) + + defaultCount := getNamespaceCount(statusTracker, "default") + otherCount := getNamespaceCount(statusTracker, "other-namespace") + Expect(defaultCount).To(Equal(2)) + Expect(otherCount).To(Equal(1)) + + statusTracker.RemoveModelValidation(mvKey1) + Expect(statusTracker.IsModelValidationTracked(mvKey1)).To(BeFalse()) + + defaultCountAfterFirst := getNamespaceCount(statusTracker, "default") + Expect(defaultCountAfterFirst).To(Equal(1)) + + statusTracker.RemoveModelValidation(mvKey1) + defaultCountAfterSecond := getNamespaceCount(statusTracker, "default") + Expect(defaultCountAfterSecond).To(Equal(1), "Second RemoveModelValidation should not decrement namespace counter") + + statusTracker.RemoveModelValidation(mvKey2) + defaultExists := namespaceExists(statusTracker, "default") + otherCountFinal := getNamespaceCount(statusTracker, "other-namespace") + Expect(defaultExists).To(BeFalse(), "Namespace should be deleted when count reaches zero") + Expect(otherCountFinal).To(Equal(1), "Other namespace should be unaffected") + + statusTracker.RemoveModelValidation(mvKey3) + otherExists := namespaceExists(statusTracker, "other-namespace") + Expect(otherExists).To(BeFalse(), "Other namespace should also be deleted when count reaches zero") + }) + + DescribeTable("should handle configuration drift detection", + func( + mvConfig testutil.TestModelValidationOptions, + podAuth string, + podConfigHash string, + expectedInjected, expectedOrphaned int32, + expectedPodName string, + ) { + mv := testutil.CreateTestModelValidation(mvConfig) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + 
statusTracker.AddModelValidation(ctx, mv) + + podAnnotations := map[string]string{ + constants.AuthMethodAnnotationKey: podAuth, + } + if podConfigHash != "" { + podAnnotations[constants.ConfigHashAnnotationKey] = podConfigHash + } + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: expectedPodName, + Namespace: "default", + UID: types.UID("uid-" + expectedPodName), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + Annotations: podAnnotations, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, expectedInjected, 0, expectedOrphaned) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + if expectedInjected > 0 { + Expect(updatedMV.Status.InjectedPods).To(HaveLen(int(expectedInjected))) + Expect(updatedMV.Status.InjectedPods[0].Name).To(Equal(expectedPodName)) + } + if expectedOrphaned > 0 { + Expect(updatedMV.Status.OrphanedPods).To(HaveLen(int(expectedOrphaned))) + Expect(updatedMV.Status.OrphanedPods[0].Name).To(Equal(expectedPodName)) + } + }, + Entry("auth method mismatch - sigstore MV with pki pod", + testutil.TestModelValidationOptions{Name: "test-mv", Namespace: "default", ConfigType: "sigstore"}, + "pki", "", int32(0), int32(1), "orphaned-pod"), + Entry("matching pki configuration", + testutil.TestModelValidationOptions{Name: "test-mv", Namespace: "default", ConfigType: "pki"}, + "pki", "", int32(1), int32(0), "matching-pod"), + Entry("sigstore config hash drift", + testutil.TestModelValidationOptions{ + Name: "test-mv", Namespace: "default", ConfigType: "sigstore", + CertIdentity: "user@example.com", CertOidcIssuer: "https://accounts.google.com", + }, + "sigstore", func() string { + oldMV := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", Namespace: "default", ConfigType: "sigstore", + CertIdentity: "different@example.com", CertOidcIssuer: "https://accounts.google.com", + }) + return oldMV.GetConfigHash() + }(), int32(0), int32(1), "drift-pod"), + Entry("pki config hash drift", + testutil.TestModelValidationOptions{ + Name: "test-mv", Namespace: "default", ConfigType: "pki", CertificateCA: "/path/to/ca.crt", + }, + "pki", func() string { + oldMV := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", Namespace: "default", ConfigType: "pki", CertificateCA: "/different/path/ca.crt", + }) + return oldMV.GetConfigHash() + }(), int32(0), int32(1), "pki-drift-pod"), + ) + + It("should handle pod deletion by name", func() { + // Create ModelValidation with Sigstore config + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + ConfigType: "sigstore", + }) + + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "deleted-pod", + Namespace: "default", + UID: types.UID("uid-deleted"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + // Now simulate pod deletion by removing it by name + podName := types.NamespacedName{Name: 
pod.Name, Namespace: pod.Namespace} + err = statusTracker.RemovePodByName(ctx, podName) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 0, 0, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(BeEmpty()) + }) + + It("should re-evaluate existing pods when ModelValidation is updated", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + ConfigType: "sigstore", + }) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + UID: types.UID("test-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + Annotations: map[string]string{ + constants.AuthMethodAnnotationKey: "sigstore", + }, + }) + + fakeClient = testutil.SetupFakeClientWithObjects(mv, pod) + statusTracker = NewStatusTrackerForTesting(fakeClient).(*StatusTrackerImpl) + + mvKey := types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + + statusTracker.AddModelValidation(ctx, mv) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + // This should trigger re-evaluation of existing pods + statusTracker.AddModelValidation(ctx, mv) + + // The pod should still be tracked (no configuration drift in this case) + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(HaveLen(1)) + Expect(updatedMV.Status.InjectedPods[0].Name).To(Equal("test-pod")) + }) + }) + + Context("when handling error conditions and edge cases", func() { + It("should handle pod without required labels", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "unlabeled-pod", + Namespace: "default", + UID: types.UID("unlabeled-uid"), + Labels: map[string]string{}, // No ModelValidation label + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) // Should not error, just ignore the pod + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 0, 0, 0) // No pods should be tracked + }) + + It("should handle pod with unknown ModelValidation", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "unknown-mv-pod", + Namespace: "default", + UID: types.UID("unknown-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "unknown-mv"}, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) // Should not error, just ignore the pod + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 0, 0, 0) // No pods should be tracked + }) + + It("should handle removal of non-existent pod by UID", func() { + mv := 
testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, _, _ = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + err := statusTracker.RemovePodEvent(ctx, types.UID("non-existent-uid")) + Expect(err).NotTo(HaveOccurred()) // Should not error for non-existent pod + }) + + It("should handle removal of non-existent pod by name", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, _, _ = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + podName := types.NamespacedName{Name: "non-existent-pod", Namespace: "default"} + err := statusTracker.RemovePodByName(ctx, podName) + Expect(err).NotTo(HaveOccurred()) // Should not error for non-existent pod + }) + + It("should handle pod with invalid annotations", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + ConfigType: "sigstore", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "invalid-pod", + Namespace: "default", + UID: types.UID("invalid-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + Annotations: map[string]string{ + constants.AuthMethodAnnotationKey: "invalid-auth-method", + constants.ConfigHashAnnotationKey: "invalid-hash", + }, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) // Should handle gracefully + + statusTracker.WaitForUpdates() + // Should be classified as orphaned due to invalid auth method + expectPodCount(ctx, fakeClient, mvKey, 0, 0, 1) + }) + + It("should handle empty UID gracefully", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, _, _ = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + err := statusTracker.RemovePodEvent(ctx, types.UID("")) + Expect(err).NotTo(HaveOccurred()) // Should handle empty UID gracefully + }) + }) +}) + +// NewStatusTrackerForRetryTesting creates a status tracker with fast test-friendly config for retry testing +func NewStatusTrackerForRetryTesting(client client.Client) StatusTracker { + return NewStatusTracker(client, StatusTrackerConfig{ + DebounceDuration: 20 * time.Millisecond, // Very fast for testing + RetryBaseDelay: 10 * time.Millisecond, // Fast retries + RetryMaxDelay: 100 * time.Millisecond, // Low max delay + RateLimitQPS: 1000, // High QPS to avoid interference + RateLimitBurst: 10000, // High burst to avoid interference + StatusUpdateTimeout: 5 * time.Second, + }) +} + +var _ = Describe("StatusTracker retry and debounce integration", func() { + var ( + ctx context.Context + cancel context.CancelFunc + fakeClient *testutil.FailingClient + statusTracker *StatusTrackerImpl + mvKey types.NamespacedName + mv *v1alpha1.ModelValidation + ) + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + mv = testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + // Setup failing environment with 2 failures before success + fakeClient = testutil.CreateFailingClientWithObjects(2, mv) + statusTracker = 
NewStatusTrackerForRetryTesting(fakeClient).(*StatusTrackerImpl) + mvKey = types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + }) + + AfterEach(func() { + if statusTracker != nil { + statusTracker.Stop() + } + if cancel != nil { + cancel() + } + }) + + It("should retry failed status updates with exponential backoff while respecting debouncing", func() { + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod-1", + Namespace: mv.Namespace, + Labels: map[string]string{ + constants.ModelValidationLabel: mv.Name, + }, + Annotations: map[string]string{ + constants.ConfigHashAnnotationKey: mv.GetConfigHash(), + constants.AuthMethodAnnotationKey: mv.GetAuthMethod(), + }, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + // Trigger multiple rapid updates to test debouncing + for i := 0; i < 5; i++ { + statusTracker.debouncedQueue.Add(mvKey) + time.Sleep(5 * time.Millisecond) // Shorter than debounce duration + } + + // Wait for debouncing and retries to complete (2 failures + 1 success) + testutil.ExpectAttemptCount(fakeClient, 3) + + // Verify that despite multiple debounce calls, we only had one series of retries + totalAttempts := fakeClient.GetAttemptCount() + actualFailures := fakeClient.GetFailureCount() + + Expect(totalAttempts).To(Equal(3), "Should have exactly 3 attempts: 2 failures + 1 success") + Expect(actualFailures).To(Equal(2), "Should have exactly 2 failures before success") + + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + }) + + It("should handle concurrent debounced updates with retries correctly", func() { + statusTracker.AddModelValidation(ctx, mv) + + // Trigger concurrent updates from multiple goroutines + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + statusTracker.debouncedQueue.Add(mvKey) + }() + } + wg.Wait() + + // Wait for processing to complete (2 failures + 1 success) + testutil.ExpectAttemptCount(fakeClient, 3) + + totalAttempts := fakeClient.GetAttemptCount() + + Expect(totalAttempts).To(Equal(3), "Concurrent updates should be debounced into single retry sequence") + }) +}) + +var _ = Describe("StatusTracker utility functions", func() { + Context("when comparing ModelValidation statuses", func() { + It("should correctly identify equal statuses", func() { + status1 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 2, + AuthMethod: "sigstore", + InjectedPods: []v1alpha1.PodTrackingInfo{ + {Name: "pod1", UID: "uid1"}, + {Name: "pod2", UID: "uid2"}, + }, + } + status2 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 2, + AuthMethod: "sigstore", + InjectedPods: []v1alpha1.PodTrackingInfo{ + {Name: "pod1", UID: "uid1"}, + {Name: "pod2", UID: "uid2"}, + }, + } + + Expect(statusEqual(status1, status2)).To(BeTrue()) + }) + + It("should correctly identify different statuses", func() { + status1 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 2, + } + status2 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 3, + } + + Expect(statusEqual(status1, status2)).To(BeFalse()) + }) + + It("should ignore LastUpdated differences", func() { + status1 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 1, + LastUpdated: metav1.Time{Time: time.Now()}, + } + status2 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 1, + LastUpdated: metav1.Time{Time: time.Now().Add(1 * time.Hour)}, + } + + Expect(statusEqual(status1, 
status2)).To(BeTrue()) + }) + }) + +}) diff --git a/internal/tracker/suite_test.go b/internal/tracker/suite_test.go new file mode 100644 index 00000000..1af9b747 --- /dev/null +++ b/internal/tracker/suite_test.go @@ -0,0 +1,13 @@ +package tracker + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . "github.com/onsi/gomega" //nolint:revive +) + +func TestTracker(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Tracker Suite") +} diff --git a/internal/webhooks/pod_webhook.go b/internal/webhooks/pod_webhook.go index 5edc1426..247ace9f 100644 --- a/internal/webhooks/pod_webhook.go +++ b/internal/webhooks/pod_webhook.go @@ -5,10 +5,12 @@ import ( "encoding/json" "fmt" "net/http" + "time" "github.com/sigstore/model-validation-operator/internal/constants" corev1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -29,7 +31,6 @@ func NewPodInterceptor(c client.Client, decoder admission.Decoder) webhook.Admis // +kubebuilder:webhook:path=/mutate-v1-pod,mutating=true,failurePolicy=fail,groups="",resources=pods,sideEffects=None,verbs=create;update,versions=v1,name=pods.validation.ml.sigstore.dev,admissionReviewVersions=v1 // +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations,verbs=get;list;watch -// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations/status,verbs=get;update;patch // +kubebuilder:rbac:groups="",resources=namespaces,verbs=get;list;watch // podInterceptor extends pods with Model Validation Init-Container if annotation is specified. @@ -55,7 +56,7 @@ func (p *podInterceptor) Handle(ctx context.Context, req admission.Request) admi logger.Error(err, "failed to get namespace") return admission.Errored(http.StatusInternalServerError, err) } - if ns.Labels[constants.IgnoreNamespaceLabel] == "true" { + if ns.Labels[constants.IgnoreNamespaceLabel] == constants.IgnoreNamespaceValue { logger.Info("Namespace has ignore label, skipping", "namespace", req.Namespace) return admission.Allowed("namespace ignored") } @@ -70,8 +71,8 @@ func (p *podInterceptor) Handle(ctx context.Context, req admission.Request) admi logger.Info("Search associated Model Validation CR", "pod", pod.Name, "namespace", pod.Namespace, "modelValidationName", modelValidationName) - rhmv := &v1alpha1.ModelValidation{} - err := p.client.Get(ctx, client.ObjectKey{Name: modelValidationName, Namespace: pod.Namespace}, rhmv) + mv := &v1alpha1.ModelValidation{} + err := p.client.Get(ctx, client.ObjectKey{Name: modelValidationName, Namespace: pod.Namespace}, mv) if err != nil { msg := fmt.Sprintf("failed to get the ModelValidation CR %s/%s", pod.Namespace, modelValidationName) logger.Error(err, msg) @@ -85,10 +86,19 @@ func (p *podInterceptor) Handle(ctx context.Context, req admission.Request) admi } args := []string{"verify"} - args = append(args, validationConfigToArgs(logger, rhmv.Spec.Config, rhmv.Spec.Model.SignaturePath)...) - args = append(args, rhmv.Spec.Model.Path) + args = append(args, validationConfigToArgs(logger, mv.Spec.Config, mv.Spec.Model.SignaturePath)...) 
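+ // The model path is appended last as the positional argument to the "verify" command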
+ args = append(args, mv.Spec.Model.Path) pp := pod.DeepCopy() + + controllerutil.AddFinalizer(pp, constants.ModelValidationFinalizer) + if pp.Annotations == nil { + pp.Annotations = make(map[string]string) + } + pp.Annotations[constants.InjectedAnnotationKey] = time.Now().Format(time.RFC3339) + pp.Annotations[constants.AuthMethodAnnotationKey] = mv.GetAuthMethod() + pp.Annotations[constants.ConfigHashAnnotationKey] = mv.GetConfigHash() + vm := []corev1.VolumeMount{} for _, c := range pod.Spec.Containers { vm = append(vm, c.VolumeMounts...) diff --git a/internal/webhooks/pod_webhook_test.go b/internal/webhooks/pod_webhook_test.go index 33a7096a..28de2841 100644 --- a/internal/webhooks/pod_webhook_test.go +++ b/internal/webhooks/pod_webhook_test.go @@ -2,11 +2,14 @@ package webhooks import ( "context" + "fmt" + "time" . "github.com/onsi/ginkgo/v2" //nolint:revive . "github.com/onsi/gomega" //nolint:revive "github.com/sigstore/model-validation-operator/api/v1alpha1" "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/testutil" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -14,74 +17,64 @@ import ( var _ = Describe("Pod webhook", func() { Context("Pod webhook test", func() { - - const ( - Name = "test" - Namespace = "default" - ) + Name := "test" + var Namespace string ctx := context.Background() - namespace := &corev1.Namespace{ - ObjectMeta: metav1.ObjectMeta{ - Name: Name, - Namespace: Namespace, - }, - } - - typeNamespaceName := types.NamespacedName{Name: Name, Namespace: Namespace} + var typeNamespaceName types.NamespacedName BeforeEach(func() { + Namespace = fmt.Sprintf("test-ns-%d", time.Now().UnixNano()) + typeNamespaceName = testutil.CreateTestNamespacedName(Name, Namespace) + By("Creating the Namespace to perform the tests") + namespace := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: Namespace, + }, + } err := k8sClient.Create(ctx, namespace) Expect(err).To(Not(HaveOccurred())) By("Create ModelValidation resource") - err = k8sClient.Create(ctx, &v1alpha1.ModelValidation{ - ObjectMeta: metav1.ObjectMeta{ - Name: Name, - Namespace: Namespace, - }, - Spec: v1alpha1.ModelValidationSpec{ - Model: v1alpha1.Model{ - Path: "test", - SignaturePath: "test", - }, - Config: v1alpha1.ValidationConfig{ - SigstoreConfig: nil, - PkiConfig: nil, - PrivateKeyConfig: nil, - }, - }, + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: Name, + Namespace: Namespace, + ConfigType: "sigstore", + CertIdentity: "test@example.com", + CertOidcIssuer: "https://accounts.google.com", }) + err = k8sClient.Create(ctx, mv) Expect(err).To(Not(HaveOccurred())) + + statusTracker.AddModelValidation(ctx, mv) }) AfterEach(func() { // TODO(user): Attention if you improve this code by adding other context test you MUST // be aware of the current delete namespace limitations. 
// More info: https://book.kubebuilder.io/reference/envtest.html#testing-considerations - By("Deleting the Namespace to perform the tests") - _ = k8sClient.Delete(ctx, namespace) - }) - It("Should create sidecar container", func() { - By("create labeled pod") - instance := &corev1.Pod{ + By("Deleting the ModelValidation resource") + _ = k8sClient.Delete(ctx, &v1alpha1.ModelValidation{ ObjectMeta: metav1.ObjectMeta{ Name: Name, Namespace: Namespace, - Labels: map[string]string{constants.ModelValidationLabel: Name}, }, - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - Image: "test", - }, - }, - }, - } + }) + + By("Deleting the Namespace to perform the tests") + _ = k8sClient.Delete(ctx, &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: Namespace}}) + }) + + It("Should create sidecar container and add finalizer", func() { + By("create labeled pod") + instance := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: Name, + Namespace: Namespace, + Labels: map[string]string{constants.ModelValidationLabel: Name}, + }) err := k8sClient.Create(ctx, instance) Expect(err).To(Not(HaveOccurred())) @@ -102,6 +95,49 @@ var _ = Describe("Pod webhook", func() { func(containers []corev1.Container) string { return containers[0].Image }, Equal(constants.ModelTransparencyCliImage)), )) + + By("Checking that finalizer was added") + Expect(found.Finalizers).To(ContainElement(constants.ModelValidationFinalizer)) + }) + + It("Should track pod in ModelValidation status", func() { + By("create labeled pod") + instance := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "tracked-pod", + Namespace: Namespace, + Labels: map[string]string{constants.ModelValidationLabel: Name}, + }) + err := k8sClient.Create(ctx, instance) + Expect(err).To(Not(HaveOccurred())) + + By("Waiting for pod to be injected") + found := &corev1.Pod{} + Eventually(func() []corev1.Container { + _ = k8sClient.Get(ctx, types.NamespacedName{Name: "tracked-pod", Namespace: Namespace}, found) + return found.Spec.InitContainers + }).Should(HaveLen(1)) + + err = statusTracker.ProcessPodEvent(ctx, found) + Expect(err).To(Not(HaveOccurred())) + + By("Checking ModelValidation status was updated") + mv := &v1alpha1.ModelValidation{} + Eventually(func() int32 { + _ = k8sClient.Get(ctx, typeNamespaceName, mv) + return mv.Status.InjectedPodCount + }, 5*time.Second).Should(BeNumerically(">", 0)) + + Expect(mv.Status.AuthMethod).To(Equal("sigstore")) // Sigstore auth method configured in test + Expect(mv.Status.InjectedPods).ToNot(BeEmpty()) + + foundTrackedPod := false + for _, tp := range mv.Status.InjectedPods { + if tp.Name == "tracked-pod" { + foundTrackedPod = true + break + } + } + Expect(foundTrackedPod).To(BeTrue(), "Pod should be tracked in status") }) }) }) diff --git a/internal/webhooks/suite_test.go b/internal/webhooks/suite_test.go index 8a6cd296..c9d583f5 100644 --- a/internal/webhooks/suite_test.go +++ b/internal/webhooks/suite_test.go @@ -9,8 +9,10 @@ import ( "runtime" "strings" "testing" + "time" "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/tracker" "k8s.io/klog/v2" "k8s.io/klog/v2/test" "sigs.k8s.io/controller-runtime/pkg/webhook" @@ -30,11 +32,12 @@ import ( // http://onsi.github.io/ginkgo/ to learn more about Ginkgo. var ( - cfg *rest.Config - k8sClient client.Client // You'll be using this client in your tests. 
- testEnv *envtest.Environment - ctx context.Context - cancel context.CancelFunc + cfg *rest.Config + k8sClient client.Client // You'll be using this client in your tests. + testEnv *envtest.Environment + ctx context.Context + cancel context.CancelFunc + statusTracker tracker.StatusTracker ) // findBinaryAssetsDirectory locates the kubernetes binaries directory @@ -133,6 +136,14 @@ var _ = BeforeSuite(func() { // Create a decoder for your webhook decoder := admission.NewDecoder(scheme.Scheme) + statusTracker = tracker.NewStatusTracker(mgr.GetClient(), tracker.StatusTrackerConfig{ + DebounceDuration: 50 * time.Millisecond, // Faster for tests + RetryBaseDelay: 10 * time.Millisecond, // Faster retries for tests + RetryMaxDelay: 1 * time.Second, // Lower max delay for tests + RateLimitQPS: 100, // Higher QPS for tests + RateLimitBurst: 1000, // Higher burst for tests + StatusUpdateTimeout: 5 * time.Second, // Shorter timeout for tests + }) podWebhookHandler := NewPodInterceptor(mgr.GetClient(), decoder) mgr.GetWebhookServer().Register("/mutate-v1-pod", &admission.Webhook{ Handler: podWebhookHandler, @@ -143,11 +154,22 @@ var _ = BeforeSuite(func() { err = mgr.Start(ctx) Expect(err).ToNot(HaveOccurred(), "failed to run manager") }() + + // Wait for webhook server to be ready by checking if it's serving + Eventually(func() error { + return mgr.GetWebhookServer().StartedChecker()(nil) + }, 10*time.Second, 100*time.Millisecond).Should(Succeed()) }) var _ = AfterSuite(func() { - cancel() By("tearing down the test environment") + + statusTracker.Stop() + cancel() + + // Give manager time to shutdown gracefully + time.Sleep(100 * time.Millisecond) + err := testEnv.Stop() Expect(err).NotTo(HaveOccurred()) }) diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index bba0ec3c..cf08d063 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -101,6 +101,15 @@ var _ = Describe("Manager", Ordered, func() { cmd := exec.Command("kubectl", "delete", "pod", curlMetricsPodName, "-n", operatorNamespace) _, _ = utils.Run(cmd) + By("cleaning up test resources before removing operator") + // Delete all pods first to trigger proper finalizer cleanup + cmd = exec.Command("kubectl", "delete", "pods", "--all", "-n", webhookTestNamespace, "--timeout=30s") + _, _ = utils.Run(cmd) + + // Then delete ModelValidation CR + cmd = exec.Command("kubectl", "delete", "modelvalidations", "--all", "-n", webhookTestNamespace, "--timeout=30s") + _, _ = utils.Run(cmd) + By("undeploying the controller-manager") cmd = exec.Command("make", "undeploy") _, _ = utils.Run(cmd)