Skip to content

Commit bacdc4d

Browse files
authored
fix(RHOAIENG-27596): LMEval Job CR escape fix (#504)
* fix: Prevent lm-eval command escape
* fix: Args test expectation
* fix: Device patching
* fix: Remove unitxt fields validation
1 parent f3a7022 commit bacdc4d

File tree

9 files changed

+2245
-135
lines changed

9 files changed

+2245
-135
lines changed

api/lmes/v1alpha1/lmevaljob_types.go

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -61,13 +61,16 @@ const (
6161
)
6262

6363
type Arg struct {
64-
Name string `json:"name"`
64+
// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
65+
Name string `json:"name"`
66+
// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._/:\- ]*$`
6567
Value string `json:"value,omitempty"`
6668
}
6769

6870
type Card struct {
6971
// Unitxt card's ID
7072
// +optional
73+
// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
7174
Name string `json:"name,omitempty"`
7275
// A JSON string for a custom unitxt card which contains the custom dataset.
7376
// Use the documentation here: https://www.unitxt.ai/en/latest/docs/adding_dataset.html#adding-to-the-catalog
@@ -212,6 +215,7 @@ type TaskRecipe struct {
212215
// GitSource specifies the git location of external tasks
213216
type GitSource struct {
214217
// URL specifies the git repository URL
218+
// +kubebuilder:validation:Pattern=`^https://[a-zA-Z0-9._/-]+$`
215219
URL string `json:"url,omitempty"`
216220
// Branch specifies the git branch to use
217221
// +optional
@@ -221,6 +225,7 @@ type GitSource struct {
221225
Commit *string `json:"commit,omitempty"`
222226
// Path specifies the path to the task file
223227
// +optional
228+
// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._/-]*$`
224229
Path string `json:"path,omitempty"`
225230
}
226231

@@ -238,6 +243,7 @@ type CustomTasks struct {
238243

239244
type TaskList struct {
240245
// TaskNames from lm-eval's task list and/or from custom tasks if CustomTasks is defined
246+
// +kubebuilder:validation:items:Pattern=`^[a-zA-Z0-9._-]+$`
241247
TaskNames []string `json:"taskNames,omitempty"`
242248
// Task Recipes specifically for Unitxt
243249
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
@@ -446,14 +452,15 @@ func (p *LMEvalPodSpec) GetSideCards() []corev1.Container {
446452
}
447453

448454
type OfflineS3Spec struct {
449-
AccessKeyIdRef corev1.SecretKeySelector `json:"accessKeyId"`
450-
SecretAccessKeyRef corev1.SecretKeySelector `json:"secretAccessKey"`
451-
Bucket corev1.SecretKeySelector `json:"bucket"`
452-
Path string `json:"path"`
453-
Region corev1.SecretKeySelector `json:"region"`
454-
Endpoint corev1.SecretKeySelector `json:"endpoint"`
455-
VerifySSL *bool `json:"verifySSL,omitempty"`
456-
CABundle *corev1.SecretKeySelector `json:"caBundle,omitempty"`
455+
AccessKeyIdRef corev1.SecretKeySelector `json:"accessKeyId"`
456+
SecretAccessKeyRef corev1.SecretKeySelector `json:"secretAccessKey"`
457+
Bucket corev1.SecretKeySelector `json:"bucket"`
458+
// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._/-]*$`
459+
Path string `json:"path"`
460+
Region corev1.SecretKeySelector `json:"region"`
461+
Endpoint corev1.SecretKeySelector `json:"endpoint"`
462+
VerifySSL *bool `json:"verifySSL,omitempty"`
463+
CABundle *corev1.SecretKeySelector `json:"caBundle,omitempty"`
457464
}
458465

459466
// OfflineStorageSpec defines the storage configuration for LMEvalJob's offline mode
@@ -493,6 +500,7 @@ type LMEvalJobSpec struct {
493500
// Important: Run "make" to regenerate code after modifying this file
494501

495502
// Model name
503+
// +kubebuilder:validation:Enum=hf;openai-completions;openai-chat-completions;local-completions;local-chat-completions;watsonx_llm;textsynth
496504
Model string `json:"model"`
497505
// Args for the model
498506
// +optional
@@ -506,6 +514,7 @@ type LMEvalJobSpec struct {
506514
// the number of documents to evaluate to the first X documents (if an integer)
507515
// per task or first X% of documents per task
508516
// +optional
517+
// +kubebuilder:validation:Pattern=`^(\d+\.?\d*|\d*\.\d+)$`
509518
Limit string `json:"limit,omitempty"`
510519
// Map to `--gen_kwargs` parameter for the underlying library.
511520
// +optional

config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ spec:
7575
items:
7676
properties:
7777
name:
78+
pattern: ^[a-zA-Z0-9._-]+$
7879
type: string
7980
value:
81+
pattern: ^[a-zA-Z0-9._/:\- ]*$
8082
type: string
8183
required:
8284
- name
@@ -87,6 +89,7 @@ spec:
8789
Accepts an integer, or a float between 0.0 and 1.0 . If passed, will limit
8890
the number of documents to evaluate to the first X documents (if an integer)
8991
per task or first X% of documents per task
92+
pattern: ^(\d+\.?\d*|\d*\.\d+)$
9093
type: string
9194
logSamples:
9295
description: |-
@@ -95,14 +98,24 @@ spec:
9598
type: boolean
9699
model:
97100
description: Model name
101+
enum:
102+
- hf
103+
- openai-completions
104+
- openai-chat-completions
105+
- local-completions
106+
- local-chat-completions
107+
- watsonx_llm
108+
- textsynth
98109
type: string
99110
modelArgs:
100111
description: Args for the model
101112
items:
102113
properties:
103114
name:
115+
pattern: ^[a-zA-Z0-9._-]+$
104116
type: string
105117
value:
118+
pattern: ^[a-zA-Z0-9._/:\- ]*$
106119
type: string
107120
required:
108121
- name
@@ -204,6 +217,7 @@ spec:
204217
type: object
205218
x-kubernetes-map-type: atomic
206219
path:
220+
pattern: ^[a-zA-Z0-9._/-]*$
207221
type: string
208222
region:
209223
description: SecretKeySelector selects a key of a Secret.
@@ -4828,9 +4842,11 @@ spec:
48284842
type: string
48294843
path:
48304844
description: Path specifies the path to the task file
4845+
pattern: ^[a-zA-Z0-9._/-]*$
48314846
type: string
48324847
url:
48334848
description: URL specifies the git repository URL
4849+
pattern: ^https://[a-zA-Z0-9._/-]+$
48344850
type: string
48354851
type: object
48364852
type: object
@@ -4839,6 +4855,7 @@ spec:
48394855
description: TaskNames from lm-eval's task list and/or from custom
48404856
tasks if CustomTasks is defined
48414857
items:
4858+
pattern: ^[a-zA-Z0-9._-]+$
48424859
type: string
48434860
type: array
48444861
taskRecipes:
@@ -4860,6 +4877,7 @@ spec:
48604877
type: string
48614878
name:
48624879
description: Unitxt card's ID
4880+
pattern: ^[a-zA-Z0-9._-]+$
48634881
type: string
48644882
type: object
48654883
demosPoolSize:

controllers/lmes/driver/driver.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func (d *driverImpl) detectDevice() error {
218218
return fmt.Errorf("failed to find the matched output")
219219
}
220220

221-
patchDevice(d.Option.Args, matches[1] == "True")
221+
d.Option.Args = patchDevice(d.Option.Args, matches[1] == "True")
222222

223223
return nil
224224
}
@@ -241,21 +241,21 @@ func (d *driverImpl) downloadS3Assets() error {
241241
return nil
242242
}
243243

244-
func patchDevice(args []string, hasCuda bool) {
245-
var device = "cpu"
244+
func patchDevice(args []string, hasCuda bool) []string {
245+
device := "cpu"
246246
if hasCuda {
247247
device = "cuda"
248248
}
249-
// patch the python command in the Option.Arg by adding the `--device cuda` option
250-
// find the string with the `python -m lm_eval` prefix. usually it should be the last one
251-
for idx, arg := range args {
252-
if strings.HasPrefix(arg, "python -m lm_eval") {
253-
if !strings.Contains(arg, "--device") {
254-
args[idx] = fmt.Sprintf("%s --device %s", arg, device)
255-
}
256-
break
249+
250+
// Check if --device already exists
251+
for _, arg := range args {
252+
if arg == "--device" {
253+
return args // already has device specified
257254
}
258255
}
256+
257+
// If we reach here, --device doesn't exist, so add it
258+
return append(args, "--device", device)
259259
}
260260

261261
// Create a domain socket and use HTTP protocol to handle communication
@@ -598,6 +598,7 @@ func (d *driverImpl) fetchGitCustomTasks() error {
598598
return err
599599
}
600600

601+
// #nosec G204 -- CustomTaskGitURL is validated by ValidateGitURL() in the controller
601602
cloneCommand := exec.Command("git", "clone", d.Option.CustomTaskGitURL, repositoryDestination)
602603
if output, err := cloneCommand.CombinedOutput(); err != nil {
603604
return fmt.Errorf("failed to clone git repository: %v, output: %s", err, string(output))
@@ -608,12 +609,14 @@ func (d *driverImpl) fetchGitCustomTasks() error {
608609

609610
// Checkout a specific branch, if specified
610611
if d.Option.CustomTaskGitBranch != "" {
612+
// #nosec G204 -- CustomTaskGitBranch is validated by ValidateGitBranch() in the controller
611613
checkoutCommand := exec.Command("git", clonedDirectory, workTree, "checkout", d.Option.CustomTaskGitBranch)
612614
if output, err := checkoutCommand.CombinedOutput(); err != nil {
613615
return fmt.Errorf("failed to checkout branch %s: %v, output: %s",
614616
d.Option.CustomTaskGitBranch, err, string(output))
615617
}
616618
} else {
619+
// #nosec G204 -- DefaultGitBranch is a constant value, not user input
617620
checkoutCmd := exec.Command("git", clonedDirectory, workTree, "checkout", DefaultGitBranch)
618621
if output, err := checkoutCmd.CombinedOutput(); err != nil {
619622
d.Option.Logger.Info("failed to checkout main branch, using default branch from clone",
@@ -623,6 +626,7 @@ func (d *driverImpl) fetchGitCustomTasks() error {
623626

624627
// Checkout a specific commit, if specified
625628
if d.Option.CustomTaskGitCommit != "" {
629+
// #nosec G204 -- CustomTaskGitCommit is validated by ValidateGitCommit() in the controller
626630
checkoutCommand := exec.Command("git", clonedDirectory, workTree, "checkout", d.Option.CustomTaskGitCommit)
627631
if output, err := checkoutCommand.CombinedOutput(); err != nil {
628632
return fmt.Errorf("failed to checkout commit %s: %v, output: %s",
@@ -644,6 +648,7 @@ func (d *driverImpl) fetchGitCustomTasks() error {
644648
return err
645649
}
646650

651+
// #nosec G204 -- taskPath is derived from validated CustomTaskGitPath, TaskRecipesPath is controlled by the application
647652
copyCmd := exec.Command("cp", "-r", taskPath+"/.", d.Option.TaskRecipesPath)
648653
output, err := copyCmd.CombinedOutput()
649654
if err != nil {

controllers/lmes/driver/driver_test.go

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -234,31 +234,26 @@ func Test_PatchDevice(t *testing.T) {
234234
OutputPath: ".",
235235
DetectDevice: true,
236236
Logger: driverLog,
237-
Args: []string{"sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2"},
237+
Args: []string{"python", "-m", "lm_eval", "--output_path", "/opt/app-root/src/output", "--model", "test", "--model_args", "arg1=value1", "--tasks", "task1,task2"},
238238
}
239239

240240
// append `--device cuda`
241-
patchDevice(driverOpt.Args, true)
242-
assert.Equal(t,
243-
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --device cuda",
244-
driverOpt.Args[2],
245-
)
241+
driverOpt.Args = patchDevice(driverOpt.Args, true)
242+
expected := []string{"python", "-m", "lm_eval", "--output_path", "/opt/app-root/src/output", "--model", "test", "--model_args", "arg1=value1", "--tasks", "task1,task2", "--device", "cuda"}
243+
assert.Equal(t, expected, driverOpt.Args)
246244

247245
// append `--device cpu`
248-
driverOpt.Args = []string{"sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2"}
249-
patchDevice(driverOpt.Args, false)
250-
assert.Equal(t,
251-
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --device cpu",
252-
driverOpt.Args[2],
253-
)
254-
255-
// no change because `--device cpu` exists
256-
driverOpt.Args = []string{"sh", "-ec", "python -m lm_eval --device cpu --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2"}
257-
patchDevice(driverOpt.Args, true)
258-
assert.Equal(t,
259-
"python -m lm_eval --device cpu --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2",
260-
driverOpt.Args[2],
261-
)
246+
driverOpt.Args = []string{"python", "-m", "lm_eval", "--output_path", "/opt/app-root/src/output", "--model", "test", "--model_args", "arg1=value1", "--tasks", "task1,task2"}
247+
driverOpt.Args = patchDevice(driverOpt.Args, false)
248+
expected = []string{"python", "-m", "lm_eval", "--output_path", "/opt/app-root/src/output", "--model", "test", "--model_args", "arg1=value1", "--tasks", "task1,task2", "--device", "cpu"}
249+
assert.Equal(t, expected, driverOpt.Args)
250+
251+
// no change because `--device` already exists
252+
driverOpt.Args = []string{"python", "-m", "lm_eval", "--device", "cpu", "--output_path", "/opt/app-root/src/output", "--model", "test", "--model_args", "arg1=value1", "--tasks", "task1,task2"}
253+
originalArgs := make([]string, len(driverOpt.Args))
254+
copy(originalArgs, driverOpt.Args)
255+
driverOpt.Args = patchDevice(driverOpt.Args, true)
256+
assert.Equal(t, originalArgs, driverOpt.Args)
262257
}
263258

264259
func Test_TaskRecipes(t *testing.T) {

0 commit comments

Comments
 (0)