Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit c06fd8d

Browse files
author
Rik
authored
Backport: Client Compatible Bedrock ARN handling (#62720) (#62793)
Client Compatible Bedrock ARN handling (#62720) * Improve Bedrock ARN handling * Fix up PR comments (cherry picked from commit 6a7666c)
1 parent 660dee6 commit c06fd8d

File tree

10 files changed

+290
-21
lines changed

10 files changed

+290
-21
lines changed

internal/completions/client/awsbedrock/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ go_library(
1212
deps = [
1313
"//internal/completions/tokenusage",
1414
"//internal/completions/types",
15+
"//internal/conf/conftypes",
1516
"//internal/httpcli",
1617
"//lib/errors",
1718
"@com_github_aws_aws_sdk_go_v2//aws",

internal/completions/client/awsbedrock/bedrock.go

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222

2323
"github.com/sourcegraph/sourcegraph/internal/completions/tokenusage"
2424
"github.com/sourcegraph/sourcegraph/internal/completions/types"
25+
"github.com/sourcegraph/sourcegraph/internal/conf/conftypes"
2526
"github.com/sourcegraph/sourcegraph/internal/httpcli"
2627
"github.com/sourcegraph/sourcegraph/lib/errors"
2728
)
@@ -68,7 +69,8 @@ func (c *awsBedrockAnthropicCompletionStreamClient) Complete(
6869
completion += content.Text
6970
}
7071

71-
err = c.tokenManager.UpdateTokenCountsFromModelUsage(response.Usage.InputTokens, response.Usage.OutputTokens, "anthropic/"+requestParams.Model, string(feature), tokenusage.AwsBedrock)
72+
parsedModelId := conftypes.NewBedrockModelRefFromModelID(requestParams.Model)
73+
err = c.tokenManager.UpdateTokenCountsFromModelUsage(response.Usage.InputTokens, response.Usage.OutputTokens, "anthropic/"+parsedModelId.Model, string(feature), tokenusage.AwsBedrock)
7274
if err != nil {
7375
return nil, err
7476
}
@@ -153,7 +155,8 @@ func (a *awsBedrockAnthropicCompletionStreamClient) Stream(
153155
case "message_delta":
154156
if event.Delta != nil {
155157
stopReason = event.Delta.StopReason
156-
err = a.tokenManager.UpdateTokenCountsFromModelUsage(inputPromptTokens, event.Usage.OutputTokens, "anthropic/"+requestParams.Model, string(feature), tokenusage.AwsBedrock)
158+
parsedModelId := conftypes.NewBedrockModelRefFromModelID(requestParams.Model)
159+
err = a.tokenManager.UpdateTokenCountsFromModelUsage(inputPromptTokens, event.Usage.OutputTokens, "anthropic/"+parsedModelId.Model, string(feature), tokenusage.AwsBedrock)
157160
if err != nil {
158161
logger.Warn("Failed to count tokens with the token manager %w ", log.Error(err))
159162
}
@@ -232,19 +235,8 @@ func (c *awsBedrockAnthropicCompletionStreamClient) makeRequest(ctx context.Cont
232235
if err != nil {
233236
return nil, errors.Wrap(err, "marshalling request body")
234237
}
235-
apiURL, err := url.Parse(c.endpoint)
236-
if err != nil || apiURL.Scheme == "" {
237-
apiURL = &url.URL{
238-
Scheme: "https",
239-
Host: fmt.Sprintf("bedrock-runtime.%s.amazonaws.com", defaultConfig.Region),
240-
}
241-
}
242238

243-
if stream {
244-
apiURL.Path = fmt.Sprintf("/model/%s/invoke-with-response-stream", requestParams.Model)
245-
} else {
246-
apiURL.Path = fmt.Sprintf("/model/%s/invoke", requestParams.Model)
247-
}
239+
apiURL := buildApiUrl(c.endpoint, requestParams.Model, stream, defaultConfig.Region)
248240

249241
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL.String(), bytes.NewReader(reqBody))
250242
if err != nil {
@@ -282,6 +274,41 @@ func (c *awsBedrockAnthropicCompletionStreamClient) makeRequest(ctx context.Cont
282274
return resp, nil
283275
}
284276

277+
// Builds a bedrock api URL from the configured endpoint url.
278+
// If the endpoint isn't valid, falls back to the default endpoint for the specified fallbackRegion
279+
func buildApiUrl(endpoint string, model string, stream bool, fallbackRegion string) *url.URL {
280+
apiURL, err := url.Parse(endpoint)
281+
if err != nil || apiURL.Scheme == "" {
282+
apiURL = &url.URL{
283+
Scheme: "https",
284+
Host: fmt.Sprintf("bedrock-runtime.%s.amazonaws.com", fallbackRegion),
285+
}
286+
}
287+
288+
bedrockModelRef := conftypes.NewBedrockModelRefFromModelID(model)
289+
290+
if bedrockModelRef.ProvisionedCapacity != nil {
291+
// We need to Query escape the provisioned capacity ARN, since otherwise
292+
// the AWS API Gateway interprets the path as a path and doesn't route
293+
// to the Bedrock service. This would results in abstract Coral errors
294+
if stream {
295+
apiURL.RawPath = fmt.Sprintf("/model/%s/invoke-with-response-stream", url.QueryEscape(*bedrockModelRef.ProvisionedCapacity))
296+
apiURL.Path = fmt.Sprintf("/model/%s/invoke-with-response-stream", *bedrockModelRef.ProvisionedCapacity)
297+
} else {
298+
apiURL.RawPath = fmt.Sprintf("/model/%s/invoke", url.QueryEscape(*bedrockModelRef.ProvisionedCapacity))
299+
apiURL.Path = fmt.Sprintf("/model/%s/invoke", *bedrockModelRef.ProvisionedCapacity)
300+
}
301+
} else {
302+
if stream {
303+
apiURL.Path = fmt.Sprintf("/model/%s/invoke-with-response-stream", bedrockModelRef.Model)
304+
} else {
305+
apiURL.Path = fmt.Sprintf("/model/%s/invoke", bedrockModelRef.Model)
306+
}
307+
}
308+
309+
return apiURL
310+
}
311+
285312
func awsConfigOptsForKeyConfig(endpoint string, accessToken string) []func(*config.LoadOptions) error {
286313
configOpts := []func(*config.LoadOptions) error{}
287314
if endpoint != "" {

internal/completions/client/awsbedrock/bedrock_test.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,56 @@ package awsbedrock
22

33
import (
44
"context"
5+
"fmt"
56
"testing"
67

78
"github.com/aws/aws-sdk-go-v2/config"
89
"github.com/stretchr/testify/require"
910
)
1011

11-
func TestAwsConfigOptsForKeyConfig(t *testing.T) {
12+
func Test_BedrockProvisionedThroughputModel(t *testing.T) {
13+
tests := []struct {
14+
want string
15+
endpoint string
16+
model string
17+
fallbackRegion string
18+
stream bool
19+
}{
20+
{
21+
want: "https://bedrock-runtime.us-west-2.amazonaws.com/model/amazon.titan-text-express-v1/invoke",
22+
endpoint: "",
23+
model: "amazon.titan-text-express-v1",
24+
fallbackRegion: "us-west-2",
25+
stream: false,
26+
},
27+
{
28+
want: "https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0:200k/invoke",
29+
endpoint: "",
30+
model: "anthropic.claude-3-sonnet-20240229-v1:0:200k",
31+
fallbackRegion: "us-west-2",
32+
stream: false,
33+
},
34+
{
35+
want: "https://vpce-12345678910.bedrock-runtime.us-west-2.vpce.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A012345678901%3Aprovisioned-model%2Fabcdefghijkl/invoke-with-response-stream",
36+
endpoint: "https://vpce-12345678910.bedrock-runtime.us-west-2.vpce.amazonaws.com",
37+
model: "anthropic.claude-instant-v1/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl",
38+
fallbackRegion: "us-east-1",
39+
stream: true,
40+
},
41+
}
42+
43+
for _, tt := range tests {
44+
t.Run(fmt.Sprintf("%q", tt.want), func(t *testing.T) {
45+
got := buildApiUrl(tt.endpoint, tt.model, tt.stream, tt.fallbackRegion)
46+
if got.String() != tt.want {
47+
t.Logf("got %q but wanted %q", got, tt.want)
48+
t.Fail()
49+
}
50+
})
51+
}
52+
}
53+
54+
func Test_AwsConfigOptsForKeyConfig(t *testing.T) {
1255

1356
t.Run("With endpoint as URL", func(t *testing.T) {
1457
endpoint := "https://example.com"

internal/completions/client/azureopenai/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ go_library(
2222
go_test(
2323
name = "azureopenai_test",
2424
srcs = ["openai_test.go"],
25+
data = glob(["testdata/**"]),
2526
embed = [":azureopenai"],
2627
deps = [
2728
"//internal/completions/tokenusage",

internal/conf/computed.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,11 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
706706
completionsConfig.ChatModel = completionsConfig.Model
707707
}
708708

709+
// This records if the modelIDs have been canonicalized by the provider
710+
// specific configuration. By default a ToLower will be applied to the modelIDs
711+
// if no other canonicalization has already been applied. In particular this
712+
// is because BedrockModelRefs need different canonicalization
713+
canonicalized := false
709714
if completionsConfig.Provider == string(conftypes.CompletionsProviderNameSourcegraph) {
710715
// If no endpoint is configured, use a default value.
711716
if completionsConfig.Endpoint == "" {
@@ -850,12 +855,27 @@ func GetCompletionsConfig(siteConfig schema.SiteConfiguration) (c *conftypes.Com
850855
if completionsConfig.CompletionModel == "" {
851856
completionsConfig.CompletionModel = "anthropic.claude-instant-v1"
852857
}
858+
859+
// We apply BedrockModelRef specific canonicalization
860+
// Make sure models are always treated case-insensitive.
861+
chatModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.ChatModel)
862+
completionsConfig.ChatModel = chatModelRef.CanonicalizedModelID()
863+
864+
fastChatModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.FastChatModel)
865+
completionsConfig.FastChatModel = fastChatModelRef.CanonicalizedModelID()
866+
867+
completionsModelRef := conftypes.NewBedrockModelRefFromModelID(completionsConfig.CompletionModel)
868+
completionsConfig.CompletionModel = completionsModelRef.CanonicalizedModelID()
869+
canonicalized = true
853870
}
854871

855-
// Make sure models are always treated case-insensitive.
856-
completionsConfig.ChatModel = strings.ToLower(completionsConfig.ChatModel)
857-
completionsConfig.FastChatModel = strings.ToLower(completionsConfig.FastChatModel)
858-
completionsConfig.CompletionModel = strings.ToLower(completionsConfig.CompletionModel)
872+
// only apply canonicalization if not already applied. Not all model IDs can simply be lowercased
873+
if !canonicalized {
874+
// Make sure models are always treated case-insensitive.
875+
completionsConfig.ChatModel = strings.ToLower(completionsConfig.ChatModel)
876+
completionsConfig.FastChatModel = strings.ToLower(completionsConfig.FastChatModel)
877+
completionsConfig.CompletionModel = strings.ToLower(completionsConfig.CompletionModel)
878+
}
859879

860880
// If after trying to set default we still have not all models configured, completions are
861881
// not available.
@@ -1185,8 +1205,9 @@ func defaultMaxPromptTokens(provider conftypes.CompletionsProviderName, model st
11851205
// this is a sane default for GPT in general.
11861206
return 7_000
11871207
case conftypes.CompletionsProviderNameAWSBedrock:
1188-
if strings.HasPrefix(model, "anthropic.") {
1189-
return anthropicDefaultMaxPromptTokens(strings.TrimPrefix(model, "anthropic."))
1208+
parsed := conftypes.NewBedrockModelRefFromModelID(model)
1209+
if strings.HasPrefix(parsed.Model, "anthropic.") {
1210+
return anthropicDefaultMaxPromptTokens(strings.TrimPrefix(parsed.Model, "anthropic."))
11901211
}
11911212
// Fallback for weird values.
11921213
return 9_000
@@ -1197,6 +1218,8 @@ func defaultMaxPromptTokens(provider conftypes.CompletionsProviderName, model st
11971218
}
11981219

11991220
func anthropicDefaultMaxPromptTokens(model string) int {
1221+
// TODO: this doesn't nearly cover all the ways that token size can be specified.
1222+
// See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
12001223
if strings.HasSuffix(model, "-100k") {
12011224
return 100_000
12021225

internal/conf/computed_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,30 @@ func TestGetCompletionsConfig(t *testing.T) {
546546
Endpoint: "us-west-2",
547547
},
548548
},
549+
{
550+
name: "AWS Bedrock completions with Provisioned Throughput for some of the models",
551+
siteConfig: schema.SiteConfiguration{
552+
CodyEnabled: pointers.Ptr(true),
553+
LicenseKey: licenseKey,
554+
Completions: &schema.Completions{
555+
Provider: "aws-bedrock",
556+
Endpoint: "us-west-2",
557+
ChatModel: "anthropic.claude-3-haiku-20240307-v1:0-100k/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl",
558+
FastChatModel: "anthropic.claude-v2",
559+
},
560+
},
561+
wantConfig: &conftypes.CompletionsConfig{
562+
ChatModel: "anthropic.claude-3-haiku-20240307-v1:0-100k/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/abcdefghijkl",
563+
ChatModelMaxTokens: 100_000,
564+
FastChatModel: "anthropic.claude-v2",
565+
FastChatModelMaxTokens: 12000,
566+
CompletionModel: "anthropic.claude-instant-v1",
567+
CompletionModelMaxTokens: 9000,
568+
AccessToken: "",
569+
Provider: "aws-bedrock",
570+
Endpoint: "us-west-2",
571+
},
572+
},
549573
{
550574
name: "zero-config cody gateway completions without license key",
551575
siteConfig: schema.SiteConfiguration{

internal/conf/conftypes/conftypes.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package conftypes
22

33
import (
44
"reflect"
5+
"strings"
56
"time"
67

78
"google.golang.org/protobuf/types/known/durationpb"
@@ -101,3 +102,50 @@ func (r *RawUnified) FromProto(in *proto.RawUnified) {
101102
func (r RawUnified) Equal(other RawUnified) bool {
102103
return r.Site == other.Site && reflect.DeepEqual(r.ServiceConnections, other.ServiceConnections)
103104
}
105+
106+
// BedrockModelRef is the parsed form of an AWS Bedrock model ID. Bedrock model
// IDs can be in one of two forms:
//   - A static model ID, e.g. "anthropic.claude-v2".
//   - A model ID and ARN for provisioned capacity, e.g.
//     "anthropic.claude-v2/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/xxxxxxxx"
//
// See the AWS docs for more information:
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_CreateProvisionedModelThroughput.html
type BedrockModelRef struct {
	// Model is the underlying LLM model Bedrock is serving, e.g.
	// "anthropic.claude-3-haiku-20240307-v1:0".
	Model string
	// ProvisionedCapacity is nil for a static model ID. If the configuration
	// is using provisioned capacity, it holds everything after the first "/"
	// of the model ID, i.e. the ARN to use for making API calls, e.g.
	// "arn:aws:bedrock:us-west-2:012345678901:provisioned-model/xxxxxxxx".
	ProvisionedCapacity *string
}

// NewBedrockModelRefFromModelID splits modelID at its first "/" into the model
// name and, when a "/" is present, the provisioned capacity ARN. No validation
// or case normalization is performed; a modelID without "/" yields a
// BedrockModelRef with a nil ProvisionedCapacity.
func NewBedrockModelRefFromModelID(modelID string) BedrockModelRef {
	// Only the first "/" separates model from ARN: the ARN itself contains
	// further slashes that must be preserved.
	model, arn, hasARN := strings.Cut(modelID, "/")

	ref := BedrockModelRef{Model: model}
	if hasARN {
		ref.ProvisionedCapacity = &arn
	}
	return ref
}

// CanonicalizedModelID returns the model ID with all case-insensitive parts
// lowercased so that model IDs can be compared.
func (bmr BedrockModelRef) CanonicalizedModelID() string {
	// Bedrock model IDs are case sensitive if they contain an ARN, so only the
	// non-ARN part is lowercased.
	model := strings.ToLower(bmr.Model)
	if bmr.ProvisionedCapacity == nil {
		return model
	}
	return model + "/" + *bmr.ProvisionedCapacity
}

internal/conf/validation/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ go_test(
4747
name = "validation_test",
4848
srcs = [
4949
"auth_test.go",
50+
"cody_test.go",
5051
"prometheus_test.go",
5152
"txemail_test.go",
5253
],

internal/conf/validation/cody.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package validation
22

33
import (
44
"fmt"
5+
"strings"
56
"time"
67

78
"github.com/sourcegraph/sourcegraph/internal/conf"
@@ -19,6 +20,8 @@ func init() {
1920
conf.ContributeValidator(embeddingsConfigValidator)
2021
}
2122

23+
// bedrockArnMessageTemplate is the fmt template for the site-config problem
// reported when a Bedrock model setting is a bare ARN. The first %s is the
// name of the completions field, the second %s is the configured value, which
// is shown appended to an example model ID.
const bedrockArnMessageTemplate = "completions.%s is invalid. Provisioned Capacity IDs must be formatted like \"model_id/provisioned_capacity_arn\".\nFor example \"anthropic.claude-instant-v1/%s\""
24+
2225
func completionsConfigValidator(q conftypes.SiteConfigQuerier) conf.Problems {
2326
problems := []string{}
2427
completionsConf := q.SiteConfig().Completions
@@ -30,6 +33,35 @@ func completionsConfigValidator(q conftypes.SiteConfigQuerier) conf.Problems {
3033
problems = append(problems, "'completions.enabled' has been superceded by 'cody.enabled', please migrate to the new configuration.")
3134
}
3235

36+
// Check for bedrock Provisioned Capacity ARNs which should instead be
37+
// formatted like:
38+
// "anthropic.claude-v2/arn:aws:bedrock:us-west-2:012345678901:provisioned-model/xxxxxxxx"
39+
if completionsConf.Provider == string(conftypes.CompletionsProviderNameAWSBedrock) {
40+
type modelID struct {
41+
value string
42+
field string
43+
}
44+
allModelIds := []modelID{
45+
{value: completionsConf.ChatModel, field: "chatModel"},
46+
{value: completionsConf.FastChatModel, field: "fastChatModel"},
47+
{value: completionsConf.CompletionModel, field: "completionModel"},
48+
}
49+
var modelIdsToCheck []modelID
50+
for _, modelId := range allModelIds {
51+
if modelId.value != "" {
52+
modelIdsToCheck = append(modelIdsToCheck, modelId)
53+
}
54+
}
55+
56+
for _, modelId := range modelIdsToCheck {
57+
// When using provisioned capacity we expect an admin would just put the ARN
58+
// here directly, but we need both the model AND the ARN. Hence the check.
59+
if strings.HasPrefix(modelId.value, "arn:aws:") {
60+
problems = append(problems, fmt.Sprintf(bedrockArnMessageTemplate, modelId.field, modelId.value))
61+
}
62+
}
63+
}
64+
3365
if len(problems) > 0 {
3466
return conf.NewSiteProblems(problems...)
3567
}

0 commit comments

Comments
 (0)