Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit 1a6a7f7

Browse files
arafatkatzeabeatrixStephen Gutekanst
authored
Adding Anthropic messages API support to the Google provider through Google vertex (#63282)
[Linear Issue](https://linear.app/sourcegraph/project/claude-3-on-gcp-8c014e1a3506/overview) This PR adds support for anthropic models in the google provider through google vertex. NOTE: The current code only supported Google Gemini API and had boiler plate code for Google vertex(only for the gemini model) this PR adds Google Vertex for anthropic models properly so this way the google provider can be run in 3 different configurations 1. Google Gemini API(this works but only for chat and not for completions which is the intended behaviour for now) 2. Google Vertex API Anthropic Model(This works perfectly and is added in this PR and tested for both chat and completions and it works great) 3. Google Vertex API Gemini Model (this doesn't work yet and can eventually be added to this codebase but we gotta add a new decoder for the streaming responses of the gemini model through this API we can take care of this later) Sense of Urgency: This is a P0 because of enterprise requirements so I would appreciate a fast approval and merging. <!-- 💡 To write a useful PR description, make sure that your description covers: - WHAT this PR is changing: - How was it PREVIOUSLY. - How it will be from NOW on. - WHY this PR is needed. - CONTEXT, i.e. to which initiative, project or RFC it belongs. The structure of the description doesn't matter as much as covering these points, so use your best judgement based on your context. Learn how to write good pull request description: https://www.notion.so/sourcegraph/Write-a-good-pull-request-description-610a7fd3e613496eb76f450db5a49b6e?pvs=4 --> ## Test plan - Run this branch for Cody instance -> https://github.com/sourcegraph/cody/pull/4606 - Ask @arafatkatze to dm you the siteadmin config to make things work - Check the logs and play with completions and chat <!-- All pull requests REQUIRE a test plan: https://docs-legacy.sourcegraph.com/dev/background-information/testing_principles --> ## Changelog <!-- 1. Ensure your pull request title is formatted as: $type($domain): $what 3. Add bullet list items for each additional detail you want to cover (see example below) 4. You can edit this after the pull request was merged, as long as release shipping it hasn't been promoted to the public. 5. For more information, please see this how-to https://www.notion.so/sourcegraph/Writing-a-changelog-entry-dd997f411d524caabf0d8d38a24a878c? Audience: TS/CSE > Customers > Teammates (in that order). Cheat sheet: $type = chore|fix|feat $domain: source|search|ci|release|plg|cody|local|... --> <!-- Example: Title: fix(search): parse quotes with the appropriate context Changelog section: ## Changelog - When a quote is used with regexp pattern type, then ... - Refactored underlying code. --> --------- Signed-off-by: Stephen Gutekanst <[email protected]> Co-authored-by: Beatrix <[email protected]> Co-authored-by: Stephen Gutekanst <[email protected]>
1 parent 472a556 commit 1a6a7f7

File tree

13 files changed

+472
-31
lines changed

13 files changed

+472
-31
lines changed

WORKSPACE

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,14 @@ go_repository(
293293
version = "v1.14.5",
294294
)
295295

296+
go_repository(
297+
name = "com_google_cloud_go_auth",
298+
build_file_proto_mode = "disable_global",
299+
importpath = "cloud.google.com/go/auth",
300+
sum = "h1:0QNO7VThG54LUzKiQxv8C6x1YX7lUrzlAa1nVLF8CIw=",
301+
version = "v0.5.1",
302+
)
303+
296304
# Overrides the default provided protobuf dep from rules_go by a more
297305
# recent one.
298306
go_repository(

deps.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8363,8 +8363,8 @@ def go_dependencies():
83638363
name = "org_golang_google_genproto_googleapis_rpc",
83648364
build_file_proto_mode = "disable_global",
83658365
importpath = "google.golang.org/genproto/googleapis/rpc",
8366-
sum = "h1:Di6ANFilr+S60a4S61ZM00vLdw0IrQOSMS2/6mrnOU0=",
8367-
version = "v0.0.0-20240617180043-68d350f18fd4",
8366+
sum = "h1:Zy9XzmMEflZ/MAaA7vNcoebnRAld7FsPW1EeBB7V0m8=",
8367+
version = "v0.0.0-20240528184218-531527333157",
83688368
)
83698369
go_repository(
83708370
name = "org_golang_google_grpc",

go.mod

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ require (
250250
cdr.dev/slog v1.4.2-0.20221206192828-e4803b10ae17
251251
chainguard.dev/apko v0.14.0
252252
cloud.google.com/go/artifactregistry v1.14.8
253+
cloud.google.com/go/auth v0.5.1
253254
connectrpc.com/connect v1.16.1
254255
connectrpc.com/grpcreflect v1.2.0
255256
connectrpc.com/otelconnect v0.7.0
@@ -468,7 +469,7 @@ require (
468469
golang.org/x/tools/go/vcs v0.1.0-deprecated // indirect
469470
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
470471
gonum.org/v1/plot v0.14.0 // indirect
471-
google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4 // indirect
472+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 // indirect
472473
gopkg.in/go-jose/go-jose.v2 v2.6.1 // indirect
473474
gopkg.in/ini.v1 v1.67.0 // indirect
474475
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect

go.sum

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ cloud.google.com/go v0.112.1 h1:uJSeirPke5UNZHIb4SxfZklVSiWWVqW4oXlETwZziwM=
2929
cloud.google.com/go v0.112.1/go.mod h1:+Vbu+Y1UU+I1rjmzeMOb/8RfkKJK2Gyxi1X6jJCZLo4=
3030
cloud.google.com/go/artifactregistry v1.14.8 h1:icIyRzJ1Ag6EOafuDuFFJ/AdStcOFRVfSGURn27/7Pk=
3131
cloud.google.com/go/artifactregistry v1.14.8/go.mod h1:1UlSXh6sTXYrIT4kMO21AE1IDlMFemlZuX6QS+JXW7I=
32+
cloud.google.com/go/auth v0.5.1 h1:0QNO7VThG54LUzKiQxv8C6x1YX7lUrzlAa1nVLF8CIw=
33+
cloud.google.com/go/auth v0.5.1/go.mod h1:vbZT8GjzDf3AVqCcQmqeeM32U9HBFc32vVVAbwDsa6s=
3234
cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
3335
cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE=
3436
cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc=
@@ -2464,8 +2466,8 @@ google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9 h1:9+tzLLstTlPTRyJ
24642466
google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9/go.mod h1:mqHbVIp48Muh7Ywss/AD6I5kNVKZMmAa/QEW58Gxp2s=
24652467
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw=
24662468
google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU=
2467-
google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4 h1:Di6ANFilr+S60a4S61ZM00vLdw0IrQOSMS2/6mrnOU0=
2468-
google.golang.org/genproto/googleapis/rpc v0.0.0-20240617180043-68d350f18fd4/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
2469+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157 h1:Zy9XzmMEflZ/MAaA7vNcoebnRAld7FsPW1EeBB7V0m8=
2470+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240528184218-531527333157/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0=
24692471
google.golang.org/grpc v1.2.1-0.20170921194603-d4b75ebd4f9f/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
24702472
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
24712473
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=

internal/completions/client/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func getBasic(endpoint string, provider conftypes.CompletionsProviderName, acces
4242
case conftypes.CompletionsProviderNameAzureOpenAI:
4343
return azureopenai.NewClient(azureopenai.GetAPIClient, endpoint, accessToken, *tokenManager)
4444
case conftypes.CompletionsProviderNameGoogle:
45-
return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, false), nil
45+
return google.NewClient(httpcli.UncachedExternalDoer, endpoint, accessToken, false)
4646
case conftypes.CompletionsProviderNameSourcegraph:
4747
return codygateway.NewClient(httpcli.CodyGatewayDoer, endpoint, accessToken, *tokenManager)
4848
case conftypes.CompletionsProviderNameFireworks:

internal/completions/client/codygateway/codygateway.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (c *codyGatewayClient) clientForParams(feature types.CompletionsFeature, re
9494
case string(conftypes.CompletionsProviderNameFireworks):
9595
return fireworks.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/fireworks"), "", ""), nil
9696
case string(conftypes.CompletionsProviderNameGoogle):
97-
return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", "", true), nil
97+
return google.NewClient(gatewayDoer(c.upstream, feature, c.gatewayURL, c.accessToken, "/v1/completions/google"), "", "", true)
9898
case "":
9999
return nil, errors.Newf("no provider provided in model %s - a model in the format '$PROVIDER/$MODEL_NAME' is expected", model)
100100
default:

internal/completions/client/google/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ load("//dev:go_defs.bzl", "go_test")
44
go_library(
55
name = "google",
66
srcs = [
7+
"anthropic_types.go",
78
"decoder.go",
9+
"gemini_types.go",
810
"google.go",
911
"models.go",
1012
"prompt.go",
11-
"types.go",
1213
],
1314
importpath = "github.com/sourcegraph/sourcegraph/internal/completions/client/google",
1415
visibility = ["//:__subpackages__"],
@@ -17,6 +18,8 @@ go_library(
1718
"//internal/httpcli",
1819
"//lib/errors",
1920
"@com_github_sourcegraph_log//:log",
21+
"@com_google_cloud_go_auth//credentials",
22+
"@com_google_cloud_go_auth//httptransport",
2023
],
2124
)
2225

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package google
2+
3+
type anthropicUsage struct {
4+
PromptTokenCount int `json:"promptTokenCount"`
5+
// Use the same name we use elsewhere (completion instead of candidates)
6+
CompletionTokenCount int `json:"candidatesTokenCount"`
7+
TotalTokenCount int `json:"totalTokenCount"`
8+
}
9+
type anthropicRequest struct {
10+
AnthropicVersion string `json:"anthropic_version"`
11+
Messages []anthropicMessage `json:"messages"`
12+
MaxTokens int `json:"max_tokens"`
13+
Stream bool `json:"stream"`
14+
System string `json:"system"`
15+
}
16+
type anthropicMessage struct {
17+
Role string `json:"role"`
18+
Content []anthropicMessagePart `json:"content"`
19+
}
20+
21+
type anthropicMessagePart struct {
22+
Type string `json:"type"`
23+
Text string `json:"text"`
24+
}
25+
type anthropicContentMessage struct {
26+
Role string `json:"role"`
27+
Parts []anthropicContentMessagePart `json:"parts"`
28+
}
29+
30+
type anthropicContentMessagePart struct {
31+
Text string `json:"text"`
32+
}
33+
34+
type anthropicResponse struct {
35+
Candidates []anthropicCandidate `json:"candidates"`
36+
UsageMetadata anthropicUsage `json:"usageMetadata"`
37+
SafetySettings []anthropicSafetySettings `json:"safetySettings,omitempty"`
38+
SafetyRatings []anthropicSafetyRating `json:"safetyRatings,omitempty"`
39+
}
40+
41+
type anthropicCandidate struct {
42+
Content anthropicContentMessage `json:"content,omitempty"`
43+
StopReason string `json:"finishReason,omitempty"`
44+
}
45+
46+
type anthropicSafetyRating struct {
47+
Category string `json:"category"`
48+
Probability string `json:"probability"`
49+
ProbabilityScore float64 `json:"probabilityScore"`
50+
Severity string `json:"severity"`
51+
SeverityScore float64 `json:"severityScore"`
52+
}
53+
54+
// Safety setting, affecting the safety-blocking behavior.
55+
// Ref: https://ai.google.dev/gemini-api/docs/safety-settings
56+
type anthropicSafetySettings struct {
57+
Category string `json:"category"`
58+
Threshold string `json:"threshold"`
59+
}
60+
61+
type anthropicStreamingResponseMessage struct {
62+
Usage *anthropicMessagesResponseUsage `json:"usage"`
63+
}
64+
65+
type anthropicMessagesResponseUsage struct {
66+
InputTokens int `json:"input_tokens"`
67+
OutputTokens int `json:"output_tokens"`
68+
}
69+
70+
type anthropicStreamingResponseTextBucket struct {
71+
Text string `json:"text"` // for event `content_block_delta`
72+
StopReason string `json:"stop_reason"` // for event `message_delta`
73+
}
74+
75+
type anthropicStreamingResponseDelta struct {
76+
Type string `json:"type"`
77+
Text string `json:"text"`
78+
}
79+
80+
// AnthropicMessagesStreamingResponse captures all relevant-to-us fields from each relevant SSE event from https://docs.anthropic.com/claude/reference/messages_post.
81+
type anthropicStreamingResponse struct {
82+
Type string `json:"type"`
83+
Delta *anthropicStreamingResponseDelta `json:"delta"`
84+
ContentBlock *anthropicStreamingResponseTextBucket `json:"content_block"`
85+
Usage *anthropicMessagesResponseUsage `json:"usage"`
86+
Message *anthropicStreamingResponseMessage `json:"message"`
87+
}
88+
89+
type anthropicContent struct {
90+
Type string `json:"type"`
91+
Text string `json:"text"`
92+
}
93+
94+
type anthropicNonStreamingResponse struct {
95+
ID string `json:"id"`
96+
Type string `json:"type"`
97+
Role string `json:"role"`
98+
Model string `json:"model"`
99+
Content []anthropicContent `json:"content"`
100+
StopReason string `json:"stop_reason"`
101+
StopSequence *string `json:"stop_sequence"`
102+
Usage anthropicMessagesResponseUsage `json:"usage"`
103+
}

internal/completions/client/google/types.go renamed to internal/completions/client/google/gemini_types.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
package google
22

3-
import "github.com/sourcegraph/sourcegraph/internal/httpcli"
3+
import (
4+
"net/http"
5+
6+
"github.com/sourcegraph/sourcegraph/internal/httpcli"
7+
)
8+
9+
type APIFamily string
410

511
type googleCompletionStreamClient struct {
6-
cli httpcli.Doer
12+
httpCli httpcli.Doer
13+
gcpCli *http.Client
714
accessToken string
815
endpoint string
916
viaGateway bool
17+
apiFamily APIFamily
1018
}
1119

1220
// The request body for the completion stream endpoint.

0 commit comments

Comments
 (0)