diff --git a/api/lmes/v1alpha1/lmevaljob_types.go b/api/lmes/v1alpha1/lmevaljob_types.go
index 05f649c4f..ac2754222 100644
--- a/api/lmes/v1alpha1/lmevaljob_types.go
+++ b/api/lmes/v1alpha1/lmevaljob_types.go
@@ -152,6 +152,42 @@ type CustomArtifacts struct {
 	Tasks []CustomArtifact `json:"tasks,omitempty"`
 }
 
+// MCP settings for the retrieval system to support end-to-end RAG. It mainly leverages
+// the tools API to search the context information including contexts_id and contexts.
+// Check the contexts and contexts_id information in this page:
+// https://www.unitxt.ai/en/latest/docs/rag_support.html
+type MCP struct {
+	// The endpoint of the MCP server. For example:
+	// http://localhost:3000/mcp
+	// +kubebuilder:validation:Pattern=`^https?://[a-zA-Z0-9._:/-]+$`
+	URL string `json:"url"`
+	// The Tool name of the MCP tool. Or the API name of the MCP tool.
+	// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
+	Tool string `json:"tool"`
+	// The field name in the MCP payload that contains the JSON string of the context information.
+	// Since the MCP payload may contain multiple records, each record is a JSON object that contains the
+	// specified field name as the key and its value is the JSON string of the context information.
+	// The underlying process parses each JSON string of the context information and aggregates them
+	// into an array.
+	// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
+	PayloadField string `json:"payloadField"`
+	// The jsonpath to the context field in the array of the context information objects
+	// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
+	ContextField string `json:"contextField"`
+	// The jsonpath to the id field in the array of the context information objects
+	// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
+	IdField string `json:"idField"`
+	// Verify server's certificate if the server is using HTTPS.
+	// +kubebuilder:default:=true
+	VerifyCertificate bool `json:"verifyCertificate"`
+}
+
+// RAG holds the retrieval settings used for end-to-end RAG evaluation.
+type RAG struct {
+	// The MCP settings. Currently, this is the only option for the end-to-end RAG.
+	MCP MCP `json:"mcp"`
+}
+
 func (c *CustomArtifacts) GetTemplates() []CustomArtifact {
 	if c == nil {
 		return nil
@@ -210,6 +246,9 @@ type TaskRecipe struct {
 	// The pool size for the fewshot
 	// +optional
 	DemosPoolSize *int `json:"demosPoolSize,omitempty"`
+	// Specify the RAG information if needed
+	// +optional
+	RAG *RAG `json:"rag,omitempty"`
 }
 
 // GitSource specifies the git location of external tasks
@@ -313,6 +352,31 @@ func (t *Task) String() string {
 	return ""
 }
 
+// String renders the RAG settings; currently only the MCP settings are emitted.
+func (r *RAG) String() string {
+	if r == nil {
+		return ""
+	}
+	return r.MCP.String()
+}
+
+// String composes the MCP settings plus the process_docs/process_results hooks like:
+//
+// session: url=http://localhost:3002/mcp
+// request: tool=search,query_field=text,context_field=context,id_field=id,verify_cert=true
+func (m *MCP) String() string {
+	if m == nil {
+		return ""
+	}
+	var b strings.Builder
+	b.WriteString(fmt.Sprintf(" session: url=%s\n", m.URL))
+	b.WriteString(fmt.Sprintf(" request: tool=%s,query_field=%s,context_field=%s,id_field=%s,verify_cert=%t",
+		m.Tool, m.PayloadField, m.ContextField, m.IdField, m.VerifyCertificate))
+	// End-2-end RAG requires the following settings:
+	b.WriteString("\nprocess_docs: !function ###UNITXT_PATH###/utils.process_docs\nprocess_results: !function ###UNITXT_PATH###/utils.postprocess_docs")
+	return b.String()
+}
+
 // Use the tp_idx and sp_idx to point to the corresponding custom template
 // and custom system_prompt
 func (t *TaskRecipe) String() string {
@@ -346,6 +410,9 @@ func (t *TaskRecipe) String() string {
 	if t.DemosPoolSize != nil {
 		b.WriteString(fmt.Sprintf(",demos_pool_size=%d", *t.DemosPoolSize))
 	}
+	if t.RAG != nil {
+		b.WriteString(fmt.Sprintf("\nrag:\n%s", t.RAG.String()))
+	}
 	return b.String()
 }
 
diff --git a/api/lmes/v1alpha1/zz_generated.deepcopy.go b/api/lmes/v1alpha1/zz_generated.deepcopy.go
index 22b60b5de..612ebca91 100644
--- a/api/lmes/v1alpha1/zz_generated.deepcopy.go
+++ b/api/lmes/v1alpha1/zz_generated.deepcopy.go
@@ -413,6 +413,21 @@ func (in *LMEvalPodSpec) DeepCopy() *LMEvalPodSpec {
 	return out
 }
 
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *MCP) DeepCopyInto(out *MCP) {
+	*out = *in
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MCP.
+func (in *MCP) DeepCopy() *MCP {
+	if in == nil {
+		return nil
+	}
+	out := new(MCP)
+	in.DeepCopyInto(out)
+	return out
+}
+
 // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
 func (in *Metric) DeepCopyInto(out *Metric) {
 	*out = *in
@@ -539,6 +554,22 @@ func (in *PersistentVolumeClaimManaged) DeepCopy() *PersistentVolumeClaimManaged
 	return out
 }
 
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *RAG) DeepCopyInto(out *RAG) {
+	*out = *in
+	out.MCP = in.MCP
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RAG.
+func (in *RAG) DeepCopy() *RAG {
+	if in == nil {
+		return nil
+	}
+	out := new(RAG)
+	in.DeepCopyInto(out)
+	return out
+}
+
 // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
 func (in *SystemPrompt) DeepCopyInto(out *SystemPrompt) {
 	*out = *in
@@ -650,6 +681,11 @@ func (in *TaskRecipe) DeepCopyInto(out *TaskRecipe) {
 		*out = new(int)
 		**out = **in
 	}
+	if in.RAG != nil {
+		in, out := &in.RAG, &out.RAG
+		*out = new(RAG)
+		**out = **in
+	}
 }
 
 // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TaskRecipe.
diff --git a/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml b/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml
index 029c8d4fc..e0be4d4af 100644
--- a/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml
+++ b/config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml
@@ -4909,6 +4909,59 @@ spec:
                         numDemos:
                           description: Number of fewshot
                           type: integer
+                        rag:
+                          description: Specify the RAG information if needed
+                          properties:
+                            mcp:
+                              description: The MCP settings. Currently, this is the
+                                only option for the end-to-end RAG.
+                              properties:
+                                contextField:
+                                  description: The jsonpath to the context field in
+                                    the array of the context information objects
+                                  pattern: ^[a-zA-Z0-9._-]+$
+                                  type: string
+                                idField:
+                                  description: The jsonpath to the id field in the
+                                    array of the context information objects
+                                  pattern: ^[a-zA-Z0-9._-]+$
+                                  type: string
+                                payloadField:
+                                  description: |-
+                                    The field name in the MCP payload that contains the JSON string of the context information.
+                                    Since the MCP payload may contain multiple records, each record is a JSON object that contains the
+                                    specified field name as the key and its value is the JSON string of the context information.
+                                    The underlying process parses each JSON string of the context information and aggregates them
+                                    into an array.
+                                  pattern: ^[a-zA-Z0-9._-]+$
+                                  type: string
+                                tool:
+                                  description: The Tool name of the MCP tool. Or the
+                                    API name of the MCP tool.
+                                  pattern: ^[a-zA-Z0-9._-]+$
+                                  type: string
+                                url:
+                                  description: |-
+                                    The endpoint of the MCP server. For example:
+                                    http://localhost:3000/mcp
+                                  pattern: ^https?://[a-zA-Z0-9._:/-]+$
+                                  type: string
+                                verifyCertificate:
+                                  default: true
+                                  description: Verify server's certificate if the
+                                    server is using HTTPS.
+                                  type: boolean
+                              required:
+                              - contextField
+                              - idField
+                              - payloadField
+                              - tool
+                              - url
+                              - verifyCertificate
+                              type: object
+                          required:
+                          - mcp
+                          type: object
                         systemPrompt:
                           description: The Unitxt System Prompt
                           properties:
diff --git a/controllers/lmes/driver/driver.go b/controllers/lmes/driver/driver.go
index 415e84453..0222f2721 100644
--- a/controllers/lmes/driver/driver.go
+++ b/controllers/lmes/driver/driver.go
@@ -52,6 +52,8 @@ const (
 	ShutdownURI = "/Shutdown"
 	GetStatusURI = "/GetStatus"
 	DefaultGitBranch = "main"
+	UnitxtPath = "/opt/app-root/src/lm_eval/tasks/unitxt"
+	UnitxtPattern = "###UNITXT_PATH###"
 )
 
 type DriverOption struct {
@@ -513,7 +515,7 @@ func (d *driverImpl) createTaskRecipes() error {
 			[]byte(fmt.Sprintf(
 				"task: %s\ninclude: unitxt\nrecipe: %s",
 				fmt.Sprintf("%s_%d", TaskRecipePrefix, i),
-				taskRecipe,
+				strings.ReplaceAll(taskRecipe, UnitxtPattern, UnitxtPath),
 			)),
 			0666,
 		)
diff --git a/controllers/lmes/driver/driver_test.go b/controllers/lmes/driver/driver_test.go
index e4cbfb62c..33954b38c 100644
--- a/controllers/lmes/driver/driver_test.go
+++ b/controllers/lmes/driver/driver_test.go
@@ -299,6 +299,42 @@ func Test_TaskRecipes(t *testing.T) {
 	)
 }
 
+func Test_TaskRecipesWithRAG(t *testing.T) {
+	info := setupTest(t, true)
+	defer info.tearDown(t)
+
+	driver, err := NewDriver(&DriverOption{
+		Context:         context.Background(),
+		OutputPath:      info.outputPath,
+		CatalogPath:     info.catalogPath,
+		Logger:          driverLog,
+		TaskRecipesPath: info.taskPath,
+		TaskRecipes: []string{
+			"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10\nrag:\n session: url=https://localhost:3002/mcp\n request: tool=search,query_field=text,context_field=context,id_field=id,verify_cert=true\nprocess_docs: !function ###UNITXT_PATH###/utils.process_docs\nprocess_results: !function ###UNITXT_PATH###/utils.postprocess_docs",
+		},
+		Args:     []string{"sh", "-ec", "sleep 2; echo 'testing progress: 100%|' >&2; sleep 4"},
+		CommPort: info.port,
+	})
+	assert.Nil(t, err)
+
+	msgs, _ := runDriverAndWait4Complete(t, driver, false)
+
+	assert.Equal(t, []string{
+		"initializing the evaluation job",
+		"testing progress: 100%",
+		"job completed",
+	}, msgs)
+
+	assert.Nil(t, driver.Shutdown())
+
+	tr0, err := os.ReadFile(filepath.Join(info.taskPath, "tr_0.yaml"))
+	assert.Nil(t, err)
+	assert.Equal(t,
+		"task: tr_0\ninclude: unitxt\nrecipe: card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10\nrag:\n session: url=https://localhost:3002/mcp\n request: tool=search,query_field=text,context_field=context,id_field=id,verify_cert=true\nprocess_docs: !function /opt/app-root/src/lm_eval/tasks/unitxt/utils.process_docs\nprocess_results: !function /opt/app-root/src/lm_eval/tasks/unitxt/utils.postprocess_docs",
+		string(tr0),
+	)
+}
+
 func Test_CustomCards(t *testing.T) {
 	info := setupTest(t, true)
 	defer info.tearDown(t)
diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go
index e4971781b..8a88305f4 100644
--- a/controllers/lmes/lmevaljob_controller_test.go
+++ b/controllers/lmes/lmevaljob_controller_test.go
@@ -1049,6 +1049,16 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) {
 			Metrics: []lmesv1alpha1.Metric{{Name: "unitxt.metric3"}, {Name: "unitxt.metric4"}},
 			NumDemos: &numDemos,
 			DemosPoolSize: &demosPoolSize,
+			RAG: &lmesv1alpha1.RAG{
+				MCP: lmesv1alpha1.MCP{
+					URL:               "https://localhost:3002/mcp",
+					Tool:              "search",
+					PayloadField:      "text",
+					ContextField:      "context",
+					IdField:           "id",
+					VerifyCertificate: true,
+				},
+			},
 		},
 	)
 
@@ -1061,7 +1071,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) {
 		"/opt/app-root/src/bin/driver",
 		"--output-path", "/opt/app-root/src/output",
 		"--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
-		"--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
+		"--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10\nrag:\n session: url=https://localhost:3002/mcp\n request: tool=search,query_field=text,context_field=context,id_field=id,verify_cert=true\nprocess_docs: !function ###UNITXT_PATH###/utils.process_docs\nprocess_results: !function ###UNITXT_PATH###/utils.postprocess_docs",
 		"--",
 	}, generateCmd(svcOpts, job))
 }
