Skip to content

Commit 9fb2a66

Browse files
committed
refactor, Reflect that model IDs now hold attributes as well
1 parent 3a66bf5 commit 9fb2a66

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

cmd/eval-dev-quality/cmd/evaluate.go

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -53,8 +53,8 @@ type Evaluate struct {
5353

5454
// Languages determines which language should be used for the evaluation, or empty if all languages should be used.
5555
Languages []string `long:"language" description:"Evaluate with this language. By default all languages are used."`
56-
// Models determines which models should be used for the evaluation, or empty if all models should be used.
57-
Models []string `long:"model" description:"Evaluate with this model. By default all models are used."`
56+
// ModelIDsWithProviderAndAttributes determines which models should be used for the evaluation, or empty if all models should be used.
57+
ModelIDsWithProviderAndAttributes []string `long:"model" description:"Evaluate with this model. By default all models are used."`
5858
// ProviderTokens holds all API tokens for the providers.
5959
ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
6060
// ProviderUrls holds all custom inference endpoint urls for the providers.
@@ -123,7 +123,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
123123
command.logger.Panicf("the configuration file is not supported in containerized runtimes")
124124
}
125125

126-
if len(command.Models) > 0 || len(command.Repositories) > 0 {
126+
if len(command.ModelIDsWithProviderAndAttributes) > 0 || len(command.Repositories) > 0 {
127127
command.logger.Panicf("do not provide models and repositories when loading a configuration file")
128128
}
129129

@@ -139,7 +139,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
139139
command.logger.Panicf("ERROR: %s", err)
140140
}
141141

142-
command.Models = configuration.Models.Selected
142+
command.ModelIDsWithProviderAndAttributes = configuration.Models.Selected
143143
command.Repositories = configuration.Repositories.Selected
144144
}
145145

@@ -258,7 +258,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
258258
// In a containerized runtime we check the availability of the testdata, repositories and models/providers inside the container.
259259
if command.Runtime != "local" {
260260
// Copy the models over.
261-
for _, modelID := range command.Models {
261+
for _, modelID := range command.ModelIDsWithProviderAndAttributes {
262262
evaluationContext.Models = append(evaluationContext.Models, llm.NewModel(nil, modelID))
263263
}
264264

@@ -376,10 +376,10 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
376376
{
377377
// Check which providers are needed for the evaluation.
378378
providersSelected := map[string]provider.Provider{}
379-
if len(command.Models) == 0 {
379+
if len(command.ModelIDsWithProviderAndAttributes) == 0 {
380380
providersSelected = provider.Providers
381381
} else {
382-
for _, model := range command.Models {
382+
for _, model := range command.ModelIDsWithProviderAndAttributes {
383383
p := strings.SplitN(model, provider.ProviderModelSeparator, 2)[0]
384384

385385
if _, ok := providersSelected[p]; ok {
@@ -425,7 +425,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
425425
// Check if a provider has the ability to pull models and do so if necessary.
426426
if puller, ok := p.(provider.Puller); ok {
427427
command.logger.Info("pulling available models for provider", "provider", p.ID())
428-
for _, modelID := range command.Models {
428+
for _, modelID := range command.ModelIDsWithProviderAndAttributes {
429429
if !strings.HasPrefix(modelID, p.ID()) { // TODO Move this into `NewModel` to validate that a model belongs to a provider.
430430
panic(fmt.Errorf("model %s does not belong to provider %s", modelID, p.ID()))
431431
}
@@ -449,23 +449,23 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
449449
}
450450
modelIDs := maps.Keys(models)
451451
sort.Strings(modelIDs)
452-
if len(command.Models) == 0 {
453-
command.Models = modelIDs
452+
if len(command.ModelIDsWithProviderAndAttributes) == 0 {
453+
command.ModelIDsWithProviderAndAttributes = modelIDs
454454
} else {
455-
for _, modelID := range command.Models {
455+
for _, modelID := range command.ModelIDsWithProviderAndAttributes {
456456
if _, ok := models[modelID]; !ok {
457457
command.logger.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", "))
458458
}
459459
}
460460
}
461-
sort.Strings(command.Models)
462-
for _, modelID := range command.Models {
461+
sort.Strings(command.ModelIDsWithProviderAndAttributes)
462+
for _, modelID := range command.ModelIDsWithProviderAndAttributes {
463463
modelsSelected[modelID] = models[modelID]
464464
}
465465

466466
// Make the resolved selected models available in the command.
467-
evaluationContext.Models = make([]model.Model, len(command.Models))
468-
for i, modelID := range command.Models {
467+
evaluationContext.Models = make([]model.Model, len(command.ModelIDsWithProviderAndAttributes))
468+
for i, modelID := range command.ModelIDsWithProviderAndAttributes {
469469
evaluationContext.Models[i] = modelsSelected[modelID]
470470
evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelID)
471471
}

cmd/eval-dev-quality/cmd/evaluate_test.go

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1324,12 +1324,12 @@ func TestEvaluateInitialize(t *testing.T) {
13241324
Name: "Selecting no model defaults to all",
13251325

13261326
Command: makeValidCommand(func(command *Evaluate) {
1327-
command.Models = []string{}
1327+
command.ModelIDsWithProviderAndAttributes = []string{}
13281328
}),
13291329

13301330
// Could also select arbitrary Ollama or new Openrouter models so sanity check that at least symflower is there.
13311331
ValidateCommand: func(t *testing.T, command *Evaluate) {
1332-
assert.Contains(t, command.Models, "symflower/symbolic-execution")
1332+
assert.Contains(t, command.ModelIDsWithProviderAndAttributes, "symflower/symbolic-execution")
13331333
},
13341334
ValidateContext: func(t *testing.T, context *evaluate.Context) {
13351335
modelIDs := make([]string, len(context.Models))
@@ -1453,7 +1453,7 @@ func TestEvaluateInitialize(t *testing.T) {
14531453
ValidateCommand: func(t *testing.T, command *Evaluate) {
14541454
assert.Equal(t, []string{
14551455
"symflower/symbolic-execution",
1456-
}, command.Models)
1456+
}, command.ModelIDsWithProviderAndAttributes)
14571457
assert.Equal(t, []string{
14581458
filepath.Join("golang", "plain"),
14591459
filepath.Join("java", "plain"),

0 commit comments

Comments
 (0)