Skip to content

Commit 377f295

Browse files
authored
Merge pull request #421 from symflower/allow-model-parameters
fix, Allow selecting models with attributes for openRouter as well
2 parents 99f5feb + 3291eac commit 377f295

File tree

6 files changed

+112
-7
lines changed

6 files changed

+112
-7
lines changed

cmd/eval-dev-quality/cmd/evaluate.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,8 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
447447
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator)
448448
}
449449

450-
modelID, _ := model.ParseModelID(modelIDsWithAttributes)
450+
modelID, attributes := model.ParseModelID(modelIDsWithAttributes)
451+
modelIDWithProvider := providerID + provider.ProviderModelSeparator + modelID
451452

452453
p, ok := providers[providerID]
453454
if !ok {
@@ -460,18 +461,18 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
460461
}
461462

462463
// TODO If a model has not been pulled before, it was not available for at least the "Ollama" provider. Make this cleaner, we should not rebuild every time.
463-
if _, ok := models[modelIDsWithProviderAndAttributes]; !ok {
464+
if _, ok := models[modelIDWithProvider]; !ok {
464465
ms, err := p.Models()
465466
if err != nil {
466467
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
467468
}
468469
for _, m := range ms {
469-
if _, ok := models[m.ID()]; ok {
470+
if _, ok := models[m.ModelID()]; ok {
470471
continue
471472
}
472473

473-
models[m.ID()] = m
474-
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID())
474+
models[m.ModelID()] = m
475+
evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ModelID())
475476
}
476477
modelIDs = maps.Keys(models)
477478
sort.Strings(modelIDs)
@@ -489,10 +490,18 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
489490
pc.AddModel(m)
490491
} else {
491492
var ok bool
492-
m, ok = models[modelIDsWithProviderAndAttributes]
493+
m, ok = models[modelIDWithProvider]
493494
if !ok {
494495
command.logger.Panicf("ERROR: model %q does not exist for provider %q. Valid models are: %s", modelIDsWithProviderAndAttributes, providerID, strings.Join(modelIDs, ", "))
495496
}
497+
498+
// If a model with attributes is requested, we add the base model plus attributes as new model to our list.
499+
if len(attributes) > 0 {
500+
modelWithAttributes := m.Clone()
501+
modelWithAttributes.SetAttributes(attributes)
502+
models[modelWithAttributes.ID()] = modelWithAttributes
503+
m = modelWithAttributes
504+
}
496505
}
497506
evaluationContext.Models = append(evaluationContext.Models, m)
498507
evaluationContext.ProviderForModel[m] = p

cmd/eval-dev-quality/cmd/evaluate_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,44 @@ func TestEvaluateInitialize(t *testing.T) {
14781478
}, config.Repositories.Selected)
14791479
},
14801480
})
1481+
validate(t, &testCase{
1482+
Name: "Model with attributes",
1483+
1484+
Command: makeValidCommand(func(command *Evaluate) {
1485+
command.ModelIDsWithProviderAndAttributes = []string{
1486+
"openrouter/openai/o3-mini@reasoning_effort=low",
1487+
"openrouter/openai/o3-mini@reasoning_effort=high",
1488+
}
1489+
command.ProviderTokens = map[string]string{
1490+
"openrouter": "fake-token",
1491+
}
1492+
}),
1493+
1494+
ValidateContext: func(t *testing.T, context *evaluate.Context) {
1495+
assert.Len(t, context.Models, 2)
1496+
1497+
assert.Equal(t, "openrouter/openai/o3-mini@reasoning_effort=high", context.Models[0].ID())
1498+
assert.Equal(t, "openrouter/openai/o3-mini", context.Models[0].ModelID())
1499+
expectedAttributes := map[string]string{
1500+
"reasoning_effort": "high",
1501+
}
1502+
assert.Equal(t, expectedAttributes, context.Models[0].Attributes())
1503+
1504+
assert.Equal(t, "openrouter/openai/o3-mini@reasoning_effort=low", context.Models[1].ID())
1505+
assert.Equal(t, "openrouter/openai/o3-mini", context.Models[1].ModelID())
1506+
expectedAttributes = map[string]string{
1507+
"reasoning_effort": "low",
1508+
}
1509+
assert.Equal(t, expectedAttributes, context.Models[1].Attributes())
1510+
},
1511+
ValidateConfiguration: func(t *testing.T, config *EvaluationConfiguration) {
1512+
expectedSelected := []string{
1513+
"openrouter/openai/o3-mini@reasoning_effort=high",
1514+
"openrouter/openai/o3-mini@reasoning_effort=low",
1515+
}
1516+
assert.Equal(t, expectedSelected, config.Models.Selected)
1517+
},
1518+
})
14811519
validate(t, &testCase{
14821520
Name: "Local runtime does not allow parallel parameter",
14831521

model/llm/llm.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,12 @@ var _ model.Model = (*Model)(nil)
7070

7171
// ID returns full identifier, including the provider and attributes.
7272
func (m *Model) ID() (id string) {
73-
return m.id
73+
attributeString := ""
74+
for key, value := range m.attributes {
75+
attributeString += "@" + key + "=" + value
76+
}
77+
78+
return m.id + attributeString
7479
}
7580

7681
// ModelID returns the unique identifier of this model with its provider.
@@ -93,11 +98,23 @@ func (m *Model) Attributes() (attributes map[string]string) {
9398
return m.attributes
9499
}
95100

101+
// SetAttributes sets the given attributes.
102+
func (m *Model) SetAttributes(attributes map[string]string) {
103+
m.attributes = attributes
104+
}
105+
96106
// MetaInformation returns the meta information of a model.
97107
func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
98108
return m.metaInformation
99109
}
100110

111+
// Clone returns a copy of the model.
112+
func (m *Model) Clone() (clone model.Model) {
113+
model := *m
114+
115+
return &model
116+
}
117+
101118
// llmSourceFilePromptContext is the base template context for an LLM generation prompt.
102119
type llmSourceFilePromptContext struct {
103120
// Language holds the programming language name.

model/model.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@ type Model interface {
1818

1919
// Attributes returns query attributes.
2020
Attributes() (attributes map[string]string)
21+
// SetAttributes sets the given attributes.
22+
SetAttributes(attributes map[string]string)
2123

2224
// MetaInformation returns the meta information of a model.
2325
MetaInformation() *MetaInformation
26+
27+
// Clone returns a copy of the model.
28+
Clone() (clone Model)
2429
}
2530

2631
// ParseModelID takes a packaged model ID with optional attributes and converts it into its model ID and optional attributes.

model/symflower/symflower.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,22 @@ func (m *Model) Attributes() (attributes map[string]string) {
8282
return nil
8383
}
8484

85+
// SetAttributes sets the given attributes.
86+
func (m *Model) SetAttributes(attributes map[string]string) {
87+
}
88+
8589
// MetaInformation returns the meta information of a model.
8690
func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
8791
return nil
8892
}
8993

94+
// Clone returns a copy of the model.
95+
func (m *Model) Clone() (clone model.Model) {
96+
model := *m
97+
98+
return &model
99+
}
100+
90101
var _ model.CapabilityWriteTests = (*Model)(nil)
91102

92103
// WriteTests generates test files for the given implementation file in a repository.

model/testing/Model_mock_gen.go

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)