From 701b1c7e2acf67fc96cd5a0061dd57ce36b414e5 Mon Sep 17 00:00:00 2001 From: Rui Vieira Date: Wed, 30 Jul 2025 11:40:42 +0100 Subject: [PATCH] feat(RHOAIENG-30963): Add support for DSC ConfigMaps --- controllers/job_mgr/job_mgr_controller.go | 3 +- controllers/lmes/constants.go | 4 + controllers/lmes/dsc_config_test.go | 332 ++++++++++++++++++ controllers/lmes/lmevaljob_controller.go | 114 +++++- controllers/lmes/lmevaljob_controller_test.go | 76 ++-- 5 files changed, 475 insertions(+), 54 deletions(-) create mode 100644 controllers/lmes/dsc_config_test.go diff --git a/controllers/job_mgr/job_mgr_controller.go b/controllers/job_mgr/job_mgr_controller.go index 675cfea74..1021e0b3e 100644 --- a/controllers/job_mgr/job_mgr_controller.go +++ b/controllers/job_mgr/job_mgr_controller.go @@ -114,7 +114,8 @@ func (job *LMEvalJob) Finished() (condition metav1.Condition, finished bool) { // PodSets will build workload podSets corresponding to the job. func (job *LMEvalJob) PodSets() []kueue.PodSet { log := log.FromContext(context.TODO()) - pod := lmes.CreatePod(lmes.Options, &job.LMEvalJob, log) + // job_mgr controller doesn't have access to reconciler, so we pass nil + pod := lmes.CreatePod(nil, lmes.Options, &job.LMEvalJob, log) podSet := kueue.PodSet{ Name: job.GetPodName(), Count: 1, diff --git a/controllers/lmes/constants.go b/controllers/lmes/constants.go index b6c283baf..c3fd055c2 100644 --- a/controllers/lmes/constants.go +++ b/controllers/lmes/constants.go @@ -45,4 +45,8 @@ const ( DefaultBatchSize = "1" DefaultDetectDevice = true ServiceName = "LMES" + // DataSienceCluster ConfigMap constants + DSCConfigMapName = "trustyai-dsc-config" + DSCAllowOnlineKey = "eval.lmeval.allowOnline" + DSCAllowCodeExecutionKey = "eval.lmeval.allowCodeExecution" ) diff --git a/controllers/lmes/dsc_config_test.go b/controllers/lmes/dsc_config_test.go new file mode 100644 index 000000000..f61b60f9a --- /dev/null +++ b/controllers/lmes/dsc_config_test.go @@ -0,0 +1,332 @@ +package lmes + +import ( + "context" + "testing" + + "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + lmesv1alpha1 "github.com/trustyai-explainability/trustyai-service-operator/api/lmes/v1alpha1" +) + +func TestGetDSCLMEvalSettings(t *testing.T) { + tests := []struct { + name string + configMap *corev1.ConfigMap + expectAllowOnline *bool + expectAllowCodeExecution *bool + expectError bool + }{ + { + name: "DSC ConfigMap with both settings enabled", + configMap: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DSCConfigMapName, + Namespace: "test-namespace", + Annotations: map[string]string{ + "opendatahub.io/config-source": "datasciencecluster", + }, + }, + Data: map[string]string{ + DSCAllowOnlineKey: "true", + DSCAllowCodeExecutionKey: "true", + }, + }, + expectAllowOnline: boolPtr(true), + expectAllowCodeExecution: boolPtr(true), + expectError: false, + }, + { + name: "DSC ConfigMap with both settings disabled", + configMap: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DSCConfigMapName, + Namespace: "test-namespace", + Annotations: map[string]string{ + "opendatahub.io/config-source": "datasciencecluster", + }, + }, + Data: map[string]string{ + DSCAllowOnlineKey: "false", + DSCAllowCodeExecutionKey: "false", + }, + }, + expectAllowOnline: boolPtr(false), + expectAllowCodeExecution: boolPtr(false), + expectError: false, + }, + { + name: "DSC ConfigMap with only online setting", + configMap: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DSCConfigMapName, + Namespace: "test-namespace", + Annotations: map[string]string{ + "opendatahub.io/config-source": "datasciencecluster", + }, + }, + Data: map[string]string{ + DSCAllowOnlineKey: "true", + }, + }, + expectAllowOnline: boolPtr(true), + expectAllowCodeExecution: nil, + expectError: false, + }, + { + name: "DSC ConfigMap with only code execution setting", + configMap: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DSCConfigMapName, + Namespace: "test-namespace", + Annotations: map[string]string{ + "opendatahub.io/config-source": "datasciencecluster", + }, + }, + Data: map[string]string{ + DSCAllowCodeExecutionKey: "true", + }, + }, + expectAllowOnline: nil, + expectAllowCodeExecution: boolPtr(true), + expectError: false, + }, + { + name: "ConfigMap without DSC annotation", + configMap: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DSCConfigMapName, + Namespace: "test-namespace", + }, + Data: map[string]string{ + DSCAllowOnlineKey: "true", + DSCAllowCodeExecutionKey: "true", + }, + }, + expectAllowOnline: nil, + expectAllowCodeExecution: nil, + expectError: false, + }, + { + name: "ConfigMap with invalid boolean values", + configMap: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DSCConfigMapName, + Namespace: "test-namespace", + Annotations: map[string]string{ + "opendatahub.io/config-source": "datasciencecluster", + }, + }, + Data: map[string]string{ + DSCAllowOnlineKey: "invalid", + DSCAllowCodeExecutionKey: "also-invalid", + }, + }, + expectAllowOnline: nil, + expectAllowCodeExecution: nil, + expectError: false, // Invalid values are logged but don't cause errors + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + + var client client.Client + if tt.configMap != nil { + client = fake.NewClientBuilder().WithScheme(scheme).WithObjects(tt.configMap).Build() + } else { + client = fake.NewClientBuilder().WithScheme(scheme).Build() + } + + reconciler := &LMEvalJobReconciler{ + Client: client, + Namespace: "test-namespace", + } + + allowOnline, allowCodeExecution, err := reconciler.getDSCLMEvalSettings(context.Background(), logr.Discard()) + + // Check error + if tt.expectError && err == nil { + t.Errorf("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + + // Check allowOnline + if tt.expectAllowOnline == nil && allowOnline != nil { + t.Errorf("Expected allowOnline to be nil, got %v", *allowOnline) + } + if tt.expectAllowOnline != nil && allowOnline == nil { + t.Errorf("Expected allowOnline to be %v, got nil", *tt.expectAllowOnline) + } + if tt.expectAllowOnline != nil && allowOnline != nil && *tt.expectAllowOnline != *allowOnline { + t.Errorf("Expected allowOnline to be %v, got %v", *tt.expectAllowOnline, *allowOnline) + } + + // Check allowCodeExecution + if tt.expectAllowCodeExecution == nil && allowCodeExecution != nil { + t.Errorf("Expected allowCodeExecution to be nil, got %v", *allowCodeExecution) + } + if tt.expectAllowCodeExecution != nil && allowCodeExecution == nil { + t.Errorf("Expected allowCodeExecution to be %v, got nil", *tt.expectAllowCodeExecution) + } + if tt.expectAllowCodeExecution != nil && allowCodeExecution != nil && *tt.expectAllowCodeExecution != *allowCodeExecution { + t.Errorf("Expected allowCodeExecution to be %v, got %v", *tt.expectAllowCodeExecution, *allowCodeExecution) + } + }) + } +} + +func TestCreatePodWithDSCSettings(t *testing.T) { + tests := []struct { + name string + dscConfigMap *corev1.ConfigMap + jobAllowOnline *bool + jobAllowCodeExecution *bool + operatorAllowOnline bool + operatorAllowCodeExec bool + expectOnlineMode bool + expectCodeExecution bool + }{ + { + name: "DSC allows both, job wants both", + dscConfigMap: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DSCConfigMapName, + Namespace: "test-namespace", + Annotations: map[string]string{ + "opendatahub.io/config-source": "datasciencecluster", + }, + }, + Data: map[string]string{ + DSCAllowOnlineKey: "true", + DSCAllowCodeExecutionKey: "true", + }, + }, + jobAllowOnline: boolPtr(true), + jobAllowCodeExecution: boolPtr(true), + operatorAllowOnline: false, + operatorAllowCodeExec: false, + expectOnlineMode: true, + expectCodeExecution: true, + }, + { + name: "DSC disallows both, job wants both", + dscConfigMap: &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DSCConfigMapName, + Namespace: "test-namespace", + Annotations: map[string]string{ + "opendatahub.io/config-source": "datasciencecluster", + }, + }, + Data: map[string]string{ + DSCAllowOnlineKey: "false", + DSCAllowCodeExecutionKey: "false", + }, + }, + jobAllowOnline: boolPtr(true), + jobAllowCodeExecution: boolPtr(true), + operatorAllowOnline: true, + operatorAllowCodeExec: true, + expectOnlineMode: false, + expectCodeExecution: false, + }, + { + name: "No DSC ConfigMap, operator allows both, job wants both", + dscConfigMap: nil, + jobAllowOnline: boolPtr(true), + jobAllowCodeExecution: boolPtr(true), + operatorAllowOnline: true, + operatorAllowCodeExec: true, + expectOnlineMode: true, + expectCodeExecution: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scheme := runtime.NewScheme() + _ = corev1.AddToScheme(scheme) + + var client client.Client + if tt.dscConfigMap != nil { + client = fake.NewClientBuilder().WithScheme(scheme).WithObjects(tt.dscConfigMap).Build() + } else { + client = fake.NewClientBuilder().WithScheme(scheme).Build() + } + + reconciler := &LMEvalJobReconciler{ + Client: client, + Namespace: "test-namespace", + } + + // Create service options + svcOpts := &serviceOptions{ + AllowOnline: tt.operatorAllowOnline, + AllowCodeExecution: tt.operatorAllowCodeExec, + } + + // Create job + job := &lmesv1alpha1.LMEvalJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-job", + Namespace: "test-namespace", + }, + Spec: lmesv1alpha1.LMEvalJobSpec{ + Model: "test-model", + TaskList: lmesv1alpha1.TaskList{ + TaskNames: []string{"test-task"}, + }, + AllowOnline: tt.jobAllowOnline, + AllowCodeExecution: tt.jobAllowCodeExecution, + }, + } + + // Create pod + pod := CreatePod(reconciler, svcOpts, job, logr.Discard()) + + // Check environment variables + codeExecutionEnabled := false + + for _, env := range pod.Spec.Containers[0].Env { + switch env.Name { + case "TRUST_REMOTE_CODE", "HF_DATASETS_TRUST_REMOTE_CODE", "UNITXT_ALLOW_UNVERIFIED_CODE": + if env.Value == "1" || env.Value == "True" { + codeExecutionEnabled = true + } + } + } + + // Check command line flags + hasAllowOnlineFlag := false + for _, cmd := range pod.Spec.Containers[0].Command { + if cmd == "--allow-online" { + hasAllowOnlineFlag = true + break + } + } + + // Verify expectations + if tt.expectOnlineMode != hasAllowOnlineFlag { + t.Errorf("Expected online mode flag to be %v, got %v", tt.expectOnlineMode, hasAllowOnlineFlag) + } + + if tt.expectCodeExecution != codeExecutionEnabled { + t.Errorf("Expected code execution to be %v, got %v", tt.expectCodeExecution, codeExecutionEnabled) + } + }) + } +} + +func boolPtr(b bool) *bool { + return &b +} diff --git a/controllers/lmes/lmevaljob_controller.go b/controllers/lmes/lmevaljob_controller.go index e8d323fb0..12893a71d 100644 --- a/controllers/lmes/lmevaljob_controller.go +++ b/controllers/lmes/lmevaljob_controller.go @@ -31,6 +31,7 @@ import ( "github.com/trustyai-explainability/trustyai-service-operator/controllers/utils" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" @@ -510,7 +511,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, // construct a new pod and create a pod for the job currentTime := v1.Now() - pod := CreatePod(Options, job, log) + pod := CreatePod(r, Options, job, log) if err := r.Create(ctx, pod, &client.CreateOptions{}); err != nil { // Failed to create the pod. Mark the status as complete with failed job.Status.State = lmesv1alpha1.CompleteJobState @@ -728,7 +729,7 @@ func (r *LMEvalJobReconciler) handleSuspend(ctx context.Context, log logr.Logger func (r *LMEvalJobReconciler) handleResume(ctx context.Context, log logr.Logger, job *lmesv1alpha1.LMEvalJob) (ctrl.Result, error) { log.Info("Resume job") - pod := CreatePod(Options, job, log) + pod := CreatePod(r, Options, job, log) if err := r.Create(ctx, pod); err != nil { log.Error(err, "failed to create pod to resume job") return r.pullingJobs.addOrUpdate(string(job.GetUID()), Options.PodCheckingInterval), nil @@ -853,7 +854,18 @@ func unmarshal(custom string, props []string) (map[string]interface{}, error) { return obj, nil } -func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { +func CreatePod(r *LMEvalJobReconciler, svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod { + + // Get DataScienceCluster LMEval settings (if reconciler is available) + var dscAllowOnline, dscAllowCodeExecution *bool + if r != nil { + if dscAllowOnlineVal, dscAllowCodeExecutionVal, err := r.getDSCLMEvalSettings(context.Background(), log); err != nil { + log.Error(err, "failed to read DSC ConfigMap settings, using operator defaults") + } else { + dscAllowOnline = dscAllowOnlineVal + dscAllowCodeExecution = dscAllowCodeExecutionVal + } + } var envVars = removeProtectedEnvVars(job.Spec.Pod.GetContainer().GetEnv()) @@ -944,14 +956,23 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo }, } - if job.Spec.AllowCodeExecution != nil && *job.Spec.AllowCodeExecution { - // Disable remote code execution by default + // Check if code execution is allowed based on permissions + // Precedence: operator config -> DSC ConfigMap + codeExecutionAllowed := false + if dscAllowCodeExecution != nil { + codeExecutionAllowed = *dscAllowCodeExecution + log.Info("using DSC ConfigMap permission for LMEval code execution", "value", codeExecutionAllowed) + } else { + codeExecutionAllowed = svcOpts.AllowCodeExecution + log.Info("using operator config permission for LMEval code execution", "value", codeExecutionAllowed) + } - if !svcOpts.AllowCodeExecution { - log.Error(fmt.Errorf("code execution not allowed by the operator"), "change this setting and redeploy the operator") + if job.Spec.AllowCodeExecution != nil && *job.Spec.AllowCodeExecution { + if !codeExecutionAllowed { + log.Error(fmt.Errorf("code execution not allowed by permissions"), "DSC or operator config disallows code execution") envVars = append(envVars, disallowRemoteCodeEnvVars...) } else { - log.Info("enabling code execution") + log.Info("enabling code execution from job spec") envVars = append(envVars, allowRemoteCodeEnvVars...) } } else { @@ -981,12 +1002,23 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo }, } - // Enforce offline mode by default - if job.Spec.AllowOnline != nil && *job.Spec.AllowOnline { + // Check if online mode is allowed + // Precedence: operator config -> DSC ConfigMap + onlineModeAllowed := false + if dscAllowOnline != nil { + onlineModeAllowed = *dscAllowOnline + log.Info("using DSC ConfigMap permission for LMEval online mode", "value", onlineModeAllowed) + } else { + onlineModeAllowed = svcOpts.AllowOnline + log.Info("using operator config permission for LMEval online mode", "value", onlineModeAllowed) + } - if !svcOpts.AllowOnline { - log.Error(fmt.Errorf("online mode not allowed by the operator"), "change this setting and redeploy the operator") + if job.Spec.AllowOnline != nil && *job.Spec.AllowOnline { + if !onlineModeAllowed { + log.Error(fmt.Errorf("online mode not allowed by permissions"), "DSC or operator config disallows online mode") envVars = append(envVars, offlineHuggingFaceEnvVars...) + } else { + log.Info("enabling online mode from job spec") } } else { envVars = append(envVars, offlineHuggingFaceEnvVars...) @@ -1138,7 +1170,7 @@ func CreatePod(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, log logr.Lo Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, Env: envVars, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, dscAllowOnline), Args: generateArgs(svcOpts, job, log), SecurityContext: mainSecurityContext, VolumeMounts: volumeMounts, @@ -1377,7 +1409,7 @@ func concatTasks(tasks lmesv1alpha1.TaskList) []string { return append(tasks.TaskNames, recipesName...) } -func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string { +func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob, dscAllowOnline *bool) []string { if job == nil { return nil } @@ -1398,7 +1430,16 @@ func generateCmd(svcOpts *serviceOptions, job *lmesv1alpha1.LMEvalJob) []string cmds = append(cmds, "--listen-port", fmt.Sprintf("%d", svcOpts.DriverPort)) } - if job.Spec.AllowOnline != nil && *job.Spec.AllowOnline && svcOpts.AllowOnline { + // Check if online mode + // Precedence: operator config -> DSC ConfigMap + onlineModeAllowed := false + if dscAllowOnline != nil { + onlineModeAllowed = *dscAllowOnline + } else { + onlineModeAllowed = svcOpts.AllowOnline + } + + if job.Spec.AllowOnline != nil && *job.Spec.AllowOnline && onlineModeAllowed { cmds = append(cmds, "--allow-online") } @@ -1531,3 +1572,46 @@ func removeProtectedEnvVars(envVars []corev1.EnvVar) []corev1.EnvVar { return allowedEnvVars } + +// getDSCLMEvalSettings reads the DataScienceCluster ConfigMap and returns the LMEval settings +// Returns nil values if the ConfigMap doesn't exist or doesn't contain the settings +func (r *LMEvalJobReconciler) getDSCLMEvalSettings(ctx context.Context, log logr.Logger) (*bool, *bool, error) { + configMapKey := types.NamespacedName{ + Namespace: r.Namespace, + Name: DSCConfigMapName, + } + + var cm corev1.ConfigMap + if err := r.Get(ctx, configMapKey, &cm); err != nil { + if errors.IsNotFound(err) { + // DataScienceCluster ConfigMap not found (use operator defaults) + log.Info("DataScienceCluster ConfigMap not found, using operator defaults", "configMap", configMapKey) + return nil, nil, nil + } + return nil, nil, fmt.Errorf("error reading DataScienceCluster configmap %s: %w", configMapKey, err) + } + + var allowOnline, allowCodeExecution *bool + + // allowOnline setting + if allowOnlineStr, ok := cm.Data[DSCAllowOnlineKey]; ok { + if allowOnlineVal, err := strconv.ParseBool(allowOnlineStr); err == nil { + allowOnline = &allowOnlineVal + log.Info("DataScienceCluster allowOnline setting", "value", allowOnlineVal) + } else { + log.Error(err, "invalid allowOnline value in DataScienceCluster ConfigMap", "value", allowOnlineStr) + } + } + + // allowCodeExecution setting + if allowCodeExecutionStr, ok := cm.Data[DSCAllowCodeExecutionKey]; ok { + if allowCodeExecutionVal, err := strconv.ParseBool(allowCodeExecutionStr); err == nil { + allowCodeExecution = &allowCodeExecutionVal + log.Info("DataScienceCluster allowCodeExecution setting", "value", allowCodeExecutionVal) + } else { + log.Error(err, "invalid allowCodeExecution value in DataScienceCluster ConfigMap", "value", allowCodeExecutionStr) + } + } + + return allowOnline, allowCodeExecution, nil +} diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go index e4971781b..eaf1b2b18 100644 --- a/controllers/lmes/lmevaljob_controller_test.go +++ b/controllers/lmes/lmevaljob_controller_test.go @@ -114,7 +114,7 @@ func Test_SimplePod(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), SecurityContext: defaultSecurityContext, VolumeMounts: []corev1.VolumeMount{ @@ -184,7 +184,7 @@ func Test_SimplePod(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -330,7 +330,7 @@ func Test_WithCustomPod(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), SecurityContext: &corev1.SecurityContext{ RunAsUser: &runAsUser, @@ -445,7 +445,7 @@ func Test_WithCustomPod(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) @@ -460,7 +460,7 @@ func Test_WithCustomPod(t *testing.T) { "custom/annotation1": "annotation1", } - newPod = CreatePod(svcOpts, job, log) + newPod = CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -609,7 +609,7 @@ func Test_EnvSecretsPod(t *testing.T) { Value: "True", }, }, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), SecurityContext: defaultSecurityContext, VolumeMounts: []corev1.VolumeMount{ @@ -632,7 +632,7 @@ func Test_EnvSecretsPod(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) // maybe only verify the envs: Containers[0].Env assert.Equal(t, expect, newPod) } @@ -734,7 +734,7 @@ func Test_FileSecretsPod(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -823,7 +823,7 @@ func Test_FileSecretsPod(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) // maybe only verify the envs: Containers[0].Env assert.Equal(t, expect, newPod) } @@ -1039,7 +1039,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { "--output-path", "/opt/app-root/src/output", "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, nil)) job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, lmesv1alpha1.TaskRecipe{ @@ -1063,7 +1063,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, nil)) } func Test_GenerateArgCmdCustomCard(t *testing.T) { @@ -1121,7 +1121,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--custom-artifact", `card|custom_0|{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "dutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, nil)) // add second task using custom recipe + custom template job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, @@ -1157,7 +1157,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--task-recipe", "card=cards.custom_1,template=templates.tp_0,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", "--custom-artifact", `template|tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, nil)) // add third task using normal card + custom system_prompt job.Spec.TaskList.TaskRecipes = append(job.Spec.TaskList.TaskRecipes, @@ -1191,7 +1191,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--custom-artifact", `template|tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, "--custom-artifact", "system_prompt|sp_0|this is a custom system promp", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, nil)) // add forth task using custom card + custom template + custom system_prompt // and reuse the template and system prompt @@ -1226,7 +1226,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--custom-artifact", `template|tp_0|{ "__type__": "input_output_template", "instruction": "In the following task, you translate a {text_type}.", "input_format": "Translate this {text_type} from {source_language} to {target_language}: {text}.", "target_prefix": "Translation: ", "output_format": "{translation}", "postprocessors": [ "processors.lower_case" ] }`, "--custom-artifact", "system_prompt|sp_0|this is a custom system promp", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, nil)) // add fifth task using regular card + custom template + custom system_prompt // both template and system prompt are new @@ -1272,7 +1272,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { "--custom-artifact", "system_prompt|sp_0|this is a custom system promp", "--custom-artifact", "system_prompt|sp_1|this is a custom system promp2", "--", - }, generateCmd(svcOpts, job)) + }, generateCmd(svcOpts, job, nil)) } func Test_CustomCardValidation(t *testing.T) { @@ -1686,7 +1686,7 @@ func Test_ManagedPVC(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -1769,7 +1769,7 @@ func Test_ManagedPVC(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -1850,7 +1850,7 @@ func Test_ExistingPVC(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -1932,7 +1932,7 @@ func Test_ExistingPVC(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -2024,7 +2024,7 @@ func Test_PVCPreference(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2118,7 +2118,7 @@ func Test_PVCPreference(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -2209,7 +2209,7 @@ func Test_OfflineMode(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2303,7 +2303,7 @@ func Test_OfflineMode(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -2417,7 +2417,7 @@ func Test_ProtectedVars(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2516,7 +2516,7 @@ func Test_ProtectedVars(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -2614,7 +2614,7 @@ func Test_OnlineModeDisabled(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2708,7 +2708,7 @@ func Test_OnlineModeDisabled(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -2802,7 +2802,7 @@ func Test_OnlineMode(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -2876,7 +2876,7 @@ func Test_OnlineMode(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -2973,7 +2973,7 @@ func Test_AllowCodeOnlineMode(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -3047,7 +3047,7 @@ func Test_AllowCodeOnlineMode(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -3142,7 +3142,7 @@ func Test_AllowCodeOfflineMode(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -3236,7 +3236,7 @@ func Test_AllowCodeOfflineMode(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -3331,7 +3331,7 @@ func Test_OfflineModeWithOutput(t *testing.T) { Name: "main", Image: svcOpts.PodImage, ImagePullPolicy: svcOpts.ImagePullPolicy, - Command: generateCmd(svcOpts, job), + Command: generateCmd(svcOpts, job, nil), Args: generateArgs(svcOpts, job, log), Ports: []corev1.ContainerPort{ { @@ -3437,7 +3437,7 @@ func Test_OfflineModeWithOutput(t *testing.T) { }, } - newPod := CreatePod(svcOpts, job, log) + newPod := CreatePod(nil, svcOpts, job, log) assert.Equal(t, expect, newPod) } @@ -3547,7 +3547,7 @@ func Test_CustomTasksGitSource(t *testing.T) { }, } - pod := CreatePod(svcOpts, job, log) + pod := CreatePod(nil, svcOpts, job, log) require.NotNil(t, pod) @@ -3638,7 +3638,7 @@ func Test_CustomTasksGitSourceOfflineMode(t *testing.T) { logger := logr.Discard() - pod := CreatePod(Options, job, logger) + pod := CreatePod(nil, Options, job, logger) if pod == nil { t.Fatal("pod should not be nil")