diff --git a/.github/workflows/test-e2e.yml b/.github/workflows/test-e2e.yml index 174b4211..be087e47 100644 --- a/.github/workflows/test-e2e.yml +++ b/.github/workflows/test-e2e.yml @@ -48,4 +48,4 @@ jobs: - name: Running Test e2e run: | go mod tidy - make test-e2e + make test-e2e-ci diff --git a/.gitignore b/.gitignore index aac1b67d..bff59b62 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,10 @@ cover.out manifests/*.yaml # Don't include generated bundles (will be built in CI/CD at some point) bundle/ + +# Generated test keys +testdata/docker/test_*.pub +testdata/docker/test_*.priv + +# Generated model signature +testdata/tensorflow_saved_model/model.sig diff --git a/Makefile b/Makefile index 46f203d6..3dc99d47 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,9 @@ BUNDLE_METADATA_OPTS ?= $(BUNDLE_CHANNELS) $(BUNDLE_DEFAULT_CHANNEL) # ghcr.io/sigstore/model-validation-operator-bundle:$VERSION and ghcr.io/sigstore/model-validation-operator-catalog:$VERSION. IMAGE_TAG_BASE ?= ghcr.io/sigstore/model-validation-operator +# IMG defines the image:tag used for the operator. +IMG ?= $(IMAGE_TAG_BASE):v$(VERSION) + # BUNDLE_IMG defines the image:tag used for the bundle. # You can use it as an arg. (E.g make bundle-build BUNDLE_IMG=/:) BUNDLE_IMG ?= $(IMAGE_TAG_BASE)-bundle:v$(VERSION) @@ -52,8 +55,6 @@ endif # Set the Operator SDK version to use. By default, what is installed on the system is used. # This is useful for CI or a project to utilize a specific version of the operator-sdk toolkit. OPERATOR_SDK_VERSION ?= v1.41.1 -# Image URL to use all building/pushing image targets -IMG ?= controller:latest # Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set) ifeq (,$(shell go env GOBIN)) @@ -118,14 +119,6 @@ vet: ## Run go vet against code. test: manifests generate fmt vet setup-envtest ## Run tests. KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -coverprofile cover.out -# TODO(user): To use a different vendor for e2e tests, modify the setup under 'tests/e2e'. -# The default setup assumes Kind is pre-installed and builds/loads the Manager Docker image locally. -# CertManager is installed by default; skip with: -# - CERT_MANAGER_INSTALL_SKIP=true -.PHONY: test-e2e -test-e2e: manifests generate fmt vet ## Run the e2e tests. Expected an isolated environment using Kind. - go test ./test/e2e/ -v -ginkgo.v - .PHONY: lint lint: golangci-lint ## Run golangci-lint linter $(GOLANGCI_LINT) run @@ -419,3 +412,158 @@ generate-manifests: manifests ## Generate manifests for all environments using g echo "Generating manifests for $$env environment..."; \ ./scripts/generate-manifests.sh $$env manifests; \ done + +##@ E2E Test Infrastructure + +# E2E Test Variables +E2E_OPERATOR_NAMESPACE ?= model-validation-operator-system +E2E_TEST_NAMESPACE ?= e2e-webhook-test-ns +E2E_TEST_MODEL ?= model-validation-test-model:latest +MODEL_TRANSPARENCY_IMG ?= ghcr.io/sigstore/model-transparency-cli:v1.0.1 +CERTMANAGER_VERSION ?= v1.18.2 +CERT_MANAGER_YAML ?= https://github.com/cert-manager/cert-manager/releases/download/$(CERTMANAGER_VERSION)/cert-manager.yaml +KIND_CLUSTER ?= kind + +# Build and sign test model +.PHONY: e2e-generate-test-keys +e2e-generate-test-keys: + @echo "Generating ECDSA P-256 test keys for model signing..." + @if [ ! 
-f testdata/docker/test_private_key.priv ]; then \ + echo "Generating private key..."; \ + openssl ecparam -name prime256v1 -genkey -noout -out testdata/docker/test_private_key.priv; \ + fi + @if [ ! -f testdata/docker/test_public_key.pub ]; then \ + echo "Generating public key..."; \ + openssl ec -in testdata/docker/test_private_key.priv -pubout -out testdata/docker/test_public_key.pub; \ + fi + @if [ ! -f testdata/docker/test_invalid_private_key.priv ]; then \ + echo "Generating invalid private key for failure tests..."; \ + openssl ecparam -name prime256v1 -genkey -noout -out testdata/docker/test_invalid_private_key.priv; \ + fi + @if [ ! -f testdata/docker/test_invalid_public_key.pub ]; then \ + echo "Generating invalid public key for failure tests..."; \ + openssl ec -in testdata/docker/test_invalid_private_key.priv -pubout -out testdata/docker/test_invalid_public_key.pub; \ + fi + +.PHONY: e2e-sign-test-model +e2e-sign-test-model: e2e-generate-test-keys + @echo "Signing test model with private key..." + @# Remove public key from model directory before signing to avoid including it in signature + @rm -f testdata/tensorflow_saved_model/test_public_key.pub + $(CONTAINER_TOOL) run --rm \ + -v $(PWD)/testdata/tensorflow_saved_model:/model \ + -v $(PWD)/testdata/docker/test_private_key.priv:/test_private_key.priv \ + --entrypoint="" \ + ghcr.io/sigstore/model-transparency-cli:v1.0.1 \ + /usr/local/bin/model_signing sign key /model \ + --private_key /test_private_key.priv \ + --signature /model/model.sig + +.PHONY: e2e-build-test-model +e2e-build-test-model: e2e-sign-test-model + @echo "Building test model image..." + cd testdata && $(CONTAINER_TOOL) build --no-cache -t $(E2E_TEST_MODEL) -f docker/test-model.Dockerfile . + +# install and uninstall cert-manager for tests + +.PHONY: e2e-install-certmanager +e2e-install-certmanager: + @echo "Installing cert-manager..." + $(KUBECTL) apply -f $(CERT_MANAGER_YAML) + @echo "Waiting for cert-manager to be ready..." + $(KUBECTL) wait --for=condition=Available deployment -n cert-manager --all --timeout=120s + +.PHONY: e2e-uninstall-certmanager +e2e-uninstall-certmanager: ## Uninstall cert-manager + @echo "Uninstalling cert-manager..." + -$(KUBECTL) delete -f $(CERT_MANAGER_YAML) + +# Load test images into the kind cluster + +.PHONY: e2e-build-image +e2e-build-image: + $(CONTAINER_TOOL) build -t $(IMG) -f $(CONTAINER_FILE) . + +.PHONY: e2e-load-images +e2e-load-images: e2e-build-image e2e-build-test-model + @echo "Pulling model-transparency-cli image..." + $(CONTAINER_TOOL) pull $(MODEL_TRANSPARENCY_IMG) + @echo "Loading manager image into Kind cluster..." + $(KIND) load docker-image -n $(KIND_CLUSTER) $(IMG) + @echo "Loading model-transparency-cli image into Kind cluster..." + $(KIND) load docker-image -n $(KIND_CLUSTER) $(MODEL_TRANSPARENCY_IMG) + @echo "Loading test model image into Kind cluster..." + $(KIND) load docker-image -n $(KIND_CLUSTER) $(E2E_TEST_MODEL) + +# Setup test environment (namespaces, local models on kind cluster, operator) + +.PHONY: e2e-setup-namespaces +e2e-setup-namespaces: + @echo "Creating operator namespace..." + $(KUBECTL) create ns $(E2E_OPERATOR_NAMESPACE) || true + @echo "Labeling operator namespace with restricted security policy..." + $(KUBECTL) label --overwrite ns $(E2E_OPERATOR_NAMESPACE) pod-security.kubernetes.io/enforce=restricted + @echo "Labeling operator namespace to be ignored by webhook..." 
+ $(KUBECTL) label --overwrite ns $(E2E_OPERATOR_NAMESPACE) validation.ml.sigstore.dev/ignore=true + @echo "Creating test namespace..." + $(KUBECTL) create ns $(E2E_TEST_NAMESPACE) || true + +.PHONY: e2e-setup-model-data +e2e-setup-model-data: e2e-load-images e2e-setup-namespaces + @echo "Cleaning up any existing model data DaemonSet..." + -$(KUBECTL) delete daemonset model-data-setup -n $(E2E_TEST_NAMESPACE) 2>/dev/null || true + @echo "Waiting for cleanup to complete..." + @sleep 5 + @echo "Deploying model data setup DaemonSet..." + $(KUBECTL) apply -f test/e2e/testdata/model-data-daemonset.yaml + @echo "Waiting for model data to be available on all nodes..." + $(KUBECTL) rollout status daemonset/model-data-setup -n $(E2E_TEST_NAMESPACE) --timeout=120s + +.PHONY: e2e-deploy-operator +e2e-deploy-operator: e2e-setup-namespaces deploy + @echo "E2E operator deployment complete" + +.PHONY: e2e-wait-operator +e2e-wait-operator: ## Wait for operator pod to be ready + @echo "Waiting for controller pod to be ready..." + $(KUBECTL) wait --for=condition=Ready pod -l control-plane=controller-manager -n $(E2E_OPERATOR_NAMESPACE) --timeout=120s + +# test environment setup and teardown - certmanager, operator and test model for testing + +.PHONY: e2e-setup +e2e-setup: e2e-install-certmanager e2e-setup-model-data e2e-deploy-operator e2e-wait-operator ## Complete e2e test setup + @echo "E2E test environment setup complete" + +.PHONY: e2e-cleanup-resources +e2e-cleanup-resources: ## Clean up test resources before removing operator + @echo "Cleaning up test resources..." + -$(KUBECTL) delete pods --all -n $(E2E_TEST_NAMESPACE) --timeout=30s + -$(KUBECTL) delete modelvalidations --all -n $(E2E_TEST_NAMESPACE) --timeout=30s + -$(KUBECTL) delete daemonset model-data-setup -n $(E2E_TEST_NAMESPACE) --timeout=30s + +.PHONY: e2e-teardown +e2e-teardown: e2e-cleanup-resources undeploy e2e-uninstall-certmanager + @echo "Tearing down e2e test environment..." + -$(KUBECTL) delete ns $(E2E_OPERATOR_NAMESPACE) --timeout=60s + -$(KUBECTL) delete ns $(E2E_TEST_NAMESPACE) --timeout=60s + +# run e2e tests + +.PHONY: test-e2e +test-e2e: manifests generate fmt vet ## Run the e2e tests, no setup and teardown. Expects the operator to be deployed. + @echo "Running e2e tests (assumes infrastructure is already set up)..." + go test ./test/e2e/ -v -ginkgo.v + +.PHONY: test-e2e-full +test-e2e-full: manifests generate fmt vet e2e-setup ## Run e2e tests with setup and teardown + @echo "Running e2e tests with full infrastructure setup..." + go test ./test/e2e/ -v -ginkgo.v; \ + TEST_RESULT=$$?; \ + $(MAKE) e2e-teardown; \ + exit $$TEST_RESULT + +.PHONY: test-e2e-ci +test-e2e-ci: manifests generate fmt vet e2e-setup ## Run the e2e tests, with setup. No teardown as the CI workflow will throw away kind + @echo "Running e2e tests with infrastructure setup for CI..." + go test ./test/e2e/ -v -ginkgo.v + diff --git a/api/v1alpha1/modelvalidation_types.go b/api/v1alpha1/modelvalidation_types.go index 47a97831..9ab28695 100644 --- a/api/v1alpha1/modelvalidation_types.go +++ b/api/v1alpha1/modelvalidation_types.go @@ -17,6 +17,9 @@ limitations under the License. 
package v1alpha1 import ( + "crypto/sha256" + "fmt" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -43,19 +46,19 @@ type PkiConfig struct { CertificateAuthority string `json:"certificateAuthority,omitempty"` } -// PrivateKeyConfig defines the private key verification configuration -// for validating model signatures using a local private key -type PrivateKeyConfig struct { - // Path to the private key. +// PublicKeyConfig defines the public key verification configuration +// for validating model signatures using a local public key +type PublicKeyConfig struct { + // Path to the public key. KeyPath string `json:"keyPath,omitempty"` } // ValidationConfig defines the various methods available for validating model signatures. // At least one validation method must be specified. type ValidationConfig struct { - SigstoreConfig *SigstoreConfig `json:"sigstoreConfig,omitempty"` - PkiConfig *PkiConfig `json:"pkiConfig,omitempty"` - PrivateKeyConfig *PrivateKeyConfig `json:"privateKeyConfig,omitempty"` + SigstoreConfig *SigstoreConfig `json:"sigstoreConfig,omitempty"` + PkiConfig *PkiConfig `json:"pkiConfig,omitempty"` + PublicKeyConfig *PublicKeyConfig `json:"publicKeyConfig,omitempty"` } // ModelValidationSpec defines the desired state of ModelValidation @@ -69,15 +72,54 @@ type ModelValidationSpec struct { Config ValidationConfig `json:"config"` } +// PodTrackingInfo contains information about a tracked pod +type PodTrackingInfo struct { + // Name is the name of the pod + Name string `json:"name"` + // UID is the unique identifier of the pod + UID string `json:"uid"` + // InjectedAt is when the pod was injected + InjectedAt metav1.Time `json:"injectedAt"` +} + // ModelValidationStatus defines the observed state of ModelValidation type ModelValidationStatus struct { // INSERT ADDITIONAL STATUS FIELD - define observed state of cluster // Important: Run "make" to regenerate code after modifying this file Conditions []metav1.Condition `json:"conditions,omitempty"` + + // InjectedPodCount is the number of pods that have been injected with validation + InjectedPodCount int32 `json:"injectedPodCount"` + + // UninjectedPodCount is the number of pods that have the label but were not injected + UninjectedPodCount int32 `json:"uninjectedPodCount"` + + // OrphanedPodCount is the number of injected pods that reference this CR but are inconsistent + OrphanedPodCount int32 `json:"orphanedPodCount"` + + // AuthMethod indicates which authentication method is being used + AuthMethod string `json:"authMethod,omitempty"` + + // InjectedPods contains detailed information about injected pods + InjectedPods []PodTrackingInfo `json:"injectedPods,omitempty"` + + // UninjectedPods contains detailed information about pods that should have been injected but weren't + UninjectedPods []PodTrackingInfo `json:"uninjectedPods,omitempty"` + + // OrphanedPods contains detailed information about pods that are injected but inconsistent + OrphanedPods []PodTrackingInfo `json:"orphanedPods,omitempty"` + + // LastUpdated is the timestamp of the last status update + LastUpdated metav1.Time `json:"lastUpdated,omitempty"` } // +kubebuilder:object:root=true // +kubebuilder:subresource:status +// +kubebuilder:printcolumn:name="Auth Method",type=string,JSONPath=`.status.authMethod` +// +kubebuilder:printcolumn:name="Injected Pods",type=integer,JSONPath=`.status.injectedPodCount` +// +kubebuilder:printcolumn:name="Uninjected Pods",type=integer,JSONPath=`.status.uninjectedPodCount` +// +kubebuilder:printcolumn:name="Orphaned 
Pods",type=integer,JSONPath=`.status.orphanedPodCount` +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" // ModelValidation is the Schema for the modelvalidations API type ModelValidation struct { @@ -97,6 +139,42 @@ type ModelValidationList struct { Items []ModelValidation `json:"items"` } +// GetAuthMethod returns the authentication method being used +func (mv *ModelValidation) GetAuthMethod() string { + if mv.Spec.Config.SigstoreConfig != nil { + return "sigstore" + } else if mv.Spec.Config.PkiConfig != nil { + return "pki" + } else if mv.Spec.Config.PublicKeyConfig != nil { + return "public-key" + } + return "unknown" +} + +// GetConfigHash returns a hash of the validation configuration for drift detection +func (mv *ModelValidation) GetConfigHash() string { + return mv.Spec.Config.GetConfigHash() +} + +// GetConfigHash returns a hash of the validation configuration for drift detection +func (vc *ValidationConfig) GetConfigHash() string { + hasher := sha256.New() + + if vc.SigstoreConfig != nil { + hasher.Write([]byte("sigstore")) + hasher.Write([]byte(vc.SigstoreConfig.CertificateIdentity)) + hasher.Write([]byte(vc.SigstoreConfig.CertificateOidcIssuer)) + } else if vc.PkiConfig != nil { + hasher.Write([]byte("pki")) + hasher.Write([]byte(vc.PkiConfig.CertificateAuthority)) + } else if vc.PublicKeyConfig != nil { + hasher.Write([]byte("publickey")) + hasher.Write([]byte(vc.PublicKeyConfig.KeyPath)) + } + + return fmt.Sprintf("%x", hasher.Sum(nil))[:16] // Use first 16 chars for brevity +} + func init() { SchemeBuilder.Register(&ModelValidation{}, &ModelValidationList{}) } diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index 47af49d6..9e17feb3 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -126,6 +126,28 @@ func (in *ModelValidationStatus) DeepCopyInto(out *ModelValidationStatus) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.InjectedPods != nil { + in, out := &in.InjectedPods, &out.InjectedPods + *out = make([]PodTrackingInfo, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.UninjectedPods != nil { + in, out := &in.UninjectedPods, &out.UninjectedPods + *out = make([]PodTrackingInfo, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.OrphanedPods != nil { + in, out := &in.OrphanedPods, &out.OrphanedPods + *out = make([]PodTrackingInfo, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + in.LastUpdated.DeepCopyInto(&out.LastUpdated) } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelValidationStatus. @@ -154,16 +176,32 @@ func (in *PkiConfig) DeepCopy() *PkiConfig { } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. -func (in *PrivateKeyConfig) DeepCopyInto(out *PrivateKeyConfig) { +func (in *PodTrackingInfo) DeepCopyInto(out *PodTrackingInfo) { + *out = *in + in.InjectedAt.DeepCopyInto(&out.InjectedAt) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PodTrackingInfo. +func (in *PodTrackingInfo) DeepCopy() *PodTrackingInfo { + if in == nil { + return nil + } + out := new(PodTrackingInfo) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *PublicKeyConfig) DeepCopyInto(out *PublicKeyConfig) { *out = *in } -// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PrivateKeyConfig. -func (in *PrivateKeyConfig) DeepCopy() *PrivateKeyConfig { +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PublicKeyConfig. +func (in *PublicKeyConfig) DeepCopy() *PublicKeyConfig { if in == nil { return nil } - out := new(PrivateKeyConfig) + out := new(PublicKeyConfig) in.DeepCopyInto(out) return out } @@ -196,9 +234,9 @@ func (in *ValidationConfig) DeepCopyInto(out *ValidationConfig) { *out = new(PkiConfig) **out = **in } - if in.PrivateKeyConfig != nil { - in, out := &in.PrivateKeyConfig, &out.PrivateKeyConfig - *out = new(PrivateKeyConfig) + if in.PublicKeyConfig != nil { + in, out := &in.PublicKeyConfig, &out.PublicKeyConfig + *out = new(PublicKeyConfig) **out = **in } } diff --git a/cmd/main.go b/cmd/main.go index ccf56a20..242c3016 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -22,8 +22,11 @@ import ( "flag" "os" "path/filepath" + "time" "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/controller" + "github.com/sigstore/model-validation-operator/internal/tracker" "github.com/sigstore/model-validation-operator/internal/utils" "github.com/sigstore/model-validation-operator/internal/webhooks" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -47,6 +50,16 @@ import ( // +kubebuilder:scaffold:imports ) +const ( + // Default configuration values for the status tracker + defaultDebounceDuration = 500 * time.Millisecond + defaultRetryBaseDelay = 100 * time.Millisecond + defaultRetryMaxDelay = 16 * time.Second + defaultRateLimitQPS = 10.0 + defaultRateLimitBurst = 100 + defaultStatusUpdateTimeout = 30 * time.Second +) + var ( scheme = runtime.NewScheme() setupLog = ctrl.Log.WithName("setup") @@ -69,6 +82,14 @@ func main() { var secureMetrics bool var enableHTTP2 bool var tlsOpts []func(*tls.Config) + + // Status tracker configuration + var debounceDuration time.Duration + var retryBaseDelay time.Duration + var retryMaxDelay time.Duration + var rateLimitQPS float64 + var rateLimitBurst int + var statusUpdateTimeout time.Duration flag.StringVar(&metricsAddr, "metrics-bind-address", "0", "The address the metrics endpoint binds to. 
"+ "Use :8443 for HTTPS or :8080 for HTTP, or leave as 0 to disable the metrics service.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") @@ -91,6 +112,20 @@ func main() { "MODEL_TRANSPARENCY_CLI_IMAGE", constants.ModelTransparencyCliImage, "Model transparency CLI image to be used.") + + // Status tracker configuration flags + flag.DurationVar(&debounceDuration, "debounce-duration", defaultDebounceDuration, + "Time to wait for more changes before updating status") + flag.DurationVar(&retryBaseDelay, "retry-base-delay", defaultRetryBaseDelay, + "Base delay for exponential backoff retries") + flag.DurationVar(&retryMaxDelay, "retry-max-delay", defaultRetryMaxDelay, + "Maximum delay for exponential backoff retries") + flag.Float64Var(&rateLimitQPS, "rate-limit-qps", defaultRateLimitQPS, + "Overall rate limit for status updates (queries per second)") + flag.IntVar(&rateLimitBurst, "rate-limit-burst", defaultRateLimitBurst, + "Burst capacity for overall rate limit") + flag.DurationVar(&statusUpdateTimeout, "status-update-timeout", defaultStatusUpdateTimeout, + "Timeout for status update operations") opts := zap.Options{ Development: true, } @@ -246,6 +281,36 @@ func main() { Handler: interceptor, }) + statusTracker := tracker.NewStatusTracker(mgr.GetClient(), tracker.StatusTrackerConfig{ + DebounceDuration: debounceDuration, + RetryBaseDelay: retryBaseDelay, + RetryMaxDelay: retryMaxDelay, + RateLimitQPS: rateLimitQPS, + RateLimitBurst: rateLimitBurst, + StatusUpdateTimeout: statusUpdateTimeout, + }) + defer statusTracker.Stop() + + podReconciler := &controller.PodReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Tracker: statusTracker, + } + if err := podReconciler.SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create pod controller") + os.Exit(1) + } + + mvReconciler := &controller.ModelValidationReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Tracker: statusTracker, + } + if err := mvReconciler.SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create ModelValidation controller") + os.Exit(1) + } + setupLog.Info("starting manager") if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil { setupLog.Error(err, "problem running manager") diff --git a/config/crd/bases/ml.sigstore.dev_modelvalidations.yaml b/config/crd/bases/ml.sigstore.dev_modelvalidations.yaml index 8996ff4c..4c5b4d19 100644 --- a/config/crd/bases/ml.sigstore.dev_modelvalidations.yaml +++ b/config/crd/bases/ml.sigstore.dev_modelvalidations.yaml @@ -14,7 +14,23 @@ spec: singular: modelvalidation scope: Namespaced versions: - - name: v1alpha1 + - additionalPrinterColumns: + - jsonPath: .status.authMethod + name: Auth Method + type: string + - jsonPath: .status.injectedPodCount + name: Injected Pods + type: integer + - jsonPath: .status.uninjectedPodCount + name: Uninjected Pods + type: integer + - jsonPath: .status.orphanedPodCount + name: Orphaned Pods + type: integer + - jsonPath: .metadata.creationTimestamp + name: Age + type: date + name: v1alpha1 schema: openAPIV3Schema: description: ModelValidation is the Schema for the modelvalidations API @@ -51,13 +67,13 @@ spec: description: Path to the certificate authority for PKI. 
type: string type: object - privateKeyConfig: + publicKeyConfig: description: |- - PrivateKeyConfig defines the private key verification configuration - for validating model signatures using a local private key + PublicKeyConfig defines the public key verification configuration + for validating model signatures using a local public key properties: keyPath: - description: Path to the private key. + description: Path to the public key. type: string type: object sigstoreConfig: @@ -89,6 +105,10 @@ spec: status: description: ModelValidationStatus defines the observed state of ModelValidation properties: + authMethod: + description: AuthMethod indicates which authentication method is being + used + type: string conditions: description: |- INSERT ADDITIONAL STATUS FIELD - define observed state of cluster @@ -148,6 +168,98 @@ spec: - type type: object type: array + injectedPodCount: + description: InjectedPodCount is the number of pods that have been + injected with validation + format: int32 + type: integer + injectedPods: + description: InjectedPods contains detailed information about injected + pods + items: + description: PodTrackingInfo contains information about a tracked + pod + properties: + injectedAt: + description: InjectedAt is when the pod was injected + format: date-time + type: string + name: + description: Name is the name of the pod + type: string + uid: + description: UID is the unique identifier of the pod + type: string + required: + - injectedAt + - name + - uid + type: object + type: array + lastUpdated: + description: LastUpdated is the timestamp of the last status update + format: date-time + type: string + orphanedPodCount: + description: OrphanedPodCount is the number of injected pods that + reference this CR but are inconsistent + format: int32 + type: integer + orphanedPods: + description: OrphanedPods contains detailed information about pods + that are injected but inconsistent + items: + description: PodTrackingInfo contains information about a tracked + pod + properties: + injectedAt: + description: InjectedAt is when the pod was injected + format: date-time + type: string + name: + description: Name is the name of the pod + type: string + uid: + description: UID is the unique identifier of the pod + type: string + required: + - injectedAt + - name + - uid + type: object + type: array + uninjectedPodCount: + description: UninjectedPodCount is the number of pods that have the + label but were not injected + format: int32 + type: integer + uninjectedPods: + description: UninjectedPods contains detailed information about pods + that should have been injected but weren't + items: + description: PodTrackingInfo contains information about a tracked + pod + properties: + injectedAt: + description: InjectedAt is when the pod was injected + format: date-time + type: string + name: + description: Name is the name of the pod + type: string + uid: + description: UID is the unique identifier of the pod + type: string + required: + - injectedAt + - name + - uid + type: object + type: array + required: + - injectedPodCount + - orphanedPodCount + - uninjectedPodCount type: object type: object served: true diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index eee3ef0e..b32aa176 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -12,6 +12,21 @@ rules: - get - list - watch +- apiGroups: + - "" + resources: + - pods + verbs: + - get + - list + - update + - watch +- apiGroups: + - "" + resources: + - pods/finalizers + verbs: + - update - apiGroups: - 
ml.sigstore.dev resources: @@ -25,6 +40,4 @@ rules: resources: - modelvalidations/status verbs: - - get - - patch - update diff --git a/examples/unsigned.yaml b/examples/unsigned.yaml index bd1ea72b..1781c3da 100644 --- a/examples/unsigned.yaml +++ b/examples/unsigned.yaml @@ -7,7 +7,7 @@ spec: config: # pkiConfig: # certificateAuthority: /path/to/ca.crt - privateKeyConfig: + publicKeyConfig: keyPath: /root/pub.key # sigstoreConfig: # certificateIdentity: "https://sigstore.example.com/certificate" diff --git a/examples/verify.yaml b/examples/verify.yaml index 39343ccb..fd6408f5 100644 --- a/examples/verify.yaml +++ b/examples/verify.yaml @@ -9,7 +9,7 @@ spec: config: # pkiConfig: # certificateAuthority: /path/to/ca.crt - # privateKeyConfig: + # publicKeyConfig: # keyPath: /root/pub.key sigstoreConfig: certificateIdentity: "https://github.com/sigstore/model-validation-operator/.github/workflows/sign-model.yaml@refs/tags/v0.0.2" diff --git a/go.mod b/go.mod index 9ad93c20..dead5372 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,11 @@ require ( github.com/go-logr/logr v1.4.2 github.com/onsi/ginkgo/v2 v2.22.0 github.com/onsi/gomega v1.36.1 + github.com/prometheus/client_golang v1.19.1 + github.com/prometheus/client_model v0.6.1 + github.com/prometheus/common v0.55.0 + github.com/stretchr/testify v1.9.0 + golang.org/x/time v0.7.0 k8s.io/api v0.32.1 k8s.io/apimachinery v0.32.1 k8s.io/client-go v0.32.1 @@ -53,9 +58,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/client_golang v1.19.1 // indirect - github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/common v0.55.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/spf13/cobra v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect @@ -78,7 +81,6 @@ require ( golang.org/x/sys v0.26.0 // indirect golang.org/x/term v0.25.0 // indirect golang.org/x/text v0.19.0 // indirect - golang.org/x/time v0.7.0 // indirect golang.org/x/tools v0.26.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 // indirect diff --git a/internal/constants/images.go b/internal/constants/images.go index ded4d049..6901dbff 100644 --- a/internal/constants/images.go +++ b/internal/constants/images.go @@ -1,6 +1,11 @@ // Package constants provides shared constants used throughout the model validation operator package constants +const ( + // ModelValidationInitContainerName is the name of the init container injected for model validation + ModelValidationInitContainerName = "model-validation" +) + var ( // ModelTransparencyCliImage is the default image for the model transparency CLI // used as an init container to validate model signatures diff --git a/internal/constants/labels.go b/internal/constants/labels.go index 1e12ff1c..3e96f546 100644 --- a/internal/constants/labels.go +++ b/internal/constants/labels.go @@ -10,6 +10,18 @@ const ( // IgnoreNamespaceLabel is the label used to ignore a namespace for model validation IgnoreNamespaceLabel = ModelValidationDomain + "/ignore" - // ModelValidationInitContainerName is the name of the init container injected for model validation - ModelValidationInitContainerName = "model-validation" + // ModelValidationFinalizer is the finalizer used to track model validation pods + 
ModelValidationFinalizer = ModelValidationDomain + "/finalizer" + + // InjectedAnnotationKey is the annotation key used to track injected pods + InjectedAnnotationKey = ModelValidationDomain + "/injected-at" + + // AuthMethodAnnotationKey is the annotation key used to track the auth method used during injection + AuthMethodAnnotationKey = ModelValidationDomain + "/auth-method" + + // ConfigHashAnnotationKey is the annotation key used to track the configuration hash during injection + ConfigHashAnnotationKey = ModelValidationDomain + "/config-hash" + + // IgnoreNamespaceValue is the value for the ignore namespace label + IgnoreNamespaceValue = "true" ) diff --git a/internal/controller/modelvalidation_controller.go b/internal/controller/modelvalidation_controller.go new file mode 100644 index 00000000..337cf154 --- /dev/null +++ b/internal/controller/modelvalidation_controller.go @@ -0,0 +1,59 @@ +// Package controller provides controllers for managing ModelValidation resources +package controller + +import ( + "context" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/tracker" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +// ModelValidationReconciler reconciles ModelValidation objects +type ModelValidationReconciler struct { + client.Client + Scheme *runtime.Scheme + Tracker tracker.StatusTracker +} + +// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations,verbs=get;list;watch +// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations/status,verbs=update + +// Reconcile handles ModelValidation events to track creation/updates/deletion +func (r *ModelValidationReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + logger := log.FromContext(ctx) + + mv := &v1alpha1.ModelValidation{} + if err := r.Get(ctx, req.NamespacedName, mv); err != nil { + if errors.IsNotFound(err) { + logger.Info("ModelValidation deleted, removing from tracking", "modelvalidation", req.NamespacedName) + r.Tracker.RemoveModelValidation(req.NamespacedName) + return reconcile.Result{}, nil + } + logger.Error(err, "Failed to get ModelValidation") + return reconcile.Result{}, err + } + + if !mv.DeletionTimestamp.IsZero() { + logger.Info("ModelValidation being deleted, removing from tracking", "modelvalidation", req.NamespacedName) + r.Tracker.RemoveModelValidation(req.NamespacedName) + return reconcile.Result{}, nil + } + + logger.Info("ModelValidation created/updated, adding to tracking", "modelvalidation", req.NamespacedName) + r.Tracker.AddModelValidation(ctx, mv) + + return reconcile.Result{}, nil +} + +// SetupWithManager sets up the controller with the Manager +func (r *ModelValidationReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&v1alpha1.ModelValidation{}). + Complete(r) +} diff --git a/internal/controller/modelvalidation_controller_test.go b/internal/controller/modelvalidation_controller_test.go new file mode 100644 index 00000000..6abfab7a --- /dev/null +++ b/internal/controller/modelvalidation_controller_test.go @@ -0,0 +1,121 @@ +package controller + +import ( + "context" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . 
"github.com/onsi/gomega" //nolint:revive + + "github.com/sigstore/model-validation-operator/internal/testutil" + "github.com/sigstore/model-validation-operator/internal/tracker" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +var _ = Describe("ModelValidationReconciler", func() { + var ( + ctx context.Context + reconciler *ModelValidationReconciler + mockTracker *tracker.MockStatusTracker + ) + + BeforeEach(func() { + ctx = context.Background() + mockTracker = tracker.NewMockStatusTracker() + }) + + Context("when reconciling ModelValidation resources", func() { + It("should call tracker for each reconcile", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + + fakeClient := testutil.SetupFakeClientWithObjects(mv) + + reconciler = &ModelValidationReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + namespace := mv.Namespace + name := mv.Name + req := testutil.CreateReconcileRequest(namespace, name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Requeue).To(BeFalse()) + + addCalls := mockTracker.GetAddModelValidationCalls() + Expect(addCalls).To(HaveLen(1)) + Expect(addCalls[0].Name).To(Equal(name)) + Expect(addCalls[0].Namespace).To(Equal(namespace)) + + // Second reconcile should also call AddModelValidation + // Controller is not idempotent - it calls tracker each time + result, err = reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Requeue).To(BeFalse()) + + addCalls = mockTracker.GetAddModelValidationCalls() + Expect(addCalls).To(HaveLen(2)) + Expect(addCalls[1].Name).To(Equal(name)) + Expect(addCalls[1].Namespace).To(Equal(namespace)) + }) + + It("should remove ModelValidation with DeletionTimestamp from tracking", func() { + now := metav1.Now() + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-deleting", + Namespace: "default", + DeletionTimestamp: &now, + Finalizers: []string{"test-finalizer"}, + }) + + fakeClient := testutil.SetupFakeClientWithObjects(mv) + + reconciler = &ModelValidationReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + namespace := mv.Namespace + name := mv.Name + req := testutil.CreateReconcileRequest(namespace, name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Requeue).To(BeFalse()) + + removeCalls := mockTracker.GetRemoveModelValidationCalls() + Expect(removeCalls).To(HaveLen(1)) + Expect(removeCalls[0].Namespace).To(Equal(namespace)) + Expect(removeCalls[0].Name).To(Equal(name)) + }) + + It("should remove deleted ModelValidation from tracking", func() { + fakeClient := testutil.SetupFakeClientWithObjects() + + reconciler = &ModelValidationReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + namespace := "default" + name := "deleted-mv" + req := testutil.CreateReconcileRequest(namespace, name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result.Requeue).To(BeFalse()) + + removeCalls := mockTracker.GetRemoveModelValidationCalls() + Expect(removeCalls).To(HaveLen(1)) + Expect(removeCalls[0].Namespace).To(Equal(namespace)) + Expect(removeCalls[0].Name).To(Equal(name)) + }) + }) +}) diff --git a/internal/controller/pod_controller.go 
b/internal/controller/pod_controller.go new file mode 100644 index 00000000..aedc11ac --- /dev/null +++ b/internal/controller/pod_controller.go @@ -0,0 +1,90 @@ +// Package controller provides controllers for managing ModelValidation resources +package controller + +import ( + "context" + + "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/tracker" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +// PodReconciler reconciles Pod objects to track injected pods +type PodReconciler struct { + client.Client + Scheme *runtime.Scheme + Tracker tracker.StatusTracker +} + +// +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;update +// +kubebuilder:rbac:groups="",resources=pods/finalizers,verbs=update +// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations,verbs=get;list +// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations/status,verbs=update + +// Reconcile handles pod events to update ModelValidation status +func (r *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + logger := log.FromContext(ctx) + + pod := &corev1.Pod{} + if err := r.Get(ctx, req.NamespacedName, pod); err != nil { + if errors.IsNotFound(err) { + logger.V(1).Info("Pod deleted, removing from tracking", "pod", req.NamespacedName) + if err := r.Tracker.RemovePodByName(ctx, req.NamespacedName); err != nil { + logger.Error(err, "Failed to remove deleted pod from tracking", "pod", req.NamespacedName) + return reconcile.Result{}, err + } + return reconcile.Result{}, nil + } + return reconcile.Result{}, err + } + + if !pod.DeletionTimestamp.IsZero() { + logger.Info("Handling pod deletion", "pod", req.NamespacedName) + + if err := r.Tracker.RemovePodEvent(ctx, pod.UID); err != nil { + logger.Error(err, "Failed to remove pod from tracking") + return reconcile.Result{}, err + } + + if controllerutil.ContainsFinalizer(pod, constants.ModelValidationFinalizer) { + controllerutil.RemoveFinalizer(pod, constants.ModelValidationFinalizer) + if err := r.Update(ctx, pod); err != nil { + logger.Error(err, "Failed to remove finalizer from pod") + return reconcile.Result{}, err + } + } + + logger.Info("Successfully handled pod deletion", "pod", req.NamespacedName) + return reconcile.Result{}, nil + } + + modelValidationName, ok := pod.Labels[constants.ModelValidationLabel] + if !ok || modelValidationName == "" { + // Try to remove the pod in case it was previously tracked but label was removed + if err := r.Tracker.RemovePodByName(ctx, req.NamespacedName); err != nil { + logger.Error(err, "Failed to remove pod without label from tracking", "pod", req.NamespacedName) + return reconcile.Result{}, err + } + return reconcile.Result{}, nil + } + if err := r.Tracker.ProcessPodEvent(ctx, pod); err != nil { + logger.Error(err, "Failed to process pod event") + return reconcile.Result{}, err + } + + return reconcile.Result{}, nil +} + +// SetupWithManager sets up the controller with the Manager +func (r *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&corev1.Pod{}). 
+ Complete(r) +} diff --git a/internal/controller/pod_controller_test.go b/internal/controller/pod_controller_test.go new file mode 100644 index 00000000..06a39f92 --- /dev/null +++ b/internal/controller/pod_controller_test.go @@ -0,0 +1,187 @@ +package controller + +import ( + "context" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . "github.com/onsi/gomega" //nolint:revive + + "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/testutil" + "github.com/sigstore/model-validation-operator/internal/tracker" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + ctrl "sigs.k8s.io/controller-runtime" +) + +var _ = Describe("PodReconciler", func() { + var ( + ctx context.Context + reconciler *PodReconciler + mockTracker *tracker.MockStatusTracker + ) + + BeforeEach(func() { + ctx = context.Background() + mockTracker = tracker.NewMockStatusTracker() + }) + + Context("when reconciling Pod resources", func() { + It("should try to remove pods without ModelValidation label if previously tracked", func() { + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + }) + + fakeClient := testutil.SetupFakeClientWithObjects(pod) + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := testutil.CreateReconcileRequest(pod.Namespace, pod.Name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + processEvents := mockTracker.GetProcessPodEventCalls() + Expect(processEvents).To(BeEmpty()) + + // Should try to remove the pod in case it was previously tracked + removeByNameCalls := mockTracker.GetRemovePodByNameCalls() + Expect(removeByNameCalls).To(HaveLen(1)) + Expect(removeByNameCalls[0].Name).To(Equal("test-pod")) + Expect(removeByNameCalls[0].Namespace).To(Equal("default")) + }) + + It("should process pods with finalizer but not being deleted", func() { + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + fakeClient := testutil.SetupFakeClientWithObjects(pod) + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := testutil.CreateReconcileRequest(pod.Namespace, pod.Name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + processEvents := mockTracker.GetProcessPodEventCalls() + Expect(processEvents).To(HaveLen(1)) + Expect(processEvents[0].Pod.Name).To(Equal(pod.Name)) + Expect(processEvents[0].Pod.Namespace).To(Equal(pod.Namespace)) + }) + + It("should handle pod deletion by removing finalizer", func() { + now := metav1.Now() + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + UID: types.UID("test-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + DeletionTimestamp: &now, + }) + + fakeClient := testutil.SetupFakeClientWithObjects(pod) + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := 
testutil.CreateReconcileRequest(pod.Namespace, pod.Name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + removeEvents := mockTracker.GetRemovePodEventCalls() + Expect(removeEvents).To(HaveLen(1)) + Expect(removeEvents[0]).To(Equal(pod.UID)) + + Eventually(ctx, func(ctx context.Context) []string { + updatedPod := &corev1.Pod{} + err := fakeClient.Get(ctx, req.NamespacedName, updatedPod) + if err != nil { + return []string{} + } + return updatedPod.Finalizers + }, "2s", "100ms").ShouldNot(ContainElement(constants.ModelValidationFinalizer)) + }) + + It("should handle pod with multiple finalizers being deleted", func() { + now := metav1.Now() + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + UID: types.UID("test-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{"other-finalizer", constants.ModelValidationFinalizer, "another-finalizer"}, + DeletionTimestamp: &now, + }) + + fakeClient := testutil.SetupFakeClientWithObjects(pod) + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := testutil.CreateReconcileRequest(pod.Namespace, pod.Name) + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + removeEvents := mockTracker.GetRemovePodEventCalls() + Expect(removeEvents).To(HaveLen(1)) + Expect(removeEvents[0]).To(Equal(pod.UID)) + + updatedPod := &corev1.Pod{} + err = fakeClient.Get(ctx, req.NamespacedName, updatedPod) + Expect(err).NotTo(HaveOccurred()) + Expect(updatedPod.Finalizers).NotTo(ContainElement(constants.ModelValidationFinalizer)) + Expect(updatedPod.Finalizers).To(ContainElement("other-finalizer")) + Expect(updatedPod.Finalizers).To(ContainElement("another-finalizer")) + }) + + It("should handle reconciling deleted pods gracefully", func() { + fakeClient := testutil.SetupFakeClientWithObjects() + + reconciler = &PodReconciler{ + Client: fakeClient, + Scheme: runtime.NewScheme(), + Tracker: mockTracker, + } + + req := testutil.CreateReconcileRequest("default", "deleted-pod") + + result, err := reconciler.Reconcile(ctx, req) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal(ctrl.Result{})) + + removeByNameCalls := mockTracker.GetRemovePodByNameCalls() + Expect(removeByNameCalls).To(HaveLen(1)) + Expect(removeByNameCalls[0].Name).To(Equal("deleted-pod")) + Expect(removeByNameCalls[0].Namespace).To(Equal("default")) + }) + }) +}) diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go new file mode 100644 index 00000000..6481b57e --- /dev/null +++ b/internal/controller/suite_test.go @@ -0,0 +1,13 @@ +package controller + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . "github.com/onsi/gomega" //nolint:revive +) + +func TestControllers(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Controller Suite") +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 00000000..54c3c91c --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,171 @@ +// Package metrics provides Prometheus metrics for the model validation operator. 
+package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "sigs.k8s.io/controller-runtime/pkg/metrics" +) + +const ( + // Metric label names + labelNamespace = "namespace" + labelModelValidation = "model_validation" + labelPodState = "pod_state" + labelStatusUpdateResult = "result" + labelDriftType = "drift_type" + + // PodStateInjected represents pods with model validation finalizers + PodStateInjected = "injected" + // PodStateUninjected represents pods without model validation finalizers + PodStateUninjected = "uninjected" + // PodStateOrphaned represents pods with configuration drift + PodStateOrphaned = "orphaned" + + // StatusUpdateSuccess indicates a successful status update + StatusUpdateSuccess = "success" + // StatusUpdateFailure indicates a failed status update + StatusUpdateFailure = "failure" +) + +var ( + // ModelValidationPodCounts tracks the current number of pods in each state per ModelValidation + ModelValidationPodCounts = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "model_validation_operator", + Name: "modelvalidation_pod_count", + Help: "Current number of pods tracked per ModelValidation by state", + }, + []string{labelNamespace, labelModelValidation, labelPodState}, + ) + + // PodStateTransitionsTotal tracks pod state transitions + PodStateTransitionsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "model_validation_operator", + Name: "pod_state_transitions_total", + Help: "Total number of pod state transitions", + }, + []string{labelNamespace, labelModelValidation, "from_state", "to_state"}, + ) + + // StatusUpdatesTotal tracks ModelValidation status updates + StatusUpdatesTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "model_validation_operator", + Name: "status_updates_total", + Help: "Total number of ModelValidation status updates", + }, + []string{labelNamespace, labelModelValidation, labelStatusUpdateResult}, + ) + + // ConfigurationDriftEventsTotal tracks configuration drift events + ConfigurationDriftEventsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "model_validation_operator", + Name: "configuration_drift_events_total", + Help: "Total number of configuration drift events detected", + }, + []string{labelNamespace, labelModelValidation, labelDriftType}, + ) + + // ModelValidationCRsTotal tracks total number of ModelValidation CRs per namespace + // Does not include authMethod for namespace-level tracking + ModelValidationCRsTotal = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "model_validation_operator", + Name: "modelvalidation_crs_total", + Help: "Total number of ModelValidation CRs being tracked per namespace", + }, + []string{labelNamespace}, + ) + + // StatusUpdateDuration tracks the duration of status update operations + StatusUpdateDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "model_validation_operator", + Name: "status_update_duration_seconds", + Help: "Duration of ModelValidation status update operations", + Buckets: prometheus.DefBuckets, + }, + []string{labelNamespace, labelModelValidation, labelStatusUpdateResult}, + ) + + // QueueSize tracks the current size of the status update queue + QueueSize = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: "model_validation_operator", + Name: "status_update_queue_size", + Help: "Current size of the status update queue", + }, + ) + + // RetryAttemptsTotal tracks retry attempts for status updates + RetryAttemptsTotal = 
prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "model_validation_operator", + Name: "status_update_retry_attempts_total", + Help: "Total number of status update retry attempts", + }, + []string{labelNamespace, labelModelValidation}, + ) +) + +func init() { + metrics.Registry.MustRegister( + ModelValidationPodCounts, + PodStateTransitionsTotal, + StatusUpdatesTotal, + ConfigurationDriftEventsTotal, + ModelValidationCRsTotal, + StatusUpdateDuration, + QueueSize, + RetryAttemptsTotal, + ) +} + +// RecordPodCount records the current pod count for a ModelValidation +func RecordPodCount(namespace, modelValidation, podState string, count float64) { + ModelValidationPodCounts.WithLabelValues(namespace, modelValidation, podState).Set(count) +} + +// RecordPodStateTransition records a pod state transition +func RecordPodStateTransition(namespace, modelValidation, fromState, toState string) { + PodStateTransitionsTotal.WithLabelValues(namespace, modelValidation, fromState, toState).Inc() +} + +// RecordStatusUpdate records a status update result +func RecordStatusUpdate(namespace, modelValidation, result string) { + StatusUpdatesTotal.WithLabelValues(namespace, modelValidation, result).Inc() +} + +// RecordConfigurationDrift records a configuration drift event +func RecordConfigurationDrift(namespace, modelValidation, driftType string) { + ConfigurationDriftEventsTotal.WithLabelValues(namespace, modelValidation, driftType).Inc() +} + +// RecordModelValidationCR records the current number of ModelValidation CRs per namespace +func RecordModelValidationCR(namespace string, count float64) { + ModelValidationCRsTotal.WithLabelValues(namespace).Set(count) +} + +// RecordStatusUpdateDuration records the duration of a status update +func RecordStatusUpdateDuration(namespace, modelValidation, result string, duration float64) { + StatusUpdateDuration.WithLabelValues(namespace, modelValidation, result).Observe(duration) +} + +// SetQueueSize sets the current queue size +func SetQueueSize(size float64) { + QueueSize.Set(size) +} + +// RecordRetryAttempt records a retry attempt +func RecordRetryAttempt(namespace, modelValidation string) { + RetryAttemptsTotal.WithLabelValues(namespace, modelValidation).Inc() +} + +// RecordMultiplePodStateTransitions records multiple identical pod state transitions +func RecordMultiplePodStateTransitions(namespace, modelValidation, fromState, toState string, count int) { + if count > 0 { + PodStateTransitionsTotal.WithLabelValues(namespace, modelValidation, fromState, toState).Add(float64(count)) + } +} diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go new file mode 100644 index 00000000..11ab7d12 --- /dev/null +++ b/internal/metrics/metrics_test.go @@ -0,0 +1,144 @@ +package metrics + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/controller-runtime/pkg/metrics" +) + +const ( + testNamespace = "test-namespace" + testModelValidation = "test-mv" +) + +func TestMetricsDefinition(t *testing.T) { + // Test that all metric variables are defined + assert.NotNil(t, ModelValidationPodCounts) + assert.NotNil(t, PodStateTransitionsTotal) + assert.NotNil(t, StatusUpdatesTotal) + assert.NotNil(t, ConfigurationDriftEventsTotal) + assert.NotNil(t, ModelValidationCRsTotal) + assert.NotNil(t, StatusUpdateDuration) + assert.NotNil(t, QueueSize) + assert.NotNil(t, RetryAttemptsTotal) +} + +// Test helper to verify gauge 
metrics +func verifyGaugeMetric(t *testing.T, gauge *prometheus.GaugeVec, labels []string, expectedValue float64) { + metric := gauge.WithLabelValues(labels...) + metricDto := &dto.Metric{} + err := metric.Write(metricDto) + assert.NoError(t, err) + assert.Equal(t, expectedValue, metricDto.GetGauge().GetValue()) +} + +// Test helper to verify counter metrics +func verifyCounterIncrement(t *testing.T, counter *prometheus.CounterVec, labels []string) float64 { + metric := counter.WithLabelValues(labels...) + metricDto := &dto.Metric{} + err := metric.Write(metricDto) + assert.NoError(t, err) + return metricDto.GetCounter().GetValue() +} + +func TestRecordPodCount(t *testing.T) { + podState := PodStateInjected + count := float64(5) + + RecordPodCount(testNamespace, testModelValidation, podState, count) + + verifyGaugeMetric(t, ModelValidationPodCounts, []string{testNamespace, testModelValidation, podState}, count) +} + +func TestRecordPodStateTransition(t *testing.T) { + fromState := PodStateUninjected + toState := PodStateInjected + + labels := []string{testNamespace, testModelValidation, fromState, toState} + initialValue := verifyCounterIncrement(t, PodStateTransitionsTotal, labels) + RecordPodStateTransition(testNamespace, testModelValidation, fromState, toState) + finalValue := verifyCounterIncrement(t, PodStateTransitionsTotal, labels) + + assert.Equal(t, initialValue+1, finalValue) +} + +func TestRecordStatusUpdate(t *testing.T) { + result := StatusUpdateSuccess + + labels := []string{testNamespace, testModelValidation, result} + initialValue := verifyCounterIncrement(t, StatusUpdatesTotal, labels) + RecordStatusUpdate(testNamespace, testModelValidation, result) + finalValue := verifyCounterIncrement(t, StatusUpdatesTotal, labels) + + assert.Equal(t, initialValue+1, finalValue) +} + +func TestRecordConfigurationDrift(t *testing.T) { + driftType := "config_hash" + + labels := []string{testNamespace, testModelValidation, driftType} + initialValue := verifyCounterIncrement(t, ConfigurationDriftEventsTotal, labels) + RecordConfigurationDrift(testNamespace, testModelValidation, driftType) + finalValue := verifyCounterIncrement(t, ConfigurationDriftEventsTotal, labels) + + assert.Equal(t, initialValue+1, finalValue) +} + +func TestRecordModelValidationCR(t *testing.T) { + count := float64(3) + + RecordModelValidationCR(testNamespace, count) + + verifyGaugeMetric(t, ModelValidationCRsTotal, []string{testNamespace}, count) +} + +func TestSetQueueSize(t *testing.T) { + size := float64(10) + + SetQueueSize(size) + + metricDto := &dto.Metric{} + err := QueueSize.Write(metricDto) + assert.NoError(t, err) + assert.Equal(t, size, metricDto.GetGauge().GetValue()) +} + +func TestRecordStatusUpdateDuration(t *testing.T) { + result := StatusUpdateSuccess + duration := 0.5 // 500ms + + RecordStatusUpdateDuration(testNamespace, testModelValidation, result, duration) + + // Verify the histogram was recorded by checking the metric family + metricFamilies, err := metrics.Registry.Gather() + assert.NoError(t, err) + + var found bool + for _, mf := range metricFamilies { + if mf.GetName() == "model_validation_operator_status_update_duration_seconds" { + for _, metric := range mf.GetMetric() { + if metric.GetHistogram().GetSampleCount() > 0 { + found = true + break + } + } + } + } + assert.True(t, found, "Expected histogram metric to be recorded") +} + +func TestRecordMultiplePodStateTransitions(t *testing.T) { + fromState := PodStateUninjected + toState := PodStateInjected + count := 3 + + labels := 
[]string{testNamespace, testModelValidation, fromState, toState} + initialValue := verifyCounterIncrement(t, PodStateTransitionsTotal, labels) + RecordMultiplePodStateTransitions(testNamespace, testModelValidation, fromState, toState, count) + finalValue := verifyCounterIncrement(t, PodStateTransitionsTotal, labels) + + assert.Equal(t, initialValue+float64(count), finalValue) +} diff --git a/internal/testutil/ginkgo_utils.go b/internal/testutil/ginkgo_utils.go new file mode 100644 index 00000000..16eb4543 --- /dev/null +++ b/internal/testutil/ginkgo_utils.go @@ -0,0 +1,41 @@ +// Package testutil provides Ginkgo-specific test utilities +package testutil + +import ( + "context" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/onsi/gomega" +) + +// ExpectModelValidationTracking verifies a ModelValidation's tracking state +func ExpectModelValidationTracking( + statusTracker ModelValidationTracker, + namespacedName types.NamespacedName, + shouldBeTracked bool, +) { + isTracked := statusTracker.IsModelValidationTracked(namespacedName) + gomega.Expect(isTracked).To(gomega.Equal(shouldBeTracked), + "Expected ModelValidation %s tracking state to be %t", namespacedName, shouldBeTracked) +} + +// GetModelValidationFromClientExpected retrieves a ModelValidation from the client with Ginkgo expectations +func GetModelValidationFromClientExpected( + ctx context.Context, + client client.Client, + namespacedName types.NamespacedName, +) *v1alpha1.ModelValidation { + mv, err := GetModelValidationFromClient(ctx, client, namespacedName) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + return mv +} + +// ExpectAttemptCount waits for and verifies expected attempt counts in retry testing +func ExpectAttemptCount(fakeClient *FailingClient, expectedAttempts int) { + gomega.Eventually(func() int { + return fakeClient.GetAttemptCount() + }, "2s", "10ms").Should(gomega.Equal(expectedAttempts)) +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 00000000..ccc5eeb4 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,263 @@ +// Package testutil provides shared test utilities for the model validation operator +package testutil + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +// DefaultNamespace is the default namespace used in tests +const DefaultNamespace = "default" + +// TestModelValidationOptions holds configuration for creating test ModelValidation resources +type TestModelValidationOptions struct { + Name string + Namespace string + DeletionTimestamp *metav1.Time + Finalizers []string + ConfigType string + CertificateCA string + CertIdentity string + CertOidcIssuer string +} + +// TestPodOptions holds configuration for creating test Pod resources +type TestPodOptions struct { + Name string + Namespace string + UID types.UID + Labels map[string]string + Annotations map[string]string + Finalizers []string + DeletionTimestamp *metav1.Time +} + +// CreateTestModelValidation creates a test ModelValidation resource with the given options +func CreateTestModelValidation(opts 
TestModelValidationOptions) *v1alpha1.ModelValidation { + if opts.Name == "" { + opts.Name = fmt.Sprintf("test-mv-%d", time.Now().UnixNano()) + } + if opts.Namespace == "" { + opts.Namespace = DefaultNamespace + } + if opts.ConfigType == "" { + opts.ConfigType = "sigstore" + } + + mv := &v1alpha1.ModelValidation{ + ObjectMeta: metav1.ObjectMeta{ + Name: opts.Name, + Namespace: opts.Namespace, + DeletionTimestamp: opts.DeletionTimestamp, + Finalizers: opts.Finalizers, + }, + Spec: v1alpha1.ModelValidationSpec{ + Model: v1alpha1.Model{ + Path: "test-model", + SignaturePath: "test-signature", + }, + }, + } + + // Configure auth method + switch opts.ConfigType { + case "pki": + certCA := opts.CertificateCA + if certCA == "" { + certCA = "test-ca" + } + mv.Spec.Config = v1alpha1.ValidationConfig{ + PkiConfig: &v1alpha1.PkiConfig{ + CertificateAuthority: certCA, + }, + } + case "sigstore": + fallthrough + default: + certIdentity := opts.CertIdentity + if certIdentity == "" { + certIdentity = "test@example.com" + } + certOidcIssuer := opts.CertOidcIssuer + if certOidcIssuer == "" { + certOidcIssuer = "https://accounts.google.com" + } + mv.Spec.Config = v1alpha1.ValidationConfig{ + SigstoreConfig: &v1alpha1.SigstoreConfig{ + CertificateIdentity: certIdentity, + CertificateOidcIssuer: certOidcIssuer, + }, + } + } + + return mv +} + +// CreateTestPod creates a test Pod resource with the given options +func CreateTestPod(opts TestPodOptions) *corev1.Pod { + if opts.Name == "" { + opts.Name = fmt.Sprintf("test-pod-%d", time.Now().UnixNano()) + } + if opts.Namespace == "" { + opts.Namespace = DefaultNamespace + } + if opts.UID == types.UID("") { + opts.UID = types.UID(fmt.Sprintf("uid-%d", time.Now().UnixNano())) + } + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: opts.Name, + Namespace: opts.Namespace, + UID: opts.UID, + Labels: opts.Labels, + Annotations: opts.Annotations, + Finalizers: opts.Finalizers, + DeletionTimestamp: opts.DeletionTimestamp, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test-image", + }, + }, + }, + } + + return pod +} + +// CreateTestNamespacedName creates a NamespacedName for testing +func CreateTestNamespacedName(name, namespace string) types.NamespacedName { + if name == "" { + name = fmt.Sprintf("test-name-%d", time.Now().UnixNano()) + } + if namespace == "" { + namespace = DefaultNamespace + } + return types.NamespacedName{Name: name, Namespace: namespace} +} + +// SetupFakeClientWithObjects creates a fake Kubernetes client with the given objects +func SetupFakeClientWithObjects(objects ...client.Object) client.Client { + scheme := runtime.NewScheme() + if err := corev1.AddToScheme(scheme); err != nil { + panic(err) // This should not happen in tests + } + if err := v1alpha1.AddToScheme(scheme); err != nil { + panic(err) // This should not happen in tests + } + + builder := fake.NewClientBuilder().WithScheme(scheme) + if len(objects) > 0 { + builder = builder.WithObjects(objects...) 
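+		// Note: controller-runtime's fake client only serves the /status
+		// subresource for object types registered explicitly below; without
+		// WithStatusSubresource, Status().Update() against the fake client
+		// would not behave like the real API server.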
+ // Add status subresource for ModelValidation objects + for _, obj := range objects { + if _, ok := obj.(*v1alpha1.ModelValidation); ok { + builder = builder.WithStatusSubresource(obj) + } + } + } + + return builder.Build() +} + +// CreateReconcileRequest creates a reconcile request for testing +func CreateReconcileRequest(namespace, name string) reconcile.Request { + return reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: name}} +} + +// ModelValidationTracker interface for tracking functionality +type ModelValidationTracker interface { + IsModelValidationTracked(namespacedName types.NamespacedName) bool +} + +// GetModelValidationFromClient retrieves a ModelValidation from the client +func GetModelValidationFromClient( + ctx context.Context, + client client.Client, + namespacedName types.NamespacedName, +) (*v1alpha1.ModelValidation, error) { + mv := &v1alpha1.ModelValidation{} + err := client.Get(ctx, namespacedName, mv) + return mv, err +} + +// FailingClient is a mock client that fails status updates for testing retry behavior +type FailingClient struct { + client.Client + FailureCount int + MaxFailures int + AttemptCount int + mu sync.Mutex +} + +// Status returns a failing sub-resource writer for testing +func (f *FailingClient) Status() client.SubResourceWriter { + return &FailingSubResourceWriter{ + SubResourceWriter: f.Client.Status(), + parent: f, + } +} + +// FailingSubResourceWriter wraps a SubResourceWriter to simulate failures +type FailingSubResourceWriter struct { + client.SubResourceWriter + parent *FailingClient +} + +// Update simulates failures for the first MaxFailures attempts, then succeeds +func (f *FailingSubResourceWriter) Update( + ctx context.Context, obj client.Object, opts ...client.SubResourceUpdateOption, +) error { + f.parent.mu.Lock() + defer f.parent.mu.Unlock() + + f.parent.AttemptCount++ + if f.parent.FailureCount < f.parent.MaxFailures { + f.parent.FailureCount++ + return errors.New("simulated status update failure") + } + + // After max failures, delegate to real client + return f.SubResourceWriter.Update(ctx, obj, opts...) +} + +// GetAttemptCount returns the current attempt count (thread-safe) +func (f *FailingClient) GetAttemptCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.AttemptCount +} + +// GetFailureCount returns the current failure count (thread-safe) +func (f *FailingClient) GetFailureCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.FailureCount +} + +// CreateFailingClientWithObjects creates a FailingClient with the given objects +// that fails the first maxFailures attempts +func CreateFailingClientWithObjects(maxFailures int, objects ...client.Object) *FailingClient { + fakeClient := SetupFakeClientWithObjects(objects...) 
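+	// Wrap the fake client so that the first maxFailures calls to
+	// Status().Update() return a simulated error before delegating to the
+	// underlying fake client (see FailingSubResourceWriter.Update above).
+	//
+	// Illustrative use in a retry test (names outside this package's
+	// helpers are assumed):
+	//
+	//	fc := CreateFailingClientWithObjects(2, mv)
+	//	st := tracker.NewStatusTracker(fc, cfg) // cfg: some StatusTrackerConfig
+	//	// ...trigger one status update; with 2 simulated failures the
+	//	// update succeeds on the 3rd attempt, so fc.GetAttemptCount() == 3.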
+ return &FailingClient{ + Client: fakeClient, + MaxFailures: maxFailures, + FailureCount: 0, + AttemptCount: 0, + } +} diff --git a/internal/tracker/debounced_queue.go b/internal/tracker/debounced_queue.go new file mode 100644 index 00000000..653d816a --- /dev/null +++ b/internal/tracker/debounced_queue.go @@ -0,0 +1,219 @@ +package tracker + +import ( + "sync" + "time" + + "github.com/sigstore/model-validation-operator/internal/metrics" + "golang.org/x/time/rate" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/util/workqueue" +) + +// DebouncedQueue provides a queue with built-in debouncing functionality +// It encapsulates both debouncing logic and workqueue implementation +type DebouncedQueue interface { + // Add adds an item to the queue with debouncing + // If the same item is already pending debounce, it resets the timer + Add(item types.NamespacedName) + + // Get gets the next item for processing (blocks if queue is empty) + Get() (types.NamespacedName, bool) + + // Done marks an item as done processing + Done(item types.NamespacedName) + + // AddWithRetry adds an item to the queue with rate limiting for retries + AddWithRetry(item types.NamespacedName) + + // ForgetRetries forgets retry tracking for an item + ForgetRetries(item types.NamespacedName) + + // GetRetryCount returns the number of retries for an item + GetRetryCount(item types.NamespacedName) int + + // Len returns the number of items currently in the queue (not pending debounce) + Len() int + + // WaitForUpdates waits for all pending debounced updates to complete + WaitForUpdates() + + // WaitForCompletion waits for all pending debounced updates to complete + // and for the queue to be fully drained (all items processed) + WaitForCompletion() + + // ShutDown shuts down the queue and stops all timers + ShutDown() +} + +// DebouncedQueueImpl implements DebouncedQueue using workqueue and timers +type DebouncedQueueImpl struct { + duration time.Duration + queue workqueue.TypedRateLimitingInterface[types.NamespacedName] + debounceTimers map[types.NamespacedName]*time.Timer + debounceWg sync.WaitGroup + mu sync.Mutex + stopCh chan struct{} + // Protects metric updates to ensure consistency with queue state + metricMu sync.Mutex +} + +// DebouncedQueueConfig holds configuration for the debounced queue +type DebouncedQueueConfig struct { + DebounceDuration time.Duration + RetryBaseDelay time.Duration + RetryMaxDelay time.Duration + RateLimitQPS float64 + RateLimitBurst int +} + +// NewDebouncedQueue creates a new debounced queue with the specified configuration +func NewDebouncedQueue(config DebouncedQueueConfig) DebouncedQueue { + // Create custom rate limiter with configurable retry parameters + rateLimiter := workqueue.NewTypedMaxOfRateLimiter( + workqueue.NewTypedItemExponentialFailureRateLimiter[types.NamespacedName]( + config.RetryBaseDelay, config.RetryMaxDelay), + &workqueue.TypedBucketRateLimiter[types.NamespacedName]{ + Limiter: rate.NewLimiter(rate.Limit(config.RateLimitQPS), config.RateLimitBurst)}, + ) + + workQueue := workqueue.NewTypedRateLimitingQueue(rateLimiter) + + return &DebouncedQueueImpl{ + duration: config.DebounceDuration, + queue: workQueue, + debounceTimers: make(map[types.NamespacedName]*time.Timer), + stopCh: make(chan struct{}), + } +} + +// updateQueueSizeMetric updates the queue size metric +// Uses a mutex to ensure metric consistency with queue state +func (dq *DebouncedQueueImpl) updateQueueSizeMetric() { + dq.metricMu.Lock() + defer dq.metricMu.Unlock() + + // Skip metric updates if 
shutdown has been initiated + if dq.isShutDown() { + return + } + + metrics.SetQueueSize(float64(dq.queue.Len())) +} + +// isShutDown checks if the queue has been shut down (non-blocking) +func (dq *DebouncedQueueImpl) isShutDown() bool { + select { + case <-dq.stopCh: + return true + default: + return false + } +} + +// Add adds an item to the queue with debouncing +// If the queue has been shut down, this method will return without adding the item +func (dq *DebouncedQueueImpl) Add(item types.NamespacedName) { + dq.mu.Lock() + defer dq.mu.Unlock() + if dq.isShutDown() { + return + } + if timer, ok := dq.debounceTimers[item]; ok { + timer.Reset(dq.duration) + } else { + dq.debounceWg.Add(1) + dq.debounceTimers[item] = time.AfterFunc(dq.duration, func() { + defer dq.debounceWg.Done() + if dq.isShutDown() { + return + } + dq.queue.Add(item) + dq.updateQueueSizeMetric() + dq.mu.Lock() + delete(dq.debounceTimers, item) + dq.mu.Unlock() + }) + } +} + +// Get gets the next item for processing +func (dq *DebouncedQueueImpl) Get() (types.NamespacedName, bool) { + return dq.queue.Get() +} + +// Done marks an item as done processing +func (dq *DebouncedQueueImpl) Done(item types.NamespacedName) { + dq.queue.Done(item) + dq.updateQueueSizeMetric() +} + +// AddWithRetry adds an item to the queue with rate limiting for retries +// If the queue has been shut down, this method will return without adding the item +func (dq *DebouncedQueueImpl) AddWithRetry(item types.NamespacedName) { + if dq.isShutDown() { + return + } + dq.queue.AddRateLimited(item) + dq.updateQueueSizeMetric() +} + +// ForgetRetries forgets retry tracking for an item +func (dq *DebouncedQueueImpl) ForgetRetries(item types.NamespacedName) { + dq.queue.Forget(item) +} + +// GetRetryCount returns the number of retries for an item +func (dq *DebouncedQueueImpl) GetRetryCount(item types.NamespacedName) int { + return dq.queue.NumRequeues(item) +} + +// Len returns the number of items currently in the queue +func (dq *DebouncedQueueImpl) Len() int { + return dq.queue.Len() +} + +// WaitForUpdates waits for all pending debounced updates to complete +func (dq *DebouncedQueueImpl) WaitForUpdates() { + dq.debounceWg.Wait() +} + +// WaitForCompletion waits for all pending debounced updates to complete +// and for the queue to be fully drained (all items processed) +func (dq *DebouncedQueueImpl) WaitForCompletion() { + dq.WaitForUpdates() + + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if dq.queue.Len() == 0 { + return + } + case <-dq.stopCh: + return + } + } +} + +// ShutDown shuts down the queue and stops all timers +func (dq *DebouncedQueueImpl) ShutDown() { + dq.mu.Lock() + defer dq.mu.Unlock() + + for _, timer := range dq.debounceTimers { + timer.Stop() + } + dq.debounceTimers = make(map[types.NamespacedName]*time.Timer) + + if !dq.isShutDown() { + close(dq.stopCh) + } + + dq.queue.ShutDown() + + // Update metrics to reflect shutdown + metrics.SetQueueSize(0) +} diff --git a/internal/tracker/debounced_queue_test.go b/internal/tracker/debounced_queue_test.go new file mode 100644 index 00000000..d97152e9 --- /dev/null +++ b/internal/tracker/debounced_queue_test.go @@ -0,0 +1,334 @@ +package tracker + +import ( + "time" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . 
"github.com/onsi/gomega" //nolint:revive + + "k8s.io/apimachinery/pkg/types" +) + +var _ = Describe("DebouncedQueue", func() { + var ( + debouncedQueue DebouncedQueue + testKey1 types.NamespacedName + testKey2 types.NamespacedName + duration time.Duration + ) + + BeforeEach(func() { + testKey1 = types.NamespacedName{Name: "test1", Namespace: "default"} + testKey2 = types.NamespacedName{Name: "test2", Namespace: "default"} + duration = 50 * time.Millisecond + debouncedQueue = NewDebouncedQueue(DebouncedQueueConfig{ + DebounceDuration: duration, + RetryBaseDelay: 100 * time.Millisecond, + RetryMaxDelay: 1000 * time.Millisecond, + RateLimitQPS: 10, + RateLimitBurst: 100, + }) + }) + + AfterEach(func() { + debouncedQueue.ShutDown() + }) + + Context("when debouncing single updates", func() { + It("should add update to queue after debounce duration", func() { + // Initially queue should be empty + Expect(debouncedQueue.Len()).To(Equal(0)) + + // Trigger debounce + debouncedQueue.Add(testKey1) + + // Should still be empty immediately + Expect(debouncedQueue.Len()).To(Equal(0)) + + // Wait for debounce duration plus buffer + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(1)) + + // Verify the correct key was added + item, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + debouncedQueue.Done(item) + }) + + It("should reset timer on subsequent updates", func() { + // First update + debouncedQueue.Add(testKey1) + + // Wait half the debounce duration + time.Sleep(duration / 2) + + // Queue should still be empty + Expect(debouncedQueue.Len()).To(Equal(0)) + + // Second update should reset the timer + debouncedQueue.Add(testKey1) + + // Wait another half duration (original timer would have fired by now) + time.Sleep(duration / 2) + + // Queue should still be empty because timer was reset + Expect(debouncedQueue.Len()).To(Equal(0)) + + // Wait for full duration from second update + Eventually(func() int { + return debouncedQueue.Len() + }, duration*2, 10*time.Millisecond).Should(Equal(1)) + }) + }) + + Context("when debouncing multiple keys", func() { + It("should handle multiple keys independently", func() { + // Trigger debounce for both keys + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // Both should be processed independently + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(2)) + + // Verify both keys are in queue + receivedKeys := make(map[types.NamespacedName]bool) + + item1, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + receivedKeys[item1] = true + debouncedQueue.Done(item1) + + item2, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + receivedKeys[item2] = true + debouncedQueue.Done(item2) + + Expect(receivedKeys).To(HaveKey(testKey1)) + Expect(receivedKeys).To(HaveKey(testKey2)) + }) + + It("should reset timers independently for different keys", func() { + // Start timer for key1 + debouncedQueue.Add(testKey1) + + // Wait and start timer for key2 + time.Sleep(duration / 2) + debouncedQueue.Add(testKey2) + + // Key1 should fire first (started earlier) + Eventually(func() int { + return debouncedQueue.Len() + }, duration*2, 10*time.Millisecond).Should(BeNumerically(">=", 1)) + + // Get first item (should be key1) + item, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + debouncedQueue.Done(item) + + // Key2 should fire 
shortly after + Eventually(func() int { + return debouncedQueue.Len() + }, duration, 10*time.Millisecond).Should(Equal(1)) + + item, shutdown = debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey2)) + debouncedQueue.Done(item) + }) + }) + + Context("when waiting for updates", func() { + It("should wait for all pending updates to complete", func() { + // Trigger multiple updates + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // WaitForUpdates should block until all timers complete + done := make(chan bool) + go func() { + debouncedQueue.WaitForUpdates() + done <- true + }() + + // Should not complete immediately + Consistently(done, duration/2).ShouldNot(Receive()) + + // Should complete after debounce duration + Eventually(done, duration*3).Should(Receive()) + + // Queue should have both items + Expect(debouncedQueue.Len()).To(Equal(2)) + }) + + It("should handle WaitForUpdates with no pending updates", func() { + // WaitForUpdates should return immediately when no updates pending + done := make(chan bool) + go func() { + debouncedQueue.WaitForUpdates() + done <- true + }() + + // Should complete immediately + Eventually(done, 100*time.Millisecond).Should(Receive()) + }) + + It("should wait for queue to drain with WaitForCompletion", func() { + // Trigger multiple updates + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // WaitForCompletion should block until all timers complete and queue is drained + done := make(chan bool) + go func() { + debouncedQueue.WaitForCompletion() + done <- true + }() + + // Should not complete immediately + Consistently(done, duration/2).ShouldNot(Receive()) + + // Should still not complete after just debounce duration (queue not drained) + time.Sleep(duration * 2) + Consistently(done, 50*time.Millisecond).ShouldNot(Receive()) + + // Drain the queue + for debouncedQueue.Len() > 0 { + item, shutdown := debouncedQueue.Get() + if shutdown { + break + } + debouncedQueue.Done(item) + } + + // Now it should complete + Eventually(done, 100*time.Millisecond).Should(Receive()) + }) + + It("should stop waiting when queue is shut down", func() { + // Trigger updates but don't process them + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // Wait for items to be in queue + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(2)) + + // Start WaitForCompletion which should block + done := make(chan bool) + go func() { + debouncedQueue.WaitForCompletion() + done <- true + }() + + // Should not complete immediately (queue not drained) + Consistently(done, 50*time.Millisecond).ShouldNot(Receive()) + + // Shut down the queue - this should cause WaitForCompletion to return + debouncedQueue.ShutDown() + + // Should complete quickly now + Eventually(done, 100*time.Millisecond).Should(Receive()) + }) + }) + + Context("when shutting down", func() { + It("should stop all pending timers", func() { + // Trigger updates + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + + // Shut down before timers fire + debouncedQueue.ShutDown() + + // Wait longer than debounce duration + time.Sleep(duration * 2) + + // Queue should remain empty (timers were stopped) + Expect(debouncedQueue.Len()).To(Equal(0)) + }) + }) + + Context("when handling rapid updates", func() { + It("should coalesce rapid updates into single queue item", func() { + // Rapid fire updates for same key + for i := 0; i < 10; i++ { + debouncedQueue.Add(testKey1) + 
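+				// Each Add before the window elapses resets the same
+				// pending timer (see DebouncedQueueImpl.Add), so these
+				// ten calls should coalesce into a single enqueue.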
time.Sleep(duration / 20) // Much shorter than debounce duration + } + + // Should result in only one queued item + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(1)) + + // Verify it's the correct key + item, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + debouncedQueue.Done(item) + + // No more items should be queued + Consistently(func() int { + return debouncedQueue.Len() + }, duration).Should(Equal(0)) + }) + + It("should handle interleaved updates for different keys", func() { + // Interleave updates for different keys + for i := 0; i < 5; i++ { + debouncedQueue.Add(testKey1) + debouncedQueue.Add(testKey2) + time.Sleep(duration / 10) + } + + // Should result in one item for each key + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(2)) + }) + }) + + Context("when handling retries", func() { + It("should support retry functionality", func() { + // Add item normally first + debouncedQueue.Add(testKey1) + + // Wait for debounced item to be available + Eventually(func() int { + return debouncedQueue.Len() + }, duration*3, 10*time.Millisecond).Should(Equal(1)) + + // Get and process item + item, shutdown := debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + + // Check retry count (should be 0 initially) + retries := debouncedQueue.GetRetryCount(testKey1) + Expect(retries).To(Equal(0)) + + // Add back with retry (simulating failure) + debouncedQueue.AddWithRetry(testKey1) + debouncedQueue.Done(item) + + // Get item again + item, shutdown = debouncedQueue.Get() + Expect(shutdown).To(BeFalse()) + Expect(item).To(Equal(testKey1)) + + // Retry count should be incremented + retries = debouncedQueue.GetRetryCount(testKey1) + Expect(retries).To(BeNumerically(">", 0)) + + // Forget retries + debouncedQueue.ForgetRetries(testKey1) + debouncedQueue.Done(item) + }) + }) +}) diff --git a/internal/tracker/mock.go b/internal/tracker/mock.go new file mode 100644 index 00000000..f22b1798 --- /dev/null +++ b/internal/tracker/mock.go @@ -0,0 +1,185 @@ +package tracker + +import ( + "context" + "sync" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" +) + +// MockStatusTracker is a mock implementation of StatusTracker for testing +type MockStatusTracker struct { + mu sync.RWMutex + + // Track method calls for verification + AddModelValidationCalls []*v1alpha1.ModelValidation + RemoveModelValidationCalls []types.NamespacedName + ProcessPodEventCalls []ProcessPodEventCall + RemovePodEventCalls []types.UID + RemovePodByNameCalls []types.NamespacedName + StopCalls int + + // Optional error responses for testing error scenarios + ProcessPodEventError error + RemovePodEventError error + RemovePodByNameError error + + // Track ModelValidations that are considered "tracked" + TrackedModelValidations map[types.NamespacedName]bool +} + +// ProcessPodEventCall captures the parameters of a ProcessPodEvent call +type ProcessPodEventCall struct { + Ctx context.Context + Pod *corev1.Pod +} + +// NewMockStatusTracker creates a new mock status tracker +func NewMockStatusTracker() *MockStatusTracker { + return &MockStatusTracker{ + TrackedModelValidations: make(map[types.NamespacedName]bool), + } +} + +// AddModelValidation implements StatusTracker interface +func (m *MockStatusTracker) AddModelValidation(_ context.Context, mv 
*v1alpha1.ModelValidation) { + m.mu.Lock() + defer m.mu.Unlock() + + m.AddModelValidationCalls = append(m.AddModelValidationCalls, mv) + mvKey := types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + m.TrackedModelValidations[mvKey] = true +} + +// RemoveModelValidation implements StatusTracker interface +func (m *MockStatusTracker) RemoveModelValidation(mvKey types.NamespacedName) { + m.mu.Lock() + defer m.mu.Unlock() + + m.RemoveModelValidationCalls = append(m.RemoveModelValidationCalls, mvKey) + delete(m.TrackedModelValidations, mvKey) +} + +// ProcessPodEvent implements StatusTracker interface +func (m *MockStatusTracker) ProcessPodEvent(ctx context.Context, pod *corev1.Pod) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.ProcessPodEventCalls = append(m.ProcessPodEventCalls, ProcessPodEventCall{ + Ctx: ctx, + Pod: pod, + }) + + return m.ProcessPodEventError +} + +// RemovePodEvent implements StatusTracker interface +func (m *MockStatusTracker) RemovePodEvent(_ context.Context, podUID types.UID) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.RemovePodEventCalls = append(m.RemovePodEventCalls, podUID) + return m.RemovePodEventError +} + +// RemovePodByName implements StatusTracker interface +func (m *MockStatusTracker) RemovePodByName(_ context.Context, podName types.NamespacedName) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.RemovePodByNameCalls = append(m.RemovePodByNameCalls, podName) + return m.RemovePodByNameError +} + +// Stop implements StatusTracker interface +func (m *MockStatusTracker) Stop() { + m.mu.Lock() + defer m.mu.Unlock() + + m.StopCalls++ +} + +// Reset clears all recorded calls +func (m *MockStatusTracker) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + + m.AddModelValidationCalls = nil + m.RemoveModelValidationCalls = nil + m.ProcessPodEventCalls = nil + m.RemovePodEventCalls = nil + m.RemovePodByNameCalls = nil + m.StopCalls = 0 + m.ProcessPodEventError = nil + m.RemovePodEventError = nil + m.RemovePodByNameError = nil + m.TrackedModelValidations = make(map[types.NamespacedName]bool) +} + +// GetAddModelValidationCalls returns a copy of the AddModelValidation calls +func (m *MockStatusTracker) GetAddModelValidationCalls() []*v1alpha1.ModelValidation { + m.mu.RLock() + defer m.mu.RUnlock() + + calls := make([]*v1alpha1.ModelValidation, len(m.AddModelValidationCalls)) + copy(calls, m.AddModelValidationCalls) + return calls +} + +// GetRemoveModelValidationCalls returns a copy of the RemoveModelValidation calls +func (m *MockStatusTracker) GetRemoveModelValidationCalls() []types.NamespacedName { + m.mu.RLock() + defer m.mu.RUnlock() + + calls := make([]types.NamespacedName, len(m.RemoveModelValidationCalls)) + copy(calls, m.RemoveModelValidationCalls) + return calls +} + +// GetProcessPodEventCalls returns a copy of the ProcessPodEvent calls +func (m *MockStatusTracker) GetProcessPodEventCalls() []ProcessPodEventCall { + m.mu.RLock() + defer m.mu.RUnlock() + + calls := make([]ProcessPodEventCall, len(m.ProcessPodEventCalls)) + copy(calls, m.ProcessPodEventCalls) + return calls +} + +// GetRemovePodEventCalls returns a copy of the RemovePodEvent calls +func (m *MockStatusTracker) GetRemovePodEventCalls() []types.UID { + m.mu.RLock() + defer m.mu.RUnlock() + + calls := make([]types.UID, len(m.RemovePodEventCalls)) + copy(calls, m.RemovePodEventCalls) + return calls +} + +// GetRemovePodByNameCalls returns a copy of the RemovePodByName calls +func (m *MockStatusTracker) GetRemovePodByNameCalls() []types.NamespacedName { + m.mu.RLock() + defer 
m.mu.RUnlock() + + calls := make([]types.NamespacedName, len(m.RemovePodByNameCalls)) + copy(calls, m.RemovePodByNameCalls) + return calls +} + +// GetStopCalls returns the number of Stop calls +func (m *MockStatusTracker) GetStopCalls() int { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.StopCalls +} + +// IsModelValidationTracked returns whether a ModelValidation is being tracked +func (m *MockStatusTracker) IsModelValidationTracked(mvKey types.NamespacedName) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.TrackedModelValidations[mvKey] +} diff --git a/internal/tracker/model_validation_info.go b/internal/tracker/model_validation_info.go new file mode 100644 index 00000000..2177cf9e --- /dev/null +++ b/internal/tracker/model_validation_info.go @@ -0,0 +1,161 @@ +// Package tracker provides status tracking functionality for ModelValidation resources +package tracker + +import ( + "github.com/sigstore/model-validation-operator/internal/metrics" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +// PodInfo stores information about a tracked pod +type PodInfo struct { + Name string + Namespace string + UID types.UID + Timestamp metav1.Time + ConfigHash string + AuthMethod string +} + +// ModelValidationInfo consolidates all tracking information for a ModelValidation resource +type ModelValidationInfo struct { + // Name of the ModelValidation CR + Name string + // Current configuration hash for drift detection + ConfigHash string + // Current authentication method + AuthMethod string + // Observed generation for detecting spec changes + ObservedGeneration int64 + // Pods with finalizer and matching configuration + InjectedPods map[types.UID]*PodInfo + // Pods with label but no finalizer + UninjectedPods map[types.UID]*PodInfo + // Pods with finalizer but configuration drift + OrphanedPods map[types.UID]*PodInfo +} + +// NewModelValidationInfo creates a new ModelValidationInfo with initialized maps +func NewModelValidationInfo(name, configHash, authMethod string, observedGeneration int64) *ModelValidationInfo { + return &ModelValidationInfo{ + Name: name, + ConfigHash: configHash, + AuthMethod: authMethod, + ObservedGeneration: observedGeneration, + InjectedPods: make(map[types.UID]*PodInfo), + UninjectedPods: make(map[types.UID]*PodInfo), + OrphanedPods: make(map[types.UID]*PodInfo), + } +} + +// getPodState returns the current state of a pod, or empty string if not found +func (mvi *ModelValidationInfo) getPodState(podUID types.UID) string { + if _, exists := mvi.InjectedPods[podUID]; exists { + return metrics.PodStateInjected + } + if _, exists := mvi.OrphanedPods[podUID]; exists { + return metrics.PodStateOrphaned + } + if _, exists := mvi.UninjectedPods[podUID]; exists { + return metrics.PodStateUninjected + } + return "" +} + +// movePodToState removes pod from current state, assigns to new state, and records transition +func (mvi *ModelValidationInfo) movePodToState(podInfo *PodInfo, newState string, targetMap map[types.UID]*PodInfo) { + prevState := mvi.getPodState(podInfo.UID) + mvi.RemovePod(podInfo.UID) + + targetMap[podInfo.UID] = podInfo + + if prevState != "" && prevState != newState { + metrics.RecordPodStateTransition(podInfo.Namespace, mvi.Name, prevState, newState) + } +} + +// AddInjectedPod adds a pod with finalizer, automatically determining if it's injected or orphaned +func (mvi *ModelValidationInfo) AddInjectedPod(podInfo *PodInfo) { + // Determine if pod is orphaned based on configuration drift + isOrphaned := 
(podInfo.ConfigHash != "" && podInfo.ConfigHash != mvi.ConfigHash) || + (podInfo.AuthMethod != "" && podInfo.AuthMethod != mvi.AuthMethod) + + if isOrphaned { + mvi.movePodToState(podInfo, metrics.PodStateOrphaned, mvi.OrphanedPods) + } else { + mvi.movePodToState(podInfo, metrics.PodStateInjected, mvi.InjectedPods) + } +} + +// AddUninjectedPod adds a pod to the uninjected category +func (mvi *ModelValidationInfo) AddUninjectedPod(podInfo *PodInfo) { + mvi.movePodToState(podInfo, metrics.PodStateUninjected, mvi.UninjectedPods) +} + +// RemovePod removes a pod from all categories +func (mvi *ModelValidationInfo) RemovePod(podUID types.UID) bool { + // A pod should only exist in one of these maps, so we can stop after first match + // Check in order of likelihood: injected, orphaned, uninjected + if _, exists := mvi.InjectedPods[podUID]; exists { + delete(mvi.InjectedPods, podUID) + return true + } + if _, exists := mvi.OrphanedPods[podUID]; exists { + delete(mvi.OrphanedPods, podUID) + return true + } + if _, exists := mvi.UninjectedPods[podUID]; exists { + delete(mvi.UninjectedPods, podUID) + return true + } + return false +} + +// UpdateConfig updates the configuration information and returns any drifted pods +// This safely handles configuration changes by detecting pods that are now orphaned +func (mvi *ModelValidationInfo) UpdateConfig(configHash, authMethod string, observedGeneration int64) []*PodInfo { + // If config hasn't changed, no drift possible + if mvi.ConfigHash == configHash && mvi.AuthMethod == authMethod { + // still update observed generation for tracking + mvi.ObservedGeneration = observedGeneration + return nil + } + + // Get drifted pods before updating config using new parameters + var driftedPods []*PodInfo + for _, podInfo := range mvi.InjectedPods { + if (podInfo.ConfigHash != "" && podInfo.ConfigHash != configHash) || + (podInfo.AuthMethod != "" && podInfo.AuthMethod != authMethod) { + driftedPods = append(driftedPods, podInfo) + } + } + + // Update the config and observed generation + mvi.ConfigHash = configHash + mvi.AuthMethod = authMethod + mvi.ObservedGeneration = observedGeneration + + // Move drifted pods to orphaned + for _, podInfo := range driftedPods { + delete(mvi.InjectedPods, podInfo.UID) + mvi.OrphanedPods[podInfo.UID] = podInfo + } + + return driftedPods +} + +// GetAllPods returns all pods from all categories for cleanup +func (mvi *ModelValidationInfo) GetAllPods() []*PodInfo { + totalPods := len(mvi.InjectedPods) + len(mvi.UninjectedPods) + len(mvi.OrphanedPods) + allPods := make([]*PodInfo, 0, totalPods) + for _, podInfo := range mvi.InjectedPods { + allPods = append(allPods, podInfo) + } + for _, podInfo := range mvi.UninjectedPods { + allPods = append(allPods, podInfo) + } + for _, podInfo := range mvi.OrphanedPods { + allPods = append(allPods, podInfo) + } + return allPods +} diff --git a/internal/tracker/pod_mapping.go b/internal/tracker/pod_mapping.go new file mode 100644 index 00000000..dd7ed58a --- /dev/null +++ b/internal/tracker/pod_mapping.go @@ -0,0 +1,84 @@ +// Package tracker provides status tracking functionality for ModelValidation resources +package tracker + +import ( + "k8s.io/apimachinery/pkg/types" +) + +// PodMapping manages bidirectional mapping between pod names and UIDs +type PodMapping struct { + nameToUID map[types.NamespacedName]types.UID + uidToName map[types.UID]types.NamespacedName +} + +// NewPodMapping creates a new bidirectional pod mapping +func NewPodMapping() *PodMapping { + return &PodMapping{ + nameToUID: 
make(map[types.NamespacedName]types.UID), + uidToName: make(map[types.UID]types.NamespacedName), + } +} + +// AddPod adds a pod to the mapping +func (pm *PodMapping) AddPod(name types.NamespacedName, uid types.UID) { + // Remove any existing mappings for this name or UID + pm.removePod(name, uid) + + pm.nameToUID[name] = uid + pm.uidToName[uid] = name +} + +// GetUIDByName returns the UID for a given pod name +func (pm *PodMapping) GetUIDByName(name types.NamespacedName) (types.UID, bool) { + uid, exists := pm.nameToUID[name] + return uid, exists +} + +// GetNameByUID returns the name for a given pod UID +func (pm *PodMapping) GetNameByUID(uid types.UID) (types.NamespacedName, bool) { + name, exists := pm.uidToName[uid] + return name, exists +} + +// RemovePodsByName removes multiple pod mappings by name +func (pm *PodMapping) RemovePodsByName(names ...types.NamespacedName) { + if len(names) == 0 { + return + } + + for _, name := range names { + if uid, exists := pm.nameToUID[name]; exists { + pm.removePod(name, uid) + } + } +} + +// RemovePodByUID removes a pod mapping by UID +func (pm *PodMapping) RemovePodByUID(uid types.UID) bool { + name, exists := pm.uidToName[uid] + if exists { + pm.removePod(name, uid) + } + return exists +} + +// removePod removes mappings +func (pm *PodMapping) removePod(name types.NamespacedName, uid types.UID) { + delete(pm.nameToUID, name) + delete(pm.uidToName, uid) +} + +// GetAllPodNames returns all tracked pod names +func (pm *PodMapping) GetAllPodNames() []types.NamespacedName { + names := make([]types.NamespacedName, 0, len(pm.nameToUID)) + for name := range pm.nameToUID { + names = append(names, name) + } + return names +} + +// Clear removes all mappings +func (pm *PodMapping) Clear() { + pm.nameToUID = make(map[types.NamespacedName]types.UID) + pm.uidToName = make(map[types.UID]types.NamespacedName) +} diff --git a/internal/tracker/status_tracker.go b/internal/tracker/status_tracker.go new file mode 100644 index 00000000..5ee31c0a --- /dev/null +++ b/internal/tracker/status_tracker.go @@ -0,0 +1,568 @@ +// Package tracker provides status tracking functionality for ModelValidation resources +package tracker + +import ( + "context" + "fmt" + "slices" + "sort" + "sync" + "time" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/metrics" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +// StatusTrackerImpl tracks injected pods and namespaces for ModelValidation resources +type StatusTrackerImpl struct { + client client.Client + mu sync.RWMutex + // Consolidated tracking information for each ModelValidation + modelValidations map[types.NamespacedName]*ModelValidationInfo + // Count of ModelValidation resources per namespace + mvNamespaces map[string]int + // Pod name/UID bidirectional mapping + podMapping *PodMapping + + // Configuration + statusUpdateTimeout time.Duration + + // Async update components + debouncedQueue DebouncedQueue + stopCh chan struct{} + wg sync.WaitGroup +} + +// StatusTrackerConfig holds configuration options for the status tracker +type StatusTrackerConfig struct { + DebounceDuration time.Duration + RetryBaseDelay time.Duration + 
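+	// RetryMaxDelay caps the exponential backoff applied to repeated
+	// status-update failures (see NewDebouncedQueue).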
RetryMaxDelay time.Duration + RateLimitQPS float64 + RateLimitBurst int + StatusUpdateTimeout time.Duration +} + +// NewStatusTracker creates a new status tracker with explicit configuration +func NewStatusTracker(client client.Client, config StatusTrackerConfig) StatusTracker { + debouncedQueue := NewDebouncedQueue(DebouncedQueueConfig{ + DebounceDuration: config.DebounceDuration, + RetryBaseDelay: config.RetryBaseDelay, + RetryMaxDelay: config.RetryMaxDelay, + RateLimitQPS: config.RateLimitQPS, + RateLimitBurst: config.RateLimitBurst, + }) + + st := &StatusTrackerImpl{ + client: client, + modelValidations: make(map[types.NamespacedName]*ModelValidationInfo), + mvNamespaces: make(map[string]int), + podMapping: NewPodMapping(), + statusUpdateTimeout: config.StatusUpdateTimeout, + debouncedQueue: debouncedQueue, + stopCh: make(chan struct{}), + } + + st.wg.Add(1) + go st.asyncUpdateWorker() + + return st +} + +// collectPodTrackingInfo is a helper to convert PodInfo to PodTrackingInfo +func collectPodTrackingInfo(pods map[types.UID]*PodInfo, result *[]v1alpha1.PodTrackingInfo) { + if pods == nil { + return + } + for _, podInfo := range pods { + *result = append(*result, v1alpha1.PodTrackingInfo{ + Name: podInfo.Name, + UID: string(podInfo.UID), + InjectedAt: podInfo.Timestamp, + }) + } +} + +// collectPodsForCleanup is a helper to collect pod names for cleanup +func collectPodsForCleanup(allPods []*PodInfo, result *[]types.NamespacedName) { + if allPods == nil { + return + } + for _, podInfo := range allPods { + *result = append(*result, types.NamespacedName{ + Name: podInfo.Name, + Namespace: podInfo.Namespace, + }) + } +} + +// Stop stops the async update worker +func (st *StatusTrackerImpl) Stop() { + // Wait for pending updates before stopping + st.debouncedQueue.WaitForUpdates() + + // Stop components + close(st.stopCh) + st.debouncedQueue.ShutDown() + st.wg.Wait() +} + +// asyncUpdateWorker processes status updates asynchronously +func (st *StatusTrackerImpl) asyncUpdateWorker() { + defer st.wg.Done() + + for { + select { + case <-st.stopCh: + return + default: + st.processNextUpdate() + } + } +} + +// processNextUpdate processes the next update from the queue +func (st *StatusTrackerImpl) processNextUpdate() { + mvKey, shutdown := st.debouncedQueue.Get() + if shutdown { + return + } + defer st.debouncedQueue.Done(mvKey) + + ctx, cancel := context.WithTimeout(context.Background(), st.statusUpdateTimeout) + defer cancel() + + if err := st.doUpdateStatus(ctx, mvKey); err != nil { + logger := log.FromContext(ctx) + logger.V(1).Info("Status update failed, will retry", + "modelvalidation", mvKey, + "attempts", st.debouncedQueue.GetRetryCount(mvKey)+1, + "error", err) + metrics.RecordRetryAttempt(mvKey.Namespace, mvKey.Name) + st.debouncedQueue.AddWithRetry(mvKey) + } else { + st.debouncedQueue.ForgetRetries(mvKey) + } +} + +// WaitForUpdates waits for all pending status updates to complete +// This is useful in tests to ensure async operations finish before assertions +func (st *StatusTrackerImpl) WaitForUpdates() { + st.debouncedQueue.WaitForCompletion() +} + +// doUpdateStatus updates the ModelValidation status with current metrics +// This is called asynchronously from the update worker +func (st *StatusTrackerImpl) doUpdateStatus(ctx context.Context, mvKey types.NamespacedName) error { + logger := log.FromContext(ctx) + startTime := time.Now() + + mv := &v1alpha1.ModelValidation{} + if err := st.client.Get(ctx, mvKey, mv); err != nil { + if errors.IsNotFound(err) { + st.mu.Lock() + 
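+			// The ModelValidation no longer exists, so drop its in-memory
+			// tracking state rather than retrying the update.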
delete(st.modelValidations, mvKey) + st.mu.Unlock() + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateFailure, time.Since(startTime)) + return nil + } + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateFailure, time.Since(startTime)) + return err + } + + trackedPods := []v1alpha1.PodTrackingInfo{} + uninjectedPods := []v1alpha1.PodTrackingInfo{} + orphanedPods := []v1alpha1.PodTrackingInfo{} + var trackedConfigHash, trackedAuthMethod string + + st.mu.RLock() + mvInfo := st.modelValidations[mvKey] + if mvInfo != nil { + collectPodTrackingInfo(mvInfo.InjectedPods, &trackedPods) + collectPodTrackingInfo(mvInfo.UninjectedPods, &uninjectedPods) + collectPodTrackingInfo(mvInfo.OrphanedPods, &orphanedPods) + trackedConfigHash = mvInfo.ConfigHash + trackedAuthMethod = mvInfo.AuthMethod + } + st.mu.RUnlock() + if mvInfo == nil { + return nil + } + sort.Slice(trackedPods, func(i, j int) bool { + return trackedPods[i].Name < trackedPods[j].Name + }) + sort.Slice(uninjectedPods, func(i, j int) bool { + return uninjectedPods[i].Name < uninjectedPods[j].Name + }) + sort.Slice(orphanedPods, func(i, j int) bool { + return orphanedPods[i].Name < orphanedPods[j].Name + }) + + // Check if ModelValidation parameters have changed since tracking began + currentConfigHash := mv.GetConfigHash() + currentAuthMethod := mv.GetAuthMethod() + + // Check for configuration drift between tracked and current state + if trackedConfigHash != currentConfigHash || trackedAuthMethod != currentAuthMethod { + logger.Error(fmt.Errorf("configuration drift detected"), "ModelValidation config drift during status update", + "modelvalidation", mvKey, + "oldConfigHash", trackedConfigHash, + "newConfigHash", currentConfigHash, + "oldAuthMethod", trackedAuthMethod, + "newAuthMethod", currentAuthMethod) + + if trackedConfigHash != currentConfigHash { + metrics.RecordConfigurationDrift(mvKey.Namespace, mvKey.Name, "config_hash") + } + if trackedAuthMethod != currentAuthMethod { + metrics.RecordConfigurationDrift(mvKey.Namespace, mvKey.Name, "auth_method") + } + + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateFailure, time.Since(startTime)) + return fmt.Errorf( + "configuration drift detected for ModelValidation %s: hash changed from %s to %s, auth method changed from %s to %s", + mvKey, trackedConfigHash, currentConfigHash, trackedAuthMethod, currentAuthMethod) + } + + newStatus := v1alpha1.ModelValidationStatus{ + Conditions: mv.Status.Conditions, + InjectedPodCount: int32(len(trackedPods)), + UninjectedPodCount: int32(len(uninjectedPods)), + OrphanedPodCount: int32(len(orphanedPods)), + AuthMethod: trackedAuthMethod, + InjectedPods: trackedPods, + UninjectedPods: uninjectedPods, + OrphanedPods: orphanedPods, + LastUpdated: metav1.Now(), + } + + if statusEqual(mv.Status, newStatus) { + logger.V(2).Info("Status unchanged, skipping update", "modelvalidation", mvKey) + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateSuccess, time.Since(startTime)) + return nil + } + + mv.Status = newStatus + if err := st.client.Status().Update(ctx, mv); err != nil { + logger.Error(err, "Failed to update ModelValidation status", "modelvalidation", mvKey) + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateFailure, time.Since(startTime)) + return err + } + + recordStatusUpdateResult(mvKey.Namespace, mvKey.Name, metrics.StatusUpdateSuccess, time.Since(startTime)) + + recordPodCounts(mvKey.Namespace, mvKey.Name, len(trackedPods), 
len(uninjectedPods), len(orphanedPods)) + + logger.Info("Updated ModelValidation status", + "modelvalidation", mvKey, + "injectedPods", mv.Status.InjectedPodCount, + "uninjectedPods", mv.Status.UninjectedPodCount, + "orphanedPods", mv.Status.OrphanedPodCount, + "authMethod", mv.Status.AuthMethod) + + return nil +} + +// comparePodSlices compares two slices of PodTrackingInfo for equality +func comparePodSlices(a, b []v1alpha1.PodTrackingInfo) bool { + return slices.EqualFunc(a, b, func(x, y v1alpha1.PodTrackingInfo) bool { + return x.Name == y.Name && x.UID == y.UID + }) +} + +// recordStatusUpdateResult records both the status update result and duration metrics +func recordStatusUpdateResult(namespace, modelValidation, result string, duration time.Duration) { + metrics.RecordStatusUpdate(namespace, modelValidation, result) + metrics.RecordStatusUpdateDuration(namespace, modelValidation, result, duration.Seconds()) +} + +// recordPodCounts records current pod counts for all states +func recordPodCounts(namespace, modelValidation string, injected, uninjected, orphaned int) { + metrics.RecordPodCount(namespace, modelValidation, metrics.PodStateInjected, float64(injected)) + metrics.RecordPodCount(namespace, modelValidation, metrics.PodStateUninjected, float64(uninjected)) + metrics.RecordPodCount(namespace, modelValidation, metrics.PodStateOrphaned, float64(orphaned)) +} + +// statusEqual compares two ModelValidationStatus objects for equality +// ignoring LastUpdated timestamp. +func statusEqual(a, b v1alpha1.ModelValidationStatus) bool { + if a.InjectedPodCount != b.InjectedPodCount || + a.UninjectedPodCount != b.UninjectedPodCount || + a.OrphanedPodCount != b.OrphanedPodCount || + a.AuthMethod != b.AuthMethod { + return false + } + + return comparePodSlices(a.InjectedPods, b.InjectedPods) && + comparePodSlices(a.UninjectedPods, b.UninjectedPods) && + comparePodSlices(a.OrphanedPods, b.OrphanedPods) +} + +// AddModelValidation adds a ModelValidation to tracking using the provided ModelValidation +func (st *StatusTrackerImpl) AddModelValidation(ctx context.Context, mv *v1alpha1.ModelValidation) { + mvKey := types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + + st.mu.Lock() + mvi, alreadyTracking := st.modelValidations[mvKey] + if !alreadyTracking { + mvInfo := NewModelValidationInfo(mv.Name, mv.GetConfigHash(), mv.GetAuthMethod(), mv.Generation) + st.modelValidations[mvKey] = mvInfo + st.mvNamespaces[mvKey.Namespace]++ + + metrics.RecordModelValidationCR(mvKey.Namespace, float64(st.mvNamespaces[mvKey.Namespace])) + } else if mv.Generation != mvi.ObservedGeneration { + // Update existing tracking with current config and handle drift + driftedPods := st.modelValidations[mvKey].UpdateConfig(mv.GetConfigHash(), mv.GetAuthMethod(), mv.Generation) + + if len(driftedPods) > 0 { + logger := log.FromContext(ctx) + logger.Info("Detected configuration drift, moved pods to orphaned status", + "modelvalidation", mvKey, "driftedPods", len(driftedPods)) + + metrics.RecordMultiplePodStateTransitions(mvKey.Namespace, mvKey.Name, + metrics.PodStateInjected, metrics.PodStateOrphaned, len(driftedPods)) + } + } + st.mu.Unlock() + + if !alreadyTracking { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), st.statusUpdateTimeout) + defer cancel() + + if err := st.seedExistingPods(ctx, mvKey); err != nil { + logger := log.FromContext(ctx) + logger.Error(err, "Failed to seed existing pods", "modelvalidation", mvKey) + } + }() + } + + st.debouncedQueue.Add(mvKey) +} + +// 
RemoveModelValidation removes a ModelValidation from tracking +func (st *StatusTrackerImpl) RemoveModelValidation(mvKey types.NamespacedName) { + st.mu.Lock() + defer st.mu.Unlock() + + // Only decrement namespace counter if ModelValidation was actually tracked + mvInfo, wasTracked := st.modelValidations[mvKey] + if wasTracked { + delete(st.modelValidations, mvKey) + st.mvNamespaces[mvKey.Namespace]-- + + if st.mvNamespaces[mvKey.Namespace] <= 0 { + delete(st.mvNamespaces, mvKey.Namespace) + metrics.RecordModelValidationCR(mvKey.Namespace, 0) + } else { + metrics.RecordModelValidationCR(mvKey.Namespace, float64(st.mvNamespaces[mvKey.Namespace])) + } + + allPods := mvInfo.GetAllPods() + var podsToCleanup []types.NamespacedName + collectPodsForCleanup(allPods, &podsToCleanup) + + st.podMapping.RemovePodsByName(podsToCleanup...) + + // Reset pod count metrics for removed ModelValidation + metrics.RecordPodCount(mvKey.Namespace, mvKey.Name, metrics.PodStateInjected, 0) + metrics.RecordPodCount(mvKey.Namespace, mvKey.Name, metrics.PodStateUninjected, 0) + metrics.RecordPodCount(mvKey.Namespace, mvKey.Name, metrics.PodStateOrphaned, 0) + } +} + +// IsModelValidationTracked checks if a ModelValidation is being tracked +func (st *StatusTrackerImpl) IsModelValidationTracked(mvKey types.NamespacedName) bool { + st.mu.RLock() + defer st.mu.RUnlock() + + return st.modelValidations[mvKey] != nil +} + +// GetObservedGeneration returns the observed generation for a tracked ModelValidation +// Returns the generation and a boolean indicating whether the ModelValidation is tracked +func (st *StatusTrackerImpl) GetObservedGeneration(mvKey types.NamespacedName) (int64, bool) { + st.mu.RLock() + defer st.mu.RUnlock() + + if mvInfo := st.modelValidations[mvKey]; mvInfo != nil { + return mvInfo.ObservedGeneration, true + } + return 0, false +} + +// seedExistingPods processes existing pods using the provided ModelValidation +func (st *StatusTrackerImpl) seedExistingPods(ctx context.Context, mvKey types.NamespacedName) error { + logger := log.FromContext(ctx) + + // List all pods in the namespace with the ModelValidation label + podList := &corev1.PodList{} + listOpts := []client.ListOption{ + client.InNamespace(mvKey.Namespace), + client.MatchingLabels{constants.ModelValidationLabel: mvKey.Name}, + } + + if err := st.client.List(ctx, podList, listOpts...); err != nil { + logger.Error(err, "Failed to list existing pods for seeding", "modelvalidation", mvKey) + return err + } + + logger.Info("Seeding existing pods", "modelvalidation", mvKey, "podCount", len(podList.Items)) + + if err := st.processSeedPods(ctx, podList.Items, mvKey); err != nil { + logger.Error(err, "Failed to process some pods during seeding", "modelvalidation", mvKey) + } + + return nil +} + +// processSeedPods processes seed pods using the provided ModelValidation +func (st *StatusTrackerImpl) processSeedPods(ctx context.Context, pods []corev1.Pod, mvKey types.NamespacedName) error { + logger := log.FromContext(ctx) + var errors []error + + for i := range pods { + pod := &pods[i] + if err := st.processSeedPodEvent(pod, mvKey); err != nil { + logger.Error(err, "Failed to process existing pod during batch seeding", "pod", pod.Name, "modelvalidation", mvKey) + errors = append(errors, err) + } + } + + // Return an aggregate error if any individual pod processing failed + if len(errors) > 0 { + return fmt.Errorf("failed to process %d out of %d pods", len(errors), len(pods)) + } + + return nil +} + +// processSeedPodEvent processes a pod event with a 
provided ModelValidation +func (st *StatusTrackerImpl) processSeedPodEvent(pod *corev1.Pod, mvKey types.NamespacedName) error { + modelValidationName, hasLabel := pod.Labels[constants.ModelValidationLabel] + if !hasLabel || modelValidationName == "" { + return nil + } + + st.mu.Lock() + defer st.mu.Unlock() + + mvInfo := st.modelValidations[mvKey] + if mvInfo != nil { + st.processPodEventCommon(pod, mvKey, mvInfo) + } + + return nil +} + +// processPodEventCommon contains the common logic for processing pod events +func (st *StatusTrackerImpl) processPodEventCommon( + pod *corev1.Pod, + mvKey types.NamespacedName, + mvInfo *ModelValidationInfo, +) { + hasOurFinalizer := controllerutil.ContainsFinalizer(pod, constants.ModelValidationFinalizer) + + podInfo := &PodInfo{ + Name: pod.Name, + Namespace: pod.Namespace, + UID: pod.UID, + Timestamp: pod.CreationTimestamp, + ConfigHash: pod.Annotations[constants.ConfigHashAnnotationKey], + AuthMethod: pod.Annotations[constants.AuthMethodAnnotationKey], + } + + podNamespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} + st.podMapping.AddPod(podNamespacedName, pod.UID) + + if hasOurFinalizer { + mvInfo.AddInjectedPod(podInfo) + } else { + mvInfo.AddUninjectedPod(podInfo) + } + + st.debouncedQueue.Add(mvKey) +} + +// ProcessPodEvent processes a pod event from controllers +// This determines how to categorize the pod based on its state and configuration consistency +func (st *StatusTrackerImpl) ProcessPodEvent(_ context.Context, pod *corev1.Pod) error { + if pod == nil { + return fmt.Errorf("pod cannot be nil") + } + + modelValidationName, hasLabel := pod.Labels[constants.ModelValidationLabel] + if !hasLabel || modelValidationName == "" { + return nil + } + + mvKey := types.NamespacedName{ + Name: modelValidationName, + Namespace: pod.Namespace, + } + + st.mu.Lock() + defer st.mu.Unlock() + + mvInfo := st.modelValidations[mvKey] + if mvInfo != nil { + st.processPodEventCommon(pod, mvKey, mvInfo) + } + return nil +} + +// RemovePodEvent removes a pod from tracking when it's deleted +func (st *StatusTrackerImpl) RemovePodEvent(_ context.Context, podUID types.UID) error { + st.mu.Lock() + defer st.mu.Unlock() + + updatedMVs := st.removePodFromTrackingMapsUnsafe(podUID) + + for mvKey := range updatedMVs { + st.debouncedQueue.Add(mvKey) + } + + return nil +} + +// RemovePodByName removes a pod from tracking by its NamespacedName +// This is useful when we know a pod was deleted but don't have its UID +func (st *StatusTrackerImpl) RemovePodByName(_ context.Context, podName types.NamespacedName) error { + st.mu.Lock() + defer st.mu.Unlock() + + podUID, exists := st.podMapping.GetUIDByName(podName) + if !exists { + return nil + } + + updatedMVs := st.removePodFromTrackingMapsUnsafe(podUID) + st.podMapping.RemovePodsByName(podName) + + for mvKey := range updatedMVs { + st.debouncedQueue.Add(mvKey) + } + + return nil +} + +// removePodFromTrackingMapsUnsafe removes a pod from tracking maps and returns updated MVs +func (st *StatusTrackerImpl) removePodFromTrackingMapsUnsafe(podUID types.UID) map[types.NamespacedName]bool { + updatedMVs := make(map[types.NamespacedName]bool) + + for mvKey, mvInfo := range st.modelValidations { + if mvInfo.RemovePod(podUID) { + updatedMVs[mvKey] = true + } + } + + return updatedMVs +} diff --git a/internal/tracker/status_tracker_intf.go b/internal/tracker/status_tracker_intf.go new file mode 100644 index 00000000..b74ec498 --- /dev/null +++ b/internal/tracker/status_tracker_intf.go @@ -0,0 +1,24 @@ 
+package tracker + +import ( + "context" + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" +) + +// StatusTracker defines the interface for tracking ModelValidation and Pod events +type StatusTracker interface { + // ModelValidation tracking methods + AddModelValidation(ctx context.Context, mv *v1alpha1.ModelValidation) + RemoveModelValidation(mvKey types.NamespacedName) + + // Pod tracking methods + ProcessPodEvent(ctx context.Context, pod *corev1.Pod) error + RemovePodEvent(ctx context.Context, podUID types.UID) error + RemovePodByName(ctx context.Context, podName types.NamespacedName) error + + // Lifecycle methods + Stop() +} diff --git a/internal/tracker/status_tracker_test.go b/internal/tracker/status_tracker_test.go new file mode 100644 index 00000000..8076d1b6 --- /dev/null +++ b/internal/tracker/status_tracker_test.go @@ -0,0 +1,762 @@ +package tracker + +import ( + "context" + "sync" + "time" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . "github.com/onsi/gomega" //nolint:revive + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/testutil" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// NewStatusTrackerForTesting creates a status tracker with fast test-friendly defaults +func NewStatusTrackerForTesting(client client.Client) StatusTracker { + return NewStatusTracker(client, StatusTrackerConfig{ + DebounceDuration: 50 * time.Millisecond, // Faster for tests + RetryBaseDelay: 10 * time.Millisecond, // Faster retries for tests + RetryMaxDelay: 1 * time.Second, // Lower max delay for tests + RateLimitQPS: 100, // Higher QPS for tests + RateLimitBurst: 1000, // Higher burst for tests + StatusUpdateTimeout: 5 * time.Second, // Shorter timeout for tests + }) +} + +// setupTestEnvironment creates a common test setup with ModelValidation and StatusTracker +func setupTestEnvironment(mv *v1alpha1.ModelValidation) (*StatusTrackerImpl, client.Client, types.NamespacedName) { + fakeClient := testutil.SetupFakeClientWithObjects(mv) + statusTracker := NewStatusTrackerForTesting(fakeClient).(*StatusTrackerImpl) + mvKey := types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + return statusTracker, fakeClient, mvKey +} + +// expectPodCount waits for and verifies expected pod counts in ModelValidation status +func expectPodCount( + ctx context.Context, + fakeClient client.Client, + mvKey types.NamespacedName, + injected, uninjected, orphaned int32, +) { + Eventually(ctx, func(ctx context.Context) bool { + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + return updatedMV.Status.InjectedPodCount == injected && + updatedMV.Status.UninjectedPodCount == uninjected && + updatedMV.Status.OrphanedPodCount == orphaned + }, "2s", "50ms").Should(BeTrue()) +} + +// getNamespaceCount returns the count of ModelValidations in the given namespace +func getNamespaceCount(tracker *StatusTrackerImpl, namespace string) int { + tracker.mu.RLock() + defer tracker.mu.RUnlock() + return tracker.mvNamespaces[namespace] +} + +// namespaceExists checks if a namespace exists in the tracker +func namespaceExists(tracker *StatusTrackerImpl, namespace string) bool { + tracker.mu.RLock() + defer tracker.mu.RUnlock() + _, exists := tracker.mvNamespaces[namespace] + return exists 
+} + +var _ = Describe("StatusTracker", func() { + var ( + ctx context.Context + statusTracker *StatusTrackerImpl + fakeClient client.Client + mvKey types.NamespacedName + ) + + BeforeEach(func() { + ctx = context.Background() + }) + + AfterEach(func() { + if statusTracker != nil { + statusTracker.Stop() + } + }) + + Context("when tracking pod injections", func() { + It("should track injected pods and update ModelValidation status", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + UID: types.UID("test-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(HaveLen(1)) + Expect(updatedMV.Status.InjectedPods[0].Name).To(Equal("test-pod")) + }) + + It("should categorize pods correctly based on finalizer and CR existence", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + injectedPod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "injected-pod", + Namespace: "default", + UID: types.UID("uid-1"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + uninjectedPod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "uninjected-pod", + Namespace: "default", + UID: types.UID("uid-2"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + }) + + err := statusTracker.ProcessPodEvent(ctx, injectedPod) + Expect(err).NotTo(HaveOccurred()) + + err = statusTracker.ProcessPodEvent(ctx, uninjectedPod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 1, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(HaveLen(1)) + Expect(updatedMV.Status.UninjectedPods).To(HaveLen(1)) + }) + }) + + Context("when removing pods", func() { + It("should remove pods from tracking and update status", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod1 := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod-1", + Namespace: "default", + UID: types.UID("test-uid-1"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + pod2 := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod-2", + Namespace: "default", + UID: types.UID("test-uid-2"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + err 
:= statusTracker.ProcessPodEvent(ctx, pod1) + Expect(err).NotTo(HaveOccurred()) + err = statusTracker.ProcessPodEvent(ctx, pod2) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 2, 0, 0) + + err = statusTracker.RemovePodEvent(ctx, pod1.UID) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(HaveLen(1)) + Expect(updatedMV.Status.InjectedPods[0].Name).To(Equal("test-pod-2")) + }) + }) + + Context("when managing ModelValidation tracking", func() { + It("should add and remove ModelValidations from tracking", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeFalse()) + + statusTracker.AddModelValidation(ctx, mv) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeTrue()) + + statusTracker.RemoveModelValidation(mvKey) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeFalse()) + }) + + It("should handle AddModelValidation idempotently", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + mvKey = testutil.CreateTestNamespacedName("test-mv", "default") + + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeFalse()) + + initialCount := getNamespaceCount(statusTracker, "default") + + statusTracker.AddModelValidation(ctx, mv) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeTrue()) + + firstCount := getNamespaceCount(statusTracker, "default") + Expect(firstCount).To(Equal(initialCount + 1)) + + statusTracker.AddModelValidation(ctx, mv) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeTrue()) + + secondCount := getNamespaceCount(statusTracker, "default") + Expect(secondCount).To(Equal(firstCount), "Second AddModelValidation should not increment namespace counter") + + statusTracker.AddModelValidation(ctx, mv) + thirdCount := getNamespaceCount(statusTracker, "default") + Expect(thirdCount).To(Equal(firstCount), "Third AddModelValidation should not increment namespace counter") + }) + + It("should skip AddModelValidation when generation hasn't changed", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-generation", + Namespace: "default", + }) + statusTracker, _, mvKey = setupTestEnvironment(mv) + + // Set initial generation + mv.Generation = 1 + statusTracker.AddModelValidation(ctx, mv) + Expect(statusTracker.IsModelValidationTracked(mvKey)).To(BeTrue()) + + initialObservedGen, tracked := statusTracker.GetObservedGeneration(mvKey) + Expect(tracked).To(BeTrue(), "ModelValidation should be tracked") + Expect(initialObservedGen).To(Equal(int64(1))) + + // Call AddModelValidation again with same generation - should be skipped + statusTracker.AddModelValidation(ctx, mv) + + currentObservedGen, tracked := statusTracker.GetObservedGeneration(mvKey) + Expect(tracked).To(BeTrue(), "ModelValidation should still be tracked") + Expect(currentObservedGen).To(Equal(int64(1)), + "Generation should remain unchanged when same generation is processed") + + // Updated generation should trigger processing 
+ mv.Generation = 2 + statusTracker.AddModelValidation(ctx, mv) + + updatedObservedGen, tracked := statusTracker.GetObservedGeneration(mvKey) + Expect(tracked).To(BeTrue(), "ModelValidation should still be tracked") + Expect(updatedObservedGen).To(Equal(int64(2)), "Generation should be updated when new generation is processed") + }) + + It("should return false for untracked ModelValidation in GetObservedGeneration", func() { + untrackedKey := types.NamespacedName{Name: "untracked-mv", Namespace: "default"} + + // Create a statusTracker but don't add any ModelValidation + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "tracked-mv", + Namespace: "default", + }) + statusTracker, _, _ = setupTestEnvironment(mv) + + _, tracked := statusTracker.GetObservedGeneration(untrackedKey) + Expect(tracked).To(BeFalse(), "Untracked ModelValidation should return false") + }) + + It("should handle RemoveModelValidation idempotently and manage namespace counts correctly", func() { + mv1 := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-1", + Namespace: "default", + }) + mv2 := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-2", + Namespace: "default", + }) + mv3 := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv-3", + Namespace: "other-namespace", + }) + fakeClient = testutil.SetupFakeClientWithObjects(mv1, mv2, mv3) + statusTracker = NewStatusTrackerForTesting(fakeClient).(*StatusTrackerImpl) + + mvKey1 := testutil.CreateTestNamespacedName("test-mv-1", "default") + mvKey2 := testutil.CreateTestNamespacedName("test-mv-2", "default") + mvKey3 := testutil.CreateTestNamespacedName("test-mv-3", "other-namespace") + + statusTracker.AddModelValidation(ctx, mv1) + statusTracker.AddModelValidation(ctx, mv2) + statusTracker.AddModelValidation(ctx, mv3) + + defaultCount := getNamespaceCount(statusTracker, "default") + otherCount := getNamespaceCount(statusTracker, "other-namespace") + Expect(defaultCount).To(Equal(2)) + Expect(otherCount).To(Equal(1)) + + statusTracker.RemoveModelValidation(mvKey1) + Expect(statusTracker.IsModelValidationTracked(mvKey1)).To(BeFalse()) + + defaultCountAfterFirst := getNamespaceCount(statusTracker, "default") + Expect(defaultCountAfterFirst).To(Equal(1)) + + statusTracker.RemoveModelValidation(mvKey1) + defaultCountAfterSecond := getNamespaceCount(statusTracker, "default") + Expect(defaultCountAfterSecond).To(Equal(1), "Second RemoveModelValidation should not decrement namespace counter") + + statusTracker.RemoveModelValidation(mvKey2) + defaultExists := namespaceExists(statusTracker, "default") + otherCountFinal := getNamespaceCount(statusTracker, "other-namespace") + Expect(defaultExists).To(BeFalse(), "Namespace should be deleted when count reaches zero") + Expect(otherCountFinal).To(Equal(1), "Other namespace should be unaffected") + + statusTracker.RemoveModelValidation(mvKey3) + otherExists := namespaceExists(statusTracker, "other-namespace") + Expect(otherExists).To(BeFalse(), "Other namespace should also be deleted when count reaches zero") + }) + + DescribeTable("should handle configuration drift detection", + func( + mvConfig testutil.TestModelValidationOptions, + podAuth string, + podConfigHash string, + expectedInjected, expectedOrphaned int32, + expectedPodName string, + ) { + mv := testutil.CreateTestModelValidation(mvConfig) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + 
statusTracker.AddModelValidation(ctx, mv) + + podAnnotations := map[string]string{ + constants.AuthMethodAnnotationKey: podAuth, + } + if podConfigHash != "" { + podAnnotations[constants.ConfigHashAnnotationKey] = podConfigHash + } + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: expectedPodName, + Namespace: "default", + UID: types.UID("uid-" + expectedPodName), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + Annotations: podAnnotations, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, expectedInjected, 0, expectedOrphaned) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + if expectedInjected > 0 { + Expect(updatedMV.Status.InjectedPods).To(HaveLen(int(expectedInjected))) + Expect(updatedMV.Status.InjectedPods[0].Name).To(Equal(expectedPodName)) + } + if expectedOrphaned > 0 { + Expect(updatedMV.Status.OrphanedPods).To(HaveLen(int(expectedOrphaned))) + Expect(updatedMV.Status.OrphanedPods[0].Name).To(Equal(expectedPodName)) + } + }, + Entry("auth method mismatch - sigstore MV with pki pod", + testutil.TestModelValidationOptions{Name: "test-mv", Namespace: "default", ConfigType: "sigstore"}, + "pki", "", int32(0), int32(1), "orphaned-pod"), + Entry("matching pki configuration", + testutil.TestModelValidationOptions{Name: "test-mv", Namespace: "default", ConfigType: "pki"}, + "pki", "", int32(1), int32(0), "matching-pod"), + Entry("sigstore config hash drift", + testutil.TestModelValidationOptions{ + Name: "test-mv", Namespace: "default", ConfigType: "sigstore", + CertIdentity: "user@example.com", CertOidcIssuer: "https://accounts.google.com", + }, + "sigstore", func() string { + oldMV := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", Namespace: "default", ConfigType: "sigstore", + CertIdentity: "different@example.com", CertOidcIssuer: "https://accounts.google.com", + }) + return oldMV.GetConfigHash() + }(), int32(0), int32(1), "drift-pod"), + Entry("pki config hash drift", + testutil.TestModelValidationOptions{ + Name: "test-mv", Namespace: "default", ConfigType: "pki", CertificateCA: "/path/to/ca.crt", + }, + "pki", func() string { + oldMV := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", Namespace: "default", ConfigType: "pki", CertificateCA: "/different/path/ca.crt", + }) + return oldMV.GetConfigHash() + }(), int32(0), int32(1), "pki-drift-pod"), + ) + + It("should handle pod deletion by name", func() { + // Create ModelValidation with Sigstore config + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + ConfigType: "sigstore", + }) + + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "deleted-pod", + Namespace: "default", + UID: types.UID("uid-deleted"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + // Now simulate pod deletion by removing it by name + podName := types.NamespacedName{Name: 
pod.Name, Namespace: pod.Namespace} + err = statusTracker.RemovePodByName(ctx, podName) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 0, 0, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(BeEmpty()) + }) + + It("should re-evaluate existing pods when ModelValidation is updated", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + ConfigType: "sigstore", + }) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod", + Namespace: "default", + UID: types.UID("test-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + Annotations: map[string]string{ + constants.AuthMethodAnnotationKey: "sigstore", + }, + }) + + fakeClient = testutil.SetupFakeClientWithObjects(mv, pod) + statusTracker = NewStatusTrackerForTesting(fakeClient).(*StatusTrackerImpl) + + mvKey := types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + + statusTracker.AddModelValidation(ctx, mv) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + // This should trigger re-evaluation of existing pods + statusTracker.AddModelValidation(ctx, mv) + + // The pod should still be tracked (no configuration drift in this case) + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + + updatedMV := testutil.GetModelValidationFromClientExpected(ctx, fakeClient, mvKey) + Expect(updatedMV.Status.InjectedPods).To(HaveLen(1)) + Expect(updatedMV.Status.InjectedPods[0].Name).To(Equal("test-pod")) + }) + }) + + Context("when handling error conditions and edge cases", func() { + It("should handle pod without required labels", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "unlabeled-pod", + Namespace: "default", + UID: types.UID("unlabeled-uid"), + Labels: map[string]string{}, // No ModelValidation label + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) // Should not error, just ignore the pod + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 0, 0, 0) // No pods should be tracked + }) + + It("should handle pod with unknown ModelValidation", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "unknown-mv-pod", + Namespace: "default", + UID: types.UID("unknown-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "unknown-mv"}, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) // Should not error, just ignore the pod + + statusTracker.WaitForUpdates() + expectPodCount(ctx, fakeClient, mvKey, 0, 0, 0) // No pods should be tracked + }) + + It("should handle removal of non-existent pod by UID", func() { + mv := 
testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, _, _ = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + err := statusTracker.RemovePodEvent(ctx, types.UID("non-existent-uid")) + Expect(err).NotTo(HaveOccurred()) // Should not error for non-existent pod + }) + + It("should handle removal of non-existent pod by name", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, _, _ = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + podName := types.NamespacedName{Name: "non-existent-pod", Namespace: "default"} + err := statusTracker.RemovePodByName(ctx, podName) + Expect(err).NotTo(HaveOccurred()) // Should not error for non-existent pod + }) + + It("should handle pod with invalid annotations", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + ConfigType: "sigstore", + }) + statusTracker, fakeClient, mvKey = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "invalid-pod", + Namespace: "default", + UID: types.UID("invalid-uid"), + Labels: map[string]string{constants.ModelValidationLabel: "test-mv"}, + Finalizers: []string{constants.ModelValidationFinalizer}, + Annotations: map[string]string{ + constants.AuthMethodAnnotationKey: "invalid-auth-method", + constants.ConfigHashAnnotationKey: "invalid-hash", + }, + }) + + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) // Should handle gracefully + + statusTracker.WaitForUpdates() + // Should be classified as orphaned due to invalid auth method + expectPodCount(ctx, fakeClient, mvKey, 0, 0, 1) + }) + + It("should handle empty UID gracefully", func() { + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + statusTracker, _, _ = setupTestEnvironment(mv) + statusTracker.AddModelValidation(ctx, mv) + + err := statusTracker.RemovePodEvent(ctx, types.UID("")) + Expect(err).NotTo(HaveOccurred()) // Should handle empty UID gracefully + }) + }) +}) + +// NewStatusTrackerForRetryTesting creates a status tracker with fast test-friendly config for retry testing +func NewStatusTrackerForRetryTesting(client client.Client) StatusTracker { + return NewStatusTracker(client, StatusTrackerConfig{ + DebounceDuration: 20 * time.Millisecond, // Very fast for testing + RetryBaseDelay: 10 * time.Millisecond, // Fast retries + RetryMaxDelay: 100 * time.Millisecond, // Low max delay + RateLimitQPS: 1000, // High QPS to avoid interference + RateLimitBurst: 10000, // High burst to avoid interference + StatusUpdateTimeout: 5 * time.Second, + }) +} + +var _ = Describe("StatusTracker retry and debounce integration", func() { + var ( + ctx context.Context + cancel context.CancelFunc + fakeClient *testutil.FailingClient + statusTracker *StatusTrackerImpl + mvKey types.NamespacedName + mv *v1alpha1.ModelValidation + ) + + BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) + mv = testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: "test-mv", + Namespace: "default", + }) + // Setup failing environment with 2 failures before success + fakeClient = testutil.CreateFailingClientWithObjects(2, mv) + statusTracker = 
NewStatusTrackerForRetryTesting(fakeClient).(*StatusTrackerImpl) + mvKey = types.NamespacedName{Name: mv.Name, Namespace: mv.Namespace} + }) + + AfterEach(func() { + if statusTracker != nil { + statusTracker.Stop() + } + if cancel != nil { + cancel() + } + }) + + It("should retry failed status updates with exponential backoff while respecting debouncing", func() { + statusTracker.AddModelValidation(ctx, mv) + + pod := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "test-pod-1", + Namespace: mv.Namespace, + Labels: map[string]string{ + constants.ModelValidationLabel: mv.Name, + }, + Annotations: map[string]string{ + constants.ConfigHashAnnotationKey: mv.GetConfigHash(), + constants.AuthMethodAnnotationKey: mv.GetAuthMethod(), + }, + Finalizers: []string{constants.ModelValidationFinalizer}, + }) + err := statusTracker.ProcessPodEvent(ctx, pod) + Expect(err).NotTo(HaveOccurred()) + + // Trigger multiple rapid updates to test debouncing + for i := 0; i < 5; i++ { + statusTracker.debouncedQueue.Add(mvKey) + time.Sleep(5 * time.Millisecond) // Shorter than debounce duration + } + + // Wait for debouncing and retries to complete (2 failures + 1 success) + testutil.ExpectAttemptCount(fakeClient, 3) + + // Verify that despite multiple debounce calls, we only had one series of retries + totalAttempts := fakeClient.GetAttemptCount() + actualFailures := fakeClient.GetFailureCount() + + Expect(totalAttempts).To(Equal(3), "Should have exactly 3 attempts: 2 failures + 1 success") + Expect(actualFailures).To(Equal(2), "Should have exactly 2 failures before success") + + expectPodCount(ctx, fakeClient, mvKey, 1, 0, 0) + }) + + It("should handle concurrent debounced updates with retries correctly", func() { + statusTracker.AddModelValidation(ctx, mv) + + // Trigger concurrent updates from multiple goroutines + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + statusTracker.debouncedQueue.Add(mvKey) + }() + } + wg.Wait() + + // Wait for processing to complete (2 failures + 1 success) + testutil.ExpectAttemptCount(fakeClient, 3) + + totalAttempts := fakeClient.GetAttemptCount() + + Expect(totalAttempts).To(Equal(3), "Concurrent updates should be debounced into single retry sequence") + }) +}) + +var _ = Describe("StatusTracker utility functions", func() { + Context("when comparing ModelValidation statuses", func() { + It("should correctly identify equal statuses", func() { + status1 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 2, + AuthMethod: "sigstore", + InjectedPods: []v1alpha1.PodTrackingInfo{ + {Name: "pod1", UID: "uid1"}, + {Name: "pod2", UID: "uid2"}, + }, + } + status2 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 2, + AuthMethod: "sigstore", + InjectedPods: []v1alpha1.PodTrackingInfo{ + {Name: "pod1", UID: "uid1"}, + {Name: "pod2", UID: "uid2"}, + }, + } + + Expect(statusEqual(status1, status2)).To(BeTrue()) + }) + + It("should correctly identify different statuses", func() { + status1 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 2, + } + status2 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 3, + } + + Expect(statusEqual(status1, status2)).To(BeFalse()) + }) + + It("should ignore LastUpdated differences", func() { + status1 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 1, + LastUpdated: metav1.Time{Time: time.Now()}, + } + status2 := v1alpha1.ModelValidationStatus{ + InjectedPodCount: 1, + LastUpdated: metav1.Time{Time: time.Now().Add(1 * time.Hour)}, + } + + Expect(statusEqual(status1, 
status2)).To(BeTrue()) + }) + }) + +}) diff --git a/internal/tracker/suite_test.go b/internal/tracker/suite_test.go new file mode 100644 index 00000000..1af9b747 --- /dev/null +++ b/internal/tracker/suite_test.go @@ -0,0 +1,13 @@ +package tracker + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" //nolint:revive + . "github.com/onsi/gomega" //nolint:revive +) + +func TestTracker(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Tracker Suite") +} diff --git a/internal/webhooks/pod_webhook.go b/internal/webhooks/pod_webhook.go index 5edc1426..b260fe34 100644 --- a/internal/webhooks/pod_webhook.go +++ b/internal/webhooks/pod_webhook.go @@ -5,10 +5,12 @@ import ( "encoding/json" "fmt" "net/http" + "time" "github.com/sigstore/model-validation-operator/internal/constants" corev1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -29,7 +31,6 @@ func NewPodInterceptor(c client.Client, decoder admission.Decoder) webhook.Admis // +kubebuilder:webhook:path=/mutate-v1-pod,mutating=true,failurePolicy=fail,groups="",resources=pods,sideEffects=None,verbs=create;update,versions=v1,name=pods.validation.ml.sigstore.dev,admissionReviewVersions=v1 // +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations,verbs=get;list;watch -// +kubebuilder:rbac:groups=ml.sigstore.dev,resources=modelvalidations/status,verbs=get;update;patch // +kubebuilder:rbac:groups="",resources=namespaces,verbs=get;list;watch // podInterceptor extends pods with Model Validation Init-Container if annotation is specified. @@ -55,7 +56,7 @@ func (p *podInterceptor) Handle(ctx context.Context, req admission.Request) admi logger.Error(err, "failed to get namespace") return admission.Errored(http.StatusInternalServerError, err) } - if ns.Labels[constants.IgnoreNamespaceLabel] == "true" { + if ns.Labels[constants.IgnoreNamespaceLabel] == constants.IgnoreNamespaceValue { logger.Info("Namespace has ignore label, skipping", "namespace", req.Namespace) return admission.Allowed("namespace ignored") } @@ -70,8 +71,8 @@ func (p *podInterceptor) Handle(ctx context.Context, req admission.Request) admi logger.Info("Search associated Model Validation CR", "pod", pod.Name, "namespace", pod.Namespace, "modelValidationName", modelValidationName) - rhmv := &v1alpha1.ModelValidation{} - err := p.client.Get(ctx, client.ObjectKey{Name: modelValidationName, Namespace: pod.Namespace}, rhmv) + mv := &v1alpha1.ModelValidation{} + err := p.client.Get(ctx, client.ObjectKey{Name: modelValidationName, Namespace: pod.Namespace}, mv) if err != nil { msg := fmt.Sprintf("failed to get the ModelValidation CR %s/%s", pod.Namespace, modelValidationName) logger.Error(err, msg) @@ -85,10 +86,19 @@ func (p *podInterceptor) Handle(ctx context.Context, req admission.Request) admi } args := []string{"verify"} - args = append(args, validationConfigToArgs(logger, rhmv.Spec.Config, rhmv.Spec.Model.SignaturePath)...) - args = append(args, rhmv.Spec.Model.Path) + args = append(args, validationConfigToArgs(logger, mv.Spec.Config, mv.Spec.Model.SignaturePath)...) 
+ args = append(args, mv.Spec.Model.Path) pp := pod.DeepCopy() + + controllerutil.AddFinalizer(pp, constants.ModelValidationFinalizer) + if pp.Annotations == nil { + pp.Annotations = make(map[string]string) + } + pp.Annotations[constants.InjectedAnnotationKey] = time.Now().Format(time.RFC3339) + pp.Annotations[constants.AuthMethodAnnotationKey] = mv.GetAuthMethod() + pp.Annotations[constants.ConfigHashAnnotationKey] = mv.GetConfigHash() + vm := []corev1.VolumeMount{} for _, c := range pod.Spec.Containers { vm = append(vm, c.VolumeMounts...) @@ -123,12 +133,12 @@ func validationConfigToArgs(logger logr.Logger, cfg v1alpha1.ValidationConfig, s return res } - if cfg.PrivateKeyConfig != nil { - logger.Info("found private-key config") + if cfg.PublicKeyConfig != nil { + logger.Info("found public-key config") res = append(res, "key", fmt.Sprintf("--signature=%s", signaturePath), - "--public_key", cfg.PrivateKeyConfig.KeyPath, + "--public_key", cfg.PublicKeyConfig.KeyPath, ) return res } diff --git a/internal/webhooks/pod_webhook_test.go b/internal/webhooks/pod_webhook_test.go index 33a7096a..a113ef4f 100644 --- a/internal/webhooks/pod_webhook_test.go +++ b/internal/webhooks/pod_webhook_test.go @@ -2,11 +2,14 @@ package webhooks import ( "context" + "fmt" + "time" . "github.com/onsi/ginkgo/v2" //nolint:revive . "github.com/onsi/gomega" //nolint:revive "github.com/sigstore/model-validation-operator/api/v1alpha1" "github.com/sigstore/model-validation-operator/internal/constants" + "github.com/sigstore/model-validation-operator/internal/testutil" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -14,94 +17,127 @@ import ( var _ = Describe("Pod webhook", func() { Context("Pod webhook test", func() { - - const ( - Name = "test" - Namespace = "default" - ) + Name := "test" + var Namespace string ctx := context.Background() - namespace := &corev1.Namespace{ - ObjectMeta: metav1.ObjectMeta{ - Name: Name, - Namespace: Namespace, - }, - } - - typeNamespaceName := types.NamespacedName{Name: Name, Namespace: Namespace} + var typeNamespaceName types.NamespacedName BeforeEach(func() { + Namespace = fmt.Sprintf("test-ns-%d", time.Now().UnixNano()) + typeNamespaceName = testutil.CreateTestNamespacedName(Name, Namespace) + By("Creating the Namespace to perform the tests") + namespace := &corev1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: Namespace, + }, + } err := k8sClient.Create(ctx, namespace) Expect(err).To(Not(HaveOccurred())) By("Create ModelValidation resource") - err = k8sClient.Create(ctx, &v1alpha1.ModelValidation{ - ObjectMeta: metav1.ObjectMeta{ - Name: Name, - Namespace: Namespace, - }, - Spec: v1alpha1.ModelValidationSpec{ - Model: v1alpha1.Model{ - Path: "test", - SignaturePath: "test", - }, - Config: v1alpha1.ValidationConfig{ - SigstoreConfig: nil, - PkiConfig: nil, - PrivateKeyConfig: nil, - }, - }, + mv := testutil.CreateTestModelValidation(testutil.TestModelValidationOptions{ + Name: Name, + Namespace: Namespace, + ConfigType: "sigstore", + CertIdentity: "test@example.com", + CertOidcIssuer: "https://accounts.google.com", }) + err = k8sClient.Create(ctx, mv) Expect(err).To(Not(HaveOccurred())) + + statusTracker.AddModelValidation(ctx, mv) }) AfterEach(func() { // TODO(user): Attention: if you improve this code by adding other context tests, you MUST // be aware of the current delete namespace limitations.
// More info: https://book.kubebuilder.io/reference/envtest.html#testing-considerations - By("Deleting the Namespace to perform the tests") - _ = k8sClient.Delete(ctx, namespace) - }) - It("Should create sidecar container", func() { - By("create labeled pod") - instance := &corev1.Pod{ + By("Deleting the ModelValidation resource") + _ = k8sClient.Delete(ctx, &v1alpha1.ModelValidation{ ObjectMeta: metav1.ObjectMeta{ Name: Name, Namespace: Namespace, - Labels: map[string]string{constants.ModelValidationLabel: Name}, - }, - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - Image: "test", - }, - }, }, - } + }) + + By("Deleting the Namespace to perform the tests") + _ = k8sClient.Delete(ctx, &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: Namespace}}) + }) + + It("Should create sidecar container and add finalizer", func() { + By("create labeled pod") + instance := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: Name, + Namespace: Namespace, + Labels: map[string]string{constants.ModelValidationLabel: Name}, + }) err := k8sClient.Create(ctx, instance) Expect(err).To(Not(HaveOccurred())) By("Checking that validation sidecar was created") found := &corev1.Pod{} - Eventually(func() error { + Eventually(ctx, func(ctx context.Context) error { return k8sClient.Get(ctx, typeNamespaceName, found) - }).Should(Succeed()) + }, 5*time.Second).Should(Succeed()) - Eventually( - func(_ Gomega) []corev1.Container { - Expect(k8sClient.Get(ctx, typeNamespaceName, found)).To(Succeed()) + Eventually(ctx, + func(g Gomega, ctx context.Context) []corev1.Container { + g.Expect(k8sClient.Get(ctx, typeNamespaceName, found)).To(Succeed()) return found.Spec.InitContainers - }, + }, 5*time.Second, ).Should(And( WithTransform(func(containers []corev1.Container) int { return len(containers) }, Equal(1)), WithTransform( func(containers []corev1.Container) string { return containers[0].Image }, Equal(constants.ModelTransparencyCliImage)), )) + + By("Checking that finalizer was added") + Expect(found.Finalizers).To(ContainElement(constants.ModelValidationFinalizer)) + }) + + It("Should track pod in ModelValidation status", func() { + By("create labeled pod") + instance := testutil.CreateTestPod(testutil.TestPodOptions{ + Name: "tracked-pod", + Namespace: Namespace, + Labels: map[string]string{constants.ModelValidationLabel: Name}, + }) + err := k8sClient.Create(ctx, instance) + Expect(err).To(Not(HaveOccurred())) + + By("Waiting for pod to be injected") + found := &corev1.Pod{} + Eventually(ctx, func(ctx context.Context) []corev1.Container { + _ = k8sClient.Get(ctx, types.NamespacedName{Name: "tracked-pod", Namespace: Namespace}, found) + return found.Spec.InitContainers + }, 5*time.Second).Should(HaveLen(1)) + + err = statusTracker.ProcessPodEvent(ctx, found) + Expect(err).To(Not(HaveOccurred())) + + By("Checking ModelValidation status was updated") + mv := &v1alpha1.ModelValidation{} + Eventually(ctx, func(ctx context.Context) int32 { + _ = k8sClient.Get(ctx, typeNamespaceName, mv) + return mv.Status.InjectedPodCount + }, 5*time.Second).Should(BeNumerically(">", 0)) + + Expect(mv.Status.AuthMethod).To(Equal("sigstore")) // Sigstore auth method configured in test + Expect(mv.Status.InjectedPods).ToNot(BeEmpty()) + + foundTrackedPod := false + for _, tp := range mv.Status.InjectedPods { + if tp.Name == "tracked-pod" { + foundTrackedPod = true + break + } + } + Expect(foundTrackedPod).To(BeTrue(), "Pod should be tracked in status") }) }) }) diff --git a/internal/webhooks/suite_test.go 
b/internal/webhooks/suite_test.go index 8a6cd296..c9d583f5 100644 --- a/internal/webhooks/suite_test.go +++ b/internal/webhooks/suite_test.go @@ -9,8 +9,10 @@ import ( "runtime" "strings" "testing" + "time" "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/tracker" "k8s.io/klog/v2" "k8s.io/klog/v2/test" "sigs.k8s.io/controller-runtime/pkg/webhook" @@ -30,11 +32,12 @@ import ( // http://onsi.github.io/ginkgo/ to learn more about Ginkgo. var ( - cfg *rest.Config - k8sClient client.Client // You'll be using this client in your tests. - testEnv *envtest.Environment - ctx context.Context - cancel context.CancelFunc + cfg *rest.Config + k8sClient client.Client // You'll be using this client in your tests. + testEnv *envtest.Environment + ctx context.Context + cancel context.CancelFunc + statusTracker tracker.StatusTracker ) // findBinaryAssetsDirectory locates the kubernetes binaries directory @@ -133,6 +136,14 @@ var _ = BeforeSuite(func() { // Create a decoder for your webhook decoder := admission.NewDecoder(scheme.Scheme) + statusTracker = tracker.NewStatusTracker(mgr.GetClient(), tracker.StatusTrackerConfig{ + DebounceDuration: 50 * time.Millisecond, // Faster for tests + RetryBaseDelay: 10 * time.Millisecond, // Faster retries for tests + RetryMaxDelay: 1 * time.Second, // Lower max delay for tests + RateLimitQPS: 100, // Higher QPS for tests + RateLimitBurst: 1000, // Higher burst for tests + StatusUpdateTimeout: 5 * time.Second, // Shorter timeout for tests + }) podWebhookHandler := NewPodInterceptor(mgr.GetClient(), decoder) mgr.GetWebhookServer().Register("/mutate-v1-pod", &admission.Webhook{ Handler: podWebhookHandler, @@ -143,11 +154,22 @@ var _ = BeforeSuite(func() { err = mgr.Start(ctx) Expect(err).ToNot(HaveOccurred(), "failed to run manager") }() + + // Wait for webhook server to be ready by checking if it's serving + Eventually(func() error { + return mgr.GetWebhookServer().StartedChecker()(nil) + }, 10*time.Second, 100*time.Millisecond).Should(Succeed()) }) var _ = AfterSuite(func() { - cancel() By("tearing down the test environment") + + statusTracker.Stop() + cancel() + + // Give manager time to shutdown gracefully + time.Sleep(100 * time.Millisecond) + err := testEnv.Stop() Expect(err).NotTo(HaveOccurred()) }) diff --git a/test/e2e/e2e_suite_test.go b/test/e2e/e2e_suite_test.go index 7e9eee22..cad4d8d9 100644 --- a/test/e2e/e2e_suite_test.go +++ b/test/e2e/e2e_suite_test.go @@ -14,13 +14,15 @@ See the License for the specific language governing permissions and limitations under the License. */ +// Package e2e contains end-to-end tests for the model validation operator package e2e import ( "fmt" - "os" "os/exec" + "strings" "testing" + "time" . "github.com/onsi/ginkgo/v2" //nolint:revive . "github.com/onsi/gomega" //nolint:revive @@ -29,17 +31,8 @@ import ( ) var ( - // Optional Environment Variables: - // - CERT_MANAGER_INSTALL_SKIP=true: Skips CertManager installation during test setup. - // These variables are useful if CertManager is already installed, avoiding - // re-installation and conflicts. - skipCertManagerInstall = os.Getenv("CERT_MANAGER_INSTALL_SKIP") == "true" - // isCertManagerAlreadyInstalled will be set true when CertManager CRDs be found on the cluster - isCertManagerAlreadyInstalled = false - - // projectImage is the name of the image which will be build and loaded - // with the code source changes to be tested. 
- projectImage = "ghcr.io/sigstore/model-validation-operator:v0.0.1" + // controllerPodName stores the name of the controller pod for debugging + controllerPodName string ) const ( @@ -57,41 +50,174 @@ const ( func TestE2E(t *testing.T) { RegisterFailHandler(Fail) _, _ = fmt.Fprintf(GinkgoWriter, "Starting model-validation-operator integration test suite\n") + + // Set test timeouts + SetDefaultEventuallyTimeout(2 * time.Minute) + SetDefaultEventuallyPollingInterval(time.Second) + RunSpecs(t, "e2e suite") } var _ = BeforeSuite(func() { - By("building the manager(Operator) image") - cmd := exec.Command("make", "docker-build", fmt.Sprintf("IMG=%s", projectImage)) - _, err := utils.Run(cmd) - ExpectWithOffset(1, err).NotTo(HaveOccurred(), "Failed to build the manager(Operator) image") - - // TODO(user): If you want to change the e2e test vendor from Kind, ensure the image is - // built and available before running the tests. Also, remove the following block. - By("loading the manager(Operator) image on Kind") - err = utils.LoadImageToKindClusterWithName(projectImage) - ExpectWithOffset(1, err).NotTo(HaveOccurred(), "Failed to load the manager(Operator) image into Kind") - - // The tests-e2e are intended to run on a temporary cluster that is created and destroyed for testing. - // To prevent errors when tests run in environments with CertManager already installed, - // we check for its presence before execution. - // Setup CertManager before the suite if not skipped and if not already installed - if !skipCertManagerInstall { - By("checking if cert manager is installed already") - isCertManagerAlreadyInstalled = utils.IsCertManagerCRDsInstalled() - if !isCertManagerAlreadyInstalled { - _, _ = fmt.Fprintf(GinkgoWriter, "Installing CertManager...\n") - Expect(utils.InstallCertManager()).To(Succeed(), "Failed to install CertManager") - } else { - _, _ = fmt.Fprintf(GinkgoWriter, "WARNING: CertManager is already installed. 
Skipping installation...\n") + By("verifying operator namespace exists") + Expect(utils.KubectlResourceExists("ns", operatorNamespace, "")).To(BeTrue(), "Operator namespace should exist") + + By("verifying test namespace exists") + Expect(utils.KubectlResourceExists("ns", webhookTestNamespace, "")).To(BeTrue(), "Test namespace should exist") + + By("verifying controller is running") + Eventually(func() error { + cmd := exec.Command("kubectl", "get", "pods", "-l", "control-plane=controller-manager", + "-n", operatorNamespace, "-o", "jsonpath={.items[0].metadata.name}") + output, err := utils.Run(cmd) + if err != nil { + return err + } + controllerPodName = output + + // Verify the pod is ready + phase, err := utils.KubectlGet("pod", controllerPodName, operatorNamespace, "jsonpath={.status.phase}") + if err != nil { + return err } + if phase != "Running" { + return fmt.Errorf("controller pod is not running: %s", phase) + } + return nil + }, 1*time.Minute, 5*time.Second).Should(Succeed(), "Controller pod should be running") + + By("setting up persistent metrics pod") + templateData := utils.CurlPodTemplateData{ + PodName: curlMetricsPodName, + Namespace: operatorNamespace, + ServiceAccount: utils.MetricsServiceAccountName, } + err := utils.KubectlApply(curlPodTemplate, templateData) + Expect(err).NotTo(HaveOccurred(), "Failed to create persistent metrics pod") + + err = utils.KubectlWait("pod", curlMetricsPodName, operatorNamespace, "condition=Ready", "30s") + Expect(err).NotTo(HaveOccurred(), "Failed to wait for persistent metrics pod to be ready") }) var _ = AfterSuite(func() { - // Teardown CertManager after the suite if not skipped and if it was not already installed - if !skipCertManagerInstall && !isCertManagerAlreadyInstalled { - _, _ = fmt.Fprintf(GinkgoWriter, "Uninstalling CertManager...\n") - utils.UninstallCertManager() + By("cleaning up test resources") + + // Clean up any test resources that may have been created during tests + By("cleaning up test pods and CRs") + cmd := exec.Command("kubectl", "delete", "pods", "--all", "-n", webhookTestNamespace, + "--timeout=30s", "--ignore-not-found=true") + _, _ = utils.Run(cmd) + + cmd = exec.Command("kubectl", "delete", "modelvalidations", "--all", "-n", webhookTestNamespace, + "--timeout=30s", "--ignore-not-found=true") + _, _ = utils.Run(cmd) + + By("cleaning up persistent metrics pod") + cleanupTemplateData := utils.CurlPodTemplateData{ + PodName: curlMetricsPodName, + Namespace: operatorNamespace, + ServiceAccount: utils.MetricsServiceAccountName, + } + if err := utils.KubectlDelete(curlPodTemplate, &utils.KubectlDeleteOptions{ + IgnoreNotFound: true, + Timeout: "30s", + TemplateData: cleanupTemplateData, + }); err != nil { + _, _ = fmt.Fprintf(GinkgoWriter, "Failed to cleanup metrics pod: %s\n", err) + } + + _, _ = fmt.Fprintf(GinkgoWriter, "Test cleanup complete.\n") +}) + +var _ = AfterEach(func() { + if !CurrentSpecReport().Failed() { + return + } + + By("Capturing comprehensive pod logs and status for failed test") + + // Capture all pods in test namespace with logs and descriptions + By("Fetching all pods in test namespace") + if testPodsOutput, err := utils.KubectlGet("pods", "", webhookTestNamespace, "wide"); err == nil { + _, _ = fmt.Fprintf(GinkgoWriter, "Test namespace pods:\n%s\n", testPodsOutput) + } else { + _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get test namespace pods: %s\n", err) + } + + // Get all pod names in test namespace + podNamesOutput, err := utils.KubectlGet("pods", "", webhookTestNamespace, 
"jsonpath={.items[*].metadata.name}") + if err == nil && podNamesOutput != "" { + podNames := strings.Fields(podNamesOutput) + for _, podName := range podNames { + By(fmt.Sprintf("Capturing logs and description for pod: %s", podName)) + + // Pod description + if podDesc, err := utils.KubectlDescribe("pod", podName, webhookTestNamespace); err == nil { + _, _ = fmt.Fprintf(GinkgoWriter, "=== Pod %s description ===\n%s\n", podName, podDesc) + } else { + _, _ = fmt.Fprintf(GinkgoWriter, "Failed to describe pod %s: %s\n", podName, err) + } + + // Main container logs + if podLogs, err := utils.Run(exec.Command("kubectl", "logs", podName, "-n", + webhookTestNamespace, "--all-containers=true")); err == nil && podLogs != "" { + _, _ = fmt.Fprintf(GinkgoWriter, "=== Pod %s logs ===\n%s\n", podName, podLogs) + } else { + _, _ = fmt.Fprintf(GinkgoWriter, "No logs available for pod %s or error: %v\n", podName, err) + } + + // Previous container logs (if restarted) + if prevLogs, err := utils.Run(exec.Command("kubectl", "logs", podName, "-n", + webhookTestNamespace, "--all-containers=true", "--previous=true")); err == nil && prevLogs != "" { + _, _ = fmt.Fprintf(GinkgoWriter, "=== Pod %s previous logs ===\n%s\n", podName, prevLogs) + } + + // Init container logs specifically + if initLogs, err := utils.Run(exec.Command("kubectl", "logs", podName, "-n", + webhookTestNamespace, "-c", "model-transparency-init")); err == nil && initLogs != "" { + _, _ = fmt.Fprintf(GinkgoWriter, "=== Pod %s init container logs ===\n%s\n", podName, initLogs) + } + } + } + + // Capture controller logs + By("Fetching controller manager pod logs") + if controllerLogs, err := utils.Run(exec.Command("kubectl", "logs", controllerPodName, + "-n", operatorNamespace)); err == nil { + _, _ = fmt.Fprintf(GinkgoWriter, "=== Controller logs ===\n%s\n", controllerLogs) + } else { + _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get Controller logs: %s\n", err) + } + + // Capture events from both namespaces + By("Fetching test namespace events") + if testEventsOutput, err := utils.KubectlGet("events", "", webhookTestNamespace, ""); err == nil { + _, _ = fmt.Fprintf(GinkgoWriter, "=== Test namespace events ===\n%s\n", testEventsOutput) + } else { + _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get test namespace events: %s\n", err) + } + + By("Fetching operator namespace events") + if eventsOutput, err := utils.KubectlGet("events", "", operatorNamespace, ""); err == nil { + _, _ = fmt.Fprintf(GinkgoWriter, "=== Operator namespace events ===\n%s\n", eventsOutput) + } else { + _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get operator namespace events: %s\n", err) + } + + // Capture ModelValidation resources for debugging + By("Fetching all ModelValidation resources") + mvOutput, err := utils.KubectlGet("modelvalidations", "", webhookTestNamespace, "wide") + if err == nil { + _, _ = fmt.Fprintf(GinkgoWriter, "=== ModelValidation resources ===\n%s\n", mvOutput) + } else { + _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get ModelValidation resources: %s\n", err) + } + + // Get detailed YAML output for ModelValidation resources + mvYamlOutput, err := utils.KubectlGet("modelvalidations", "", webhookTestNamespace, "yaml") + if err == nil && mvYamlOutput != "" { + _, _ = fmt.Fprintf(GinkgoWriter, "=== ModelValidation resources (YAML) ===\n%s\n", mvYamlOutput) + } else if err != nil { + _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get ModelValidation YAML: %s\n", err) } }) diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index bba0ec3c..d6453c8f 
100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -17,153 +17,27 @@ limitations under the License. package e2e import ( - "bytes" - "encoding/json" - "fmt" - "os" "os/exec" - "path/filepath" - "time" _ "embed" . "github.com/onsi/ginkgo/v2" //nolint:revive . "github.com/onsi/gomega" //nolint:revive + corev1 "k8s.io/api/core/v1" "github.com/sigstore/model-validation-operator/internal/constants" utils "github.com/sigstore/model-validation-operator/test/utils" ) -//go:embed testdata/modelvalidation_cr.yaml -var modelValidationCR []byte - -//go:embed testdata/test_pod.yaml -var testPod []byte - -// metricsServiceName is the name of the metrics service of the project -const metricsServiceName = "model-validation-controller-manager-metrics-service" - // metricsRoleBindingName is the name of the RBAC that will be created to allow get the metrics data const metricsRoleBindingName = "model-validation-operator-metrics-binding" -// serviceAccountName is the name of the service account used by the controller manager -const serviceAccountName = "model-validation-controller-manager" - // curlMetricsPodName is the name of the pod used to access the metrics endpoint const curlMetricsPodName = "curl-metrics" const e2eTestPodName = "e2e-test-pod" var _ = Describe("Manager", Ordered, func() { - var controllerPodName string - - // Before running the tests, set up the environment by creating the namespace, - // enforce the restricted security policy to the namespace, installing CRDs, - // and deploying the controller. - BeforeAll(func() { - By("creating manager namespace") - cmd := exec.Command("kubectl", "create", "ns", operatorNamespace) - _, err := utils.Run(cmd) - Expect(err).NotTo(HaveOccurred(), "Failed to create operator namespace") - - By("labeling the operator namespace to enforce the restricted security policy") - cmd = exec.Command("kubectl", "label", "--overwrite", "ns", operatorNamespace, - "pod-security.kubernetes.io/enforce=restricted") - _, err = utils.Run(cmd) - Expect(err).NotTo(HaveOccurred(), "Failed to label operator namespace with restricted policy") - - By("labeling the operator namespace to be ignored by the webhook") - cmd = exec.Command("kubectl", "label", "--overwrite", "ns", operatorNamespace, - fmt.Sprintf("%s=true", constants.IgnoreNamespaceLabel)) - _, err = utils.Run(cmd) - Expect(err).NotTo(HaveOccurred(), "Failed to label operator namespace for webhook ignore") - - By("creating webhook test namespace") - cmd = exec.Command("kubectl", "create", "ns", webhookTestNamespace) - _, err = utils.Run(cmd) - Expect(err).NotTo(HaveOccurred(), "Failed to create webhook test namespace") - - By("installing CRDs") - cmd = exec.Command("make", "install") - _, err = utils.Run(cmd) - Expect(err).NotTo(HaveOccurred(), "Failed to install CRDs") - - By("deploying the controller-manager") - cmd = exec.Command("make", "deploy", fmt.Sprintf("IMG=%s", projectImage)) - _, err = utils.Run(cmd) - Expect(err).NotTo(HaveOccurred(), "Failed to deploy the controller-manager") - }) - - // After all tests have been executed, clean up by undeploying the controller, uninstalling CRDs, - // and deleting the namespace. 
- AfterAll(func() { - By("cleaning up the curl pod for metrics") - cmd := exec.Command("kubectl", "delete", "pod", curlMetricsPodName, "-n", operatorNamespace) - _, _ = utils.Run(cmd) - - By("undeploying the controller-manager") - cmd = exec.Command("make", "undeploy") - _, _ = utils.Run(cmd) - - By("uninstalling CRDs") - cmd = exec.Command("make", "uninstall") - _, _ = utils.Run(cmd) - - By("removing manager namespace") - cmd = exec.Command("kubectl", "delete", "ns", operatorNamespace) - _, _ = utils.Run(cmd) - - By("removing webhook test namespace") - cmd = exec.Command("kubectl", "delete", "ns", webhookTestNamespace) - _, _ = utils.Run(cmd) - }) - - // After each test, check for failures and collect logs, events, - // and pod descriptions for debugging. - AfterEach(func() { - specReport := CurrentSpecReport() - if specReport.Failed() { - By("Fetching controller manager pod logs") - cmd := exec.Command("kubectl", "logs", controllerPodName, "-n", operatorNamespace) - controllerLogs, err := utils.Run(cmd) - if err == nil { - _, _ = fmt.Fprintf(GinkgoWriter, "Controller logs:\n %s", controllerLogs) - } else { - _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get Controller logs: %s", err) - } - - By("Fetching Kubernetes events") - cmd = exec.Command("kubectl", "get", "events", "-n", operatorNamespace, "--sort-by=.lastTimestamp") - eventsOutput, err := utils.Run(cmd) - if err == nil { - _, _ = fmt.Fprintf(GinkgoWriter, "Kubernetes events:\n%s", eventsOutput) - } else { - _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get Kubernetes events: %s", err) - } - - By("Fetching curl-metrics logs") - cmd = exec.Command("kubectl", "logs", curlMetricsPodName, "-n", operatorNamespace) - metricsOutput, err := utils.Run(cmd) - if err == nil { - _, _ = fmt.Fprintf(GinkgoWriter, "Metrics logs:\n %s", metricsOutput) - } else { - _, _ = fmt.Fprintf(GinkgoWriter, "Failed to get curl-metrics logs: %s", err) - } - - By("Fetching controller manager pod description") - cmd = exec.Command("kubectl", "describe", "pod", controllerPodName, "-n", operatorNamespace) - podDescription, err := utils.Run(cmd) - if err == nil { - fmt.Println("Pod description:\n", podDescription) - } else { - fmt.Println("Failed to describe controller pod") - } - } - }) - - SetDefaultEventuallyTimeout(2 * time.Minute) - SetDefaultEventuallyPollingInterval(time.Second) - Context("Manager", func() { It("should run successfully", func() { By("validating that the controller-manager pod is running as expected") @@ -199,27 +73,26 @@ var _ = Describe("Manager", Ordered, func() { It("should ensure the metrics endpoint is serving metrics", func() { By("creating a ClusterRoleBinding for the service account to allow access to metrics") - cmd := exec.Command("kubectl", "create", "clusterrolebinding", metricsRoleBindingName, - "--clusterrole=model-validation-metrics-reader", - fmt.Sprintf("--serviceaccount=%s:%s", operatorNamespace, serviceAccountName), - ) - _, err := utils.Run(cmd) - Expect(err).NotTo(HaveOccurred(), "Failed to create ClusterRoleBinding") + err := utils.KubectlApply(clusterRoleBindingTemplate, utils.ClusterRoleBindingTemplateData{ + Name: metricsRoleBindingName, + ServiceAccountName: utils.ServiceAccountName, + Namespace: operatorNamespace, + ClusterRoleName: "model-validation-metrics-reader", + }) + Expect(err).NotTo(HaveOccurred(), "Failed to apply ClusterRoleBinding") By("validating that the metrics service is available") - cmd = exec.Command("kubectl", "get", "service", metricsServiceName, "-n", operatorNamespace) - _, err = utils.Run(cmd) 
-			Expect(err).NotTo(HaveOccurred(), "Metrics service should exist")
+			exists := utils.KubectlResourceExists("service", utils.MetricsServiceName, operatorNamespace)
+			Expect(exists).To(BeTrue(), "Metrics service should exist")

 			By("getting the service account token")
-			token, err := serviceAccountToken()
+			token, err := utils.CreateServiceAccountToken(utils.ServiceAccountName, operatorNamespace)
 			Expect(err).NotTo(HaveOccurred())
 			Expect(token).NotTo(BeEmpty())

 			By("waiting for the metrics endpoint to be ready")
 			verifyMetricsEndpointReady := func(g Gomega) {
-				cmd = exec.Command("kubectl", "get", "endpoints", metricsServiceName, "-n", operatorNamespace)
-				output, err := utils.Run(cmd)
+				output, err := utils.KubectlGet("endpoints", utils.MetricsServiceName, operatorNamespace, "")
 				g.Expect(err).NotTo(HaveOccurred())
 				g.Expect(output).To(ContainSubstring("8443"), "Metrics endpoint is not ready")
 			}
@@ -235,49 +108,8 @@ var _ = Describe("Manager", Ordered, func() {
 			}
 			Eventually(verifyMetricsServerStarted).Should(Succeed())

-			By("creating the curl-metrics pod to access the metrics endpoint")
-			cmd = exec.Command("kubectl", "run", curlMetricsPodName, "--restart=Never",
-				"--namespace", operatorNamespace,
-				"--image=curlimages/curl:latest",
-				"--overrides",
-				fmt.Sprintf(`{
-					"spec": {
-						"containers": [{
-							"name": "curl",
-							"image": "curlimages/curl:latest",
-							"command": ["/bin/sh", "-c"],
-							"args": ["curl -v -k -H 'Authorization: Bearer %s' https://%s.%s.svc.cluster.local:8443/metrics"],
-							"securityContext": {
-								"allowPrivilegeEscalation": false,
-								"capabilities": {
-									"drop": ["ALL"]
-								},
-								"runAsNonRoot": true,
-								"runAsUser": 1000,
-								"seccompProfile": {
-									"type": "RuntimeDefault"
-								}
-							}
-						}],
-						"serviceAccount": "%s"
-					}
-				}`, token, metricsServiceName, operatorNamespace, serviceAccountName))
-			_, err = utils.Run(cmd)
-			Expect(err).NotTo(HaveOccurred(), "Failed to create curl-metrics pod")
-
-			By("waiting for the curl-metrics pod to complete.")
-			verifyCurlUp := func(g Gomega) {
-				cmd := exec.Command("kubectl", "get", "pods", curlMetricsPodName,
-					"-o", "jsonpath={.status.phase}",
-					"-n", operatorNamespace)
-				output, err := utils.Run(cmd)
-				g.Expect(err).NotTo(HaveOccurred())
-				g.Expect(output).To(Equal("Succeeded"), "curl pod in wrong status")
-			}
-			Eventually(verifyCurlUp, 5*time.Minute).Should(Succeed())
-
-			By("getting the metrics by checking curl-metrics logs")
-			metricsOutput := getMetricsOutput()
+			By("getting the metrics from the persistent curl pod")
+			metricsOutput := utils.GetMetricsOutput(operatorNamespace, curlMetricsPodName)
 			Expect(metricsOutput).To(ContainSubstring(
 				"controller_runtime_webhook_requests_total",
 			))
@@ -285,107 +117,43 @@ var _ = Describe("Manager", Ordered, func() {

 		It("should inject the model validation init container", func() {
 			By("deploying a ModelValidation CR")
-			cmd := exec.Command("kubectl", "apply", "-f", "-")
-			cmd.Stdin = bytes.NewReader(modelValidationCR)
-			_, err := utils.Run(cmd)
+			err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{
+				ModelName: "e2e-test-model",
+				Namespace: webhookTestNamespace,
+			})
 			Expect(err).NotTo(HaveOccurred(), "Failed to apply ModelValidation CR")

 			By("deploying a pod with the model validation label")
-			cmd = exec.Command("kubectl", "apply", "-f", "-")
-			cmd.Stdin = bytes.NewReader(testPod)
-			_, err = utils.Run(cmd)
+			err = utils.KubectlApply(podTemplate, utils.PodTemplateData{
+				PodName:   "e2e-test-pod",
+				Namespace: webhookTestNamespace,
+				ModelName: "e2e-test-model",
+			})
 			Expect(err).NotTo(HaveOccurred(),
"Failed to apply test pod") By("verifying the init container is injected") verifyInitContainerInjection := func(g Gomega) { - cmd := exec.Command("kubectl", "get", "pod", e2eTestPodName, "-n", webhookTestNamespace, "-o", "json") - output, err := utils.Run(cmd) - g.Expect(err).NotTo(HaveOccurred()) - - var pod struct { - Spec struct { - InitContainers []struct { - Name string `json:"name"` - } `json:"initContainers"` - } `json:"spec"` - } - err = json.Unmarshal([]byte(output), &pod) + var pod corev1.Pod + err := utils.KubectlGetJSON("pod", e2eTestPodName, webhookTestNamespace, &pod) g.Expect(err).NotTo(HaveOccurred()) g.Expect(pod.Spec.InitContainers).To(HaveLen(1)) g.Expect(pod.Spec.InitContainers[0].Name).To(Equal(constants.ModelValidationInitContainerName)) } Eventually(verifyInitContainerInjection).Should(Succeed()) }) - }) - // +kubebuilder:scaffold:e2e-webhooks-checks - - // TODO: Customize the e2e test suite with scenarios specific to your project. - // Consider applying sample/CR(s) and check their status and/or verifying - // the reconciliation by using the metrics, i.e.: - // metricsOutput := getMetricsOutput() - // Expect(metricsOutput).To(ContainSubstring( - // fmt.Sprintf(`controller_runtime_reconcile_total{controller="%s",result="success"} 1`, - // strings.ToLower(), - // )) - // }) + AfterAll(func() { + By("cleaning up ClusterRoleBinding") + _ = utils.KubectlDelete(clusterRoleBindingTemplate, &utils.KubectlDeleteOptions{ + Timeout: "30s", + IgnoreNotFound: true, + TemplateData: utils.ClusterRoleBindingTemplateData{ + Name: metricsRoleBindingName, + ServiceAccountName: utils.ServiceAccountName, + Namespace: operatorNamespace, + ClusterRoleName: "model-validation-metrics-reader", + }, + }) + }) + }) }) - -// serviceAccountToken returns a token for the specified service account in the given namespace. -// It uses the Kubernetes TokenRequest API to generate a token by directly sending a request -// and parsing the resulting token from the API response. -func serviceAccountToken() (string, error) { - const tokenRequestRawString = `{ - "apiVersion": "authentication.k8s.io/v1", - "kind": "TokenRequest" - }` - - // Temporary file to store the token request - secretName := fmt.Sprintf("%s-token-request", serviceAccountName) - tokenRequestFile := filepath.Join("/tmp", secretName) - err := os.WriteFile(tokenRequestFile, []byte(tokenRequestRawString), os.FileMode(0o644)) - if err != nil { - return "", err - } - - var out string - verifyTokenCreation := func(g Gomega) { - // Execute kubectl command to create the token - cmd := exec.Command("kubectl", "create", "--raw", fmt.Sprintf( - "/api/v1/namespaces/%s/serviceaccounts/%s/token", - operatorNamespace, - serviceAccountName, - ), "-f", tokenRequestFile) - - output, err := cmd.CombinedOutput() - g.Expect(err).NotTo(HaveOccurred()) - - // Parse the JSON output to extract the token - var token tokenRequest - err = json.Unmarshal(output, &token) - g.Expect(err).NotTo(HaveOccurred()) - - out = token.Status.Token - } - Eventually(verifyTokenCreation).Should(Succeed()) - - return out, err -} - -// getMetricsOutput retrieves and returns the logs from the curl pod used to access the metrics endpoint. 
-func getMetricsOutput() string {
-	By("getting the curl-metrics logs")
-	cmd := exec.Command("kubectl", "logs", curlMetricsPodName, "-n", operatorNamespace)
-	metricsOutput, err := utils.Run(cmd)
-	Expect(err).NotTo(HaveOccurred(), "Failed to retrieve logs from curl pod")
-	Expect(metricsOutput).To(ContainSubstring("< HTTP/1.1 200 OK"))
-	return metricsOutput
-}
-
-// tokenRequest is a simplified representation of the Kubernetes TokenRequest API response,
-// containing only the token field that we need to extract.
-type tokenRequest struct {
-	Status struct {
-		Token string `json:"token"`
-	} `json:"status"`
-}
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
new file mode 100644
index 00000000..7f71f65e
--- /dev/null
+++ b/test/e2e/integration_test.go
@@ -0,0 +1,184 @@
+/*
+Copyright 2025.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package e2e
+
+import (
+	"fmt"
+	"time"
+
+	_ "embed"
+
+	. "github.com/onsi/ginkgo/v2" //nolint:revive
+	. "github.com/onsi/gomega"    //nolint:revive
+	corev1 "k8s.io/api/core/v1"
+
+	"github.com/sigstore/model-validation-operator/api/v1alpha1"
+	"github.com/sigstore/model-validation-operator/internal/metrics"
+	utils "github.com/sigstore/model-validation-operator/test/utils"
+)
+
+const integrationTestModelName = "integration-test-model"
+const integrationTestNamespace = "e2e-webhook-test-ns"
+
+// defaultIntegrationPodData creates a standard integration test pod configuration
+func defaultIntegrationPodData(podName, namespace, modelName string) utils.PodTemplateData {
+	return utils.DefaultPodData(podName, namespace, modelName, "integration")
+}
+
+var _ = Describe("ModelValidation Integration Tests", Ordered, func() {
+	Context("End-to-End Status and Metrics Integration", func() {
+
+		It("should demonstrate full lifecycle with status and metrics consistency", func() {
+			By("deploying a ModelValidation CR with signed model configuration")
+			err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{
+				ModelName: integrationTestModelName,
+				Namespace: integrationTestNamespace,
+			})
+			Expect(err).NotTo(HaveOccurred(), "Failed to apply ModelValidation CR")
+
+			By("waiting for ModelValidation CR to be ready")
+			Eventually(func() error {
+				return utils.KubectlGetJSON("modelvalidation", integrationTestModelName,
+					integrationTestNamespace, &v1alpha1.ModelValidation{})
+			}, 30*time.Second, 2*time.Second).Should(Succeed(), "ModelValidation CR should be available")
+
+			By("verifying initial status is correctly set")
+			Eventually(func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", integrationTestModelName, integrationTestNamespace, &mv)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(mv.Status.AuthMethod).To(Equal("public-key"))
+				g.Expect(mv.Status.InjectedPodCount).To(Equal(int32(0)))
+				g.Expect(mv.Status.LastUpdated).NotTo(BeZero())
+			}, 30*time.Second).Should(Succeed())
+
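+			// Illustrative note: the scraped output is plain Prometheus text
+			// exposition, one sample per line, e.g. (label set assumed from the
+			// queries used elsewhere in this suite):
+			//
+			//	model_validation_operator_modelvalidation_crs_total{namespace="e2e-webhook-test-ns"} 1
+			//
+			// The check below only asserts that the metric families are present;
+			// exact sample values are compared later in this test.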
+			By("verifying initial metrics reflect the new CR")
+			Eventually(func(g Gomega) {
+				metricsOutput := utils.GetMetricsOutput(operatorNamespace, curlMetricsPodName)
+				g.Expect(metricsOutput).To(ContainSubstring("model_validation_operator_modelvalidation_crs_total"))
+				g.Expect(metricsOutput).To(ContainSubstring("model_validation_operator_status_updates_total"))
+			}, 30*time.Second).Should(Succeed())
+
+			By("deploying pods with realistic model volumes that will be injected")
+			podNames := []string{"integration-pod-1", "integration-pod-2", "integration-pod-3"}
+			for _, podName := range podNames {
+				err = utils.KubectlApply(podTemplate,
+					defaultIntegrationPodData(podName, integrationTestNamespace, integrationTestModelName))
+				Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("Failed to deploy pod %s", podName))
+			}
+
+			By("verifying all pods get the init container injected")
+			for _, podName := range podNames {
+				Eventually(func(g Gomega) {
+					var pod corev1.Pod
+					err := utils.KubectlGetJSON("pod", podName, integrationTestNamespace, &pod)
+					g.Expect(err).NotTo(HaveOccurred())
+					g.Expect(pod.Spec.InitContainers).ToNot(BeEmpty())
+					g.Expect(utils.HasValidationContainer(&pod)).To(BeTrue())
+				}, 30*time.Second).Should(Succeed())
+			}
+
+			By("verifying status reflects all injected pods")
+			Eventually(func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", integrationTestModelName, integrationTestNamespace, &mv)
+				g.Expect(err).NotTo(HaveOccurred())
+				status := mv.Status
+				g.Expect(status.InjectedPodCount).To(Equal(int32(3)))
+				g.Expect(status.UninjectedPodCount).To(Equal(int32(0)))
+				g.Expect(status.OrphanedPodCount).To(Equal(int32(0)))
+				g.Expect(status.InjectedPods).To(HaveLen(3))
+
+				// Verify all expected pod names are present
+				podNamesInStatus := make(map[string]bool)
+				for _, pod := range status.InjectedPods {
+					podNamesInStatus[pod.Name] = true
+					g.Expect(pod.UID).NotTo(BeEmpty())
+					g.Expect(pod.InjectedAt).NotTo(BeZero())
+				}
+				for _, expectedName := range podNames {
+					g.Expect(podNamesInStatus[expectedName]).To(BeTrue())
+				}
+			}, 60*time.Second, 1*time.Second).Should(Succeed())
+
+			By("verifying metrics match the status counts")
+			Eventually(func(g Gomega) {
+				metricsOutput := utils.GetMetricsOutput(operatorNamespace, curlMetricsPodName)
+
+				expectedInjectedMetric := fmt.Sprintf(
+					`model_validation_operator_modelvalidation_pod_count{model_validation="%s",namespace="%s",pod_state="%s"} 3`,
+					integrationTestModelName, integrationTestNamespace, metrics.PodStateInjected)
+				g.Expect(metricsOutput).To(ContainSubstring(expectedInjectedMetric))
+
+				expectedUninjectedMetric := fmt.Sprintf(
+					`model_validation_operator_modelvalidation_pod_count{model_validation="%s",`+
+						`namespace="%s",pod_state="%s"} 0`,
+					integrationTestModelName, integrationTestNamespace, metrics.PodStateUninjected)
+				g.Expect(metricsOutput).To(ContainSubstring(expectedUninjectedMetric))
+
+				g.Expect(metricsOutput).To(ContainSubstring("model_validation_operator_status_updates_total"))
+			}, 45*time.Second).Should(Succeed())
+
+			By("simulating pod deletion and verifying status/metrics updates")
+			err = utils.KubectlDelete(podTemplate, &utils.KubectlDeleteOptions{
+				TemplateData: defaultIntegrationPodData(podNames[0], integrationTestNamespace, integrationTestModelName),
+			})
+			Expect(err).NotTo(HaveOccurred())
+
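+			// The operator observes the deletion and reconciles the CR status
+			// asynchronously, so the updated count is polled with Eventually
+			// rather than asserted immediately after the delete returns.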
+			By("verifying status reflects pod deletion")
+			Eventually(func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", integrationTestModelName, integrationTestNamespace, &mv)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(mv.Status.InjectedPodCount).To(Equal(int32(2)))
+			}, 30*time.Second).Should(Succeed())
+
+			By("verifying metrics reflect the pod deletion")
+			Eventually(func(g Gomega) {
+				metricsOutput := utils.GetMetricsOutput(operatorNamespace, curlMetricsPodName)
+				expectedMetric := fmt.Sprintf(
+					`model_validation_operator_modelvalidation_pod_count{model_validation="%s",namespace="%s",pod_state="%s"} 2`,
+					integrationTestModelName, integrationTestNamespace, metrics.PodStateInjected)
+				g.Expect(metricsOutput).To(ContainSubstring(expectedMetric))
+			}, 30*time.Second).Should(Succeed())
+
+			By("cleaning up all test resources")
+			for _, podName := range podNames[1:] {
+				_ = utils.KubectlDelete(podTemplate, &utils.KubectlDeleteOptions{
+					IgnoreNotFound: true,
+					TemplateData:   defaultIntegrationPodData(podName, integrationTestNamespace, integrationTestModelName),
+				})
+			}
+
+			_ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{
+				IgnoreNotFound: true,
+				TemplateData: utils.CRTemplateData{
+					ModelName: integrationTestModelName,
+					Namespace: integrationTestNamespace,
+				},
+			})
+
+			By("verifying cleanup is reflected in status and metrics")
+			Eventually(func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", integrationTestModelName, integrationTestNamespace, &mv)
+				g.Expect(err).To(HaveOccurred()) // Should fail because resource is deleted
+			}, 30*time.Second).Should(Succeed())
+		})
+
+	})
+})
diff --git a/test/e2e/metrics_test.go b/test/e2e/metrics_test.go
new file mode 100644
index 00000000..961c5d47
--- /dev/null
+++ b/test/e2e/metrics_test.go
@@ -0,0 +1,300 @@
+/*
+Copyright 2025.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package e2e
+
+import (
+	"fmt"
+	"time"
+
+	_ "embed"
+
+	. "github.com/onsi/ginkgo/v2" //nolint:revive
"github.com/onsi/gomega" //nolint:revive + + "github.com/sigstore/model-validation-operator/api/v1alpha1" + "github.com/sigstore/model-validation-operator/internal/metrics" + utils "github.com/sigstore/model-validation-operator/test/utils" +) + +const metricsTestPodName = "metrics-test-pod" +const metricsTestModelName = "metrics-test-model" + +// defaultMetricsPodData creates a standard metrics test pod configuration +func defaultMetricsPodData(podName string) utils.PodTemplateData { + return utils.DefaultPodData(podName, webhookTestNamespace, metricsTestModelName, "metrics") +} + +var _ = Describe("ModelValidation Metrics", Ordered, func() { + Context("Prometheus Metrics Collection", func() { + + It("should expose pod count metrics", func() { + By("deploying a ModelValidation CR") + err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + }) + Expect(err).NotTo(HaveOccurred()) + + By("waiting for ModelValidation CR to be ready") + Eventually(func() error { + return utils.KubectlGetJSON("modelvalidation", metricsTestModelName, + webhookTestNamespace, &v1alpha1.ModelValidation{}) + }, 30*time.Second, 2*time.Second).Should(Succeed(), "ModelValidation CR should be available") + + By("getting baseline metrics before pod creation") + injectedLabels := map[string]string{ + "namespace": webhookTestNamespace, + "model_validation": metricsTestModelName, + "pod_state": metrics.PodStateInjected, + } + baselineInjectedCount := utils.GetMetricValue( + "model_validation_operator_modelvalidation_pod_count", injectedLabels, + operatorNamespace, curlMetricsPodName) + + By("deploying a pod to trigger metrics updates") + err = utils.KubectlApply(podTemplate, defaultMetricsPodData(metricsTestPodName)) + Expect(err).NotTo(HaveOccurred()) + + By("verifying metrics increment after pod injection") + Eventually(func() int { + return utils.GetMetricValue("model_validation_operator_modelvalidation_pod_count", + injectedLabels, operatorNamespace, curlMetricsPodName) + }, 30*time.Second, 5*time.Second).Should(Equal(baselineInjectedCount + 1)) + + By("cleaning up test resources - deleting pod first") + _ = utils.KubectlDelete(podTemplate, &utils.KubectlDeleteOptions{ + Timeout: "30s", + IgnoreNotFound: true, + TemplateData: defaultMetricsPodData(metricsTestPodName), + }) + + By("verifying metrics decrement after pod deletion") + Eventually(func() int { + return utils.GetMetricValue("model_validation_operator_modelvalidation_pod_count", + injectedLabels, operatorNamespace, curlMetricsPodName) + }, 30*time.Second, 5*time.Second).Should(Equal(baselineInjectedCount)) + + _ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{ + Timeout: "30s", + IgnoreNotFound: true, + TemplateData: utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + }, + }) + }) + + It("should track status update metrics", func() { + By("getting baseline status update metrics") + statusLabels := map[string]string{ + "namespace": webhookTestNamespace, + "model_validation": metricsTestModelName, + "result": "success", + } + baselineSuccessCount := utils.GetMetricValue( + "model_validation_operator_status_updates_total", statusLabels, + operatorNamespace, curlMetricsPodName) + + By("deploying a ModelValidation CR") + err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + }) + Expect(err).NotTo(HaveOccurred()) + + 
By("verifying status update metrics increment") + Eventually(func() int { + return utils.GetMetricValue("model_validation_operator_status_updates_total", + statusLabels, operatorNamespace, curlMetricsPodName) + }, 30*time.Second, 5*time.Second).Should(BeNumerically(">", baselineSuccessCount)) + + By("cleaning up test resources") + _ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{ + Timeout: "30s", + IgnoreNotFound: true, + TemplateData: utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + }, + }) + }) + + It("should track pod state transitions", func() { + By("deploying a ModelValidation CR") + err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + }) + Expect(err).NotTo(HaveOccurred()) + + By("deploying a pod (this creates '' -> 'injected' transition, not tracked)") + err = utils.KubectlApply(podTemplate, defaultMetricsPodData(fmt.Sprintf("%s-transition", metricsTestPodName))) + Expect(err).NotTo(HaveOccurred()) + + By("getting baseline transition metrics for injected -> orphaned") + transitionLabels := map[string]string{ + "from_state": metrics.PodStateInjected, + "to_state": metrics.PodStateOrphaned, + "namespace": webhookTestNamespace, + "model_validation": metricsTestModelName, + } + baselineTransitionCount := utils.GetMetricValue( + "model_validation_operator_pod_state_transitions_total", transitionLabels, + operatorNamespace, curlMetricsPodName) + + By("updating the ModelValidation CR keyPath to trigger injected -> orphaned transition") + err = utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + KeyPath: "/keys/different_key.pub", + }) + Expect(err).NotTo(HaveOccurred()) + + By("verifying transition metrics increment after MV update") + Eventually(func() int { + return utils.GetMetricValue("model_validation_operator_pod_state_transitions_total", + transitionLabels, operatorNamespace, curlMetricsPodName) + }, 30*time.Second, 5*time.Second).Should(Equal(baselineTransitionCount + 1)) + + By("cleaning up test resources") + podData := defaultMetricsPodData(fmt.Sprintf("%s-transition", metricsTestPodName)) + _ = utils.KubectlDelete(podTemplate, &utils.KubectlDeleteOptions{ + Timeout: "30s", + IgnoreNotFound: true, + TemplateData: podData, + }) + + _ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{ + Timeout: "30s", + IgnoreNotFound: true, + TemplateData: utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + }, + }) + }) + + It("should track ModelValidation CR count per namespace", func() { + By("getting baseline transition metrics for namespace count") + namespaceLabels := map[string]string{ + "namespace": webhookTestNamespace, + } + namespaceCount := utils.GetMetricValue( + "model_validation_operator_modelvalidation_crs_total", namespaceLabels, + operatorNamespace, curlMetricsPodName) + By("deploying multiple ModelValidation CRs in the test namespace") + for i := 1; i <= 2; i++ { + err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{ + ModelName: fmt.Sprintf("metrics-test-model-%d", i), + Namespace: webhookTestNamespace, + }) + Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("Failed to apply ModelValidation CR %d", i)) + } + + By("verifying CR count after ModelValidation deployments") + Eventually(func() int { + return 
utils.GetMetricValue("model_validation_operator_modelvalidation_crs_total", + namespaceLabels, operatorNamespace, curlMetricsPodName) + }, 30*time.Second, 5*time.Second).Should(Equal(namespaceCount + 2)) + + By("cleaning up test resources") + for i := 1; i <= 2; i++ { + _ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{ + Timeout: "30s", + IgnoreNotFound: true, + }) + } + }) + + It("should expose queue size metrics", func() { + By("checking metrics for queue information") + metricsOutput := utils.GetMetricsOutput(operatorNamespace, curlMetricsPodName) + + Expect(metricsOutput).To(ContainSubstring("model_validation_operator_status_update_queue_size")) + + Expect(metricsOutput).To(ContainSubstring("model_validation_operator_status_update_duration_seconds")) + }) + }) + + Context("Metrics Integration", func() { + It("should show consistent metrics across pod lifecycle", func() { + By("deploying a ModelValidation CR") + err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + }) + Expect(err).NotTo(HaveOccurred()) + + By("getting baseline metrics") + baselineMetrics := utils.GetMetricsOutput(operatorNamespace, curlMetricsPodName) + baselineInjectedCount := utils.ExtractMetricValue(baselineMetrics, + "model_validation_operator_modelvalidation_pod_count", + map[string]string{ + "namespace": webhookTestNamespace, + "model_validation": metricsTestModelName, + "pod_state": metrics.PodStateInjected, + }) + + By("deploying a pod") + err = utils.KubectlApply(podTemplate, defaultMetricsPodData(fmt.Sprintf("%s-lifecycle", metricsTestPodName))) + Expect(err).NotTo(HaveOccurred()) + + By("verifying metrics increment after pod injection") + Eventually(func(g Gomega) { + currentMetrics := utils.GetMetricsOutput(operatorNamespace, curlMetricsPodName) + currentInjectedCount := utils.ExtractMetricValue(currentMetrics, + "model_validation_operator_modelvalidation_pod_count", + map[string]string{ + "namespace": webhookTestNamespace, + "model_validation": metricsTestModelName, + "pod_state": metrics.PodStateInjected, + }) + g.Expect(currentInjectedCount).To(Equal(baselineInjectedCount + 1)) + }, 30*time.Second, 5*time.Second).Should(Succeed()) + + By("deleting the pod") + podToDelete := fmt.Sprintf( + "apiVersion: v1\nkind: Pod\nmetadata:\n name: %s\n namespace: %s", + fmt.Sprintf("%s-lifecycle", metricsTestPodName), webhookTestNamespace) + err = utils.KubectlDelete([]byte(podToDelete), nil) + Expect(err).NotTo(HaveOccurred()) + + By("verifying metrics decrement after pod deletion") + Eventually(func(g Gomega) { + currentMetrics := utils.GetMetricsOutput(operatorNamespace, curlMetricsPodName) + currentInjectedCount := utils.ExtractMetricValue(currentMetrics, + "model_validation_operator_modelvalidation_pod_count", + map[string]string{ + "namespace": webhookTestNamespace, + "model_validation": metricsTestModelName, + "pod_state": metrics.PodStateInjected, + }) + g.Expect(currentInjectedCount).To(Equal(baselineInjectedCount)) + }, 30*time.Second, 5*time.Second).Should(Succeed()) + + By("cleaning up test resources") + _ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{ + Timeout: "30s", + IgnoreNotFound: true, + TemplateData: utils.CRTemplateData{ + ModelName: metricsTestModelName, + Namespace: webhookTestNamespace, + }, + }) + }) + }) +}) diff --git a/test/e2e/status_tracking_test.go b/test/e2e/status_tracking_test.go new file mode 100644 index 00000000..9d17afc4 --- /dev/null 
+++ b/test/e2e/status_tracking_test.go
@@ -0,0 +1,178 @@
+/*
+Copyright 2025.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package e2e
+
+import (
+	"fmt"
+	"time"
+
+	_ "embed"
+
+	. "github.com/onsi/ginkgo/v2" //nolint:revive
+	. "github.com/onsi/gomega"    //nolint:revive
+	corev1 "k8s.io/api/core/v1"
+
+	"github.com/sigstore/model-validation-operator/api/v1alpha1"
+	"github.com/sigstore/model-validation-operator/internal/constants"
+	utils "github.com/sigstore/model-validation-operator/test/utils"
+)
+
+const statusTestPodName = "status-test-pod"
+const statusTestModelName = "status-test-model"
+
+// defaultStatusPodData creates a standard status test pod configuration
+func defaultStatusPodData(podName string) utils.PodTemplateData {
+	return utils.DefaultPodData(podName, webhookTestNamespace, statusTestModelName, "status")
+}
+
+var _ = Describe("ModelValidation Status Tracking", Ordered, func() {
+	Context("Status Field Updates", func() {
+		It("should track injected pod count correctly", func() {
+			By("deploying a ModelValidation CR with signed model")
+			err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{
+				ModelName: statusTestModelName,
+				Namespace: webhookTestNamespace,
+			})
+			Expect(err).NotTo(HaveOccurred(), "Failed to apply signed ModelValidation CR")
+
+			By("verifying initial status shows zero counts")
+			verifyInitialStatus := func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", statusTestModelName, webhookTestNamespace, &mv)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(mv.Status.AuthMethod).To(Equal("public-key"))
+				g.Expect(mv.Status.InjectedPodCount).To(Equal(int32(0)))
+				g.Expect(mv.Status.UninjectedPodCount).To(Equal(int32(0)))
+				g.Expect(mv.Status.OrphanedPodCount).To(Equal(int32(0)))
+			}
+			Eventually(verifyInitialStatus).Should(Succeed())
+
+			By("deploying a pod with the model validation label and signed model volume")
+			err = utils.KubectlApply(podTemplate, defaultStatusPodData(statusTestPodName))
+			Expect(err).NotTo(HaveOccurred(), "Failed to apply signed model pod")
+
+			By("waiting for the init container to be injected")
+			verifyInitContainerInjection := func(g Gomega) {
+				var pod corev1.Pod
+				err := utils.KubectlGetJSON("pod", statusTestPodName, webhookTestNamespace, &pod)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(pod.Spec.InitContainers).To(HaveLen(1))
+				g.Expect(pod.Spec.InitContainers[0].Name).To(Equal(constants.ModelValidationInitContainerName))
+			}
+			Eventually(verifyInitContainerInjection).Should(Succeed())
+
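+			// Status fields are maintained asynchronously from webhook injection,
+			// so the counts below are polled with a 30s Eventually instead of
+			// being read once.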
+			By("verifying status shows injected pod count increment")
+			verifyInjectedStatus := func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", statusTestModelName, webhookTestNamespace, &mv)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(mv.Status.AuthMethod).To(Equal("public-key"))
+				g.Expect(mv.Status.InjectedPodCount).To(Equal(int32(1)))
+				g.Expect(mv.Status.UninjectedPodCount).To(Equal(int32(0)))
+				g.Expect(mv.Status.OrphanedPodCount).To(Equal(int32(0)))
+				g.Expect(mv.Status.InjectedPods).To(HaveLen(1))
+				g.Expect(mv.Status.InjectedPods[0].Name).To(Equal(statusTestPodName))
+				g.Expect(mv.Status.InjectedPods[0].UID).NotTo(BeEmpty())
+				g.Expect(mv.Status.InjectedPods[0].InjectedAt).NotTo(BeZero())
+				g.Expect(mv.Status.LastUpdated).NotTo(BeZero())
+			}
+			Eventually(verifyInjectedStatus, 30*time.Second).Should(Succeed())
+
+			By("cleaning up test resources")
+			_ = utils.KubectlDelete(podTemplate, &utils.KubectlDeleteOptions{
+				Timeout:        "30s",
+				IgnoreNotFound: true,
+				TemplateData:   defaultStatusPodData(statusTestPodName),
+			})
+			_ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{
+				Timeout:        "30s",
+				IgnoreNotFound: true,
+				TemplateData: utils.CRTemplateData{
+					ModelName: statusTestModelName,
+					Namespace: webhookTestNamespace,
+				},
+			})
+		})
+
+		It("should handle multiple pods correctly", func() {
+			By("deploying a ModelValidation CR")
+			err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{
+				ModelName: statusTestModelName,
+				Namespace: webhookTestNamespace,
+			})
+			Expect(err).NotTo(HaveOccurred())
+
+			By("deploying multiple pods with the same validation label")
+			for i := 1; i <= 3; i++ {
+				podName := fmt.Sprintf("status-test-pod-%d", i)
+				err = utils.KubectlApply(podTemplate, defaultStatusPodData(podName))
+				Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("Failed to apply pod %d", i))
+			}
+
+			By("verifying status shows correct pod count")
+			verifyMultiplePodStatus := func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", statusTestModelName, webhookTestNamespace, &mv)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(mv.Status.InjectedPodCount).To(Equal(int32(3)))
+				g.Expect(mv.Status.InjectedPods).To(HaveLen(3))
+			}
+			Eventually(verifyMultiplePodStatus, 60*time.Second).Should(Succeed())
+
+			By("cleaning up test resources")
+			for i := 1; i <= 3; i++ {
+				podToDelete := fmt.Sprintf(
+					"apiVersion: v1\nkind: Pod\nmetadata:\n  name: %s\n  namespace: %s",
+					fmt.Sprintf("status-test-pod-%d", i), webhookTestNamespace)
+				_ = utils.KubectlDelete([]byte(podToDelete), &utils.KubectlDeleteOptions{Timeout: "30s", IgnoreNotFound: true})
+			}
+			_ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{
+				Timeout:        "30s",
+				IgnoreNotFound: true,
+				TemplateData: utils.CRTemplateData{
+					ModelName: statusTestModelName,
+					Namespace: webhookTestNamespace,
+				},
+			})
+		})
+
+		It("should track pod deletion and update counts", func() {
+			By("deploying a ModelValidation CR and pod")
+			err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{
+				ModelName: statusTestModelName,
+				Namespace: webhookTestNamespace,
+			})
+			Expect(err).NotTo(HaveOccurred())
+
+			err = utils.KubectlApply(podTemplate, defaultStatusPodData(statusTestPodName))
+			Expect(err).NotTo(HaveOccurred())
+
+			By("waiting for pod injection and status update")
+			Eventually(func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", statusTestModelName, webhookTestNamespace, &mv)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(mv.Status.InjectedPodCount).To(Equal(int32(1)))
+			}, 30*time.Second).Should(Succeed())
+
+			By("deleting the pod")
+			err = utils.KubectlDelete(podTemplate, &utils.KubectlDeleteOptions{
+				TemplateData: defaultStatusPodData(statusTestPodName),
+			})
+			Expect(err).NotTo(HaveOccurred())
+
+			By("verifying status reflects pod deletion")
+			verifyPodDeletion := func(g Gomega) {
+				var mv v1alpha1.ModelValidation
+				err := utils.KubectlGetJSON("modelvalidation", statusTestModelName, webhookTestNamespace, &mv)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(mv.Status.InjectedPodCount).To(Equal(int32(0)))
+				g.Expect(mv.Status.InjectedPods).To(BeEmpty())
+			}
+			Eventually(verifyPodDeletion, 30*time.Second).Should(Succeed())
+
+			By("cleaning up ModelValidation CR")
+			_ = utils.KubectlDelete(modelValidationTemplate, &utils.KubectlDeleteOptions{
+				Timeout:        "30s",
+				IgnoreNotFound: true,
+				TemplateData: utils.CRTemplateData{
+					ModelName: statusTestModelName,
+					Namespace: webhookTestNamespace,
+				},
+			})
+		})
+	})
+})
diff --git a/test/e2e/testdata.go b/test/e2e/testdata.go
new file mode 100644
index 00000000..58d3d46d
--- /dev/null
+++ b/test/e2e/testdata.go
@@ -0,0 +1,35 @@
+/*
+Copyright 2025.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package e2e
+
+import _ "embed"
+
+// ModelValidation CR template used by multiple tests
+//
+//go:embed testdata/modelvalidation_template.yaml
+var modelValidationTemplate []byte
+
+// Pod template
+//
+//go:embed testdata/pod_template.yaml
+var podTemplate []byte
+
+//go:embed testdata/curl_metrics_pod_template.yaml
+var curlPodTemplate []byte
+
+//go:embed testdata/clusterrolebinding_template.yaml
+var clusterRoleBindingTemplate []byte
diff --git a/test/e2e/testdata/clusterrolebinding_template.yaml b/test/e2e/testdata/clusterrolebinding_template.yaml
new file mode 100644
index 00000000..1f030a47
--- /dev/null
+++ b/test/e2e/testdata/clusterrolebinding_template.yaml
@@ -0,0 +1,12 @@
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRoleBinding
+metadata:
+  name: {{.Name}}
+subjects:
+- kind: ServiceAccount
+  name: {{.ServiceAccountName}}
+  namespace: {{.Namespace}}
+roleRef:
+  kind: ClusterRole
+  name: {{.ClusterRoleName}}
+  apiGroup: rbac.authorization.k8s.io
\ No newline at end of file
diff --git a/test/e2e/testdata/curl_metrics_pod_template.yaml b/test/e2e/testdata/curl_metrics_pod_template.yaml
new file mode 100644
index 00000000..a0673404
--- /dev/null
+++ b/test/e2e/testdata/curl_metrics_pod_template.yaml
@@ -0,0 +1,41 @@
+---
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+  name: {{.ServiceAccount}}
+  namespace: {{.Namespace}}
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: ClusterRoleBinding
+metadata:
+  name: {{.ServiceAccount}}-binding
+roleRef:
+  apiGroup: rbac.authorization.k8s.io
+  kind: ClusterRole
+  name: model-validation-metrics-reader
+subjects:
+- kind: ServiceAccount
+  name: {{.ServiceAccount}}
+  namespace: {{.Namespace}}
+---
+apiVersion: v1
+kind: Pod
+metadata:
+  name: {{.PodName}}
+  namespace: {{.Namespace}}
+spec:
+  containers:
+  - name: curl
+    image: curlimages/curl:latest
+    command: ["/bin/sh", "-c"]
+    args: ["sleep 3600"]
+    securityContext:
+      allowPrivilegeEscalation: false
+      capabilities:
+        drop: ["ALL"]
+      runAsNonRoot: true
+      runAsUser: 1000
+      seccompProfile:
+        type: RuntimeDefault
+  serviceAccount: {{.ServiceAccount}}
+  restartPolicy: Never
\ No newline at end of file
diff --git a/test/e2e/testdata/model-data-daemonset.yaml b/test/e2e/testdata/model-data-daemonset.yaml
new file mode 100644
index 00000000..7ed9293d
--- /dev/null
+++ b/test/e2e/testdata/model-data-daemonset.yaml
@@ -0,0 +1,56 @@
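+# Stages the signed test model and its verification keys from the test-model
+# image into hostPath directories (/tmp/e2e-model-data and /tmp/e2e-keys-data),
+# which the e2e pod template then mounts read-only at /data and /keys.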
+apiVersion: apps/v1
+kind: DaemonSet
+metadata:
+  name: model-data-setup
+  namespace: e2e-webhook-test-ns
+  labels:
+    app: model-data-setup
+spec:
+  selector:
+    matchLabels:
+      app: model-data-setup
+  template:
+    metadata:
+      labels:
+        app: model-data-setup
+    spec:
+      containers:
+      - name: model-setup
+        image: model-validation-test-model:latest
+        imagePullPolicy: Never
+        command:
+        - sh
+        - -c
+        - |
+          echo "Setting up model data on node $NODE_NAME..."
+          mkdir -p /host-data /host-keys
+          cp -r /data/* /host-data/
+          cp -r /keys/* /host-keys/
+          echo "Model data setup complete on node $NODE_NAME"
+          # Keep running so DaemonSet stays active
+          sleep infinity
+        env:
+        - name: NODE_NAME
+          valueFrom:
+            fieldRef:
+              fieldPath: spec.nodeName
+        volumeMounts:
+        - name: host-model-data
+          mountPath: /host-data
+        - name: host-keys-data
+          mountPath: /host-keys
+        securityContext:
+          runAsNonRoot: false
+          runAsUser: 0
+      volumes:
+      - name: host-model-data
+        hostPath:
+          path: /tmp/e2e-model-data
+          type: DirectoryOrCreate
+      - name: host-keys-data
+        hostPath:
+          path: /tmp/e2e-keys-data
+          type: DirectoryOrCreate
+      tolerations:
+      - operator: Exists
+        effect: NoSchedule
\ No newline at end of file
diff --git a/test/e2e/testdata/modelvalidation_cr.yaml b/test/e2e/testdata/modelvalidation_cr.yaml
deleted file mode 100644
index df6e9aff..00000000
--- a/test/e2e/testdata/modelvalidation_cr.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-apiVersion: ml.sigstore.dev/v1alpha1
-kind: ModelValidation
-metadata:
-  name: e2e-test-model
-  namespace: e2e-webhook-test-ns
-spec:
-  config:
-    sigstoreConfig:
-      certificateIdentity: "https://github.com/sigstore/model-validation-operator/.github/workflows/sign-model.yaml@refs/tags/v0.0.2"
-      certificateOidcIssuer: "https://token.actions.githubusercontent.com"
-  model:
-    path: /data
-    signaturePath: /data/model.sig
diff --git a/test/e2e/testdata/modelvalidation_template.yaml b/test/e2e/testdata/modelvalidation_template.yaml
new file mode 100644
index 00000000..20a6e4a2
--- /dev/null
+++ b/test/e2e/testdata/modelvalidation_template.yaml
@@ -0,0 +1,12 @@
+apiVersion: ml.sigstore.dev/v1alpha1
+kind: ModelValidation
+metadata:
+  name: {{.ModelName}}
+  namespace: {{.Namespace}}
+spec:
+  config:
+    publicKeyConfig:
+      keyPath: {{if .KeyPath}}{{.KeyPath}}{{else}}/keys/test_public_key.pub{{end}}
+  model:
+    path: /data
+    signaturePath: /data/model.sig
diff --git a/test/e2e/testdata/pod_template.yaml b/test/e2e/testdata/pod_template.yaml
new file mode 100644
index 00000000..34a888d3
--- /dev/null
+++ b/test/e2e/testdata/pod_template.yaml
@@ -0,0 +1,32 @@
+apiVersion: v1
+kind: Pod
+metadata:
+  name: {{.PodName}}
+  namespace: {{.Namespace}}
+  labels:
+    validation.ml.sigstore.dev/ml: "{{.ModelName}}"{{if .TestBatch}}
+    test-batch: "{{.TestBatch}}"{{end}}
+spec:
+  restartPolicy: Never
+  containers:
+  - name: test-container
+    image: busybox:latest
+    command: ["sh", "-c", "sleep 3600"]{{if .VolumeMounts}}
+    volumeMounts:{{range .VolumeMounts}}
+    - name: {{.Name}}
+      mountPath: {{.MountPath}}
+      readOnly: {{.ReadOnly}}{{end}}{{end}}
+    securityContext:
+      allowPrivilegeEscalation: false
+      capabilities:
+        drop:
+        - ALL
+      runAsNonRoot: true
+      runAsUser: 1000
+      seccompProfile:
+        type: RuntimeDefault{{if .Volumes}}
+  volumes:{{range .Volumes}}
+  - name: {{.Name}}{{if .HostPath}}
+    hostPath:
+      path: {{.HostPath.Path}}
+      type: {{if .HostPath.Type}}{{.HostPath.Type}}{{else}}Directory{{end}}{{end}}{{end}}{{end}}
\ No newline at end of file
diff --git a/test/e2e/testdata/test_pod.yaml b/test/e2e/testdata/test_pod.yaml
deleted file mode 100644
index 5b58d43e..00000000
--- a/test/e2e/testdata/test_pod.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-apiVersion: v1
-kind: Pod
-metadata:
-  name: e2e-test-pod
-  namespace: e2e-webhook-test-ns
-  labels:
-    validation.ml.sigstore.dev/ml: "e2e-test-model"
-spec:
-  restartPolicy: Never
-  containers:
-  - name: test-container
-    image: busybox
-    command: ["sh", "-c", "sleep 3600"]
-    securityContext:
-      allowPrivilegeEscalation: false
-      capabilities:
-        drop:
-        - ALL
-      runAsNonRoot: true
-      runAsUser: 1000
-      seccompProfile:
-        type: RuntimeDefault
\ No newline at end of file
diff --git a/test/e2e/validation_test.go b/test/e2e/validation_test.go
new file mode 100644
index 00000000..db20eae8
--- /dev/null
+++ b/test/e2e/validation_test.go
@@ -0,0 +1,163 @@
+/*
+Copyright 2025.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package e2e
+
+import (
+	"fmt"
+	"os/exec"
+	"time"
+
+	_ "embed"
+
+	. "github.com/onsi/ginkgo/v2" //nolint:revive
+	. "github.com/onsi/gomega"    //nolint:revive
+	corev1 "k8s.io/api/core/v1"
+
+	"github.com/sigstore/model-validation-operator/internal/constants"
+	utils "github.com/sigstore/model-validation-operator/test/utils"
+)
+
+const validationTestNamespace = "e2e-webhook-test-ns"
+const validationTestModelName = "validation-test-model"
+
+// defaultValidationPodData creates a standard validation test pod configuration
+func defaultValidationPodData(podName, namespace, modelName string) utils.PodTemplateData {
+	return utils.DefaultPodData(podName, namespace, modelName, "validation")
+}
+
+var _ = Describe("ModelValidation Success/Failure Scenarios", Ordered, func() {
+	Context("Validation Success and Failure", func() {
+		It("should fail validation with invalid signature (current test model)", func() {
+			podName := "failure-test-pod"
+
+			By("deploying a ModelValidation CR for failure test")
+			err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{
+				ModelName: validationTestModelName,
+				Namespace: validationTestNamespace,
+				KeyPath:   "/keys/test_invalid_public_key.pub",
+			})
+			Expect(err).NotTo(HaveOccurred())
+
+			By("deploying a pod that should fail validation due to invalid signature")
+			err = utils.KubectlApply(podTemplate,
+				defaultValidationPodData(podName, validationTestNamespace, validationTestModelName))
+			Expect(err).NotTo(HaveOccurred())
+
+			By("verifying the init container was injected")
+			Eventually(func(g Gomega) {
+				var pod corev1.Pod
+				err := utils.KubectlGetJSON("pod", podName, validationTestNamespace, &pod)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(pod.Spec.InitContainers).ToNot(BeEmpty())
+				g.Expect(utils.HasValidationContainer(&pod)).To(BeTrue())
+			}, 30*time.Second, 1*time.Second).Should(Succeed())
+
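+			// With restartPolicy: Never, a failed init container is not retried,
+			// so the pod settles in the Failed phase with a terminated init
+			// container status that the assertions below inspect.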
+			By("verifying the init container fails due to invalid signature")
+			Eventually(func(g Gomega) {
+				var pod corev1.Pod
+				err := utils.KubectlGetJSON("pod", podName, validationTestNamespace, &pod)
+				g.Expect(err).NotTo(HaveOccurred())
+
+				g.Expect(pod.Status.Phase).To(Equal(corev1.PodFailed))
+
+				g.Expect(pod.Status.InitContainerStatuses).ToNot(BeEmpty())
+				initStatus := pod.Status.InitContainerStatuses[0]
+				g.Expect(initStatus.State.Terminated).NotTo(BeNil())
+				g.Expect(initStatus.State.Terminated.ExitCode).To(Equal(int32(1))) // Should fail with exit code 1
+			}, 60*time.Second, 5*time.Second).Should(Succeed())
+
+			By("verifying the init container logs show validation failure")
+			Eventually(func(g Gomega) {
+				cmd := exec.Command("kubectl", "logs", podName, "-n", validationTestNamespace, "-c",
+					constants.ModelValidationInitContainerName)
+				output, err := utils.Run(cmd)
+				g.Expect(err).NotTo(HaveOccurred())
+
+				g.Expect(output).To(ContainSubstring("Verification failed with error"))
+			}, 30*time.Second, 5*time.Second).Should(Succeed())
+
+			By("validation failure test completed successfully - webhook injection and " +
+				"validation failure both working as expected")
+
+			By("cleaning up failure test resources")
+			podToDelete := fmt.Sprintf(
+				"apiVersion: v1\nkind: Pod\nmetadata:\n  name: %s\n  namespace: %s",
+				podName, validationTestNamespace)
+			_ = utils.KubectlDelete([]byte(podToDelete), &utils.KubectlDeleteOptions{Timeout: "30s", IgnoreNotFound: true})
+		})
+
+		It("should successfully validate with public key signature", func() {
+			modelName := "public-key-success-test"
+			podName := "success-test-pod"
+
+			By("deploying a ModelValidation CR with public key configuration")
+			err := utils.KubectlApply(modelValidationTemplate, utils.CRTemplateData{
+				ModelName: modelName,
+				Namespace: validationTestNamespace,
+			})
+			Expect(err).NotTo(HaveOccurred())
+
+			By("deploying a pod that should pass validation with public key signature")
+			err = utils.KubectlApply(podTemplate, defaultValidationPodData(podName, validationTestNamespace, modelName))
+			Expect(err).NotTo(HaveOccurred())
+
+			By("verifying the init container was injected")
+			Eventually(func(g Gomega) {
+				var pod corev1.Pod
+				err := utils.KubectlGetJSON("pod", podName, validationTestNamespace, &pod)
+				g.Expect(err).NotTo(HaveOccurred())
+				g.Expect(pod.Spec.InitContainers).ToNot(BeEmpty())
+
+				g.Expect(utils.HasValidationContainer(&pod)).To(BeTrue())
+			}, 30*time.Second, 2*time.Second).Should(Succeed())
+
+			By("verifying the init container succeeds and pod reaches Running state")
+			Eventually(func(g Gomega) {
+				var pod corev1.Pod
+				err := utils.KubectlGetJSON("pod", podName, validationTestNamespace, &pod)
+				g.Expect(err).NotTo(HaveOccurred())
+
+				g.Expect(pod.Status.Phase).To(Equal(corev1.PodRunning))
+
+				g.Expect(pod.Status.InitContainerStatuses).ToNot(BeEmpty())
+				initStatus := pod.Status.InitContainerStatuses[0]
+				g.Expect(initStatus.State.Terminated).NotTo(BeNil())
+				g.Expect(initStatus.State.Terminated.ExitCode).To(Equal(int32(0))) // Should succeed with exit code 0
+			}, 60*time.Second, 5*time.Second).Should(Succeed())
+
+			By("verifying the init container logs show successful validation")
+			Eventually(func(g Gomega) {
+				cmd := exec.Command("kubectl", "logs", podName, "-n", validationTestNamespace, "-c",
+					constants.ModelValidationInitContainerName)
+				output, err := utils.Run(cmd)
+				g.Expect(err).NotTo(HaveOccurred())
+
+				g.Expect(output).To(ContainSubstring("Verification succeeded"))
+			}, 30*time.Second, 5*time.Second).Should(Succeed())
+
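+			// Cleanup below uses minimal inline manifests (kind, name, namespace
+			// only): kubectl delete only needs identifying fields, so the full
+			// templates do not have to be re-rendered here.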
+			By("cleaning up success test resources")
+			podToDelete := fmt.Sprintf(
+				"apiVersion: v1\nkind: Pod\nmetadata:\n  name: %s\n  namespace: %s",
+				podName, validationTestNamespace)
+			_ = utils.KubectlDelete([]byte(podToDelete), &utils.KubectlDeleteOptions{Timeout: "30s", IgnoreNotFound: true})
+			crToDelete := fmt.Sprintf(
+				"apiVersion: ml.sigstore.dev/v1alpha1\nkind: ModelValidation\nmetadata:\n  name: %s\n  namespace: %s",
+				modelName, validationTestNamespace)
+			_ = utils.KubectlDelete([]byte(crToDelete), &utils.KubectlDeleteOptions{Timeout: "30s", IgnoreNotFound: true})
+		})
+	})
+})
diff --git a/test/utils/utils.go b/test/utils/utils.go
index 0adccea7..4335e4fd 100644
--- a/test/utils/utils.go
+++ b/test/utils/utils.go
@@ -18,27 +18,82 @@ limitations under the License.
 package utils_test //nolint:revive

 import (
-	"bufio"
 	"bytes"
+	"encoding/json"
 	"fmt"
 	"os"
 	"os/exec"
 	"strings"
+	"text/template"

 	. "github.com/onsi/ginkgo/v2" //nolint:revive,staticcheck
+	dto "github.com/prometheus/client_model/go"
+	"github.com/prometheus/common/expfmt"
+	corev1 "k8s.io/api/core/v1"
 )

 const (
-	prometheusOperatorVersion = "v0.84.0"
-	prometheusOperatorURL     = "https://github.com/prometheus-operator/prometheus-operator/" +
-		"releases/download/%s/bundle.yaml"
+	// ServiceAccountName is the name of the service account used by the model validation operator
+	ServiceAccountName = "model-validation-controller-manager"

-	certmanagerVersion = "v1.18.2"
-	certmanagerURLTmpl = "https://github.com/cert-manager/cert-manager/releases/download/%s/cert-manager.yaml"
+	// MetricsServiceAccountName is the name of the service account used for metrics access in tests
+	MetricsServiceAccountName = "e2e-metrics-reader"
+
+	// MetricsServiceName is the name of the metrics service exposed by the model validation operator
+	MetricsServiceName = "model-validation-controller-manager-metrics-service"
 )

-func warnError(err error) {
-	_, _ = fmt.Fprintf(GinkgoWriter, "warning: %v\n", err)
+// PodTemplateData represents template data for e2e pod tests
+type PodTemplateData struct {
+	PodName      string
+	Namespace    string
+	ModelName    string
+	TestBatch    string
+	VolumeMounts []VolumeMount
+	Volumes      []Volume
+}
+
+// VolumeMount represents a volume mount configuration
+type VolumeMount struct {
+	Name      string
+	MountPath string
+	ReadOnly  bool
+}
+
+// Volume represents a volume configuration
+type Volume struct {
+	Name     string
+	HostPath *HostPathVolume
+}
+
+// HostPathVolume represents a host path volume configuration
+type HostPathVolume struct {
+	Path string
+	Type string
+}
+
+// CRTemplateData represents template data for custom resource tests
+type CRTemplateData struct {
+	ModelName string
+	Namespace string
+	KeyPath   string
+}
+
+// CurlPodTemplateData represents template data for curl pod tests
+type CurlPodTemplateData struct {
+	PodName        string
+	Namespace      string
+	Token          string
+	ServiceName    string
+	ServiceAccount string
+}
+
+// ClusterRoleBindingTemplateData represents template data for cluster role binding tests
+type ClusterRoleBindingTemplateData struct {
+	Name               string
+	ServiceAccountName string
+	Namespace          string
+	ClusterRoleName    string
 }

 // Run executes the provided command within this context
@@ -55,198 +110,308 @@ func Run(cmd *exec.Cmd) (string, error) {
 	_, _ = fmt.Fprintf(GinkgoWriter, "running: %s\n", command)
 	output, err := cmd.CombinedOutput()
 	if err != nil {
-		return string(output), fmt.Errorf("%s failed with error: (%v) %s", command, err, string(output))
+		return string(output), fmt.Errorf("%s failed with error: (%v) %s",
+			command, err, string(output))
 	}

 	return string(output), nil
 }

-// InstallPrometheusOperator installs the prometheus Operator to be used to export the enabled metrics.
-func InstallPrometheusOperator() error {
-	url := fmt.Sprintf(prometheusOperatorURL, prometheusOperatorVersion)
-	cmd := exec.Command("kubectl", "create", "-f", url)
-	_, err := Run(cmd)
-	return err
+// GetNonEmptyLines converts given command output string into individual objects
+// according to line breakers, and ignores the empty elements in it.
+func GetNonEmptyLines(output string) []string {
+	var res []string
+	elements := strings.Split(output, "\n")
+	for _, element := range elements {
+		if element != "" {
+			res = append(res, element)
+		}
+	}
+
+	return res
 }

-// UninstallPrometheusOperator uninstalls the prometheus
-func UninstallPrometheusOperator() {
-	url := fmt.Sprintf(prometheusOperatorURL, prometheusOperatorVersion)
-	cmd := exec.Command("kubectl", "delete", "-f", url)
-	if _, err := Run(cmd); err != nil {
-		warnError(err)
+// GetProjectDir will return the directory where the project is
+func GetProjectDir() (string, error) {
+	wd, err := os.Getwd()
+	if err != nil {
+		return wd, err
 	}
+	wd = strings.ReplaceAll(wd, "/test/e2e", "")
+	return wd, nil
 }

-// IsPrometheusCRDsInstalled checks if any Prometheus CRDs are installed
-// by verifying the existence of key CRDs related to Prometheus.
-func IsPrometheusCRDsInstalled() bool {
-	// List of common Prometheus CRDs
-	prometheusCRDs := []string{
-		"prometheuses.monitoring.coreos.com",
-		"prometheusrules.monitoring.coreos.com",
-		"prometheusagents.monitoring.coreos.com",
-	}
+// KubectlApply applies a YAML resource from embedded bytes data or template
+// If templateData is nil, yamlData is used directly. If templateData is provided,
+// yamlData is treated as a template.
+func KubectlApply(yamlData []byte, templateData any) error {
+	var finalData []byte
+	var err error

-	cmd := exec.Command("kubectl", "get", "crds", "-o", "custom-columns=NAME:.metadata.name")
-	output, err := Run(cmd)
-	if err != nil {
-		return false
+	if templateData != nil {
+		finalData, err = executeTemplate(yamlData, templateData)
+		if err != nil {
+			return err
+		}
+	} else {
+		finalData = yamlData
 	}
-	crdList := GetNonEmptyLines(output)
-	for _, crd := range prometheusCRDs {
-		for _, line := range crdList {
-			if strings.Contains(line, crd) {
-				return true
+
+	cmd := exec.Command("kubectl", "apply", "-f", "-")
+	cmd.Stdin = bytes.NewReader(finalData)
+	_, err = Run(cmd)
+	return err
+}
+
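+// Example usage (illustrative only; "example-model" and "example-ns" are
+// placeholders, and crTemplate is any caller-supplied []byte holding a
+// Go-templated manifest):
+//
+//	err := KubectlApply(crTemplate, CRTemplateData{ModelName: "example-model", Namespace: "example-ns"})
+//	_ = KubectlDelete(crTemplate, &KubectlDeleteOptions{
+//		IgnoreNotFound: true,
+//		TemplateData:   CRTemplateData{ModelName: "example-model", Namespace: "example-ns"},
+//	})
+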
+// KubectlDeleteOptions contains options for kubectl delete operations
+type KubectlDeleteOptions struct {
+	Timeout        string // e.g. "30s", "5m"
+	IgnoreNotFound bool   // Use --ignore-not-found flag
+	TemplateData   any    // Optional template data for Go template processing
+}
+
+// KubectlDelete deletes a YAML resource from embedded bytes data or template
+// with optional settings. If opts is nil, yamlData is used directly with default
+// settings. If opts.TemplateData is provided, yamlData is treated as a template.
+func KubectlDelete(yamlData []byte, opts *KubectlDeleteOptions) error {
+	args := []string{"delete", "-f", "-"}
+	var finalData []byte
+	var err error
+
+	if opts != nil {
+		// Handle template processing if template data is provided
+		if opts.TemplateData != nil {
+			finalData, err = executeTemplate(yamlData, opts.TemplateData)
+			if err != nil {
+				return err
 			}
+		} else {
+			finalData = yamlData
+		}
+
+		if opts.Timeout != "" {
+			args = append(args, "--timeout="+opts.Timeout)
+		}
+		if opts.IgnoreNotFound {
+			args = append(args, "--ignore-not-found=true")
 		}
+	} else {
+		finalData = yamlData
 	}

-	return false
+	cmd := exec.Command("kubectl", args...)
+	cmd.Stdin = bytes.NewReader(finalData)
+	_, err = Run(cmd)
+	return err
 }

-// UninstallCertManager uninstalls the cert manager
-func UninstallCertManager() {
-	url := fmt.Sprintf(certmanagerURLTmpl, certmanagerVersion)
-	cmd := exec.Command("kubectl", "delete", "-f", url)
-	if _, err := Run(cmd); err != nil {
-		warnError(err)
+// KubectlGet retrieves a Kubernetes resource and returns the output
+// If name is empty, retrieves all resources of the specified type
+func KubectlGet(resource, name, namespace string, outputFormat string) (string, error) {
+	args := []string{"get", resource}
+	if name != "" {
+		args = append(args, name)
+	}
+	if namespace != "" {
+		args = append(args, "-n", namespace)
 	}
+	if outputFormat != "" {
+		args = append(args, "-o", outputFormat)
+	}
+
+	cmd := exec.Command("kubectl", args...)
+	return Run(cmd)
 }

-// InstallCertManager installs the cert manager bundle.
-func InstallCertManager() error {
-	url := fmt.Sprintf(certmanagerURLTmpl, certmanagerVersion)
-	cmd := exec.Command("kubectl", "apply", "-f", url)
-	if _, err := Run(cmd); err != nil {
+// KubectlGetJSON retrieves a Kubernetes resource as JSON and unmarshals it
+func KubectlGetJSON(resource, name, namespace string, result any) error {
+	output, err := KubectlGet(resource, name, namespace, "json")
+	if err != nil {
 		return err
 	}
-	// Wait for cert-manager-webhook to be ready, which can take time if cert-manager
-	// was re-installed after uninstalling on a cluster.
-	cmd = exec.Command("kubectl", "wait", "deployment.apps/cert-manager-webhook",
-		"--for", "condition=Available",
-		"--namespace", "cert-manager",
-		"--timeout", "5m",
-	)
+	return json.Unmarshal([]byte(output), result)
+}
+
+// KubectlWait waits for a condition on a Kubernetes resource
+func KubectlWait(resource, name, namespace, condition string, timeout string) error {
+	args := []string{"wait", resource, name}
+	if namespace != "" {
+		args = append(args, "-n", namespace)
+	}
+	args = append(args, "--for", condition, "--timeout", timeout)
+
+	cmd := exec.Command("kubectl", args...)
 	_, err := Run(cmd)
 	return err
 }

-// IsCertManagerCRDsInstalled checks if any Cert Manager CRDs are installed
-// by verifying the existence of key CRDs related to Cert Manager.
-// IsCertManagerCRDsInstalled checks if any Cert Manager CRDs are installed
-// by verifying the existence of key CRDs related to Cert Manager.
-func IsCertManagerCRDsInstalled() bool {
-	// List of common Cert Manager CRDs
-	certManagerCRDs := []string{
-		"certificates.cert-manager.io",
-		"issuers.cert-manager.io",
-		"clusterissuers.cert-manager.io",
-		"certificaterequests.cert-manager.io",
-		"orders.acme.cert-manager.io",
-		"challenges.acme.cert-manager.io",
+// executeTemplate processes a Go template with the given data and returns the result
+func executeTemplate(templateData []byte, data any) ([]byte, error) {
+	tmpl, err := template.New("kubectl").Parse(string(templateData))
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse template: %w", err)
+	}
+
+	var buf bytes.Buffer
+	if err := tmpl.Execute(&buf, data); err != nil {
+		return nil, fmt.Errorf("failed to execute template: %w", err)
+	}
+
+	return buf.Bytes(), nil
+}
+
+// DefaultPodData creates a pod template with common defaults
+func DefaultPodData(podName, namespace, modelName, testBatch string) PodTemplateData {
+	modelDataVolume := Volume{
+		Name: "model-data",
+		HostPath: &HostPathVolume{
+			Path: "/tmp/e2e-model-data",
+			Type: "Directory",
+		},
+	}
+	keysDataVolume := Volume{
+		Name: "keys-data",
+		HostPath: &HostPathVolume{
+			Path: "/tmp/e2e-keys-data",
+			Type: "Directory",
+		},
+	}
+	modelDataMount := VolumeMount{
+		Name:      "model-data",
+		MountPath: "/data",
+		ReadOnly:  true,
+	}
+	keysDataMount := VolumeMount{
+		Name:      "keys-data",
+		MountPath: "/keys",
+		ReadOnly:  true,
 	}
-	// Execute the kubectl command to get all CRDs
-	cmd := exec.Command("kubectl", "get", "crds")
-	output, err := Run(cmd)
-	if err != nil {
-		return false
+	data := PodTemplateData{
+		PodName:      podName,
+		Namespace:    namespace,
+		ModelName:    modelName,
+		TestBatch:    testBatch,
+		VolumeMounts: []VolumeMount{modelDataMount, keysDataMount},
+		Volumes:      []Volume{modelDataVolume, keysDataVolume},
 	}
-	// Check if any of the Cert Manager CRDs are present
-	crdList := GetNonEmptyLines(output)
-	for _, crd := range certManagerCRDs {
-		for _, line := range crdList {
-			if strings.Contains(line, crd) {
-				return true
-			}
-		}
+	return data
+}
+
+// GetMetricsOutput retrieves fresh metrics by executing curl in the persistent pod
+func GetMetricsOutput(namespace, podName string) string {
+	token, err := CreateServiceAccountToken(MetricsServiceAccountName, namespace)
+	if err != nil {
+		return ""
 	}
-	return false
+	curlCommand := fmt.Sprintf(
+		"curl -s -k -H 'Authorization: Bearer %s' https://%s.%s.svc.cluster.local:8443/metrics 2>/dev/null",
+		token, MetricsServiceName, namespace)
+	execCmd := exec.Command("kubectl", "exec", podName, "-n", namespace,
+		"--", "/bin/sh", "-c", curlCommand)
+
+	output, err := Run(execCmd)
+	if err != nil {
+		return ""
+	}
+	return output
 }
 
-// LoadImageToKindClusterWithName loads a local docker image to the kind cluster
-func LoadImageToKindClusterWithName(name string) error {
-	cluster := "kind"
-	if v, ok := os.LookupEnv("KIND_CLUSTER"); ok {
-		cluster = v
+// KubectlDescribe describes a Kubernetes resource
+func KubectlDescribe(resource, name, namespace string) (string, error) {
+	args := []string{"describe", resource, name}
+	if namespace != "" {
+		args = append(args, "-n", namespace)
 	}
-	kindOptions := []string{"load", "docker-image", name, "--name", cluster}
-	cmd := exec.Command("kind", kindOptions...)
-	_, err := Run(cmd)
-	return err
+
+	cmd := exec.Command("kubectl", args...)
+	return Run(cmd)
 }
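
executeTemplate is plain text/template under the hood, so DefaultPodData works with any manifest that references the PodTemplateData fields. A small sketch with an inline template (the YAML body is illustrative; the real templates live in embedded files):

    yaml := []byte("apiVersion: v1\nkind: Pod\nmetadata:\n" +
        "  name: {{ .PodName }}\n  namespace: {{ .Namespace }}\n")
    rendered, err := executeTemplate(yaml,
        DefaultPodData("demo-pod", "e2e-webhook-test-ns", "demo-model", "batch-a"))
    if err != nil {
        return err
    }
    fmt.Println(string(rendered))
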
-// GetNonEmptyLines converts given command output string into individual objects
-// according to line breakers, and ignores the empty elements in it.
-func GetNonEmptyLines(output string) []string {
-	var res []string
-	elements := strings.Split(output, "\n")
-	for _, element := range elements {
-		if element != "" {
-			res = append(res, element)
+// KubectlResourceExists checks if a Kubernetes resource exists
+func KubectlResourceExists(resource, name, namespace string) bool {
+	_, err := KubectlGet(resource, name, namespace, "")
+	return err == nil
+}
+
+// HasValidationContainer checks if a pod has the model-validation init container
+func HasValidationContainer(pod *corev1.Pod) bool {
+	for _, container := range pod.Spec.InitContainers {
+		if strings.Contains(container.Name, "model-validation") {
+			return true
 		}
 	}
-
-	return res
+	return false
 }
 
-// GetProjectDir will return the directory where the project is
-func GetProjectDir() (string, error) {
-	wd, err := os.Getwd()
+// CreateServiceAccountToken creates a token for the specified service account in the given namespace
+func CreateServiceAccountToken(serviceAccountName, namespace string) (string, error) {
+	tokenCmd := exec.Command("kubectl", "create", "token", serviceAccountName, "-n", namespace)
+	token, err := Run(tokenCmd)
 	if err != nil {
-		return wd, err
+		return "", fmt.Errorf(
+			"failed to create token for service account %s in namespace %s: %w",
+			serviceAccountName, namespace, err)
 	}
-	wd = strings.ReplaceAll(wd, "/test/e2e", "")
-	return wd, nil
+	return strings.TrimSpace(token), nil
 }
 
-// UncommentCode searches for target in the file and remove the comment prefix
-// of the target content. The target content may span multiple lines.
-func UncommentCode(filename, target, prefix string) error {
-	// false positive
-	// nolint:gosec
-	content, err := os.ReadFile(filename)
+// ExtractMetricValue extracts a specific metric value from Prometheus output
+func ExtractMetricValue(metricsOutput, metricName string, labels map[string]string) int {
+	parser := expfmt.TextParser{}
+	metricFamilies, err := parser.TextToMetricFamilies(strings.NewReader(metricsOutput))
 	if err != nil {
-		return err
+		return 0
 	}
-	strContent := string(content)
-	idx := strings.Index(strContent, target)
-	if idx < 0 {
-		return fmt.Errorf("unable to find the code %s to be uncomment", target)
+	metricFamily, exists := metricFamilies[metricName]
+	if !exists {
+		return 0
 	}
-	out := new(bytes.Buffer)
-	_, err = out.Write(content[:idx])
-	if err != nil {
-		return err
+	for _, metric := range metricFamily.GetMetric() {
+		if labelsMatch(metric.GetLabel(), labels) {
+			// Extract the metric value based on type
+			switch metricFamily.GetType() {
+			case dto.MetricType_GAUGE:
+				if gauge := metric.GetGauge(); gauge != nil {
+					return int(gauge.GetValue())
+				}
+			case dto.MetricType_COUNTER:
+				if counter := metric.GetCounter(); counter != nil {
+					return int(counter.GetValue())
+				}
+			case dto.MetricType_HISTOGRAM:
+				if histogram := metric.GetHistogram(); histogram != nil {
+					return int(histogram.GetSampleCount())
+				}
+			}
+		}
 	}
-	scanner := bufio.NewScanner(bytes.NewBufferString(target))
-	if !scanner.Scan() {
-		return nil
+	return 0
+}
+
+// labelsMatch checks if all expected labels are present and match in the actual Prometheus labels
+func labelsMatch(actualLabels []*dto.LabelPair, expectedLabels map[string]string) bool {
+	actualLabelMap := make(map[string]string)
+	for _, labelPair := range actualLabels {
+		actualLabelMap[labelPair.GetName()] = labelPair.GetValue()
 	}
-	for {
-		_, err := out.WriteString(strings.TrimPrefix(scanner.Text(), prefix))
-		if err != nil {
-			return err
-		}
-		// Avoid writing a newline in case the previous line was the last in target.
-		if !scanner.Scan() {
-			break
-		}
-		if _, err := out.WriteString("\n"); err != nil {
-			return err
+
+	for key, expectedValue := range expectedLabels {
+		actualValue, exists := actualLabelMap[key]
+		if !exists || actualValue != expectedValue {
+			return false
 		}
 	}
+	return true
+}
 
-	_, err = out.Write(content[idx+len(target):])
-	if err != nil {
-		return err
-	}
-	// false positive
-	// nolint:gosec
-	return os.WriteFile(filename, out.Bytes(), 0644)
+// GetMetricValue gets the current value of a specific metric
+func GetMetricValue(metricName string, labels map[string]string, namespace, podName string) int {
+	metrics := GetMetricsOutput(namespace, podName)
+	return ExtractMetricValue(metrics, metricName, labels)
 }
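
A typical assertion reads a counter before and after the action under test and checks the delta. In this sketch the metric name, labels, and surrounding variables are hypothetical, not names exported by the operator:

    before := GetMetricValue("model_validation_pods_processed_total",
        map[string]string{"namespace": testNamespace}, operatorNS, metricsPod)
    // ... create a pod that triggers the validating webhook ...
    after := GetMetricValue("model_validation_pods_processed_total",
        map[string]string{"namespace": testNamespace}, operatorNS, metricsPod)
    if after != before+1 {
        return fmt.Errorf("expected counter to grow by 1, got %d -> %d", before, after)
    }
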
diff --git a/testdata/docker/test-model.Dockerfile b/testdata/docker/test-model.Dockerfile
new file mode 100644
index 00000000..c21472c1
--- /dev/null
+++ b/testdata/docker/test-model.Dockerfile
@@ -0,0 +1,15 @@
+FROM busybox
+
+# Copy the real TensorFlow SavedModel files from testdata
+COPY tensorflow_saved_model/ /data/
+
+# Copy public keys for verification tests into a separate directory
+RUN mkdir -p /keys
+COPY docker/test_public_key.pub /keys/test_public_key.pub
+COPY docker/test_invalid_public_key.pub /keys/test_invalid_public_key.pub
+
+# Make files readable, keep directories traversable, and drop stray public keys from /data
+RUN chmod -R a+rX /data /keys && rm -f /data/*.pub
+
+# Default command
+CMD ["sleep", "3600"]
\ No newline at end of file
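
A quick local smoke check of the image layout, sketched with the suite's Run helper (invoking docker directly and the image tag are assumptions matching the Makefile defaults, not part of this patch):

    // /data should contain the SavedModel plus its model.sig, and no key files.
    out, err := Run(exec.Command("docker", "run", "--rm",
        "model-validation-test-model:latest", "ls", "/data"))
    if err != nil {
        return err
    }
    if !strings.Contains(out, "model.sig") || strings.Contains(out, ".pub") {
        return fmt.Errorf("unexpected /data contents: %s", out)
    }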