Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions api/lmes/v1alpha1/lmevaljob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,41 @@ type CustomArtifacts struct {
Tasks []CustomArtifact `json:"tasks,omitempty"`
}

// MCP settings for the retrieval system to support end-to-end RAG. It mainly leverages
// the tools API to search the context information including contexts_id and contexts.
// Check the contexts and contexts_id information in this page:
// https://www.unitxt.ai/en/latest/docs/rag_support.html
type MCP struct {
	// The endpoint of the MCP server: scheme, host, an optional port and an optional path.
	// For example:
	// http://localhost:3000/mcp
	// +kubebuilder:validation:Pattern=`^https?://[a-zA-Z0-9._-]+(:[0-9]+)?(/[a-zA-Z0-9._/-]*)?$`
	URL string `json:"url"`
	// The name of the MCP tool, or the API name of the MCP tool.
	// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
	Tool string `json:"tool"`
	// The field name in the MCP payload that contains the JSON string of the context information.
	// Since the MCP payload may contain multiple records, each record is a JSON object that contains
	// the specified field name as the key, and its value is the JSON string of the context information.
	// The underlying process parses each JSON string of the context information and aggregates them
	// into an array.
	// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
	PayloadField string `json:"payloadField"`
	// The jsonpath to the context field in the array of the context information objects.
	// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
	ContextField string `json:"contextField"`
	// The jsonpath to the id field in the array of the context information objects.
	// +kubebuilder:validation:Pattern=`^[a-zA-Z0-9._-]+$`
	IdField string `json:"idField"`
	// Verify the server's certificate if the server is using HTTPS.
	// +kubebuilder:default:=true
	VerifyCertificate bool `json:"verifyCertificate"`
}

// RAG groups the settings needed for end-to-end RAG. It is referenced from a
// task recipe through the optional "rag" field.
type RAG struct {
// The MCP settings. Currently, this is the only option for the end-to-end RAG.
MCP MCP `json:"mcp"`
}

