Skip to content

Commit 315cb3b

Browse files
committed
Redo how we load providers and models from the CLI to allow for attributes
Part of #407
1 parent 2344479 commit 315cb3b

File tree

4 files changed

+153
-115
lines changed

4 files changed

+153
-115
lines changed

cmd/eval-dev-quality/cmd/evaluate.go

Lines changed: 123 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -341,133 +341,162 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
341341
evaluationContext.Languages[i] = languagesSelected[languageID]
342342
}
343343

344-
// Register custom OpenAI API providers and models.
344+
// Gather models and initialize providers.
345+
var serviceShutdown []func() (err error)
345346
{
346-
customProviders := map[string]*openaiapi.Provider{}
347-
for providerID, providerURL := range command.ProviderUrls {
348-
if !strings.HasPrefix(providerID, "custom-") {
349-
continue
347+
// Gather providers.
348+
providers := map[string]provider.Provider{}
349+
if len(command.ModelIDsWithProviderAndAttributes) == 0 {
350+
for providerID, provider := range provider.Providers {
351+
providers[providerID] = provider
352+
command.logger.Info("selected provider", "provider", providerID)
350353
}
354+
} else {
355+
// Register custom providers.
356+
for providerID, providerURL := range command.ProviderUrls {
357+
if !strings.HasPrefix(providerID, "custom-") {
358+
command.logger.Panicf("ERROR: cannot set URL of %q because it is not a custom provider", providerID)
359+
}
351360

352-
p := openaiapi.NewProvider(providerID, providerURL)
353-
provider.Register(p)
354-
customProviders[providerID] = p
355-
}
356-
for _, model := range command.ModelIDsWithAttributes {
357-
if !strings.HasPrefix(model, "custom-") {
358-
continue
361+
p := openaiapi.NewProvider(providerID, providerURL)
362+
provider.Register(p)
363+
providers[providerID] = p
364+
command.logger.Info("selected provider", "provider", providerID)
359365
}
360366

361-
providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
362-
if !ok {
363-
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
364-
}
365-
modelProvider, ok := customProviders[providerID]
366-
if !ok {
367-
command.logger.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
368-
}
367+
// Add remaining providers from models.
368+
for _, modelIDsWithProviderAndAttributes := range command.ModelIDsWithProviderAndAttributes {
369+
providerID, _, ok := strings.Cut(modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
370+
if !ok {
371+
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
372+
}
369373

370-
modelProvider.AddModel(llm.NewModel(modelProvider, model))
374+
p, ok := provider.Providers[providerID]
375+
if !ok {
376+
command.logger.Panicf("ERROR: unknown provider %q for model %q", providerID, modelIDsWithProviderAndAttributes)
377+
}
378+
if _, ok := providers[providerID]; !ok {
379+
providers[providerID] = p
380+
command.logger.Info("selected provider", "provider", providerID)
381+
}
382+
}
371383
}
372-
}
373384

374-
// Gather models.
375-
var serviceShutdown []func() (err error)
376-
{
377-
// Check which providers are needed for the evaluation.
378-
providersSelected := map[string]provider.Provider{}
379-
if len(command.ModelIDsWithAttributes) == 0 {
380-
providersSelected = provider.Providers
381-
} else {
382-
for _, model := range command.ModelIDsWithAttributes {
383-
p := strings.SplitN(model, provider.ProviderModelSeparator, 2)[0]
385+
// Initialize providers.
386+
{
387+
providerIDsSorted := maps.Keys(providers)
388+
sort.Strings(providerIDsSorted)
389+
for _, providerID := range providerIDsSorted {
390+
p := providers[providerID]
384391

385-
if _, ok := providersSelected[p]; ok {
386-
continue
392+
command.logger.Info("initializing provider", "provider", providerID)
393+
if t, ok := p.(provider.InjectToken); ok {
394+
if token, ok := command.ProviderTokens[p.ID()]; ok {
395+
command.logger.Info("set token of provider", "provider", providerID)
396+
t.SetToken(token)
397+
}
387398
}
399+
command.logger.Info("checking availability for provider", "provider", providerID)
400+
if err := p.Available(command.logger); err != nil {
401+
command.logger.Info("skipping provider because it is not available", "error", err, "provider", providerID)
402+
delete(providers, providerID)
388403

389-
if provider, ok := provider.Providers[p]; !ok {
390-
command.logger.Panicf("Provider %q does not exist", p)
391-
} else {
392-
providersSelected[provider.ID()] = provider
404+
continue
405+
}
406+
if service, ok := p.(provider.Service); ok {
407+
command.logger.Info("starting services for provider", "provider", p.ID())
408+
shutdown, err := service.Start(command.logger)
409+
if err != nil {
410+
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
411+
}
412+
serviceShutdown = append(serviceShutdown, shutdown)
393413
}
394414
}
395415
}
396416

417+
// Gather models.
397418
models := map[string]model.Model{}
398-
modelsSelected := map[string]model.Model{}
399-
evaluationContext.ProviderForModel = map[model.Model]provider.Provider{}
400-
for _, p := range providersSelected {
401-
command.logger.Info("querying provider models", "provider", p.ID())
419+
{
420+
addAllModels := len(command.ModelIDsWithProviderAndAttributes) == 0
421+
for _, p := range providers {
422+
ms, err := p.Models()
423+
if err != nil {
424+
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
425+
}
426+
for _, m := range ms {
427+
models[m.ID()] = m
428+
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
402429

403-
if t, ok := p.(provider.InjectToken); ok {
404-
token, ok := command.ProviderTokens[p.ID()]
405-
if ok {
406-
t.SetToken(token)
430+
if addAllModels {
431+
command.ModelIDsWithProviderAndAttributes = append(command.ModelIDsWithProviderAndAttributes, m.ID())
432+
}
407433
}
408434
}
409-
if err := p.Available(command.logger); err != nil {
410-
command.logger.Warn("skipping unavailable provider", "provider", p.ID(), "error", err)
435+
}
436+
modelIDs := maps.Keys(models)
437+
sort.Strings(modelIDs)
438+
sort.Strings(command.ModelIDsWithProviderAndAttributes)
411439

412-
continue
413-
}
440+
// Check and initialize models.
441+
evaluationContext.ProviderForModel = map[model.Model]provider.Provider{}
442+
for _, modelIDsWithProviderAndAttributes := range command.ModelIDsWithProviderAndAttributes {
443+
command.logger.Info("selecting model", "model", modelIDsWithProviderAndAttributes)
414444

415-
// Start services of providers.
416-
if service, ok := p.(provider.Service); ok {
417-
command.logger.Info("starting services for provider", "provider", p.ID())
418-
shutdown, err := service.Start(command.logger)
419-
if err != nil {
420-
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
421-
}
422-
serviceShutdown = append(serviceShutdown, shutdown)
445+
providerID, modelIDsWithAttributes, ok := strings.Cut(modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
446+
if !ok {
447+
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
423448
}
424449

425-
// Check if a provider has the ability to pull models and do so if necessary.
450+
modelID, _ := model.ParseModelID(modelIDsWithAttributes)
451+
452+
p, ok := providers[providerID]
453+
if !ok {
454+
command.logger.Panicf("ERROR: cannot find provider %q", providerID)
455+
}
426456
if puller, ok := p.(provider.Puller); ok {
427-
command.logger.Info("pulling available models for provider", "provider", p.ID())
428-
for _, modelID := range command.ModelIDsWithAttributes {
429-
if !strings.HasPrefix(modelID, p.ID()) { // TODO Move this into `NewModel` to validate that a model belongs to a provider.
430-
panic(fmt.Errorf("model %s does not belong to provider %s", modelID, p.ID()))
457+
command.logger.Info("pulling model", "model", modelID)
458+
if err := puller.Pull(command.logger, modelID); err != nil {
459+
command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err)
460+
}
461+
462+
// TODO If a model has not been pulled before, it was not available for at least the "Ollama" provider. Make this cleaner: we should not re-query the provider and rebuild the model list every time.
463+
if _, ok := models[modelIDsWithProviderAndAttributes]; !ok {
464+
ms, err := p.Models()
465+
if err != nil {
466+
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
431467
}
468+
for _, m := range ms {
469+
if _, ok := models[m.ID()]; ok {
470+
continue
471+
}
432472

433-
if err := puller.Pull(command.logger, modelID); err != nil {
434-
command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err)
473+
models[m.ID()] = m
474+
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
435475
}
476+
modelIDs = maps.Keys(models)
477+
sort.Strings(modelIDs)
436478
}
437479
}
438480

439-
ms, err := p.Models()
440-
if err != nil {
441-
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
442-
}
481+
var m model.Model
482+
if strings.HasPrefix(providerID, "custom-") {
483+
pc, ok := p.(*openaiapi.Provider)
484+
if !ok {
485+
command.logger.Panicf("ERROR: %q is not a custom provider", providerID)
486+
}
443487

444-
for _, m := range ms {
445-
models[m.ID()] = m
446-
evaluationContext.ProviderForModel[m] = p
447-
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
448-
}
449-
}
450-
modelIDs := maps.Keys(models)
451-
sort.Strings(modelIDs)
452-
if len(command.ModelIDsWithAttributes) == 0 {
453-
command.ModelIDsWithAttributes = modelIDs
454-
} else {
455-
for _, modelID := range command.ModelIDsWithAttributes {
456-
if _, ok := models[modelID]; !ok {
457-
command.logger.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", "))
488+
m = llm.NewModel(pc, modelIDsWithProviderAndAttributes)
489+
pc.AddModel(m)
490+
} else {
491+
var ok bool
492+
m, ok = models[modelIDsWithProviderAndAttributes]
493+
if !ok {
494+
command.logger.Panicf("ERROR: model %q does not exist for provider %q. Valid models are: %s", modelIDsWithProviderAndAttributes, providerID, strings.Join(modelIDs, ", "))
458495
}
459496
}
460-
}
461-
sort.Strings(command.ModelIDsWithAttributes)
462-
for _, modelID := range command.ModelIDsWithAttributes {
463-
modelsSelected[modelID] = models[modelID]
464-
}
465-
466-
// Make the resolved selected models available in the command.
467-
evaluationContext.Models = make([]model.Model, len(command.ModelIDsWithAttributes))
468-
for i, modelID := range command.ModelIDsWithAttributes {
469-
evaluationContext.Models[i] = modelsSelected[modelID]
470-
evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelID)
497+
evaluationContext.Models = append(evaluationContext.Models, m)
498+
evaluationContext.ProviderForModel[m] = p
499+
evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelIDsWithProviderAndAttributes)
471500
}
472501
}
473502

cmd/eval-dev-quality/cmd/evaluate_test.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -462,10 +462,10 @@ func TestEvaluateExecute(t *testing.T) {
462462

463463
ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
464464
filepath.Join("result-directory", "evaluation.log"): func(t *testing.T, filePath string, data string) {
465-
assert.Contains(t, data, "\"msg\":\"skipping unavailable provider\",\"provider\":\"openrouter\"")
465+
assert.Contains(t, data, `"msg":"skipping provider because it is not available","error":"missing access token","provider":"openrouter"`)
466466
},
467467
},
468-
ExpectedPanicContains: "ERROR: model openrouter/auto does not exist",
468+
ExpectedPanicContains: `ERROR: cannot find provider "openrouter"`,
469469
})
470470
})
471471
t.Run("Ollama", func(t *testing.T) {
@@ -1271,13 +1271,17 @@ func TestEvaluateInitialize(t *testing.T) {
12711271
// makeValidCommand is a helper to abstract all the default values that have to be set to make a command valid.
12721272
makeValidCommand := func(modify func(command *Evaluate)) *Evaluate {
12731273
c := &Evaluate{
1274+
ModelIDsWithProviderAndAttributes: []string{"symflower/smart-template"},
1275+
QueryAttempts: 1,
1276+
1277+
ResultPath: filepath.Join("$TEMP_PATH", "result-directory"),
1278+
TestdataPath: filepath.Join("..", "..", "..", "testdata"),
1279+
12741280
ExecutionTimeout: 1,
1275-
Parallel: 1,
1276-
QueryAttempts: 1,
1277-
ResultPath: filepath.Join("$TEMP_PATH", "result-directory"),
12781281
Runs: 1,
1279-
Runtime: "local",
1280-
TestdataPath: filepath.Join("..", "..", "..", "testdata"),
1282+
1283+
Runtime: "local",
1284+
Parallel: 1,
12811285
}
12821286

12831287
if modify != nil {
@@ -1325,10 +1329,14 @@ func TestEvaluateInitialize(t *testing.T) {
13251329

13261330
Command: makeValidCommand(func(command *Evaluate) {
13271331
command.ModelIDsWithProviderAndAttributes = []string{}
1332+
command.ProviderTokens = map[string]string{
1333+
"openrouter": "fake-token",
1334+
}
13281335
}),
13291336

13301337
// Arbitrary Ollama or new OpenRouter models could also be selected, so just sanity check that at least the symflower models are there.
13311338
ValidateCommand: func(t *testing.T, command *Evaluate) {
1339+
assert.Contains(t, command.ModelIDsWithProviderAndAttributes, "symflower/smart-template")
13321340
assert.Contains(t, command.ModelIDsWithProviderAndAttributes, "symflower/symbolic-execution")
13331341
},
13341342
ValidateContext: func(t *testing.T, context *evaluate.Context) {
@@ -1448,6 +1456,7 @@ func TestEvaluateInitialize(t *testing.T) {
14481456

14491457
Command: makeValidCommand(func(command *Evaluate) {
14501458
command.Configuration = "config.json"
1459+
command.ModelIDsWithProviderAndAttributes = nil
14511460
}),
14521461

14531462
ValidateCommand: func(t *testing.T, command *Evaluate) {
@@ -1513,7 +1522,7 @@ func TestEvaluateInitialize(t *testing.T) {
15131522
}
15141523

15151524
validate(t, &testCase{
1516-
Name: "Parallel parameter hast to be greater then zero",
1525+
Name: "Parallel parameter has to be greater then zero",
15171526

15181527
Command: makeValidCommand(func(command *Evaluate) {
15191528
command.Runtime = "docker"

evaluate/evaluate_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ func TestEvaluate(t *testing.T) {
244244

245245
Before: func(t *testing.T, logger *log.Logger, resultPath string) {
246246
// Set up mocks when the test is running.
247-
mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("", ErrEmptyResponseFromModel)
247+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel)
248248
},
249249
After: func(t *testing.T, logger *log.Logger, resultPath string) {
250250
mockedQuery.AssertNumberOfCalls(t, "Query", 2)
@@ -325,10 +325,10 @@ func TestEvaluate(t *testing.T) {
325325

326326
Before: func(t *testing.T, logger *log.Logger, resultPath string) {
327327
// Set up mocks when the test is running.
328-
mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("", ErrEmptyResponseFromModel).Once()
329-
mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
330-
mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("", ErrEmptyResponseFromModel).Once()
331-
mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
328+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel).Once()
329+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
330+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel).Once()
331+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
332332
},
333333
After: func(t *testing.T, logger *log.Logger, resultPath string) {
334334
mockedQuery.AssertNumberOfCalls(t, "Query", 4)
@@ -424,7 +424,7 @@ func TestEvaluate(t *testing.T) {
424424

425425
Before: func(t *testing.T, logger *log.Logger, resultPath string) {
426426
// Set up mocks when the test is running.
427-
mockedQuery.On("Query", mock.Anything, mockedModelID, mock.Anything).Return("model-response", nil).After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
427+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
428428
},
429429
After: func(t *testing.T, logger *log.Logger, resultPath string) {
430430
mockedQuery.AssertNumberOfCalls(t, "Query", 2)

0 commit comments

Comments
 (0)