@@ -341,133 +341,162 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
341341 evaluationContext .Languages [i ] = languagesSelected [languageID ]
342342 }
343343
344- // Register custom OpenAI API providers and models.
344+ // Gather models and initialize providers.
345+ var serviceShutdown []func () (err error )
345346 {
346- customProviders := map [string ]* openaiapi.Provider {}
347- for providerID , providerURL := range command .ProviderUrls {
348- if ! strings .HasPrefix (providerID , "custom-" ) {
349- continue
347+ // Gather providers.
348+ providers := map [string ]provider.Provider {}
349+ if len (command .ModelIDsWithProviderAndAttributes ) == 0 {
350+ for providerID , provider := range provider .Providers {
351+ providers [providerID ] = provider
352+ command .logger .Info ("selected provider" , "provider" , providerID )
350353 }
354+ } else {
355+ // Register custom providers.
356+ for providerID , providerURL := range command .ProviderUrls {
357+ if ! strings .HasPrefix (providerID , "custom-" ) {
358+ command .logger .Panicf ("ERROR: cannot set URL of %q because it is not a custom provider" , providerID )
359+ }
351360
352- p := openaiapi .NewProvider (providerID , providerURL )
353- provider .Register (p )
354- customProviders [providerID ] = p
355- }
356- for _ , model := range command .ModelIDsWithAttributes {
357- if ! strings .HasPrefix (model , "custom-" ) {
358- continue
361+ p := openaiapi .NewProvider (providerID , providerURL )
362+ provider .Register (p )
363+ providers [providerID ] = p
364+ command .logger .Info ("selected provider" , "provider" , providerID )
359365 }
360366
361- providerID , _ , ok := strings .Cut (model , provider .ProviderModelSeparator )
362- if ! ok {
363- command .logger .Panicf ("ERROR: cannot split %q into provider and model name by %q" , model , provider .ProviderModelSeparator )
364- }
365- modelProvider , ok := customProviders [providerID ]
366- if ! ok {
367- command .logger .Panicf ("ERROR: unknown custom provider %q for model %q" , providerID , model )
368- }
367+ // Add remaining providers from models.
368+ for _ , modelIDsWithProviderAndAttributes := range command .ModelIDsWithProviderAndAttributes {
369+ providerID , _ , ok := strings .Cut (modelIDsWithProviderAndAttributes , provider .ProviderModelSeparator )
370+ if ! ok {
371+ command .logger .Panicf ("ERROR: cannot split %q into provider and model name by %q" , modelIDsWithProviderAndAttributes , provider .ProviderModelSeparator )
372+ }
369373
370- modelProvider .AddModel (llm .NewModel (modelProvider , model ))
374+ p , ok := provider .Providers [providerID ]
375+ if ! ok {
376+ command .logger .Panicf ("ERROR: unknown provider %q for model %q" , providerID , modelIDsWithProviderAndAttributes )
377+ }
378+ if _ , ok := providers [providerID ]; ! ok {
379+ providers [providerID ] = p
380+ command .logger .Info ("selected provider" , "provider" , providerID )
381+ }
382+ }
371383 }
372- }
373384
374- // Gather models.
375- var serviceShutdown []func () (err error )
376- {
377- // Check which providers are needed for the evaluation.
378- providersSelected := map [string ]provider.Provider {}
379- if len (command .ModelIDsWithAttributes ) == 0 {
380- providersSelected = provider .Providers
381- } else {
382- for _ , model := range command .ModelIDsWithAttributes {
383- p := strings .SplitN (model , provider .ProviderModelSeparator , 2 )[0 ]
385+ // Initialize providers.
386+ {
387+ providerIDsSorted := maps .Keys (providers )
388+ sort .Strings (providerIDsSorted )
389+ for _ , providerID := range providerIDsSorted {
390+ p := providers [providerID ]
384391
385- if _ , ok := providersSelected [p ]; ok {
386- continue
392+ command .logger .Info ("initializing provider" , "provider" , providerID )
393+ if t , ok := p .(provider.InjectToken ); ok {
394+ if token , ok := command .ProviderTokens [p .ID ()]; ok {
395+ command .logger .Info ("set token of provider" , "provider" , providerID )
396+ t .SetToken (token )
397+ }
387398 }
399+ command .logger .Info ("checking availability for provider" , "provider" , providerID )
400+ if err := p .Available (command .logger ); err != nil {
401+ command .logger .Info ("skipping provider because it is not available" , "error" , err , "provider" , providerID )
402+ delete (providers , providerID )
388403
389- if provider , ok := provider .Providers [p ]; ! ok {
390- command .logger .Panicf ("Provider %q does not exist" , p )
391- } else {
392- providersSelected [provider .ID ()] = provider
404+ continue
405+ }
406+ if service , ok := p .(provider.Service ); ok {
407+ command .logger .Info ("starting services for provider" , "provider" , p .ID ())
408+ shutdown , err := service .Start (command .logger )
409+ if err != nil {
410+ command .logger .Panicf ("ERROR: could not start services for provider %q: %s" , p , err )
411+ }
412+ serviceShutdown = append (serviceShutdown , shutdown )
393413 }
394414 }
395415 }
396416
417+ // Gather models.
397418 models := map [string ]model.Model {}
398- modelsSelected := map [string ]model.Model {}
399- evaluationContext .ProviderForModel = map [model.Model ]provider.Provider {}
400- for _ , p := range providersSelected {
401- command .logger .Info ("querying provider models" , "provider" , p .ID ())
419+ {
420+ addAllModels := len (command .ModelIDsWithProviderAndAttributes ) == 0
421+ for _ , p := range providers {
422+ ms , err := p .Models ()
423+ if err != nil {
424+ command .logger .Panicf ("ERROR: could not query models for provider %q: %s" , p .ID (), err )
425+ }
426+ for _ , m := range ms {
427+ models [m .ID ()] = m
428+ evaluationConfiguration .Models .Available = append (evaluationConfiguration .Models .Available , m .ID ())
402429
403- if t , ok := p .(provider.InjectToken ); ok {
404- token , ok := command .ProviderTokens [p .ID ()]
405- if ok {
406- t .SetToken (token )
430+ if addAllModels {
431+ command .ModelIDsWithProviderAndAttributes = append (command .ModelIDsWithProviderAndAttributes , m .ID ())
432+ }
407433 }
408434 }
409- if err := p .Available (command .logger ); err != nil {
410- command .logger .Warn ("skipping unavailable provider" , "provider" , p .ID (), "error" , err )
435+ }
436+ modelIDs := maps .Keys (models )
437+ sort .Strings (modelIDs )
438+ sort .Strings (command .ModelIDsWithProviderAndAttributes )
411439
412- continue
413- }
440+ // Check and initialize models.
441+ evaluationContext .ProviderForModel = map [model.Model ]provider.Provider {}
442+ for _ , modelIDsWithProviderAndAttributes := range command .ModelIDsWithProviderAndAttributes {
443+ command .logger .Info ("selecting model" , "model" , modelIDsWithProviderAndAttributes )
414444
415- // Start services of providers.
416- if service , ok := p .(provider.Service ); ok {
417- command .logger .Info ("starting services for provider" , "provider" , p .ID ())
418- shutdown , err := service .Start (command .logger )
419- if err != nil {
420- command .logger .Panicf ("ERROR: could not start services for provider %q: %s" , p , err )
421- }
422- serviceShutdown = append (serviceShutdown , shutdown )
445+ providerID , modelIDsWithAttributes , ok := strings .Cut (modelIDsWithProviderAndAttributes , provider .ProviderModelSeparator )
446+ if ! ok {
447+ command .logger .Panicf ("ERROR: cannot split %q into provider and model name by %q" , modelIDsWithProviderAndAttributes , provider .ProviderModelSeparator )
423448 }
424449
425- // Check if a provider has the ability to pull models and do so if necessary.
450+ modelID , _ := model .ParseModelID (modelIDsWithAttributes )
451+
452+ p , ok := providers [providerID ]
453+ if ! ok {
454+ command .logger .Panicf ("ERROR: cannot find provider %q" , providerID )
455+ }
426456 if puller , ok := p .(provider.Puller ); ok {
427- command .logger .Info ("pulling available models for provider" , "provider" , p .ID ())
428- for _ , modelID := range command .ModelIDsWithAttributes {
429- if ! strings .HasPrefix (modelID , p .ID ()) { // TODO Move this into `NewModel` to validate that a model belongs to a provider.
430- panic (fmt .Errorf ("model %s does not belong to provider %s" , modelID , p .ID ()))
457+ command .logger .Info ("pulling model" , "model" , modelID )
458+ if err := puller .Pull (command .logger , modelID ); err != nil {
459+ command .logger .Panicf ("ERROR: could not pull model %q: %s" , modelID , err )
460+ }
461+
462+ // TODO If a model has not been pulled before, it was not available for at least the "Ollama" provider. Make this cleaner, we should not rebuild every time.
463+ if _ , ok := models [modelIDsWithProviderAndAttributes ]; ! ok {
464+ ms , err := p .Models ()
465+ if err != nil {
466+ command .logger .Panicf ("ERROR: could not query models for provider %q: %s" , p .ID (), err )
431467 }
468+ for _ , m := range ms {
469+ if _ , ok := models [m .ID ()]; ok {
470+ continue
471+ }
432472
433- if err := puller . Pull ( command . logger , modelID ); err != nil {
434- command . logger . Panicf ( "ERROR: could not pull model %q: %s" , modelID , err )
473+ models [ m . ID ()] = m
474+ evaluationConfiguration . Models . Available = append ( evaluationConfiguration . Models . Available , m . ID () )
435475 }
476+ modelIDs = maps .Keys (models )
477+ sort .Strings (modelIDs )
436478 }
437479 }
438480
439- ms , err := p .Models ()
440- if err != nil {
441- command .logger .Panicf ("ERROR: could not query models for provider %q: %s" , p .ID (), err )
442- }
481+ var m model.Model
482+ if strings .HasPrefix (providerID , "custom-" ) {
483+ pc , ok := p .(* openaiapi.Provider )
484+ if ! ok {
485+ command .logger .Panicf ("ERROR: %q is not a custom provider" , providerID )
486+ }
443487
444- for _ , m := range ms {
445- models [m .ID ()] = m
446- evaluationContext .ProviderForModel [m ] = p
447- evaluationConfiguration .Models .Available = append (evaluationConfiguration .Models .Available , m .ID ())
448- }
449- }
450- modelIDs := maps .Keys (models )
451- sort .Strings (modelIDs )
452- if len (command .ModelIDsWithAttributes ) == 0 {
453- command .ModelIDsWithAttributes = modelIDs
454- } else {
455- for _ , modelID := range command .ModelIDsWithAttributes {
456- if _ , ok := models [modelID ]; ! ok {
457- command .logger .Panicf ("ERROR: model %s does not exist. Valid models are: %s" , modelID , strings .Join (modelIDs , ", " ))
488+ m = llm .NewModel (pc , modelIDsWithProviderAndAttributes )
489+ pc .AddModel (m )
490+ } else {
491+ var ok bool
492+ m , ok = models [modelIDsWithProviderAndAttributes ]
493+ if ! ok {
494+ command .logger .Panicf ("ERROR: model %q does not exist for provider %q. Valid models are: %s" , modelIDsWithProviderAndAttributes , providerID , strings .Join (modelIDs , ", " ))
458495 }
459496 }
460- }
461- sort .Strings (command .ModelIDsWithAttributes )
462- for _ , modelID := range command .ModelIDsWithAttributes {
463- modelsSelected [modelID ] = models [modelID ]
464- }
465-
466- // Make the resolved selected models available in the command.
467- evaluationContext .Models = make ([]model.Model , len (command .ModelIDsWithAttributes ))
468- for i , modelID := range command .ModelIDsWithAttributes {
469- evaluationContext .Models [i ] = modelsSelected [modelID ]
470- evaluationConfiguration .Models .Selected = append (evaluationConfiguration .Models .Selected , modelID )
497+ evaluationContext .Models = append (evaluationContext .Models , m )
498+ evaluationContext .ProviderForModel [m ] = p
499+ evaluationConfiguration .Models .Selected = append (evaluationConfiguration .Models .Selected , modelIDsWithProviderAndAttributes )
471500 }
472501 }
473502
0 commit comments