@@ -53,8 +53,8 @@ type Evaluate struct {
5353
5454 // Languages determines which language should be used for the evaluation, or empty if all languages should be used.
5555 Languages []string `long:"language" description:"Evaluate with this language. By default all languages are used."`
56- // Models determines which models should be used for the evaluation, or empty if all models should be used.
57- Models []string `long:"model" description:"Evaluate with this model. By default all models are used."`
56+ // ModelIDsWithProviderAndAttributes determines which models should be used for the evaluation, or empty if all models should be used.
57+ ModelIDsWithProviderAndAttributes []string `long:"model" description:"Evaluate with this model. By default all models are used."`
5858 // ProviderTokens holds all API tokens for the providers.
5959 ProviderTokens map [string ]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
6060 // ProviderUrls holds all custom inference endpoint urls for the providers.
@@ -123,7 +123,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
123123 command .logger .Panicf ("the configuration file is not supported in containerized runtimes" )
124124 }
125125
126- if len (command .Models ) > 0 || len (command .Repositories ) > 0 {
126+ if len (command .ModelIDsWithProviderAndAttributes ) > 0 || len (command .Repositories ) > 0 {
127127 command .logger .Panicf ("do not provide models and repositories when loading a configuration file" )
128128 }
129129
@@ -139,7 +139,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
139139 command .logger .Panicf ("ERROR: %s" , err )
140140 }
141141
142- command .Models = configuration .Models .Selected
142+ command .ModelIDsWithProviderAndAttributes = configuration .Models .Selected
143143 command .Repositories = configuration .Repositories .Selected
144144 }
145145
@@ -258,43 +258,13 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
258258 // In a containerized runtime we check the availability of the testdata, repositories and models/providers inside the container.
259259 if command .Runtime != "local" {
260260 // Copy the models over.
261- for _ , modelID := range command .Models {
261+ for _ , modelID := range command .ModelIDsWithProviderAndAttributes {
262262 evaluationContext .Models = append (evaluationContext .Models , llm .NewModel (nil , modelID ))
263263 }
264264
265265 return evaluationContext , evaluationConfiguration , func () {}
266266 }
267267
268- // Register custom OpenAI API providers and models.
269- {
270- customProviders := map [string ]* openaiapi.Provider {}
271- for providerID , providerURL := range command .ProviderUrls {
272- if ! strings .HasPrefix (providerID , "custom-" ) {
273- continue
274- }
275-
276- p := openaiapi .NewProvider (providerID , providerURL )
277- provider .Register (p )
278- customProviders [providerID ] = p
279- }
280- for _ , model := range command .Models {
281- if ! strings .HasPrefix (model , "custom-" ) {
282- continue
283- }
284-
285- providerID , _ , ok := strings .Cut (model , provider .ProviderModelSeparator )
286- if ! ok {
287- command .logger .Panicf ("ERROR: cannot split %q into provider and model name by %q" , model , provider .ProviderModelSeparator )
288- }
289- modelProvider , ok := customProviders [providerID ]
290- if ! ok {
291- command .logger .Panicf ("ERROR: unknown custom provider %q for model %q" , providerID , model )
292- }
293-
294- modelProvider .AddModel (llm .NewModel (modelProvider , model ))
295- }
296- }
297-
298268 // Ensure the "testdata" path exists and make it absolute.
299269 {
300270 if err := osutil .DirExists (command .TestdataPath ); err != nil {
@@ -371,101 +341,162 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
371341 evaluationContext .Languages [i ] = languagesSelected [languageID ]
372342 }
373343
374- // Gather models.
375- serviceShutdown := []func () (err error ){}
344+ // Gather models and initialize providers .
345+ var serviceShutdown []func () (err error )
376346 {
377- // Check which providers are needed for the evaluation.
378- providersSelected := map [string ]provider.Provider {}
379- if len (command .Models ) == 0 {
380- providersSelected = provider .Providers
347+ // Gather providers.
348+ providers := map [string ]provider.Provider {}
349+ if len (command .ModelIDsWithProviderAndAttributes ) == 0 {
350+ for providerID , provider := range provider .Providers {
351+ providers [providerID ] = provider
352+ command .logger .Info ("selected provider" , "provider" , providerID )
353+ }
381354 } else {
382- for _ , model := range command .Models {
383- p := strings .SplitN (model , provider .ProviderModelSeparator , 2 )[0 ]
355+ // Register custom providers.
356+ for providerID , providerURL := range command .ProviderUrls {
357+ if ! strings .HasPrefix (providerID , "custom-" ) {
358+ command .logger .Panicf ("ERROR: cannot set URL of %q because it is not a custom provider" , providerID )
359+ }
384360
385- if _ , ok := providersSelected [p ]; ok {
386- continue
361+ p := openaiapi .NewProvider (providerID , providerURL )
362+ provider .Register (p )
363+ providers [providerID ] = p
364+ command .logger .Info ("selected provider" , "provider" , providerID )
365+ }
366+
367+ // Add remaining providers from models.
368+ for _ , modelIDsWithProviderAndAttributes := range command .ModelIDsWithProviderAndAttributes {
369+ providerID , _ , ok := strings .Cut (modelIDsWithProviderAndAttributes , provider .ProviderModelSeparator )
370+ if ! ok {
371+ command .logger .Panicf ("ERROR: cannot split %q into provider and model name by %q" , modelIDsWithProviderAndAttributes , provider .ProviderModelSeparator )
387372 }
388373
389- if provider , ok := provider .Providers [p ]; ! ok {
390- command .logger .Panicf ("Provider %q does not exist" , p )
391- } else {
392- providersSelected [provider .ID ()] = provider
374+ p , ok := provider .Providers [providerID ]
375+ if ! ok {
376+ command .logger .Panicf ("ERROR: unknown provider %q for model %q" , providerID , modelIDsWithProviderAndAttributes )
377+ }
378+ if _ , ok := providers [providerID ]; ! ok {
379+ providers [providerID ] = p
380+ command .logger .Info ("selected provider" , "provider" , providerID )
393381 }
394382 }
395383 }
396384
397- models := map [string ]model.Model {}
398- modelsSelected := map [string ]model.Model {}
399- evaluationContext .ProviderForModel = map [model.Model ]provider.Provider {}
400- for _ , p := range providersSelected {
401- command .logger .Info ("querying provider models" , "provider" , p .ID ())
385+ // Initialize providers.
386+ {
387+ providerIDsSorted := maps .Keys (providers )
388+ sort .Strings (providerIDsSorted )
389+ for _ , providerID := range providerIDsSorted {
390+ p := providers [providerID ]
402391
403- if t , ok := p .(provider.InjectToken ); ok {
404- token , ok := command .ProviderTokens [p .ID ()]
405- if ok {
406- t .SetToken (token )
392+ command .logger .Info ("initializing provider" , "provider" , providerID )
393+ if t , ok := p .(provider.InjectToken ); ok {
394+ if token , ok := command .ProviderTokens [p .ID ()]; ok {
395+ command .logger .Info ("set token of provider" , "provider" , providerID )
396+ t .SetToken (token )
397+ }
407398 }
408- }
409- if err := p .Available (command .logger ); err != nil {
410- command .logger .Warn ("skipping unavailable provider" , "provider" , p .ID (), "error" , err )
399+ command .logger .Info ("checking availability for provider" , "provider" , providerID )
400+ if err := p .Available (command .logger ); err != nil {
401+ command .logger .Info ("skipping provider because it is not available" , "error" , err , "provider" , providerID )
402+ delete (providers , providerID )
411403
412- continue
404+ continue
405+ }
406+ if service , ok := p .(provider.Service ); ok {
407+ command .logger .Info ("starting services for provider" , "provider" , p .ID ())
408+ shutdown , err := service .Start (command .logger )
409+ if err != nil {
410+ command .logger .Panicf ("ERROR: could not start services for provider %q: %s" , p , err )
411+ }
412+ serviceShutdown = append (serviceShutdown , shutdown )
413+ }
413414 }
415+ }
414416
415- // Start services of providers.
416- if service , ok := p .(provider.Service ); ok {
417- command .logger .Info ("starting services for provider" , "provider" , p .ID ())
418- shutdown , err := service .Start (command .logger )
417+ // Gather models.
418+ models := map [string ]model.Model {}
419+ {
420+ addAllModels := len (command .ModelIDsWithProviderAndAttributes ) == 0
421+ for _ , p := range providers {
422+ ms , err := p .Models ()
419423 if err != nil {
420- command .logger .Panicf ("ERROR: could not start services for provider %q: %s" , p , err )
424+ command .logger .Panicf ("ERROR: could not query models for provider %q: %s" , p . ID () , err )
421425 }
422- serviceShutdown = append (serviceShutdown , shutdown )
423- }
426+ for _ , m := range ms {
427+ models [m .ID ()] = m
428+ evaluationConfiguration .Models .Available = append (evaluationConfiguration .Models .Available , m .ID ())
424429
425- // Check if a provider has the ability to pull models and do so if necessary.
426- if puller , ok := p .(provider.Puller ); ok {
427- command .logger .Info ("pulling available models for provider" , "provider" , p .ID ())
428- for _ , modelID := range command .Models {
429- if strings .HasPrefix (modelID , p .ID ()) {
430- if err := puller .Pull (command .logger , modelID ); err != nil {
431- command .logger .Panicf ("ERROR: could not pull model %q: %s" , modelID , err )
432- }
430+ if addAllModels {
431+ command .ModelIDsWithProviderAndAttributes = append (command .ModelIDsWithProviderAndAttributes , m .ID ())
433432 }
434433 }
435434 }
435+ }
436+ modelIDs := maps .Keys (models )
437+ sort .Strings (modelIDs )
438+ sort .Strings (command .ModelIDsWithProviderAndAttributes )
436439
437- ms , err := p .Models ()
438- if err != nil {
439- command .logger .Panicf ("ERROR: could not query models for provider %q: %s" , p .ID (), err )
440+ // Check and initialize models.
441+ evaluationContext .ProviderForModel = map [model.Model ]provider.Provider {}
442+ for _ , modelIDsWithProviderAndAttributes := range command .ModelIDsWithProviderAndAttributes {
443+ command .logger .Info ("selecting model" , "model" , modelIDsWithProviderAndAttributes )
444+
445+ providerID , modelIDsWithAttributes , ok := strings .Cut (modelIDsWithProviderAndAttributes , provider .ProviderModelSeparator )
446+ if ! ok {
447+ command .logger .Panicf ("ERROR: cannot split %q into provider and model name by %q" , modelIDsWithProviderAndAttributes , provider .ProviderModelSeparator )
440448 }
441449
442- for _ , m := range ms {
443- models [m .ID ()] = m
444- evaluationContext .ProviderForModel [m ] = p
445- evaluationConfiguration .Models .Available = append (evaluationConfiguration .Models .Available , m .ID ())
450+ modelID , _ := model .ParseModelID (modelIDsWithAttributes )
451+
452+ p , ok := providers [providerID ]
453+ if ! ok {
454+ command .logger .Panicf ("ERROR: cannot find provider %q" , providerID )
446455 }
447- }
448- modelIDs := maps .Keys (models )
449- sort .Strings (modelIDs )
450- if len (command .Models ) == 0 {
451- command .Models = modelIDs
452- } else {
453- for _ , modelID := range command .Models {
454- if _ , ok := models [modelID ]; ! ok {
455- command .logger .Panicf ("ERROR: model %s does not exist. Valid models are: %s" , modelID , strings .Join (modelIDs , ", " ))
456+ if puller , ok := p .(provider.Puller ); ok {
457+ command .logger .Info ("pulling model" , "model" , modelID )
458+ if err := puller .Pull (command .logger , modelID ); err != nil {
459+ command .logger .Panicf ("ERROR: could not pull model %q: %s" , modelID , err )
460+ }
461+
462+ // TODO If a model has not been pulled before, it was not available, at least for the "Ollama" provider. Make this cleaner: we should not re-query the provider and rebuild the list of models every time.
463+ if _ , ok := models [modelIDsWithProviderAndAttributes ]; ! ok {
464+ ms , err := p .Models ()
465+ if err != nil {
466+ command .logger .Panicf ("ERROR: could not query models for provider %q: %s" , p .ID (), err )
467+ }
468+ for _ , m := range ms {
469+ if _ , ok := models [m .ID ()]; ok {
470+ continue
471+ }
472+
473+ models [m .ID ()] = m
474+ evaluationConfiguration .Models .Available = append (evaluationConfiguration .Models .Available , m .ID ())
475+ }
476+ modelIDs = maps .Keys (models )
477+ sort .Strings (modelIDs )
456478 }
457479 }
458- }
459- sort .Strings (command .Models )
460- for _ , modelID := range command .Models {
461- modelsSelected [modelID ] = models [modelID ]
462- }
463480
464- // Make the resolved selected models available in the command.
465- evaluationContext .Models = make ([]model.Model , len (command .Models ))
466- for i , modelID := range command .Models {
467- evaluationContext .Models [i ] = modelsSelected [modelID ]
468- evaluationConfiguration .Models .Selected = append (evaluationConfiguration .Models .Selected , modelID )
481+ var m model.Model
482+ if strings .HasPrefix (providerID , "custom-" ) {
483+ pc , ok := p .(* openaiapi.Provider )
484+ if ! ok {
485+ command .logger .Panicf ("ERROR: %q is not a custom provider" , providerID )
486+ }
487+
488+ m = llm .NewModel (pc , modelIDsWithProviderAndAttributes )
489+ pc .AddModel (m )
490+ } else {
491+ var ok bool
492+ m , ok = models [modelIDsWithProviderAndAttributes ]
493+ if ! ok {
494+ command .logger .Panicf ("ERROR: model %q does not exist for provider %q. Valid models are: %s" , modelIDsWithProviderAndAttributes , providerID , strings .Join (modelIDs , ", " ))
495+ }
496+ }
497+ evaluationContext .Models = append (evaluationContext .Models , m )
498+ evaluationContext .ProviderForModel [m ] = p
499+ evaluationConfiguration .Models .Selected = append (evaluationConfiguration .Models .Selected , modelIDsWithProviderAndAttributes )
469500 }
470501 }
471502
@@ -613,7 +644,7 @@ func (command *Evaluate) evaluateDocker(ctx *evaluate.Context) (err error) {
613644 "-e" , "SYMFLOWER_INTERNAL_LICENSE_FILE" ,
614645 "-e" , "SYMFLOWER_LICENSE_KEY" ,
615646 "-v" , volumeName + ":/app/evaluation" ,
616- "--rm" , // automatically remove container after it finished
647+ "--rm" , // Automatically remove container after it finished.
617648 command .RuntimeImage ,
618649 }
619650
@@ -706,7 +737,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
706737 // Define a regex that matches all characters that are neither alphanumeric nor "-", so that they can be replaced.
707738 kubeNameRegex := regexp .MustCompile (`[^a-zA-Z0-9-]+` )
708739
709- jobTmpl , err := template .ParseFiles (filepath .Join ("conf" , "kube" , "job.yml" ))
740+ kubernetesJobTemplate , err := template .ParseFiles (filepath .Join ("conf" , "kube" , "job.yml" ))
710741 if err != nil {
711742 return pkgerrors .Wrap (err , "could not create kubernetes job template" )
712743 }
@@ -735,7 +766,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
735766 "kubectl" ,
736767 "apply" ,
737768 "-f" ,
738- "-" , // apply STDIN
769+ "-" , // Apply STDIN.
739770 }
740771
741772 // Commands for the evaluation to run inside the container.
@@ -763,14 +794,14 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
763794 }
764795
765796 parallel .Execute (func () {
766- var tmplData bytes.Buffer
767- if err := jobTmpl .Execute (& tmplData , data ); err != nil {
797+ var kubernetesJobData bytes.Buffer
798+ if err := kubernetesJobTemplate .Execute (& kubernetesJobData , data ); err != nil {
768799 command .logger .Panicf ("ERROR: %s" , err )
769800 }
770801
771802 commandOutput , err := util .CommandWithResult (context .Background (), command .logger , & util.Command {
772803 Command : kubeCommand ,
773- Stdin : tmplData .String (),
804+ Stdin : kubernetesJobData .String (),
774805 })
775806 if err != nil {
776807 command .logger .Error ("kubernetes evaluation failed" , "error" , pkgerrors .WithMessage (pkgerrors .WithStack (err ), commandOutput ))
@@ -830,7 +861,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
830861
831862 var storageTemplateData bytes.Buffer
832863 if err := storageTemplate .Execute (& storageTemplateData , data ); err != nil {
833- return pkgerrors .Wrap (err , "could not execute storate template" )
864+ return pkgerrors .Wrap (err , "could not execute storage template" )
834865 }
835866
836867 // Create the storage access pod.
@@ -839,7 +870,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
839870 "kubectl" ,
840871 "apply" ,
841872 "-f" ,
842- "-" , // apply STDIN
873+ "-" , // Apply STDIN.
843874 },
844875 Stdin : storageTemplateData .String (),
845876 })
0 commit comments