Skip to content

Commit d2683d5

Browse files
authored
Merge pull request #408 from symflower/openai-reasoning_effort
Allow to set reasoning_effort for models (e.g. OpenAI's o3-mini)
2 parents 5478ba6 + 4ad31d4 commit d2683d5

File tree

17 files changed

+365
-190
lines changed

17 files changed

+365
-190
lines changed

cmd/eval-dev-quality/cmd/evaluate.go

Lines changed: 142 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ type Evaluate struct {
5353

5454
// Languages determines which language should be used for the evaluation, or empty if all languages should be used.
5555
Languages []string `long:"language" description:"Evaluate with this language. By default all languages are used."`
56-
// Models determines which models should be used for the evaluation, or empty if all models should be used.
57-
Models []string `long:"model" description:"Evaluate with this model. By default all models are used."`
56+
// ModelIDsWithProviderAndAttributes determines which models should be used for the evaluation, or empty if all models should be used.
57+
ModelIDsWithProviderAndAttributes []string `long:"model" description:"Evaluate with this model. By default all models are used."`
5858
// ProviderTokens holds all API tokens for the providers.
5959
ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
6060
// ProviderUrls holds all custom inference endpoint urls for the providers.
@@ -123,7 +123,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
123123
command.logger.Panicf("the configuration file is not supported in containerized runtimes")
124124
}
125125

126-
if len(command.Models) > 0 || len(command.Repositories) > 0 {
126+
if len(command.ModelIDsWithProviderAndAttributes) > 0 || len(command.Repositories) > 0 {
127127
command.logger.Panicf("do not provide models and repositories when loading a configuration file")
128128
}
129129

@@ -139,7 +139,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
139139
command.logger.Panicf("ERROR: %s", err)
140140
}
141141

142-
command.Models = configuration.Models.Selected
142+
command.ModelIDsWithProviderAndAttributes = configuration.Models.Selected
143143
command.Repositories = configuration.Repositories.Selected
144144
}
145145

