Skip to content

Commit a3c7f01

Browse files
committed
Differentiate between ID (with provider and attributes) and just the model ID (that we need to query LLM models)
Part of #407
1 parent 315cb3b commit a3c7f01

File tree

7 files changed

+47
-13
lines changed

7 files changed

+47
-13
lines changed

model/llm/llm.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ import (
2424

2525
// Model represents a LLM model accessed via a provider.
2626
type Model struct {
27+
// id holds the full identifier, including the provider and attributes.
28+
id string
2729
// provider is the client to query the LLM model.
2830
provider provider.Query
29-
// modelID holds the identifier for the LLM modelID.
31+
// modelID holds the identifier for the LLM model.
3032
modelID string
3133

3234
// attributes holds query attributes.
@@ -41,6 +43,7 @@ type Model struct {
4143
// NewModel returns an LLM model corresponding to the given identifier which is queried via the given provider.
4244
func NewModel(provider provider.Query, modelIDWithAttributes string) (llmModel *Model) {
4345
llmModel = &Model{
46+
id: modelIDWithAttributes,
4447
provider: provider,
4548

4649
queryAttempts: 1,
@@ -53,6 +56,7 @@ func NewModel(provider provider.Query, modelIDWithAttributes string) (llmModel *
5356
// NewModelWithMetaInformation returns a LLM model with meta information corresponding to the given identifier which is queried via the given provider.
5457
func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string, metaInformation *model.MetaInformation) *Model {
5558
return &Model{
59+
id: modelIdentifier,
5660
provider: provider,
5761
modelID: modelIdentifier,
5862

@@ -62,6 +66,18 @@ func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string
6266
}
6367
}
6468

69+
var _ model.Model = (*Model)(nil)
70+
71+
// ID returns full identifier, including the provider and attributes.
72+
func (m *Model) ID() (id string) {
73+
return m.id
74+
}
75+
76+
// ModelID returns the unique identifier of this model.
77+
func (m *Model) ModelID() (modelID string) {
78+
return m.modelID
79+
}
80+
6581
// Attributes returns query attributes.
6682
func (m *Model) Attributes() (attributes map[string]string) {
6783
return m.attributes
@@ -241,13 +257,6 @@ func (ctx *llmMigrateSourceFilePromptContext) Format() (message string, err erro
241257
return b.String(), nil
242258
}
243259

244-
var _ model.Model = (*Model)(nil)
245-
246-
// ID returns the unique ID of this model.
247-
func (m *Model) ID() (id string) {
248-
return m.modelID
249-
}
250-
251260
var _ model.CapabilityWriteTests = (*Model)(nil)
252261

253262
// WriteTests generates test files for the given implementation file in a repository.

model/model.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ import (
99

1010
// Model defines a model that can be queried for generations.
1111
type Model interface {
12-
// ID returns the unique ID of this model.
12+
// ID returns full identifier, including the provider and attributes.
1313
ID() (id string)
14+
// ModelID returns the unique identifier of this model.
15+
ModelID() (modelID string)
1416

1517
// Attributes returns query attributes.
1618
Attributes() (attributes map[string]string)

model/symflower/symflower.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,16 @@ func NewModelSmartTemplateWithTimeout(timeout time.Duration) (model *Model) {
6262

6363
var _ model.Model = (*Model)(nil)
6464

65-
// ID returns the unique ID of this model.
65+
// ID returns full identifier, including the provider and attributes.
6666
func (m *Model) ID() (id string) {
6767
return "symflower" + provider.ProviderModelSeparator + m.id
6868
}
6969

70+
// ModelID returns the unique identifier of this model.
71+
func (m *Model) ModelID() (modelID string) {
72+
return "symflower" + provider.ProviderModelSeparator + m.id
73+
}
74+
7075
// Attributes returns query attributes.
7176
func (m *Model) Attributes() (attributes map[string]string) {
7277
return nil

model/testing/Model_mock_gen.go

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

provider/ollama/ollama.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ var _ provider.Query = (*Provider)(nil)
8383
// Query queries the provider with the given model name.
8484
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
8585
client := p.client()
86-
modelIdentifier := strings.TrimPrefix(model.ID(), p.ID()+provider.ProviderModelSeparator)
86+
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)
8787

8888
return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
8989
}

provider/openai-api/openai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ var _ provider.Query = (*Provider)(nil)
6262
// Query queries the provider with the given model name.
6363
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
6464
client := p.client()
65-
modelIdentifier := strings.TrimPrefix(model.ID(), p.ID()+provider.ProviderModelSeparator)
65+
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)
6666

6767
return QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
6868
}

provider/openrouter/openrouter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ var _ provider.Query = (*Provider)(nil)
140140
// Query queries the provider with the given model name.
141141
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (response string, err error) {
142142
client := p.client()
143-
modelIdentifier := strings.TrimPrefix(model.ID(), p.ID()+provider.ProviderModelSeparator)
143+
modelIdentifier := strings.TrimPrefix(model.ModelID(), p.ID()+provider.ProviderModelSeparator)
144144

145145
return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, model.Attributes(), promptText)
146146
}

0 commit comments

Comments
 (0)