This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit 582022d

Return model IDs from GraphQL, not model Names (#64307)
::sigh:: The problem here is super in the weeds, but ultimately this fixes a problem introduced when using AWS Bedrock with Sourcegraph instances using the older-style "completions" config.

## The problem

AWS Bedrock has some LLM model names that contain a colon, e.g. `anthropic.claude-3-opus-20240229-v1:0`. Cody clients connecting to Sourcegraph instances using the older-style "completions" config obtain the available LLM models via GraphQL. So the Cody client would see that the chat model is `anthropic.claude-3-opus-20240229-v1:0`.

However, under the hood, the Sourcegraph instance converts the site config into the newer `modelconfig` format. And during that conversion, we use a _different value_ for the **model ID** than what is in the site config. (The **model name** is what is sent to the LLM API, and is unmodified. The **model ID** is a stable, unique identifier, but it is sanitized so that it adheres to naming rules.)

Because of this, we have a problem. When the Cody client makes a request to the HTTP completions API with the model name `anthropic.claude-3-opus-20240229-v1:0` or `anthropic/anthropic.claude-3-opus-20240229-v1:0`, it fails, because there is no model with ID `...v1:0`. (We only have the sanitized version, `...v1_0`.)

## The fix

There were a few ways we could fix this, but this change goes with simply having the GraphQL component return the model ID instead of the model name, so that when the Cody client passes that model ID to the completions API, everything works as it should. And, practically speaking, for 99.9% of cases, the model name and model ID will be identical: we only strip out non-URL-safe characters and colons, which usually aren't used in model names.

## Potential bugs

With this fix, however, there is a specific combination of { client, server, model name } where things could in theory break.
Specifically:

| Client | Server | Model name | Works |
| --- | --- | --- | --- |
| unaware-of-modelconfig | not-using-modelconfig | standard | 🟢 [1] |
| aware-of-modelconfig | not-using-modelconfig | standard | 🟢 [1] |
| unaware-of-modelconfig | using-modelconfig | standard | 🟢 [1] |
| aware-of-modelconfig | using-modelconfig | standard | 🟢 [3] |
| unaware-of-modelconfig | not-using-modelconfig | non-standard | 🔴 [2] |
| aware-of-modelconfig | not-using-modelconfig | non-standard | 🔴 [2] |
| unaware-of-modelconfig | using-modelconfig | non-standard | 🔴 [2] |
| aware-of-modelconfig | using-modelconfig | non-standard | 🟢 [3] |

1. If the model name is something that doesn't require sanitization, there is no problem. The model ID will be the same as the model name, and things will work like they do today.
2. If the model name gets sanitized, then iff the Cody client were to make a decision based on that exact model name, it wouldn't work, because it would receive the sanitized name and not the real one. As long as the Cody client is only passing that model name on to the Sourcegraph backend, which will recognize the sanitized model name / ID, all is well.
3. If the client and server are new and using modelconfig, then this shouldn't be a problem, because the client would use a different API to fetch the Sourcegraph instance's supported models, and within the client, natively refer to the model ID instead of the model name.

Fixes [PRIME-464](https://linear.app/sourcegraph/issue/PRIME-464/aws-bedrock-x-completions-config-does-not-work-if-model-name-has-a).

## Test plan

Added some unit tests.

## Changelog

NA
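As a rough illustration of the sanitization behavior described above: the sketch below is inferred from this commit's test expectations, not copied from the real `SanitizeResourceName` implementation, which may differ in details. Based on those tests, colons, spaces, and non-ASCII bytes become underscores, while URL-path punctuation such as `/`, `@`, `;`, `&`, `!`, and `#` passes through.

```go
package main

import (
	"fmt"
	"strings"
)

// sanitizeModelName is a hypothetical sketch of the sanitization applied
// when deriving a model ID from a "completions"-config model name.
// Inferred from the unit tests in this commit; not the actual
// SanitizeResourceName code.
func sanitizeModelName(name string) string {
	var b strings.Builder
	for _, c := range []byte(name) {
		switch {
		case c == ':', c == ' ', c > 127:
			// Colons, spaces, and non-ASCII bytes are replaced.
			b.WriteByte('_')
		default:
			b.WriteByte(c)
		}
	}
	return b.String()
}

func main() {
	// The AWS Bedrock model name with a colon from the commit message:
	fmt.Println(sanitizeModelName("anthropic.claude-3-opus-20240229-v1:0"))
	// Output: anthropic.claude-3-opus-20240229-v1_0
}
```

This matches the test expectations in `resolver_test.go` below, including the case where a name full of "wonky" characters requires no sanitization at all.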
1 parent c414477 commit 582022d

File tree

4 files changed (+83, −21 lines)


cmd/frontend/internal/modelconfig/resolver.go

Lines changed: 12 additions & 6 deletions

```diff
@@ -69,7 +69,7 @@ type completionsConfigResolver struct {
 }
 
 func (c *completionsConfigResolver) ChatModel() (string, error) {
-	return c.config.ChatModel, nil
+	return convertLegacyModelNameToModelID(c.config.ChatModel), nil
 }
 
 func (c *completionsConfigResolver) ChatModelMaxTokens() (*int32, error) {
@@ -92,7 +92,7 @@ func (c *completionsConfigResolver) DisableClientConfigAPI() bool {
 }
 
 func (c *completionsConfigResolver) FastChatModel() (string, error) {
-	return c.config.FastChatModel, nil
+	return convertLegacyModelNameToModelID(c.config.FastChatModel), nil
 }
 
 func (c *completionsConfigResolver) FastChatModelMaxTokens() (*int32, error) {
@@ -108,7 +108,7 @@ func (c *completionsConfigResolver) Provider() string {
 }
 
 func (c *completionsConfigResolver) CompletionModel() (string, error) {
-	return c.config.CompletionModel, nil
+	return convertLegacyModelNameToModelID(c.config.CompletionModel), nil
 }
 
 func (c *completionsConfigResolver) CompletionModelMaxTokens() (*int32, error) {
@@ -254,12 +254,18 @@ func (r *modelconfigResolver) CompletionModelMaxTokens() (*int32, error) {
 // the provider as needed to match older behavior. (See unit tests and convertProviderID for
 // more information.)
 func (r *modelconfigResolver) toLegacyModelRef(model modelconfigSDK.Model) string {
+	modelID := model.ModelRef.ModelID()
 	providerID := model.ModelRef.ProviderID()
 	legacyProviderName := r.convertProviderID(providerID)
 
-	// For compatibility, we are returning the model _name_ instead of the model _id_.
-	// So the client will see "claude-3-xxxx" not the shortened model ID like "claude-3".
-	return fmt.Sprintf("%s/%s", legacyProviderName, model.ModelName)
+	// Potential issue: Older Cody clients calling the GraphQL API may expect to see the model **name**
+	// such as "claude-3-sonnet-20240229". But it is important that we only return the model **ID**,
+	// because that is what the HTTP completions API is expecting to see from the client.
+	//
+	// So when using older Cody clients, unaware of the newer modelconfig system, this could lead
+	// to some errors. (But newer clients won't be using this GraphQL endpoint at all and instead
+	// just use the newer modelconfig system, so hopefully this won't be a major concern.)
+	return fmt.Sprintf("%s/%s", legacyProviderName, modelID)
 }
 
 // convertProviderID returns the _API Provider_ for the referenced modelconfig provider.
```
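The string shape that `toLegacyModelRef` produces after this change can be sketched standalone. The types below are simplified stand-ins for the real `modelconfigSDK` types, and the provider mapping inside `convertProviderID` is a guess for illustration (the commit only hints that Cody Gateway maps to the `sourcegraph` API provider).

```go
package main

import (
	"fmt"
	"strings"
)

// ModelRef mimics the modelconfig mref format "providerID::apiVersion::modelID".
type ModelRef string

func (m ModelRef) ProviderID() string {
	return strings.SplitN(string(m), "::", 3)[0]
}

func (m ModelRef) ModelID() string {
	parts := strings.SplitN(string(m), "::", 3)
	return parts[2]
}

// convertProviderID stands in for mapping a modelconfig provider to the
// legacy API-provider name. Hypothetical mapping for illustration only.
func convertProviderID(providerID string) string {
	if providerID == "cody-gateway" {
		return "sourcegraph"
	}
	return providerID
}

// toLegacyModelRef returns "apiProvider/modelID", matching the fixed
// behavior in this commit: the sanitized model ID, not the model name.
func toLegacyModelRef(mref ModelRef) string {
	return fmt.Sprintf("%s/%s", convertProviderID(mref.ProviderID()), mref.ModelID())
}

func main() {
	fmt.Println(toLegacyModelRef("anthropic::unknown::anthropic.claude-3-opus-20240229-v1_0"))
	// Output: anthropic/anthropic.claude-3-opus-20240229-v1_0
}
```

The key point is that the second path segment is now `ModelRef.ModelID()` rather than `Model.ModelName`, so the value a client echoes back is always resolvable by the completions API.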

cmd/frontend/internal/modelconfig/resolver_test.go

Lines changed: 52 additions & 13 deletions

```diff
@@ -59,6 +59,45 @@ func TestCompletionsResolver(t *testing.T) {
 		model, err = testResolver.CompletionModel()
 		assert.EqualValues(t, siteConfigData.CompletionModel, model)
 		assert.NoError(t, err)
+
+		// In the GraphQL resolver we are returning the model name expressed in
+		// the site config, but the HTTP completions API only accepts model IDs.
+		// For the "completions" config, these are 99% identical, but in some
+		// cases may differ.
+		//
+		// In the completions API (see get_model.go) we look up a model by its
+		// mref or model ID, and then use the unmodified model name when making
+		// the API request.
+		t.Run("Sanitization", func(t *testing.T) {
+			// Copy and introduce more challenging model names.
+			updatedSiteConfigData := *siteConfigData
+			updatedSiteConfigData.ChatModel = "anthropic.claude-3-opus-20240229-v1:0/so:many:colons"
+			updatedSiteConfigData.FastChatModel = "all/sorts@of;special_chars&but!no#sanitization"
+			updatedSiteConfigData.CompletionModel = "other invalid tokens 😭😭😭"
+
+			updatedResolver := &completionsConfigResolver{
+				config: &updatedSiteConfigData,
+			}
+
+			var (
+				model string
+				err   error
+			)
+			model, err = updatedResolver.ChatModel()
+			assert.NotEqualValues(t, updatedSiteConfigData.ChatModel, model)
+			assert.EqualValues(t, "anthropic.claude-3-opus-20240229-v1_0/so_many_colons", model)
+			assert.NoError(t, err)
+
+			// Fast chat had wonky characters, but none required sanitizing.
+			model, err = updatedResolver.FastChatModel()
+			assert.EqualValues(t, updatedSiteConfigData.FastChatModel, model)
+			assert.NoError(t, err)
+
+			model, err = updatedResolver.CompletionModel()
+			assert.NotEqualValues(t, updatedSiteConfigData.CompletionModel, model)
+			assert.EqualValues(t, "other_invalid_tokens_____________", model)
+			assert.NoError(t, err)
+		})
 	})
 }
 
@@ -75,8 +114,8 @@ func TestModelConfigResolver(t *testing.T) {
 		},
 	}
 	awsBedrockModel := modelconfigSDK.Model{
-		ModelRef:  modelconfigSDK.ModelRef("test-provider_aws-bedrock::xxx::test-model_aws-bedrock"),
-		ModelName: "aws-bedrock-model-name",
+		ModelRef:  modelconfigSDK.ModelRef("test-provider_aws-bedrock::xxx::aws-bedrock_model-id"),
+		ModelName: "aws-bedrock_model-name",
 	}
 
 	// Azure OpenAI provider and model.
@@ -89,8 +128,8 @@ func TestModelConfigResolver(t *testing.T) {
 		},
 	}
 	azureOpenAIModel := modelconfigSDK.Model{
-		ModelRef:  modelconfigSDK.ModelRef("test-provider_azure-openai::xxx::test-model_azure-openai"),
-		ModelName: "azure-openai-model-name",
+		ModelRef:  modelconfigSDK.ModelRef("test-provider_azure-openai::xxx::azure-openai_model-id"),
+		ModelName: "azure-openai_model-name",
 	}
 
 	// Cody Gateway provider and model.
@@ -103,8 +142,8 @@ func TestModelConfigResolver(t *testing.T) {
 		},
 	}
 	codyGatewayModel := modelconfigSDK.Model{
-		ModelRef:  modelconfigSDK.ModelRef("test-provider_cody-gateway::xxx::test-model_cody-gateway"),
-		ModelName: "cody-gateway-model-name",
+		ModelRef:  modelconfigSDK.ModelRef("test-provider_cody-gateway::xxx::cody-gateway_model-id"),
+		ModelName: "cody-gateway_model-name",
 	}
 
 	modelconfigData := modelconfigSDK.ModelConfiguration{
@@ -146,24 +185,24 @@ func TestModelConfigResolver(t *testing.T) {
 	})
 
 	t.Run("Models", func(t *testing.T) {
-		// Note that for all these cases the returned string doesn't match
-		// either the Provider ID nor the Model ID. Instead, it is the name
-		// of the API Provider (e.g. "sourcegraph" if using Cody Gateway),
-		// and we return the model name.
+		// The models returned here are kinda confusing:
+		// We replace the "provider" with whatever underlying API is used for serving responses.
+		// However, we return the model IDs (rather than model Names) since that's what the
+		// completions API expects.
 		var (
 			model string
 			err   error
 		)
 		model, err = testResolver.ChatModel()
-		assert.Equal(t, "aws-bedrock/aws-bedrock-model-name", model)
+		assert.Equal(t, "aws-bedrock/aws-bedrock_model-id", model)
 		assert.NoError(t, err)
 
 		model, err = testResolver.CompletionModel()
-		assert.Equal(t, "azure-openai/azure-openai-model-name", model)
+		assert.Equal(t, "azure-openai/azure-openai_model-id", model)
 		assert.NoError(t, err)
 
 		model, err = testResolver.FastChatModel()
-		assert.Equal(t, "sourcegraph/cody-gateway-model-name", model)
+		assert.Equal(t, "sourcegraph/cody-gateway_model-id", model)
 		assert.NoError(t, err)
 	})
 }
```
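The lookup behavior these tests rely on (resolve by sanitized ID, send the unmodified name upstream) can be sketched in isolation. The `Model` struct and `lookupModel` below are hypothetical stand-ins for `modelconfigSDK.Model` and the logic in `get_model.go`, written only to illustrate why returning the ID from GraphQL fixes the failure.

```go
package main

import "fmt"

// Model pairs the sanitized, stable model ID with the unmodified model
// name that is sent to the backing LLM API. Simplified stand-in for
// modelconfigSDK.Model.
type Model struct {
	ID   string // sanitized, e.g. "anthropic.claude-3-opus-20240229-v1_0"
	Name string // unmodified, e.g. "anthropic.claude-3-opus-20240229-v1:0"
}

// lookupModel mimics the completions API behavior described in the test
// comments: resolve the client-supplied model ID, then use the unmodified
// name for the upstream request. Hypothetical sketch, not get_model.go.
func lookupModel(models []Model, id string) (Model, error) {
	for _, m := range models {
		if m.ID == id {
			return m, nil
		}
	}
	return Model{}, fmt.Errorf("no model with ID %q", id)
}

func main() {
	models := []Model{{
		ID:   "anthropic.claude-3-opus-20240229-v1_0",
		Name: "anthropic.claude-3-opus-20240229-v1:0",
	}}

	// Before this fix, GraphQL handed the client the model *name*, so the
	// lookup by ID failed for names containing a colon:
	if _, err := lookupModel(models, "anthropic.claude-3-opus-20240229-v1:0"); err != nil {
		fmt.Println("old behavior:", err)
	}

	// After the fix, the client receives the model *ID*, and the lookup
	// succeeds; the unmodified name is then sent to AWS:
	m, _ := lookupModel(models, "anthropic.claude-3-opus-20240229-v1_0")
	fmt.Println("new behavior: send", m.Name, "to AWS")
}
```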

cmd/frontend/internal/modelconfig/siteconfig_completions.go

Lines changed: 17 additions & 2 deletions

```diff
@@ -49,6 +49,21 @@ type legacyModelRef struct {
 	serverSideConfig *types.ServerSideModelConfig
 }
 
+// convertLegacyModelNameToModelID returns the ID that should be used for a model name
+// defined in the "completions" site config.
+//
+// When sending LLM models to the client, it expects to see the exact value specified in the site
+// configuration. So the client sees the model **name**. However, internally, this Sourcegraph
+// instance converts the site configuration into a modelconfigSDK.ModelConfiguration, which may
+// have a slightly different model **ID** from the model name.
+//
+// When converting older-style completions config, we just keep these identical for 99.9% of
+// cases. (No need to differ.) But model IDs must adhere to naming rules, so we need to
+// sanitize the results.
+func convertLegacyModelNameToModelID(model string) string {
+	return modelconfig.SanitizeResourceName(model)
+}
+
 // parseLegacyModelRef takes a reference to a model from the site configuration in the "legacy format",
 // and infers all the surrounding data. e.g. "claude-instant", "openai/gpt-4o".
 func parseLegacyModelRef(
@@ -83,7 +98,7 @@ func parseLegacyModelRef(
 	// The model ID may contain colons or other invalid characters. So we strip those out here,
 	// so that the Model's mref is valid.
 	// But the model NAME remains unchanged. As that's what is sent to AWS.
-	modelID = modelconfig.SanitizeResourceName(bedrockModelRef.Model)
+	modelID = convertLegacyModelNameToModelID(bedrockModelRef.Model)
 	modelName = bedrockModelRef.Model
 
 	if bedrockModelRef.ProvisionedCapacity != nil {
@@ -122,7 +137,7 @@ func parseLegacyModelRef(
 		modelID = modelNameFromConfig[kind]
 	}
 	// Finally, sanitize the user-supplied model ID to ensure it is valid.
-	modelID = modelconfig.SanitizeResourceName(modelID)
+	modelID = convertLegacyModelNameToModelID(modelID)
 
 	default:
 		// No other processing is needed.
```

cmd/frontend/internal/modelconfig/siteconfig_completions_test.go

Lines changed: 2 additions & 0 deletions

```diff
@@ -169,6 +169,8 @@ func TestConvertCompletionsConfig(t *testing.T) {
 	}
 	{
 		m := siteModelConfig.ModelOverrides[0]
+		// Notice how the model ID has been sanitized. (No colon.) But the model name is the same
+		// as in the site config. (Since that's how the model is identified in its backing API.)
 		assert.EqualValues(t, "anthropic::unknown::anthropic.claude-3-opus-20240229-v1_0", m.ModelRef)
 		assert.EqualValues(t, "anthropic.claude-3-opus-20240229-v1_0", m.ModelRef.ModelID())
 		// Unlike the Model's ID, the Name is unchanged, as this is what AWS expects in the API call.
```