@@ -258,43 +258,13 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
258258
// In a containerized runtime we check the availability of the testdata, repositories and models/providers inside the container.
259259
if command.Runtime != "local" {
260260
// Copy the models over.
261-
for _, modelID := range command.Models {
261+
for _, modelID := range command.ModelIDsWithProviderAndAttributes {
262262
evaluationContext.Models = append(evaluationContext.Models, llm.NewModel(nil, modelID))
263263
}
264264

265265
return evaluationContext, evaluationConfiguration, func() {}
266266
}
267267

268-
// Register custom OpenAI API providers and models.
269-
{
270-
customProviders := map[string]*openaiapi.Provider{}
271-
for providerID, providerURL := range command.ProviderUrls {
272-
if !strings.HasPrefix(providerID, "custom-") {
273-
continue
274-
}
275-
276-
p := openaiapi.NewProvider(providerID, providerURL)
277-
provider.Register(p)
278-
customProviders[providerID] = p
279-
}
280-
for _, model := range command.Models {
281-
if !strings.HasPrefix(model, "custom-") {
282-
continue
283-
}
284-
285-
providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
286-
if !ok {
287-
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
288-
}
289-
modelProvider, ok := customProviders[providerID]
290-
if !ok {
291-
command.logger.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
292-
}
293-
294-
modelProvider.AddModel(llm.NewModel(modelProvider, model))
295-
}
296-
}
297-
298268
// Ensure the "testdata" path exists and make it absolute.
299269
{
300270
if err := osutil.DirExists(command.TestdataPath); err != nil {
@@ -371,101 +341,162 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
371341
evaluationContext.Languages[i] = languagesSelected[languageID]
372342
}
373343

374-
// Gather models.
375-
serviceShutdown := []func() (err error){}
344+
// Gather models and initialize providers.
345+
var serviceShutdown []func() (err error)
376346
{
377-
// Check which providers are needed for the evaluation.
378-
providersSelected := map[string]provider.Provider{}
379-
if len(command.Models) == 0 {
380-
providersSelected = provider.Providers
347+
// Gather providers.
348+
providers := map[string]provider.Provider{}
349+
if len(command.ModelIDsWithProviderAndAttributes) == 0 {
350+
for providerID, provider := range provider.Providers {
351+
providers[providerID] = provider
352+
command.logger.Info("selected provider", "provider", providerID)
353+
}
381354
} else {
382-
for _, model := range command.Models {
383-
p := strings.SplitN(model, provider.ProviderModelSeparator, 2)[0]
355+
// Register custom providers.
356+
for providerID, providerURL := range command.ProviderUrls {
357+
if !strings.HasPrefix(providerID, "custom-") {
358+
command.logger.Panicf("ERROR: cannot set URL of %q because it is not a custom provider", providerID)
359+
}
384360

385-
if _, ok := providersSelected[p]; ok {
386-
continue
361+
p := openaiapi.NewProvider(providerID, providerURL)
362+
provider.Register(p)
363+
providers[providerID] = p
364+
command.logger.Info("selected provider", "provider", providerID)
365+
}
366+
367+
// Add remaining providers from models.
368+
for _, modelIDsWithProviderAndAttributes := range command.ModelIDsWithProviderAndAttributes {
369+
providerID, _, ok := strings.Cut(modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
370+
if !ok {
371+
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
387372
}
388373

389-
if provider, ok := provider.Providers[p]; !ok {
390-
command.logger.Panicf("Provider %q does not exist", p)
391-
} else {
392-
providersSelected[provider.ID()] = provider
374+
p, ok := provider.Providers[providerID]
375+
if !ok {
376+
command.logger.Panicf("ERROR: unknown provider %q for model %q", providerID, modelIDsWithProviderAndAttributes)
377+
}
378+
if _, ok := providers[providerID]; !ok {
379+
providers[providerID] = p
380+
command.logger.Info("selected provider", "provider", providerID)
393381
}
394382
}
395383
}
396384

397-
models := map[string]model.Model{}
398-
modelsSelected := map[string]model.Model{}
399-
evaluationContext.ProviderForModel = map[model.Model]provider.Provider{}
400-
for _, p := range providersSelected {
401-
command.logger.Info("querying provider models", "provider", p.ID())
385+
// Initialize providers.
386+
{
387+
providerIDsSorted := maps.Keys(providers)
388+
sort.Strings(providerIDsSorted)
389+
for _, providerID := range providerIDsSorted {
390+
p := providers[providerID]
402391

403-
if t, ok := p.(provider.InjectToken); ok {
404-
token, ok := command.ProviderTokens[p.ID()]
405-
if ok {
406-
t.SetToken(token)
392+
command.logger.Info("initializing provider", "provider", providerID)
393+
if t, ok := p.(provider.InjectToken); ok {
394+
if token, ok := command.ProviderTokens[p.ID()]; ok {
395+
command.logger.Info("set token of provider", "provider", providerID)
396+
t.SetToken(token)
397+
}
407398
}
408-
}
409-
if err := p.Available(command.logger); err != nil {
410-
command.logger.Warn("skipping unavailable provider", "provider", p.ID(), "error", err)
399+
command.logger.Info("checking availability for provider", "provider", providerID)
400+
if err := p.Available(command.logger); err != nil {
401+
command.logger.Info("skipping provider because it is not available", "error", err, "provider", providerID)
402+
delete(providers, providerID)
411403

412-
continue
404+
continue
405+
}
406+
if service, ok := p.(provider.Service); ok {
407+
command.logger.Info("starting services for provider", "provider", p.ID())
408+
shutdown, err := service.Start(command.logger)
409+
if err != nil {
410+
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
411+
}
412+
serviceShutdown = append(serviceShutdown, shutdown)
413+
}
413414
}
415+
}
414416

415-
// Start services of providers.
416-
if service, ok := p.(provider.Service); ok {
417-
command.logger.Info("starting services for provider", "provider", p.ID())
418-
shutdown, err := service.Start(command.logger)
417+
// Gather models.
418+
models := map[string]model.Model{}
419+
{
420+
addAllModels := len(command.ModelIDsWithProviderAndAttributes) == 0
421+
for _, p := range providers {
422+
ms, err := p.Models()
419423
if err != nil {
420-
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
424+
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
421425
}
422-
serviceShutdown = append(serviceShutdown, shutdown)
423-
}
426+
for _, m := range ms {
427+
models[m.ID()] = m
428+
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
424429

425-
// Check if a provider has the ability to pull models and do so if necessary.
426-
if puller, ok := p.(provider.Puller); ok {
427-
command.logger.Info("pulling available models for provider", "provider", p.ID())
428-
for _, modelID := range command.Models {
429-
if strings.HasPrefix(modelID, p.ID()) {
430-
if err := puller.Pull(command.logger, modelID); err != nil {
431-
command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err)
432-
}
430+
if addAllModels {
431+
command.ModelIDsWithProviderAndAttributes = append(command.ModelIDsWithProviderAndAttributes, m.ID())
433432
}
434433
}
435434
}
435+
}
436+
modelIDs := maps.Keys(models)
437+
sort.Strings(modelIDs)
438+
sort.Strings(command.ModelIDsWithProviderAndAttributes)
436439

437-
ms, err := p.Models()
438-
if err != nil {
439-
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
440+
// Check and initialize models.
441+
evaluationContext.ProviderForModel = map[model.Model]provider.Provider{}
442+
for _, modelIDsWithProviderAndAttributes := range command.ModelIDsWithProviderAndAttributes {
443+
command.logger.Info("selecting model", "model", modelIDsWithProviderAndAttributes)
444+
445+
providerID, modelIDsWithAttributes, ok := strings.Cut(modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
446+
if !ok {
447+
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
440448
}
441449

442-
for _, m := range ms {
443-
models[m.ID()] = m
444-
evaluationContext.ProviderForModel[m] = p
445-
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
450+
modelID, _ := model.ParseModelID(modelIDsWithAttributes)
451+
452+
p, ok := providers[providerID]
453+
if !ok {
454+
command.logger.Panicf("ERROR: cannot find provider %q", providerID)
446455
}
447-
}
448-
modelIDs := maps.Keys(models)
449-
sort.Strings(modelIDs)
450-
if len(command.Models) == 0 {
451-
command.Models = modelIDs
452-
} else {
453-
for _, modelID := range command.Models {
454-
if _, ok := models[modelID]; !ok {
455-
command.logger.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", "))
456+
if puller, ok := p.(provider.Puller); ok {
457+
command.logger.Info("pulling model", "model", modelID)
458+
if err := puller.Pull(command.logger, modelID); err != nil {
459+
command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err)
460+
}
461+
462+
// TODO If a model has not been pulled before, it was not available for at least the "Ollama" provider. Make this cleaner, we should not rebuild every time.
463+
if _, ok := models[modelIDsWithProviderAndAttributes]; !ok {
464+
ms, err := p.Models()
465+
if err != nil {
466+
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
467+
}
468+
for _, m := range ms {
469+
if _, ok := models[m.ID()]; ok {
470+
continue
471+
}
472+
473+
models[m.ID()] = m
474+
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
475+
}
476+
modelIDs = maps.Keys(models)
477+
sort.Strings(modelIDs)
456478
}
457479
}
458-
}
459-
sort.Strings(command.Models)
460-
for _, modelID := range command.Models {
461-
modelsSelected[modelID] = models[modelID]
462-
}
463480

464-
// Make the resolved selected models available in the command.
465-
evaluationContext.Models = make([]model.Model, len(command.Models))
466-
for i, modelID := range command.Models {
467-
evaluationContext.Models[i] = modelsSelected[modelID]
468-
evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelID)
481+
var m model.Model
482+
if strings.HasPrefix(providerID, "custom-") {
483+
pc, ok := p.(*openaiapi.Provider)
484+
if !ok {
485+
command.logger.Panicf("ERROR: %q is not a custom provider", providerID)
486+
}
487+
488+
m = llm.NewModel(pc, modelIDsWithProviderAndAttributes)
489+
pc.AddModel(m)
490+
} else {
491+
var ok bool
492+
m, ok = models[modelIDsWithProviderAndAttributes]
493+
if !ok {
494+
command.logger.Panicf("ERROR: model %q does not exist for provider %q. Valid models are: %s", modelIDsWithProviderAndAttributes, providerID, strings.Join(modelIDs, ", "))
495+
}
496+
}
497+
evaluationContext.Models = append(evaluationContext.Models, m)
498+
evaluationContext.ProviderForModel[m] = p
499+
evaluationConfiguration.Models.Selected = append(evaluationConfiguration.Models.Selected, modelIDsWithProviderAndAttributes)
469500
}
470501
}
471502

@@ -613,7 +644,7 @@ func (command *Evaluate) evaluateDocker(ctx *evaluate.Context) (err error) {
613644
"-e", "SYMFLOWER_INTERNAL_LICENSE_FILE",
614645
"-e", "SYMFLOWER_LICENSE_KEY",
615646
"-v", volumeName + ":/app/evaluation",
616-
"--rm", // automatically remove container after it finished
647+
"--rm", // Automatically remove container after it finished.
617648
command.RuntimeImage,
618649
}
619650

@@ -706,7 +737,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
706737
// Define a regex to replace all non alphanumeric characters and "-".
707738
kubeNameRegex := regexp.MustCompile(`[^a-zA-Z0-9-]+`)
708739

709-
jobTmpl, err := template.ParseFiles(filepath.Join("conf", "kube", "job.yml"))
740+
kubernetesJobTemplate, err := template.ParseFiles(filepath.Join("conf", "kube", "job.yml"))
710741
if err != nil {
711742
return pkgerrors.Wrap(err, "could not create kubernetes job template")
712743
}
@@ -735,7 +766,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
735766
"kubectl",
736767
"apply",
737768
"-f",
738-
"-", // apply STDIN
769+
"-", // Apply STDIN.
739770
}
740771

741772
// Commands for the evaluation to run inside the container.
@@ -763,14 +794,14 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
763794
}
764795

765796
parallel.Execute(func() {
766-
var tmplData bytes.Buffer
767-
if err := jobTmpl.Execute(&tmplData, data); err != nil {
797+
var kubernetesJobData bytes.Buffer
798+
if err := kubernetesJobTemplate.Execute(&kubernetesJobData, data); err != nil {
768799
command.logger.Panicf("ERROR: %s", err)
769800
}
770801

771802
commandOutput, err := util.CommandWithResult(context.Background(), command.logger, &util.Command{
772803
Command: kubeCommand,
773-
Stdin: tmplData.String(),
804+
Stdin: kubernetesJobData.String(),
774805
})
775806
if err != nil {
776807
command.logger.Error("kubernetes evaluation failed", "error", pkgerrors.WithMessage(pkgerrors.WithStack(err), commandOutput))
@@ -830,7 +861,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
830861

831862
var storageTemplateData bytes.Buffer
832863
if err := storageTemplate.Execute(&storageTemplateData, data); err != nil {
833-
return pkgerrors.Wrap(err, "could not execute storate template")
864+
return pkgerrors.Wrap(err, "could not execute storage template")
834865
}
835866

836867
// Create the storage access pod.
@@ -839,7 +870,7 @@ func (command *Evaluate) evaluateKubernetes(ctx *evaluate.Context) (err error) {
839870
"kubectl",
840871
"apply",
841872
"-f",
842-
"-", // apply STDIN
873+
"-", // Apply STDIN.
843874
},
844875
Stdin: storageTemplateData.String(),
845876
})

0 commit comments

Comments
 (0)