diff --git a/controllers/lmes/validation.go b/controllers/lmes/validation.go
index 64ad18237..84452a624 100644
--- a/controllers/lmes/validation.go
+++ b/controllers/lmes/validation.go
@@ -122,7 +122,7 @@ func ValidateUserInput(job *lmesv1alpha1.LMEvalJob) error {
 	// Validate custom task git source
 	if job.Spec.TaskList.CustomTasks != nil && len(job.Spec.TaskList.TaskNames) > 0 {
 		gitSource := job.Spec.TaskList.CustomTasks.Source.GitSource
-		if err := ValidateGitURL(gitSource.URL); err != nil {
+		if err := ValidateURL(gitSource.URL); err != nil {
 			return fmt.Errorf("invalid git URL: %w", err)
 		}
 		if gitSource.Path != "" {
@@ -247,6 +247,13 @@ func ValidateTaskRecipes(args []lmesv1alpha1.TaskRecipe, argType string) error {
 			}
 		}
 
+		// validate the RAG settings if present
+		if arg.RAG != nil {
+			if err := ValidateMCP(&arg.RAG.MCP); err != nil {
+				return fmt.Errorf("%s[%d] RAG: %w", argType, i, err)
+			}
+		}
+
 	}
 	return nil
 }
@@ -458,20 +465,20 @@ func ValidateChatTemplateName(name string) error {
 	return nil
 }
 
-// ValidateGitURL validates git repository URLs
-func ValidateGitURL(url string) error {
+// ValidateURL validates an HTTPS URL (e.g. a git repository URL): non-empty, no shell metacharacters, restricted charset
+func ValidateURL(url string) error {
 	if url == "" {
-		return fmt.Errorf("git URL cannot be empty")
+		return fmt.Errorf("URL cannot be empty")
 	}
 
 	// Check for shell metacharacters
 	if ContainsShellMetacharacters(url) {
-		return fmt.Errorf("git URL contains shell metacharacters")
+		return fmt.Errorf("URL contains shell metacharacters")
 	}
 
 	// Must be HTTPS URL for security
 	if !regexp.MustCompile(`^https://[a-zA-Z0-9._/-]+$`).MatchString(url) {
-		return fmt.Errorf("git URL must be a valid HTTPS URL (only alphanumeric, ., _, /, - allowed)")
+		return fmt.Errorf("URL must be a valid HTTPS URL (only alphanumeric, ., _, /, - allowed)")
 	}
 
 	return nil
@@ -529,3 +536,46 @@ func ValidateGitCommit(commit string) error {
 
 	return nil
 }
+
+// mcpURLPattern mirrors the kubebuilder pattern of MCP.URL: an HTTP or HTTPS
+// endpoint that may carry a port, e.g. http://localhost:3000/mcp.
+var mcpURLPattern = regexp.MustCompile(`^https?://[a-zA-Z0-9._:/-]+$`)
+
+// mcpNamePattern only allows alphanumeric, hyphen, underscore, and period.
+var mcpNamePattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
+
+// ValidateMCP validates the MCP settings of a task recipe. A nil mcp is valid.
+func ValidateMCP(mcp *lmesv1alpha1.MCP) error {
+	if mcp == nil {
+		return nil
+	}
+	// only support MCP for now, so directly check the MCP settings
+	if mcp.URL == "" {
+		return fmt.Errorf("invalid MCP URL: URL cannot be empty")
+	}
+	if ContainsShellMetacharacters(mcp.URL) {
+		return fmt.Errorf("invalid MCP URL: URL contains shell metacharacters")
+	}
+	// Unlike git sources, MCP servers may be plain HTTP and may listen on a custom port
+	if !mcpURLPattern.MatchString(mcp.URL) {
+		return fmt.Errorf("invalid MCP URL: URL must be a valid HTTP or HTTPS URL (only alphanumeric, ., _, :, /, - allowed)")
+	}
+
+	if !mcpNamePattern.MatchString(mcp.Tool) {
+		return fmt.Errorf("MCP tool contains invalid characters (only alphanumeric, ., _, - allowed)")
+	}
+
+	if !mcpNamePattern.MatchString(mcp.PayloadField) {
+		return fmt.Errorf("MCP payloadField contains invalid characters (only alphanumeric, ., _, - allowed)")
+	}
+
+	if !mcpNamePattern.MatchString(mcp.ContextField) {
+		return fmt.Errorf("MCP contextField contains invalid characters (only alphanumeric, ., _, - allowed)")
+	}
+
+	if !mcpNamePattern.MatchString(mcp.IdField) {
+		return fmt.Errorf("MCP idField contains invalid characters (only alphanumeric, ., _, - allowed)")
+	}
+
+	return nil
+}