func (c *CustomArtifacts) GetTemplates() []CustomArtifact {
if c == nil {
return nil
Expand Down Expand Up @@ -210,6 +245,9 @@ type TaskRecipe struct {
// The pool size for the fewshot
// +optional
DemosPoolSize *int `json:"demosPoolSize,omitempty"`
// Specify the RAG information if needed
// +optional
RAG *RAG `json:"rag,omitempty"`
}

// GitSource specifies the git location of external tasks
Expand Down Expand Up @@ -313,6 +351,30 @@ func (t *Task) String() string {
return ""
}

// String renders the RAG settings by delegating to the embedded MCP settings.
// A nil receiver yields an empty string.
func (r *RAG) String() string {
	if r != nil {
		return r.MCP.String()
	}
	return ""
}

// String renders the MCP settings as the text placed below the "rag:" line of
// a task recipe, in the following format:
//
//	 session: url=http://localhost:3002/mcp
//	 request: tool=search,query_field=text,context_field=context,id_field=id,verify_cert=true
//	process_docs: !function ###UNITXT_PATH###/utils.process_docs
//	process_results: !function ###UNITXT_PATH###/utils.postprocess_docs
//
// PayloadField is emitted as query_field. The ###UNITXT_PATH### marker is left
// in place for a later substitution step. A nil receiver yields an empty string.
func (m *MCP) String() string {
	if m == nil {
		return ""
	}

	// End-2-end RAG requires the following settings:
	const ragHooks = "\nprocess_docs: !function ###UNITXT_PATH###/utils.process_docs\nprocess_results: !function ###UNITXT_PATH###/utils.postprocess_docs"

	var out strings.Builder
	fmt.Fprintf(&out, " session: url=%s\n", m.URL)
	fmt.Fprintf(&out, " request: tool=%s,query_field=%s,context_field=%s,id_field=%s,verify_cert=%t",
		m.Tool, m.PayloadField, m.ContextField, m.IdField, m.VerifyCertificate)
	out.WriteString(ragHooks)
	return out.String()
}

// Use the tp_idx and sp_idx to point to the corresponding custom template
// and custom system_prompt
func (t *TaskRecipe) String() string {
Expand Down Expand Up @@ -346,6 +408,9 @@ func (t *TaskRecipe) String() string {
if t.DemosPoolSize != nil {
b.WriteString(fmt.Sprintf(",demos_pool_size=%d", *t.DemosPoolSize))
}
if t.RAG != nil {
b.WriteString(fmt.Sprintf("\nrag:\n%s", t.RAG.String()))
}
return b.String()
}

Expand Down
36 changes: 36 additions & 0 deletions api/lmes/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4909,6 +4909,59 @@ spec:
numDemos:
description: Number of fewshot
type: integer
rag:
description: Specify the RAG information if needed
properties:
mcp:
description: The MCP settings. Currently, this is the
only option for the end-to-end RAG,
properties:
contextField:
description: The jsonpath to the context field in
the array of the context information objects
pattern: ^[a-zA-Z0-9._-]+$
type: string
idField:
description: The jsonpath to the id field in the
array of the context information objects
pattern: ^[a-zA-Z0-9._-]+$
type: string
payloadField:
description: |-
The field name in the MCP payload that contains the JSON string of the context information.
Since the MCP payload may contain multiple records, each record is a JSON object contains the
specified field name as the key and its value is the JSON string of the context information.
The underlying process parses each JSON string of the context information and aggregate them
into an array.
pattern: ^[a-zA-Z0-9._-]+$
type: string
tool:
description: The Tool name of the MCP tool. Or the
API name of the MCP tool.
pattern: ^[a-zA-Z0-9._-]+$
type: string
url:
description: |-
The endpoint of the MCP server. For example:
http://localhost:3000/mcp
pattern: ^https?://[a-zA-Z0-9._/-]+$
type: string
verifyCertificate:
default: true
description: Verify server's certificate if the
server is using HTTPS.
type: boolean
required:
- contextField
- idField
- payloadField
- tool
- url
- verifyCertificate
type: object
required:
- mcp
type: object
systemPrompt:
description: The Unitxt System Prompt
properties:
Expand Down
4 changes: 3 additions & 1 deletion controllers/lmes/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ const (
ShutdownURI = "/Shutdown"
GetStatusURI = "/GetStatus"
DefaultGitBranch = "main"
UnitxtPath = "/opt/app-root/src/lm_eval/tasks/unitxt"
UnitxtPattern = "###UNITXT_PATH###"
)

type DriverOption struct {
Expand Down Expand Up @@ -513,7 +515,7 @@ func (d *driverImpl) createTaskRecipes() error {
[]byte(fmt.Sprintf(
"task: %s\ninclude: unitxt\nrecipe: %s",
fmt.Sprintf("%s_%d", TaskRecipePrefix, i),
taskRecipe,
strings.Replace(taskRecipe, UnitxtPattern, UnitxtPath, -1),
)),
0666,
)
Expand Down
36 changes: 36 additions & 0 deletions controllers/lmes/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,42 @@ func Test_TaskRecipes(t *testing.T) {
)
}

// Test_TaskRecipesWithRAG checks that a task recipe carrying a RAG section is
// written to tr_0.yaml with the ###UNITXT_PATH### placeholder replaced by the
// unitxt task directory, and that the rest of the recipe is passed through
// unchanged.
func Test_TaskRecipesWithRAG(t *testing.T) {
	info := setupTest(t, true)
	defer info.tearDown(t)

	driver, err := NewDriver(&DriverOption{
		Context:         context.Background(),
		OutputPath:      info.outputPath,
		CatalogPath:     info.catalogPath,
		Logger:          driverLog,
		TaskRecipesPath: info.taskPath,
		TaskRecipes: []string{
			// Same shape as the recipe expected by the controller test:
			// the session and request lines share the same one-space prefix.
			"card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10\nrag:\n session: url=https://localhost:3002/mcp\n request: tool=search,query_field=text,context_field=context,id_field=id,verify_cert=true\nprocess_docs: !function ###UNITXT_PATH###/utils.process_docs\nprocess_results: !function ###UNITXT_PATH###/utils.postprocess_docs",
		},
		Args:     []string{"sh", "-ec", "sleep 2; echo 'testing progress: 100%|' >&2; sleep 4"},
		CommPort: info.port,
	})
	assert.Nil(t, err)

	msgs, _ := runDriverAndWait4Complete(t, driver, false)

	assert.Equal(t, []string{
		"initializing the evaluation job",
		"testing progress: 100%",
		"job completed",
	}, msgs)

	assert.Nil(t, driver.Shutdown())

	tr0, err := os.ReadFile(filepath.Join(info.taskPath, "tr_0.yaml"))
	assert.Nil(t, err)
	assert.Equal(t,
		"task: tr_0\ninclude: unitxt\nrecipe: card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10\nrag:\n session: url=https://localhost:3002/mcp\n request: tool=search,query_field=text,context_field=context,id_field=id,verify_cert=true\nprocess_docs: !function /opt/app-root/src/lm_eval/tasks/unitxt/utils.process_docs\nprocess_results: !function /opt/app-root/src/lm_eval/tasks/unitxt/utils.postprocess_docs",
		string(tr0),
	)
}

func Test_CustomCards(t *testing.T) {
info := setupTest(t, true)
defer info.tearDown(t)
Expand Down
12 changes: 11 additions & 1 deletion controllers/lmes/lmevaljob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,16 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) {
Metrics: []lmesv1alpha1.Metric{{Name: "unitxt.metric3"}, {Name: "unitxt.metric4"}},
NumDemos: &numDemos,
DemosPoolSize: &demosPoolSize,
RAG: &lmesv1alpha1.RAG{
MCP: lmesv1alpha1.MCP{
URL: "https://localhost:3002/mcp",
Tool: "search",
PayloadField: "text",
ContextField: "context",
IdField: "id",
VerifyCertificate: true,
},
},
},
)

Expand All @@ -1061,7 +1071,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) {
"/opt/app-root/src/bin/driver",
"--output-path", "/opt/app-root/src/output",
"--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10",
"--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10",
"--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10\nrag:\n session: url=https://localhost:3002/mcp\n request: tool=search,query_field=text,context_field=context,id_field=id,verify_cert=true\nprocess_docs: !function ###UNITXT_PATH###/utils.process_docs\nprocess_results: !function ###UNITXT_PATH###/utils.postprocess_docs",
"--",
}, generateCmd(svcOpts, job))
}
Expand Down
50 changes: 44 additions & 6 deletions controllers/lmes/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func ValidateUserInput(job *lmesv1alpha1.LMEvalJob) error {
// Validate custom task git source
if job.Spec.TaskList.CustomTasks != nil && len(job.Spec.TaskList.TaskNames) > 0 {
gitSource := job.Spec.TaskList.CustomTasks.Source.GitSource
if err := ValidateGitURL(gitSource.URL); err != nil {
if err := ValidateURL(gitSource.URL); err != nil {
return fmt.Errorf("invalid git URL: %w", err)
}
if gitSource.Path != "" {
Expand Down Expand Up @@ -247,6 +247,13 @@ func ValidateTaskRecipes(args []lmesv1alpha1.TaskRecipe, argType string) error {
}
}

// validate RAG settings if present
if arg.RAG != nil {
if err := ValidateMCP(&arg.RAG.MCP); err != nil {
return fmt.Errorf("%s[%d] RAG: %w", argType, i, err)
}
}

}
return nil
}
Expand Down Expand Up @@ -458,20 +465,20 @@ func ValidateChatTemplateName(name string) error {
return nil
}

// ValidateGitURL validates git repository URLs
func ValidateGitURL(url string) error {
// ValidateURL validates that a URL is a non-empty HTTPS URL without shell metacharacters (e.g. a git repository URL)
func ValidateURL(url string) error {
if url == "" {
return fmt.Errorf("git URL cannot be empty")
return fmt.Errorf("URL cannot be empty")
}

// Check for shell metacharacters
if ContainsShellMetacharacters(url) {
return fmt.Errorf("git URL contains shell metacharacters")
return fmt.Errorf("URL contains shell metacharacters")
}

// Must be HTTPS URL for security
if !regexp.MustCompile(`^https://[a-zA-Z0-9._/-]+$`).MatchString(url) {
return fmt.Errorf("git URL must be a valid HTTPS URL (only alphanumeric, ., _, /, - allowed)")
return fmt.Errorf("URL must be a valid HTTPS URL (only alphanumeric, ., _, /, - allowed)")
}

return nil
Expand Down Expand Up @@ -529,3 +536,34 @@ func ValidateGitCommit(commit string) error {

return nil
}

// mcpURLPattern accepts an http or https endpoint made of a host, an optional
// numeric port and an optional path, e.g. http://localhost:3000/mcp. Only
// alphanumerics and '.', '_', '-', '/' (plus the ':' before the port) can match.
var mcpURLPattern = regexp.MustCompile(`^https?://[a-zA-Z0-9._-]+(:[0-9]+)?(/[a-zA-Z0-9._/-]*)?$`)

// mcpNamePattern restricts MCP tool and field names to alphanumeric, hyphen,
// underscore, and period.
var mcpNamePattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)

// ValidateMCP validates the MCP settings of a RAG configuration.
//
// A nil mcp is accepted and returns nil. Otherwise the URL must be a non-empty
// http or https endpoint (optionally with a port), and the tool, payloadField,
// contextField and idField values must consist only of alphanumerics, '.', '_'
// and '-'. The first violation found is returned as an error.
func ValidateMCP(mcp *lmesv1alpha1.MCP) error {
	if mcp == nil {
		return nil
	}

	// The MCP endpoint may be plain HTTP and may carry a port (the documented
	// example is http://localhost:3000/mcp), so the HTTPS-only, port-less
	// ValidateURL check cannot be reused here.
	if mcp.URL == "" {
		return fmt.Errorf("invalid MCP URL: URL cannot be empty")
	}
	if !mcpURLPattern.MatchString(mcp.URL) {
		return fmt.Errorf("invalid MCP URL: URL must be a valid HTTP or HTTPS URL (only alphanumeric, ., _, /, - and an optional :port allowed)")
	}

	// Checked in declaration order so the first offending field is reported.
	fields := []struct {
		name  string
		value string
	}{
		{"tool", mcp.Tool},
		{"payloadField", mcp.PayloadField},
		{"contextField", mcp.ContextField},
		{"idField", mcp.IdField},
	}
	for _, f := range fields {
		if !mcpNamePattern.MatchString(f.value) {
			return fmt.Errorf("MCP %s contains invalid characters (only alphanumeric, ., _, - allowed)", f.name)
		}
	}

	return nil
}
Loading