diff --git a/Dockerfile.lmes-job b/Dockerfile.lmes-job index d7948677a..77452cf32 100644 --- a/Dockerfile.lmes-job +++ b/Dockerfile.lmes-job @@ -12,10 +12,10 @@ RUN mkdir -p /opt/app-root/src/my_catalogs/cards && chmod -R g+rwx /opt/app-root RUN mkdir -p /opt/app-root/src/.cache ENV PATH="/opt/app-root/bin:/opt/app-root/src/.local/bin/:/opt/app-root/src/bin:/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" -RUN curl -L https://github.com/opendatahub-io/lm-evaluation-harness/archive/refs/heads/incubation.zip -o repo.zip && \ +RUN curl -L https://github.com/tarilabs/lm-evaluation-harness/archive/refs/heads/tarilabs-20250917.zip -o repo.zip && \ unzip repo.zip && \ - cp -r lm-evaluation-harness-incubation/* . && \ - rm -rf lm-evaluation-harness-incubation repo.zip && \ + cp -r lm-evaluation-harness-tarilabs-20250917/* . && \ + rm -rf lm-evaluation-harness-tarilabs-20250917 repo.zip && \ pip install --no-cache-dir -r requirements.txt && \ pip install --no-cache-dir -e . diff --git a/api/lmes/v1alpha1/lmevaljob_types.go b/api/lmes/v1alpha1/lmevaljob_types.go index ce3332ac9..952feeed1 100644 --- a/api/lmes/v1alpha1/lmevaljob_types.go +++ b/api/lmes/v1alpha1/lmevaljob_types.go @@ -400,6 +400,9 @@ type Outputs struct { // Create an operator managed PVC // +optional PersistentVolumeClaimManaged *PersistentVolumeClaimManaged `json:"pvcManaged,omitempty"` + // Upload results to OCI registry + // +optional + OCISpec *OCISpec `json:"oci,omitempty"` } func (c *LMEvalContainer) GetSecurityContext() *corev1.SecurityContext { @@ -463,6 +466,40 @@ type OfflineS3Spec struct { CABundle *corev1.SecretKeySelector `json:"caBundle,omitempty"` } +type OCISpec struct { + // Registry URL (e.g., quay.io, registry.redhat.com) + Registry corev1.SecretKeySelector `json:"registry"` + // Repository path (e.g., myorg/evaluation-results) + Repository corev1.SecretKeySelector `json:"repository"` + // Optional tag for the artifact (defaults to job name if not specified) + // +optional + // +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]*$` + Tag string `json:"tag,omitempty"` + // Path within the results to package as artifact + // +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._/-]*$` + Path string `json:"path"` + // Subject for the OCI artifact + // +optional + // +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._:/@-]*$` + // +kubebuilder:validation:MaxLength=255 + Subject string `json:"subject,omitempty"` + // Username for registry authentication + // +optional + UsernameRef *corev1.SecretKeySelector `json:"username,omitempty"` + // Password for registry authentication + // +optional + PasswordRef *corev1.SecretKeySelector `json:"password,omitempty"` + // Token for registry authentication (alternative to username/password) + // +optional + TokenRef *corev1.SecretKeySelector `json:"token,omitempty"` + // Whether to verify SSL certificates + // +optional + VerifySSL *bool `json:"verifySSL,omitempty"` + // CA bundle for custom certificates + // +optional + CABundle *corev1.SecretKeySelector `json:"caBundle,omitempty"` +} + // OfflineStorageSpec defines the storage configuration for LMEvalJob's offline mode type OfflineStorageSpec struct { PersistentVolumeClaimName *string `json:"pvcName,omitempty"` @@ -585,10 +622,26 @@ func (s *LMEvalJobSpec) HasOfflineS3() bool { return s.Offline != nil && s.Offline.StorageSpec.S3Spec != nil } +func (s *LMEvalJobSpec) HasOCIOutput() bool { + return s.Outputs != nil && s.Outputs.OCISpec != nil +} + func (s *OfflineS3Spec) HasCertificates() bool { return s.CABundle != nil } +func (s *OCISpec) HasCertificates() bool { + return s.CABundle != nil +} + +func (s *OCISpec) HasUsernamePassword() bool { + return s.UsernameRef != nil && s.PasswordRef != nil +} + +func (s *OCISpec) HasToken() bool { + return s.TokenRef != nil +} + // HasCustomOutput returns whether an LMEvalJobSpec defines custom outputs or not func (s *LMEvalJobSpec) HasCustomOutput() bool { return s.Outputs != nil @@ -604,6 +657,11 @@ func (o *Outputs) HasExistingPVC() bool { return o.PersistentVolumeClaimName != nil } +// HasOCI returns whether the outputs define OCI upload +func (o *Outputs) HasOCI() bool { + return o != nil && o.OCISpec != nil +} + // LMEvalJobStatus defines the observed state of LMEvalJob type LMEvalJobStatus struct { // Important: Run "make" to regenerate code after modifying this file diff --git a/api/lmes/v1alpha1/zz_generated.deepcopy.go b/api/lmes/v1alpha1/zz_generated.deepcopy.go index 77e4c07b7..52cae752a 100644 --- a/api/lmes/v1alpha1/zz_generated.deepcopy.go +++ b/api/lmes/v1alpha1/zz_generated.deepcopy.go @@ -433,6 +433,48 @@ func (in *Metric) DeepCopy() *Metric { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *OCISpec) DeepCopyInto(out *OCISpec) { + *out = *in + in.Registry.DeepCopyInto(&out.Registry) + in.Repository.DeepCopyInto(&out.Repository) + if in.UsernameRef != nil { + in, out := &in.UsernameRef, &out.UsernameRef + *out = new(v1.SecretKeySelector) + (*in).DeepCopyInto(*out) + } + if in.PasswordRef != nil { + in, out := &in.PasswordRef, &out.PasswordRef + *out = new(v1.SecretKeySelector) + (*in).DeepCopyInto(*out) + } + if in.TokenRef != nil { + in, out := &in.TokenRef, &out.TokenRef + *out = new(v1.SecretKeySelector) + (*in).DeepCopyInto(*out) + } + if in.VerifySSL != nil { + in, out := &in.VerifySSL, &out.VerifySSL + *out = new(bool) + **out = **in + } + if in.CABundle != nil { + in, out := &in.CABundle, &out.CABundle + *out = new(v1.SecretKeySelector) + (*in).DeepCopyInto(*out) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OCISpec. +func (in *OCISpec) DeepCopy() *OCISpec { + if in == nil { + return nil + } + out := new(OCISpec) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *OfflineS3Spec) DeepCopyInto(out *OfflineS3Spec) { *out = *in @@ -517,6 +559,11 @@ func (in *Outputs) DeepCopyInto(out *Outputs) { *out = new(PersistentVolumeClaimManaged) **out = **in } + if in.OCISpec != nil { + in, out := &in.OCISpec, &out.OCISpec + *out = new(OCISpec) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Outputs. diff --git a/cmd/lmes_driver/main.go b/cmd/lmes_driver/main.go index 950c1d665..8281a8d33 100644 --- a/cmd/lmes_driver/main.go +++ b/cmd/lmes_driver/main.go @@ -60,6 +60,7 @@ var ( detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU") commPort = flag.Int("listen-port", driver.DefaultPort, "driver serves APIs on the port") downloadAssetsS3 = flag.Bool("download-assets-s3", false, "Download assets from S3") + uploadToOCI = flag.Bool("upload-to-oci", false, "Upload results to OCI registry") customTaskGitURL = flag.String("custom-task-git-url", "", "Git repository URL for custom tasks") customTaskGitBranch = flag.String("custom-task-git-branch", "", "Git repository branch for custom tasks") customTaskGitCommit = flag.String("custom-task-git-commit", "", "Git commit for custom tasks") @@ -129,6 +130,7 @@ func main() { Args: args, CommPort: *commPort, DownloadAssetsS3: *downloadAssetsS3, + UploadToOCI: *uploadToOCI, CustomTaskGitURL: *customTaskGitURL, CustomTaskGitBranch: *customTaskGitBranch, CustomTaskGitCommit: *customTaskGitCommit, diff --git a/config/base/params.env b/config/base/params.env index 9b8581d0b..ddf3c5322 100644 --- a/config/base/params.env +++ b/config/base/params.env @@ -3,7 +3,7 @@ trustyaiOperatorImage=quay.io/trustyai/trustyai-service-operator:latest oauthProxyImage=quay.io/openshift/origin-oauth-proxy:4.14.0 kServeServerless=enabled lmes-driver-image=quay.io/trustyai/ta-lmes-driver:latest -lmes-pod-image=quay.io/trustyai/ta-lmes-job:latest +lmes-pod-image=quay.io/mmortari/lm-evaluation-harness/job:latest lmes-pod-checking-interval=10s lmes-image-pull-policy=Always lmes-max-batch-size=24 diff --git a/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml b/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml index 033df5e83..3989bda56 100644 --- a/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml +++ b/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml @@ -276,6 +276,151 @@ spec: outputs: description: Outputs specifies storage for evaluation results properties: + oci: + description: Upload results to OCI registry + properties: + caBundle: + description: CA bundle for custom certificates + properties: + key: + description: The key of the secret to select from. Must + be a valid secret key. + type: string + name: + description: |- + Name of the referent. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + optional: + description: Specify whether the Secret or its key must + be defined + type: boolean + required: + - key + type: object + x-kubernetes-map-type: atomic + password: + description: Password for registry authentication + properties: + key: + description: The key of the secret to select from. Must + be a valid secret key. + type: string + name: + description: |- + Name of the referent. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + optional: + description: Specify whether the Secret or its key must + be defined + type: boolean + required: + - key + type: object + x-kubernetes-map-type: atomic + path: + description: Path within the results to package as artifact + pattern: ^[a-zA-Z0-9._/-]*$ + type: string + registry: + description: Registry URL (e.g., quay.io, registry.redhat.com) + properties: + key: + description: The key of the secret to select from. Must + be a valid secret key. + type: string + name: + description: |- + Name of the referent. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + optional: + description: Specify whether the Secret or its key must + be defined + type: boolean + required: + - key + type: object + x-kubernetes-map-type: atomic + repository: + description: Repository path (e.g., myorg/evaluation-results) + properties: + key: + description: The key of the secret to select from. Must + be a valid secret key. + type: string + name: + description: |- + Name of the referent. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + optional: + description: Specify whether the Secret or its key must + be defined + type: boolean + required: + - key + type: object + x-kubernetes-map-type: atomic + subject: + description: Subject for the OCI artifact + pattern: ^[a-zA-Z0-9._:/@-]*$ + type: string + tag: + description: Optional tag for the artifact (defaults to job + name if not specified) + pattern: ^[a-zA-Z0-9._-]*$ + type: string + token: + description: Token for registry authentication (alternative + to username/password) + properties: + key: + description: The key of the secret to select from. Must + be a valid secret key. + type: string + name: + description: |- + Name of the referent. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + optional: + description: Specify whether the Secret or its key must + be defined + type: boolean + required: + - key + type: object + x-kubernetes-map-type: atomic + username: + description: Username for registry authentication + properties: + key: + description: The key of the secret to select from. Must + be a valid secret key. + type: string + name: + description: |- + Name of the referent. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + optional: + description: Specify whether the Secret or its key must + be defined + type: boolean + required: + - key + type: object + x-kubernetes-map-type: atomic + verifySSL: + description: Whether to verify SSL certificates + type: boolean + required: + - path + - registry + - repository + type: object pvcManaged: description: Create an operator managed PVC properties: diff --git a/config/overlays/odh/params.env b/config/overlays/odh/params.env index 91c3ea58e..dabc63d88 100644 --- a/config/overlays/odh/params.env +++ b/config/overlays/odh/params.env @@ -3,7 +3,7 @@ trustyaiOperatorImage=quay.io/opendatahub/trustyai-service-operator:latest oauthProxyImage=quay.io/openshift/origin-oauth-proxy:4.14.0 kServeServerless=enabled lmes-driver-image=quay.io/opendatahub/ta-lmes-driver:latest -lmes-pod-image=quay.io/opendatahub/ta-lmes-job:latest +lmes-pod-image=quay.io/mmortari/lm-evaluation-harness/job:latest lmes-pod-checking-interval=10s lmes-image-pull-policy=Always lmes-max-batch-size=24 diff --git a/controllers/dsc/config.go b/controllers/dsc/config.go index c389bcfc2..7489c5623 100644 --- a/controllers/dsc/config.go +++ b/controllers/dsc/config.go @@ -50,7 +50,7 @@ func (r *DSCConfigReader) ReadDSCConfig(ctx context.Context, log *logr.Logger) ( if errors.IsNotFound(err) { log.V(1).Info("DSC ConfigMap not found, using default configuration", "configmap", configMapKey) - return &DSCConfig{}, nil + return nil, nil } return nil, fmt.Errorf("error reading DSC ConfigMap %s: %w", configMapKey, err) } diff --git a/controllers/dsc/config_test.go b/controllers/dsc/config_test.go index ac9cabc65..c12bfbf38 100644 --- a/controllers/dsc/config_test.go +++ b/controllers/dsc/config_test.go @@ -26,12 +26,12 @@ func TestDSCConfigReader_ReadDSCConfig(t *testing.T) { expectedCode bool }{ { - name: "ConfigMap not found - should use defaults", + name: "ConfigMap not found - should return nil", namespace: "test-namespace", configMapData: nil, expectError: false, - expectedOnline: false, // Default value - expectedCode: false, // Default value + expectedOnline: false, // Should not be applied + expectedCode: false, // Should not be applied }, { name: "Valid configuration with both settings enabled", @@ -121,9 +121,14 @@ func TestDSCConfigReader_ReadDSCConfig(t *testing.T) { assert.NoError(t, err) } - // Assert configuration values - assert.Equal(t, tt.expectedOnline, dscConfig.AllowOnline, "AllowOnline should match expected value") - assert.Equal(t, tt.expectedCode, dscConfig.AllowCodeExecution, "AllowCodeExecution should match expected value") + // Assert configuration values - only if dscConfig is not nil + if dscConfig != nil { + assert.Equal(t, tt.expectedOnline, dscConfig.AllowOnline, "AllowOnline should match expected value") + assert.Equal(t, tt.expectedCode, dscConfig.AllowCodeExecution, "AllowCodeExecution should match expected value") + } else { + // When ConfigMap is not found, dscConfig should be nil + assert.Nil(t, dscConfig, "DSC config should be nil when ConfigMap is not found") + } }) } } @@ -146,9 +151,8 @@ func TestDSCConfigReader_ReadDSCConfig_ConfigMapNotFound(t *testing.T) { // Should not return error when ConfigMap is not found assert.NoError(t, err) - // Values should be at defaults - assert.False(t, dscConfig.AllowOnline) - assert.False(t, dscConfig.AllowCodeExecution) + // DSC config should be nil when ConfigMap is not found + assert.Nil(t, dscConfig, "DSC config should be nil when ConfigMap is not found") } func TestDSCConfigReader_ReadDSCConfig_ClientError(t *testing.T) { diff --git a/controllers/job_mgr/job_mgr_controller.go b/controllers/job_mgr/job_mgr_controller.go index 675cfea74..c46f22823 100644 --- a/controllers/job_mgr/job_mgr_controller.go +++ b/controllers/job_mgr/job_mgr_controller.go @@ -114,7 +114,13 @@ func (job *LMEvalJob) Finished() (condition metav1.Condition, finished bool) { // PodSets will build workload podSets corresponding to the job. func (job *LMEvalJob) PodSets() []kueue.PodSet { log := log.FromContext(context.TODO()) - pod := lmes.CreatePod(lmes.Options, &job.LMEvalJob, log) + // Use global Options permissions for job manager. + // This will be updated before every job deployment. + permConfig := &lmes.PermissionConfig{ + AllowOnline: lmes.Options.AllowOnline, + AllowCodeExecution: lmes.Options.AllowCodeExecution, + } + pod := lmes.CreatePod(lmes.Options, &job.LMEvalJob, permConfig, log) podSet := kueue.PodSet{ Name: job.GetPodName(), Count: 1, diff --git a/controllers/lmes/config.go b/controllers/lmes/config.go index cbf3fbf36..d2c10a84c 100644 --- a/controllers/lmes/config.go +++ b/controllers/lmes/config.go @@ -17,6 +17,7 @@ limitations under the License. package lmes import ( + "context" "fmt" "reflect" "strconv" @@ -27,6 +28,8 @@ import ( "github.com/trustyai-explainability/trustyai-service-operator/controllers/dsc" "github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/driver" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" ) // set by job_mgr controllerSetup func @@ -118,9 +121,68 @@ func constructOptionsFromConfigMap(log *logr.Logger, configmap *corev1.ConfigMap } // ApplyDSCConfig applies DSC configuration to the LMES Options +// Only applies configuration if DSC config is available and not nil func ApplyDSCConfig(dscConfig *dsc.DSCConfig) { if dscConfig != nil { Options.AllowOnline = dscConfig.AllowOnline Options.AllowCodeExecution = dscConfig.AllowCodeExecution } } + +// PermissionConfig holds the effective permissions for LMEval jobs +type PermissionConfig struct { + AllowOnline bool + AllowCodeExecution bool +} + +// ReadEffectivePermissions reads both the default configmap and DSC config, +// returning the effective permissions (DSC config overrides defaults) +func ReadEffectivePermissions(ctx context.Context, client client.Client, namespace, configMapName string, log *logr.Logger) (*PermissionConfig, error) { + // Start with default values + config := &PermissionConfig{ + AllowOnline: false, + AllowCodeExecution: false, + } + + // Read default configmap first + var cm corev1.ConfigMap + if err := client.Get(ctx, types.NamespacedName{Namespace: namespace, Name: configMapName}, &cm); err != nil { + log.Error(err, "failed to get default configmap", "namespace", namespace, "name", configMapName) + return nil, err + } + + // Apply default config values for permissions + if allowOnlineStr, found := cm.Data[AllowOnline]; found { + if allowOnline, err := strconv.ParseBool(allowOnlineStr); err == nil { + config.AllowOnline = allowOnline + } + } + if allowCodeExecutionStr, found := cm.Data[AllowCodeExecution]; found { + if allowCodeExecution, err := strconv.ParseBool(allowCodeExecutionStr); err == nil { + config.AllowCodeExecution = allowCodeExecution + } + } + + // Read DSC configuration if available and override defaults + dscReader := dsc.NewDSCConfigReader(client, namespace) + if dscConfig, err := dscReader.ReadDSCConfig(ctx, log); err != nil { + log.Error(err, "failed to read DSC configuration, using default config") + // Continue with default config - DSC config is optional + } else if dscConfig != nil { + config.AllowOnline = dscConfig.AllowOnline + config.AllowCodeExecution = dscConfig.AllowCodeExecution + log.V(1).Info("Applied DSC configuration overrides", + "allowOnline", dscConfig.AllowOnline, + "allowCodeExecution", dscConfig.AllowCodeExecution) + } + + return config, nil +} + +// NewDefaultPermissionConfig creates a permission config with default values for testing +func NewDefaultPermissionConfig() *PermissionConfig { + return &PermissionConfig{ + AllowOnline: false, + AllowCodeExecution: false, + } +} diff --git a/controllers/lmes/driver/driver.go b/controllers/lmes/driver/driver.go index a3f7743b1..ab29840a0 100644 --- a/controllers/lmes/driver/driver.go +++ b/controllers/lmes/driver/driver.go @@ -67,6 +67,7 @@ type DriverOption struct { Args []string CommPort int DownloadAssetsS3 bool + UploadToOCI bool CustomTaskGitURL string CustomTaskGitBranch string CustomTaskGitCommit string @@ -242,6 +243,64 @@ func (d *driverImpl) downloadS3Assets() error { return nil } +func (d *driverImpl) uploadToOCI() error { + if d == nil || !d.Option.UploadToOCI { + return nil + } + + fmt.Println("Uploading results to OCI registry") + + // Build command arguments: scripts/oci.py + registryFromEnv := os.Getenv("OCI_REGISTRY") + if registryFromEnv == "" { + return fmt.Errorf("OCI_REGISTRY environment variable not set") + } + + pathFromEnv := os.Getenv("OCI_PATH") + var resultsLocation string + if pathFromEnv == "" { + // If OCI_PATH is not set, use the output path directly + resultsLocation = d.Option.OutputPath + } else { + // If OCI_PATH is set, join it with the output path + resultsLocation = filepath.Join(d.Option.OutputPath, pathFromEnv) + } + + cmd := []string{"python", "/opt/app-root/src/scripts/oci.py", registryFromEnv, resultsLocation} + fmt.Printf("[DEBUG] OCI upload CLI: %v\n", cmd) + + // List all files and directories in resultsLocation + fmt.Printf("[DEBUG] Contents of results location (%s):\n", resultsLocation) + _ = filepath.Walk(resultsLocation, func(path string, info os.FileInfo, err error) error { + if err != nil { + fmt.Printf(" [error] %v\n", err) + return nil + } + rel, _ := filepath.Rel(resultsLocation, path) + if rel == "." { + fmt.Printf(" %s/\n", rel) + } else if info.IsDir() { + fmt.Printf(" %s/\n", rel) + } else { + fmt.Printf(" %s\n", rel) + } + return nil + }) + + output, err := exec.Command( + "python", + "/opt/app-root/src/scripts/oci.py", + registryFromEnv, + resultsLocation, + ).Output() + fmt.Println(string(output)) + if err != nil { + return fmt.Errorf("failed to upload results to OCI: %v", err) + } + + return nil +} + func patchDevice(args []string, hasCuda bool) []string { device := "cpu" if hasCuda { @@ -443,6 +502,11 @@ func (d *driverImpl) updateCompleteStatus(err error) { var results string results, err = d.getResults() d.status.Results = results + + // Upload results to OCI if configured + if err == nil { + err = d.uploadToOCI() + } } if err != nil { diff --git a/controllers/lmes/driver/driver_test.go b/controllers/lmes/driver/driver_test.go index 1d9e9efac..71a6fa495 100644 --- a/controllers/lmes/driver/driver_test.go +++ b/controllers/lmes/driver/driver_test.go @@ -443,3 +443,151 @@ func Test_ProgramError(t *testing.T) { assert.Nil(t, driver.Shutdown()) } + +func Test_OCIUploadSuccess(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + + // Set up environment variables for OCI + os.Setenv("OCI_REGISTRY", "registry.example.com") + os.Setenv("OCI_PATH", "results") + defer func() { + os.Unsetenv("OCI_REGISTRY") + os.Unsetenv("OCI_PATH") + }() + + driver, err := NewDriver(&DriverOption{ + Context: context.Background(), + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, + Logger: driverLog, + Args: []string{"sh", "-ec", "echo 'test completed'"}, + CommPort: info.port, + UploadToOCI: true, + }) + assert.Nil(t, err) + + // This will fail because the OCI script doesn't exist, but we can test the setup + msgs, _ := runDriverAndWait4Complete(t, driver, true) + + // Should fail during OCI upload since script doesn't exist + assert.Contains(t, msgs[len(msgs)-1], "failed to upload results to OCI") + + assert.Nil(t, driver.Shutdown()) +} + +func Test_OCIUploadDisabled(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + + driver, err := NewDriver(&DriverOption{ + Context: context.Background(), + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, + Logger: driverLog, + Args: []string{"sh", "-ec", "echo 'test completed'"}, + CommPort: info.port, + UploadToOCI: false, + }) + assert.Nil(t, err) + + msgs, _ := runDriverAndWait4Complete(t, driver, false) + + assert.Contains(t, msgs, "job completed", "Should complete successfully") + // The first message may vary depending on timing, so just check that we completed + + assert.Nil(t, driver.Shutdown()) +} + +func Test_OCIUploadMissingRegistry(t *testing.T) { + info := setupTest(t, true) + defer info.tearDown(t) + + // Don't set OCI_REGISTRY environment variable + os.Unsetenv("OCI_REGISTRY") + + driver, err := NewDriver(&DriverOption{ + Context: context.Background(), + OutputPath: info.outputPath, + CatalogPath: info.catalogPath, + Logger: driverLog, + Args: []string{"sh", "-ec", "echo 'test completed'"}, + CommPort: info.port, + UploadToOCI: true, + }) + assert.Nil(t, err) + + msgs, _ := runDriverAndWait4Complete(t, driver, true) + + // Should fail with missing registry error + assert.Contains(t, msgs[len(msgs)-1], "OCI_REGISTRY environment variable not set") + + assert.Nil(t, driver.Shutdown()) +} + +func Test_OCIUploadToOCIFunction(t *testing.T) { + info := setupTest(t, false) + defer info.tearDown(t) + + tests := []struct { + name string + uploadToOCI bool + registryEnv string + pathEnv string + expectError bool + expectedErrMsg string + }{ + { + name: "disabled OCI upload", + uploadToOCI: false, + expectError: false, + }, + { + name: "missing registry env", + uploadToOCI: true, + registryEnv: "", + expectError: true, + expectedErrMsg: "OCI_REGISTRY environment variable not set", + }, + { + name: "script execution fails", + uploadToOCI: true, + registryEnv: "registry.example.com", + pathEnv: "results", + expectError: true, + expectedErrMsg: "failed to upload results to OCI", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up environment + if tt.registryEnv != "" { + os.Setenv("OCI_REGISTRY", tt.registryEnv) + defer os.Unsetenv("OCI_REGISTRY") + } + if tt.pathEnv != "" { + os.Setenv("OCI_PATH", tt.pathEnv) + defer os.Unsetenv("OCI_PATH") + } + + driver := &driverImpl{ + Option: &DriverOption{ + OutputPath: info.outputPath, + UploadToOCI: tt.uploadToOCI, + }, + } + + err := driver.uploadToOCI() + + if tt.expectError { + assert.NotNil(t, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } + } else { + assert.Nil(t, err) + } + }) + } +} diff --git a/controllers/lmes/lmevaljob_controller.go b/controllers/lmes/lmevaljob_controller.go index 57df5b404..a1e8a75af 100644 --- a/controllers/lmes/lmevaljob_controller.go +++ b/controllers/lmes/lmevaljob_controller.go @@ -276,6 +276,7 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error { if err := constructOptionsFromConfigMap(&log, &cm); err != nil { return err } + log.Info("Constructed options from configmap", "options", Options) // Read DSC configuration if available dscReader := dsc.NewDSCConfigReader(r.Client, r.Namespace) @@ -286,6 +287,7 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error { ApplyDSCConfig(dscConfig) } + log.Info("Applied DSC configuration", "options", Options) return nil })); err != nil { return err @@ -518,9 +520,23 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, return ctrl.Result{}, err } + // Read current permissions for this job deployment + permConfig, err := ReadEffectivePermissions(ctx, r.Client, r.Namespace, r.ConfigMap, &log) + if err != nil { + // Failed to read permissions. Mark the status as complete with failed + job.Status.State = lmesv1alpha1.CompleteJobState + job.Status.Reason = lmesv1alpha1.FailedReason + job.Status.Message = fmt.Sprintf("Failed to read configuration: %s", err.Error()) + if err := r.Status().Update(ctx, job); err != nil { + log.Error(err, "unable to update LMEvalJob status for config read failure") + } + log.Error(err, "Failed to read permissions for LMEvalJob", "name", job.Name) + return ctrl.Result{}, err + } + // construct a new pod and create a pod for the job currentTime := v1.Now() - pod := CreatePod(Options, job, log) + pod := CreatePod(Options, job, permConfig, log) if err := r.Create(ctx, pod, &client.CreateOptions{}); err != nil { // Failed to create the pod. Mark the status as complete with failed job.Status.State = lmesv1alpha1.CompleteJobState @@ -738,13 +754,21 @@ func (r *LMEvalJobReconciler) handleSuspend(ctx context.Context, log logr.Logger func (r *LMEvalJobReconciler) handleResume(ctx context.Context, log logr.Logger, job *lmesv1alpha1.LMEvalJob) (ctrl.Result, error) { log.Info("Resume job") - pod := CreatePod(Options, job, log) - if err := r.Create(ctx, pod); err != nil { - log.Error(err, "failed to create pod to resume job") + + // Read effective permissions for this job deployment + permConfig, err := ReadEffectivePermissions(ctx, r.Client, r.Namespace, r.ConfigMap, &log) + if err != nil { + log.Error(err, "Failed to read effective permissions for LMEvalJob resume", "name", job.Name) + return r.pullingJobs.addOrUpdate(string(job.GetUID()), Options.PodCheckingInterval), err + } + + pod := CreatePod(Options, job, permConfig, log) + if createErr := r.Create(ctx, pod); createErr != nil { + log.Error(createErr, "failed to create pod to resume job") return r.pullingJobs.addOrUpdate(string(job.GetUID()), Options.PodCheckingInterval), nil } job.Status.State = lmesv1alpha1.ScheduledJobState - err := r.Status().Update(ctx, job) + err = r.Status().Update(ctx, job) if err != nil { log.Error(err, "failed to update job status to scheduled") } @@ -863,7 +887,7 @@ func unmarshal(custom string, props []string) (map[string]interface{}, error) { return obj, nil } -func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { +func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, permConfig *PermissionConfig, log logr.Logger) *corev1.Pod { var envVars = removeProtectedEnvVars(job.Spec.Pod.GetContainer().GetEnv()) @@ -887,13 +911,12 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo }, } - if job.Spec.HasCustomOutput() { + if job.Spec.Outputs != nil && (job.Spec.Outputs.HasManagedPVC() || job.Spec.Outputs.HasExistingPVC()) { outputPVCMount := corev1.VolumeMount{ Name: "outputs", MountPath: OutputPath, } volumeMounts = append(volumeMounts, outputPVCMount) - } var volumes = []corev1.Volume{ @@ -904,7 +927,7 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo }, } - if job.Spec.HasCustomOutput() { + if job.Spec.Outputs != nil && (job.Spec.Outputs.HasManagedPVC() || job.Spec.Outputs.HasExistingPVC()) { var claimName string if job.Spec.Outputs.HasManagedPVC() { @@ -965,7 +988,7 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo if job.Spec.AllowCodeExecution != nil && *job.Spec.AllowCodeExecution { // Disable remote code execution by default - if !svcOpts.AllowCodeExecution { + if !permConfig.AllowCodeExecution { log.Error(fmt.Errorf("code execution not allowed by the operator"), "change this setting and redeploy the operator") envVars = append(envVars, disallowRemoteCodeEnvVars...) } else { @@ -1002,7 +1025,7 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo // Enforce offline mode by default if job.Spec.AllowOnline != nil && *job.Spec.AllowOnline { - if !svcOpts.AllowOnline { + if !permConfig.AllowOnline { log.Error(fmt.Errorf("online mode not allowed by the operator"), "change this setting and redeploy the operator") envVars = append(envVars, offlineHuggingFaceEnvVars...) } @@ -1142,6 +1165,120 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo } + // Always add OCI env vars if configured, regardless of offline/online mode + if job.Spec.HasOCIOutput() && job.Spec.Outputs != nil && job.Spec.Outputs.OCISpec != nil { + ociEnvVars := []corev1.EnvVar{ + { + Name: "OCI_REGISTRY", + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: job.Spec.Outputs.OCISpec.Registry.Name, + }, + Key: job.Spec.Outputs.OCISpec.Registry.Key, + }, + }, + }, + { + Name: "OCI_REPOSITORY", + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: job.Spec.Outputs.OCISpec.Repository.Name, + }, + Key: job.Spec.Outputs.OCISpec.Repository.Key, + }, + }, + }, + { + Name: "OCI_PATH", + Value: job.Spec.Outputs.OCISpec.Path, + }, + } + + // Add tag if specified, otherwise driver will use job name as default + if job.Spec.Outputs.OCISpec.Tag != "" { + ociEnvVars = append(ociEnvVars, corev1.EnvVar{ + Name: "OCI_TAG", + Value: job.Spec.Outputs.OCISpec.Tag, + }) + } + + // Add subject if specified + if job.Spec.Outputs.OCISpec.Subject != "" { + ociEnvVars = append(ociEnvVars, corev1.EnvVar{ + Name: "OCI_SUBJECT", + Value: job.Spec.Outputs.OCISpec.Subject, + }) + } + + // Handle authentication - either username/password or token + if job.Spec.Outputs.OCISpec.HasUsernamePassword() { + ociAuthEnvVars := []corev1.EnvVar{ + { + Name: "OCI_USERNAME", + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: job.Spec.Outputs.OCISpec.UsernameRef, + }, + }, + { + Name: "OCI_PASSWORD", + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: job.Spec.Outputs.OCISpec.PasswordRef, + }, + }, + } + ociEnvVars = append(ociEnvVars, ociAuthEnvVars...) + } else if job.Spec.Outputs.OCISpec.HasToken() { + ociTokenEnvVar := corev1.EnvVar{ + Name: "OCI_TOKEN", + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: job.Spec.Outputs.OCISpec.TokenRef, + }, + } + ociEnvVars = append(ociEnvVars, ociTokenEnvVar) + } + + // Handle SSL verification + ociVerifySSL := "true" + if job.Spec.Outputs.OCISpec.VerifySSL != nil { + ociVerifySSL = strconv.FormatBool(*job.Spec.Outputs.OCISpec.VerifySSL) + } + ociEnvVars = append(ociEnvVars, corev1.EnvVar{ + Name: "OCI_VERIFY_SSL", + Value: ociVerifySSL, + }) + + envVars = append(envVars, ociEnvVars...) + + // If certificates are specified, create volume to hold them + if job.Spec.Outputs.OCISpec.HasCertificates() { + ociCertificatesMount := corev1.VolumeMount{ + Name: "certificates-oci", + MountPath: "/etc/certificates/oci", + } + volumeMounts = append(volumeMounts, ociCertificatesMount) + + ociCertificatesVolume := corev1.Volume{ + Name: "certificates-oci", + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: job.Spec.Outputs.OCISpec.CABundle.Name, + }, + }, + }, + } + volumes = append(volumes, ociCertificatesVolume) + + ociCertificateEnvVar := corev1.EnvVar{ + Name: "OCI_CA_BUNDLE", + Value: fmt.Sprintf("/etc/certificates/oci/%s", job.Spec.Outputs.OCISpec.CABundle.Key), + } + envVars = append(envVars, ociCertificateEnvVar) + } + } + volumes = append(volumes, job.Spec.Pod.GetVolumes()...) volumeMounts = append(volumeMounts, job.Spec.Pod.GetContainer().GetVolumMounts()...) labels := getPodLabels(job.Labels, log) @@ -1156,7 +1293,7 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, Env: envVars, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, permConfig), Args: generateArgs(svcOpts, job, log), SecurityContext: mainSecurityContext, VolumeMounts: volumeMounts, @@ -1401,7 +1538,7 @@ func concatTasks(tasks lmesv1alpha1.TaskList) []string { return append(tasks.TaskNames, recipesName...) } -func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string { +func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, permConfig *PermissionConfig) []string { if job == nil { return nil } @@ -1414,6 +1551,10 @@ func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string cmds = append(cmds, "--download-assets-s3") } + if job.Spec.HasOCIOutput() { + cmds = append(cmds, "--upload-to-oci") + } + if svcOpts.DetectDevice { cmds = append(cmds, "--detect-device") } @@ -1422,7 +1563,7 @@ func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string cmds = append(cmds, "--listen-port", fmt.Sprintf("%d", svcOpts.DriverPort)) } - if job.Spec.AllowOnline != nil && *job.Spec.AllowOnline && svcOpts.AllowOnline { + if job.Spec.AllowOnline != nil && *job.Spec.AllowOnline && permConfig.AllowOnline { cmds = append(cmds, "--allow-online") } diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go index 724a2f340..6fcc07068 100644 --- a/controllers/lmes/lmevaljob_controller_test.go +++ b/controllers/lmes/lmevaljob_controller_test.go @@ -114,7 +114,7 @@ func Test_SimplePod(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), SecurityContext: defaultSecurityContext, VolumeMounts: []corev1.VolumeMount{ @@ -188,7 +188,7 @@ func Test_SimplePod(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -334,7 +334,7 @@ func Test_WithCustomPod(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ RunAsUser: &runAsUser, @@ -453,7 +453,7 @@ func Test_WithCustomPod(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) @@ -468,7 +468,7 @@ func Test_WithCustomPod(t *testing.T) { "custom/annotation1": "annotation1", } - newPod = CreatePod(svcOpts, job, log) + newPod = CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -621,7 +621,7 @@ func Test_EnvSecretsPod(t *testing.T) { Value: "True", }, }, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), SecurityContext: defaultSecurityContext, VolumeMounts: []corev1.VolumeMount{ @@ -644,7 +644,7 @@ func Test_EnvSecretsPod(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) // maybe only verify the envs: Containers[0].Env assert.Equal(t, expect, newPod) } @@ -746,7 +746,7 @@ func Test_FileSecretsPod(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -839,7 +839,7 @@ func Test_FileSecretsPod(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) // maybe only verify the envs: Containers[0].Env assert.Equal(t, expect, newPod) } @@ -1055,7 +1055,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { "--output-path", "/opt/app-root/src/output", "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, NewDefaultPermissionConfig())) job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, lmesv1alpha1.TaskRecipe{ @@ -1079,7 +1079,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, NewDefaultPermissionConfig())) } func Test_GenerateArgCmdCustomCard(t *testing.T) { @@ -1137,7 +1137,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--custom-artifact", `card|custom_0|{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, NewDefaultPermissionConfig())) // add second task using custom recipe + custom template job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, @@ -1173,7 +1173,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--task-recipe", "card=cards.custom_1,template=templates.tp_0,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", "--custom-artifact", `template|tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, NewDefaultPermissionConfig())) // add third task using normal card + custom system_prompt job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, @@ -1207,7 +1207,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--custom-artifact", `template|tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, "--custom-artifact", "system_prompt|sp_0|this is a custom system promp", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, NewDefaultPermissionConfig())) // add forth task using custom card + custom template + custom system_prompt // and reuse the template and system prompt @@ -1242,7 +1242,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--custom-artifact", `template|tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, "--custom-artifact", "system_prompt|sp_0|this is a custom system promp", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, NewDefaultPermissionConfig())) // add fifth task using regular card + custom template + custom system_prompt // both template and system prompt are new @@ -1288,7 +1288,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--custom-artifact", "system_prompt|sp_0|this is a custom system promp", "--custom-artifact", "system_prompt|sp_1|this is a custom system promp2", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, NewDefaultPermissionConfig())) } func Test_CustomCardValidation(t *testing.T) { @@ -1702,7 +1702,7 @@ func Test_ManagedPVC(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -1789,7 +1789,7 @@ func Test_ManagedPVC(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -1870,7 +1870,7 @@ func Test_ExistingPVC(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -1956,7 +1956,7 @@ func Test_ExistingPVC(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -2048,7 +2048,7 @@ func Test_PVCPreference(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2146,7 +2146,7 @@ func Test_PVCPreference(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -2237,7 +2237,7 @@ func Test_OfflineMode(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2335,7 +2335,7 @@ func Test_OfflineMode(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -2453,7 +2453,7 @@ func Test_ProtectedVars(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2556,7 +2556,7 @@ func Test_ProtectedVars(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -2654,7 +2654,7 @@ func Test_OnlineModeDisabled(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2752,7 +2752,7 @@ func Test_OnlineModeDisabled(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -2846,7 +2846,7 @@ func Test_OnlineMode(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2924,7 +2924,7 @@ func Test_OnlineMode(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -3021,7 +3021,7 @@ func Test_AllowCodeOnlineMode(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -3099,7 +3099,7 @@ func Test_AllowCodeOnlineMode(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -3194,7 +3194,7 @@ func Test_AllowCodeOfflineMode(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -3292,7 +3292,7 @@ func Test_AllowCodeOfflineMode(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -3387,7 +3387,7 @@ func Test_OfflineModeWithOutput(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, NewDefaultPermissionConfig()), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -3497,7 +3497,7 @@ func Test_OfflineModeWithOutput(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) assert.Equal(t, expect, newPod) } @@ -3607,7 +3607,7 @@ func Test_CustomTasksGitSource(t *testing.T) { }, } - pod := CreatePod(svcOpts, job, log) + pod := CreatePod(svcOpts, job, NewDefaultPermissionConfig(), log) require.NotNil(t, pod) @@ -3698,7 +3698,7 @@ func Test_CustomTasksGitSourceOfflineMode(t *testing.T) { logger := logr.Discard() - pod := CreatePod(Options, job, logger) + pod := CreatePod(Options, job, NewDefaultPermissionConfig(), logger) if pod == nil { t.Fatal("pod should not be nil") @@ -3846,3 +3846,553 @@ func Test_AllowCodeExecution(t *testing.T) { args := generateArgs(svcOpts, job, log) assert.Contains(t, args, "--confirm_run_unsafe_code") } + +// Test OCI controller functionality +func Test_OCICommandGeneration(t *testing.T) { + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + MaxBatchSize: 20, + DefaultBatchSize: "4", + } + + t.Run("WithOCIOutput", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Path: "results", + }, + }, + }, + } + + cmds := generateCmd(svcOpts, job) + assert.Contains(t, cmds, "--upload-to-oci", "Should include OCI upload flag") + }) + + t.Run("WithoutOCIOutput", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + }, + } + + cmds := generateCmd(svcOpts, job) + assert.NotContains(t, cmds, "--upload-to-oci", "Should not include OCI upload flag") + }) +} + +func Test_OCIPodConfiguration(t *testing.T) { + logger := log.FromContext(context.Background()) + svcOpts := &serviceOptions{ + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, + MaxBatchSize: 20, + DefaultBatchSize: "4", + } + + t.Run("BasicOCIConfiguration", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Path: "results", + UsernameRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "username", + }, + PasswordRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "password", + }, + }, + }, + }, + } + + // Debug: Check if HasOCIOutput returns true + hasOCI := job.Spec.HasOCIOutput() + assert.True(t, hasOCI, "Job should have OCI output configured") + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify OCI environment variables are set + assert.Contains(t, envMap, "OCI_REGISTRY", "Should have OCI_REGISTRY env var") + assert.Contains(t, envMap, "OCI_REPOSITORY", "Should have OCI_REPOSITORY env var") + assert.Contains(t, envMap, "OCI_PATH", "Should have OCI_PATH env var") + assert.Contains(t, envMap, "OCI_VERIFY_SSL", "Should have OCI_VERIFY_SSL env var") + assert.Contains(t, envMap, "OCI_USERNAME", "Should have OCI_USERNAME env var when auth is configured") + assert.Contains(t, envMap, "OCI_PASSWORD", "Should have OCI_PASSWORD env var when auth is configured") + assert.NotContains(t, envMap, "OCI_TOKEN", "Should not have OCI_TOKEN env var without token auth") + + // Verify environment variable sources + assert.Equal(t, "oci-secret", envMap["OCI_REGISTRY"].ValueFrom.SecretKeyRef.Name) + assert.Equal(t, "registry", envMap["OCI_REGISTRY"].ValueFrom.SecretKeyRef.Key) + assert.Equal(t, "oci-secret", envMap["OCI_REPOSITORY"].ValueFrom.SecretKeyRef.Name) + assert.Equal(t, "repository", envMap["OCI_REPOSITORY"].ValueFrom.SecretKeyRef.Key) + assert.Equal(t, "results", envMap["OCI_PATH"].Value) + assert.Equal(t, "true", envMap["OCI_VERIFY_SSL"].Value) + assert.Equal(t, "oci-secret", envMap["OCI_USERNAME"].ValueFrom.SecretKeyRef.Name) + assert.Equal(t, "username", envMap["OCI_USERNAME"].ValueFrom.SecretKeyRef.Key) + assert.Equal(t, "oci-secret", envMap["OCI_PASSWORD"].ValueFrom.SecretKeyRef.Name) + assert.Equal(t, "password", envMap["OCI_PASSWORD"].ValueFrom.SecretKeyRef.Key) + }) + + t.Run("OCIWithToken", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci-token", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Path: "results", + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-token"}, + Key: "token", + }, + }, + }, + }, + } + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify token authentication is used + assert.Contains(t, envMap, "OCI_TOKEN", "Should have OCI_TOKEN env var") + assert.NotContains(t, envMap, "OCI_USERNAME", "Should not have OCI_USERNAME env var") + assert.NotContains(t, envMap, "OCI_PASSWORD", "Should not have OCI_PASSWORD env var") + + // Verify token source + assert.Equal(t, "oci-token", envMap["OCI_TOKEN"].ValueFrom.SecretKeyRef.Name) + assert.Equal(t, "token", envMap["OCI_TOKEN"].ValueFrom.SecretKeyRef.Key) + }) + + t.Run("OCIWithCustomTag", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci-tag", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Tag: "custom-tag-v1.0", + Path: "results", + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-token"}, + Key: "token", + }, + }, + }, + }, + } + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify custom tag is set + assert.Contains(t, envMap, "OCI_TAG", "Should have OCI_TAG env var") + assert.Equal(t, "custom-tag-v1.0", envMap["OCI_TAG"].Value) + }) + + t.Run("OCIWithSubject", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci-subject", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Subject: "llama-2-7b-chat", + Path: "results", + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-token"}, + Key: "token", + }, + }, + }, + }, + } + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify subject is set + assert.Contains(t, envMap, "OCI_SUBJECT", "Should have OCI_SUBJECT env var") + assert.Equal(t, "llama-2-7b-chat", envMap["OCI_SUBJECT"].Value) + }) + + t.Run("OCIWithSubjectOmitted", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci-subject-omitted", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Path: "results", + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-token"}, + Key: "token", + }, + }, + }, + }, + } + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify OCI_SUBJECT is not set when omitted + assert.NotContains(t, envMap, "OCI_SUBJECT", "Should not have OCI_SUBJECT env var when omitted") + }) + + t.Run("OCIWithSubjectEmpty", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci-subject-empty", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Subject: "", + Path: "results", + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-token"}, + Key: "token", + }, + }, + }, + }, + } + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify OCI_SUBJECT is not set when empty + assert.NotContains(t, envMap, "OCI_SUBJECT", "Should not have OCI_SUBJECT env var when empty") + }) + + t.Run("OCIWithSubjectSpecialChars", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci-subject-special", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Subject: "valid-subject-123", + Path: "results", + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-token"}, + Key: "token", + }, + }, + }, + }, + } + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify subject with valid special characters is set + assert.Contains(t, envMap, "OCI_SUBJECT", "Should have OCI_SUBJECT env var") + assert.Equal(t, "valid-subject-123", envMap["OCI_SUBJECT"].Value) + }) + + t.Run("OCIWithSubjectDigestFormat", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci-subject-digest", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Subject: "sha256:a1b2c3d4e5f6789012345678901234567890abcdef1234567890abcdef12345678", + Path: "results", + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-token"}, + Key: "token", + }, + }, + }, + }, + } + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify subject with OCI digest format is set + assert.Contains(t, envMap, "OCI_SUBJECT", "Should have OCI_SUBJECT env var") + assert.Equal(t, "sha256:a1b2c3d4e5f6789012345678901234567890abcdef1234567890abcdef12345678", envMap["OCI_SUBJECT"].Value) + }) + + t.Run("OCIWithCertificates", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-oci-certs", + Namespace: "default", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"task1"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Registry: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "registry", + }, + Repository: corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-secret"}, + Key: "repository", + }, + Path: "results", + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-token"}, + Key: "token", + }, + CABundle: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "ca-bundle"}, + Key: "ca.crt", + }, + VerifySSL: ptr.To(false), + }, + }, + }, + } + + pod := CreatePod(svcOpts, job, logger) + assert.NotNil(t, pod, "Should generate pod successfully") + + // Check environment variables + envVars := pod.Spec.Containers[0].Env + envMap := make(map[string]corev1.EnvVar) + for _, env := range envVars { + envMap[env.Name] = env + } + + // Verify SSL configuration + assert.Contains(t, envMap, "OCI_VERIFY_SSL", "Should have OCI_VERIFY_SSL env var") + assert.Equal(t, "false", envMap["OCI_VERIFY_SSL"].Value) + assert.Contains(t, envMap, "OCI_CA_BUNDLE", "Should have OCI_CA_BUNDLE env var") + assert.Equal(t, "/etc/certificates/oci/ca.crt", envMap["OCI_CA_BUNDLE"].Value) + + // Check volume mounts + volumeMounts := pod.Spec.Containers[0].VolumeMounts + var ociCertMount *corev1.VolumeMount + for _, mount := range volumeMounts { + if mount.Name == "certificates-oci" { + ociCertMount = &mount + break + } + } + assert.NotNil(t, ociCertMount, "Should have OCI certificate volume mount") + assert.Equal(t, "/etc/certificates/oci", ociCertMount.MountPath) + + // Check volumes + volumes := pod.Spec.Volumes + var ociCertVolume *corev1.Volume + for _, volume := range volumes { + if volume.Name == "certificates-oci" { + ociCertVolume = &volume + break + } + } + assert.NotNil(t, ociCertVolume, "Should have OCI certificate volume") + assert.NotNil(t, ociCertVolume.ConfigMap, "Should be a ConfigMap volume") + assert.Equal(t, "ca-bundle", ociCertVolume.ConfigMap.Name) + }) +} diff --git a/controllers/lmes/lmevaljob_controller_validation_test.go b/controllers/lmes/lmevaljob_controller_validation_test.go index e237c9c13..15bb6f460 100644 --- a/controllers/lmes/lmevaljob_controller_validation_test.go +++ b/controllers/lmes/lmevaljob_controller_validation_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" lmesv1alpha1 "github.com/trustyai-explainability/trustyai-service-operator/api/lmes/v1alpha1" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/log" ) @@ -1488,3 +1489,329 @@ func Test_ComplexValidationScenario(t *testing.T) { assert.Error(t, err, "Should reject JSON with command execution patterns") }) } + +// Test OCI functionality +func Test_OCIHelperMethods(t *testing.T) { + t.Run("HasOCIOutput", func(t *testing.T) { + // Test with no outputs + job := &lmesv1alpha1.LMEvalJob{ + Spec: lmesv1alpha1.LMEvalJobSpec{}, + } + assert.False(t, job.Spec.HasOCIOutput(), "Should return false when outputs is nil") + + // Test with outputs but no OCI spec + job.Spec.Outputs = &lmesv1alpha1.Outputs{} + assert.False(t, job.Spec.HasOCIOutput(), "Should return false when OCI spec is nil") + + // Test with OCI spec + job.Spec.Outputs.OCISpec = &lmesv1alpha1.OCISpec{} + assert.True(t, job.Spec.HasOCIOutput(), "Should return true when OCI spec is present") + }) + + t.Run("OutputsHasOCI", func(t *testing.T) { + // Test with nil outputs + var outputs *lmesv1alpha1.Outputs + assert.False(t, outputs.HasOCI(), "Should return false when outputs is nil") + + // Test with outputs but no OCI spec + outputs = &lmesv1alpha1.Outputs{} + assert.False(t, outputs.HasOCI(), "Should return false when OCI spec is nil") + + // Test with OCI spec + outputs.OCISpec = &lmesv1alpha1.OCISpec{} + assert.True(t, outputs.HasOCI(), "Should return true when OCI spec is present") + }) +} + +func Test_OCISpecHelperMethods(t *testing.T) { + t.Run("HasCertificates", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{} + assert.False(t, ociSpec.HasCertificates(), "Should return false when CABundle is nil") + + ociSpec.CABundle = &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "ca-bundle"}, + Key: "ca.crt", + } + assert.True(t, ociSpec.HasCertificates(), "Should return true when CABundle is present") + }) + + t.Run("HasUsernamePassword", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{} + assert.False(t, ociSpec.HasUsernamePassword(), "Should return false when both are nil") + + // Test with only username + ociSpec.UsernameRef = &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "username", + } + assert.False(t, ociSpec.HasUsernamePassword(), "Should return false when password is nil") + + // Test with only password + ociSpec.UsernameRef = nil + ociSpec.PasswordRef = &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "password", + } + assert.False(t, ociSpec.HasUsernamePassword(), "Should return false when username is nil") + + // Test with both + ociSpec.UsernameRef = &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "username", + } + assert.True(t, ociSpec.HasUsernamePassword(), "Should return true when both are present") + }) + + t.Run("HasToken", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{} + assert.False(t, ociSpec.HasToken(), "Should return false when token is nil") + + ociSpec.TokenRef = &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "token-secret"}, + Key: "token", + } + assert.True(t, ociSpec.HasToken(), "Should return true when token is present") + }) +} + +func Test_ValidateOCIPath(t *testing.T) { + testCases := []struct { + name string + path string + expectError bool + errorMsg string + }{ + { + name: "EmptyPath", + path: "", + expectError: false, + }, + { + name: "ValidSimplePath", + path: "results", + expectError: false, + }, + { + name: "ValidNestedPath", + path: "evaluation/results", + expectError: false, + }, + { + name: "ValidPathWithDots", + path: "results.json", + expectError: false, + }, + { + name: "ValidPathWithUnderscores", + path: "eval_results_2024", + expectError: false, + }, + { + name: "ValidPathWithHyphens", + path: "eval-results-v1", + expectError: false, + }, + { + name: "InvalidPathTraversal", + path: "../results", + expectError: true, + errorMsg: "invalid pattern: ../", + }, + { + name: "InvalidPathTraversalDeep", + path: "results/../../etc", + expectError: true, + errorMsg: "invalid pattern: ../", + }, + { + name: "InvalidCurrentDir", + path: "./results", + expectError: true, + errorMsg: "invalid pattern: ./", + }, + { + name: "InvalidShellMetacharacters", + path: "results; rm -rf /", + expectError: true, + errorMsg: "invalid characters", + }, + { + name: "InvalidSpecialCharacters", + path: "results@#$%", + expectError: true, + errorMsg: "invalid characters", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateOCIPath(tc.path) + if tc.expectError { + assert.Error(t, err, "Expected error for path: %s", tc.path) + assert.Contains(t, err.Error(), tc.errorMsg, "Error message should contain: %s", tc.errorMsg) + } else { + assert.NoError(t, err, "Expected no error for path: %s", tc.path) + } + }) + } +} + +func Test_ValidateOCIAuth(t *testing.T) { + t.Run("NilSpec", func(t *testing.T) { + err := ValidateOCIAuth(nil) + assert.Error(t, err, "Should reject nil spec") + assert.Contains(t, err.Error(), "cannot be nil") + }) + + t.Run("NoAuthentication", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{} + err := ValidateOCIAuth(ociSpec) + assert.Error(t, err, "Should reject spec with no authentication") + assert.Contains(t, err.Error(), "requires either username/password or token") + }) + + t.Run("ValidUsernamePassword", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{ + UsernameRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "username", + }, + PasswordRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "password", + }, + } + err := ValidateOCIAuth(ociSpec) + assert.NoError(t, err, "Should accept valid username/password authentication") + }) + + t.Run("ValidToken", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{ + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "token-secret"}, + Key: "token", + }, + } + err := ValidateOCIAuth(ociSpec) + assert.NoError(t, err, "Should accept valid token authentication") + }) + + t.Run("BothUsernamePasswordAndToken", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{ + UsernameRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "username", + }, + PasswordRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "password", + }, + TokenRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "token-secret"}, + Key: "token", + }, + } + err := ValidateOCIAuth(ociSpec) + assert.Error(t, err, "Should reject spec with both username/password and token") + assert.Contains(t, err.Error(), "cannot have both") + }) + + t.Run("OnlyUsername", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{ + UsernameRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "username", + }, + } + err := ValidateOCIAuth(ociSpec) + assert.Error(t, err, "Should reject spec with only username") + assert.Contains(t, err.Error(), "OCI authentication requires either username/password or token") + }) + + t.Run("OnlyPassword", func(t *testing.T) { + ociSpec := &lmesv1alpha1.OCISpec{ + PasswordRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "creds"}, + Key: "password", + }, + } + err := ValidateOCIAuth(ociSpec) + assert.Error(t, err, "Should reject spec with only password") + assert.Contains(t, err.Error(), "OCI authentication requires either username/password or token") + }) +} + +func Test_ValidateUserInputWithOCI(t *testing.T) { + t.Run("ValidOCIConfiguration", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "hf", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"hellaswag"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Path: "results", + UsernameRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-creds"}, + Key: "username", + }, + PasswordRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-creds"}, + Key: "password", + }, + }, + }, + }, + } + err := ValidateUserInput(job) + assert.NoError(t, err, "Should accept valid OCI configuration") + }) + + t.Run("InvalidOCIPath", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "hf", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"hellaswag"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Path: "../malicious", + UsernameRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-creds"}, + Key: "username", + }, + PasswordRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "oci-creds"}, + Key: "password", + }, + }, + }, + }, + } + err := ValidateUserInput(job) + assert.Error(t, err, "Should reject invalid OCI path") + assert.Contains(t, err.Error(), "invalid OCI path") + }) + + t.Run("InvalidOCIAuth", func(t *testing.T) { + job := &lmesv1alpha1.LMEvalJob{ + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "hf", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"hellaswag"}, + }, + Outputs: &lmesv1alpha1.Outputs{ + OCISpec: &lmesv1alpha1.OCISpec{ + Path: "results", + // No authentication specified + }, + }, + }, + } + err := ValidateUserInput(job) + assert.Error(t, err, "Should reject invalid OCI authentication") + assert.Contains(t, err.Error(), "invalid OCI authentication") + }) +} diff --git a/controllers/lmes/validation.go b/controllers/lmes/validation.go index df15ff25c..d7c5fcdd1 100644 --- a/controllers/lmes/validation.go +++ b/controllers/lmes/validation.go @@ -112,6 +112,16 @@ func ValidateUserInput(job *lmesv1alpha1.LMEvalJob) error { } } + // Validate OCI path and authentication + if job.Spec.HasOCIOutput() { + if err := ValidateOCIPath(job.Spec.Outputs.OCISpec.Path); err != nil { + return fmt.Errorf("invalid OCI path: %w", err) + } + if err := ValidateOCIAuth(job.Spec.Outputs.OCISpec); err != nil { + return fmt.Errorf("invalid OCI authentication: %w", err) + } + } + // Validate batch size if job.Spec.BatchSize != nil { if err := ValidateBatchSizeInput(*job.Spec.BatchSize); err != nil { @@ -529,3 +539,57 @@ func ValidateGitCommit(commit string) error { return nil } + +// ValidateOCIPath validates OCI artifact paths +func ValidateOCIPath(path string) error { + if path == "" { + return nil // Empty path is valid for root + } + + // Check for shell metacharacters + if ContainsShellMetacharacters(path) { + return fmt.Errorf("OCI path contains invalid characters") + } + + // OCI paths should not contain invalid patterns + dangerousPatterns := []string{"../", "..\\", "./", ".\\"} + for _, pattern := range dangerousPatterns { + if strings.Contains(path, pattern) { + return fmt.Errorf("OCI path contains invalid pattern: %s", pattern) + } + } + + // OCI paths allowed characters (similar to filesystem paths) + if !regexp.MustCompile(`^[a-zA-Z0-9._/-]*$`).MatchString(path) { + return fmt.Errorf("OCI path contains invalid characters (only alphanumeric, ., _, /, - allowed)") + } + + return nil +} + +// ValidateOCIAuth validates OCI authentication configuration +func ValidateOCIAuth(ociSpec *lmesv1alpha1.OCISpec) error { + if ociSpec == nil { + return fmt.Errorf("OCI spec cannot be nil") + } + + // Must have either username/password or token, but not both + hasUsernamePassword := ociSpec.HasUsernamePassword() + hasToken := ociSpec.HasToken() + + if !hasUsernamePassword && !hasToken { + return fmt.Errorf("OCI authentication requires either username/password or token") + } + + if hasUsernamePassword && hasToken { + return fmt.Errorf("OCI authentication cannot have both username/password and token") + } + + // If using username/password, both must be present + if (ociSpec.UsernameRef != nil && ociSpec.PasswordRef == nil) || + (ociSpec.UsernameRef == nil && ociSpec.PasswordRef != nil) { + return fmt.Errorf("OCI authentication with username/password requires both username and password") + } + + return nil +}