Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit b3fe6dc

Browse files
authored
fix(cody-gateway): getAPIURL before transformBody (#63406)
<!-- 💡 To write a useful PR description, make sure that your description covers: - WHAT this PR is changing: - How was it PREVIOUSLY. - How it will be from NOW on. - WHY this PR is needed. - CONTEXT, i.e. to which initiative, project or RFC it belongs. The structure of the description doesn't matter as much as covering these points, so use your best judgement based on your context. Learn how to write good pull request description: https://www.notion.so/sourcegraph/Write-a-good-pull-request-description-610a7fd3e613496eb76f450db5a49b6e?pvs=4 --> Fix an issue where the requestBody is used after transformBody has been executed. ## Test plan <!-- All pull requests REQUIRE a test plan: https://docs-legacy.sourcegraph.com/dev/background-information/testing_principles --> 1. Start Cody Gateway locally 2. Start SG local dev instance 3. Connect SG local dev instance to your local Cody Gateway instance 4. Set Gemini Flash as your chatModel 5. Connect Cody to your local dev instance 6. Ask Cody a question and verify you are getting a response ![image](https://github.com/sourcegraph/sourcegraph/assets/68532117/fbce22f9-8531-4f6e-8eb7-5c6b26e0a9fa) ## Changelog <!-- 1. Ensure your pull request title is formatted as: $type($domain): $what 8. Add bullet list items for each additional detail you want to cover (see example below) 9. You can edit this after the pull request was merged, as long as release shipping it hasn't been promoted to the public. 10. For more information, please see this how-to https://www.notion.so/sourcegraph/Writing-a-changelog-entry-dd997f411d524caabf0d8d38a24a878c? Audience: TS/CSE > Customers > Teammates (in that order). Cheat sheet: $type = chore|fix|feat $domain: source|search|ci|release|plg|cody|local|... --> <!-- Example: Title: fix(search): parse quotes with the appropriate context Changelog section: ## Changelog - When a quote is used with regexp pattern type, then ... - Refactored underlying code. -->
1 parent 78dcd57 commit b3fe6dc

File tree

10 files changed

+32
-24
lines changed

10 files changed

+32
-24
lines changed

cmd/cody-gateway/internal/httpapi/completions/anthropic.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func (a *AnthropicHandlerMethods) transformRequest(r *http.Request) {
166166
r.Header.Set("anthropic-version", "2023-01-01")
167167
}
168168

169-
func (a *AnthropicHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody anthropicRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
169+
func (a *AnthropicHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody anthropicRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {
170170
var err error
171171

172172
// Setting a default -1 value so that in case of errors the tokenizer computed tokens don't impact the data
@@ -183,7 +183,7 @@ func (a *AnthropicHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBo
183183

184184
// Try to parse the request we saw, if it was non-streaming, we can simply parse
185185
// it as JSON.
186-
if !reqBody.Stream {
186+
if !isStreamRequest {
187187
var res anthropicResponse
188188
if err := json.NewDecoder(r).Decode(&res); err != nil {
189189
logger.Error("failed to parse Anthropic response as JSON", log.Error(err))

cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ func (a *AnthropicMessagesHandlerMethods) transformRequest(r *http.Request) {
215215
r.Header.Set("anthropic-version", "2023-06-01")
216216
}
217217

218-
func (a *AnthropicMessagesHandlerMethods) parseResponseAndUsage(logger log.Logger, body anthropicMessagesRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
218+
func (a *AnthropicMessagesHandlerMethods) parseResponseAndUsage(logger log.Logger, body anthropicMessagesRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {
219219
// First, extract prompt usage details from the request.
220220
for _, m := range body.Messages {
221221
promptUsage.characters += len(m.Content)
@@ -232,7 +232,7 @@ func (a *AnthropicMessagesHandlerMethods) parseResponseAndUsage(logger log.Logge
232232

233233
// Try to parse the request we saw, if it was non-streaming, we can simply parse
234234
// it as JSON.
235-
if !body.ShouldStream() {
235+
if !isStreamRequest {
236236
var res anthropicMessagesNonStreamingResponse
237237
if err := json.NewDecoder(r).Decode(&res); err != nil {
238238
logger.Error("failed to parse Anthropic response as JSON", log.Error(err))

cmd/cody-gateway/internal/httpapi/completions/fireworks.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,13 @@ func (f *FireworksHandlerMethods) transformRequest(r *http.Request) {
148148
r.Header.Set("Authorization", "Bearer "+f.config.AccessToken)
149149
}
150150

151-
func (f *FireworksHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody fireworksRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
151+
func (f *FireworksHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody fireworksRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {
152152
// First, extract prompt usage details from the request.
153153
promptUsage.characters = len(reqBody.Prompt)
154154

155155
// Try to parse the request we saw, if it was non-streaming, we can simply parse
156156
// it as JSON.
157-
if !reqBody.Stream {
157+
if !isStreamRequest {
158158
var res fireworksResponse
159159
if err := json.NewDecoder(r).Decode(&res); err != nil {
160160
logger.Error("failed to parse fireworks response as JSON", log.Error(err))

cmd/cody-gateway/internal/httpapi/completions/fireworks_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func TestFireworksRequestGetTokenCount(t *testing.T) {
1818
req := fireworksRequest{Stream: true}
1919
r := strings.NewReader(fireworksStreamingResponse)
2020
handler := &FireworksHandlerMethods{}
21-
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
21+
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r, true)
2222

2323
assert.Equal(t, 79, promptUsage.tokens)
2424
assert.Equal(t, 30, completionUsage.tokens)
@@ -28,7 +28,7 @@ func TestFireworksRequestGetTokenCount(t *testing.T) {
2828
req := fireworksRequest{Stream: false}
2929
r := strings.NewReader(fireworksNonStreamingResponse)
3030
handler := &FireworksHandlerMethods{}
31-
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
31+
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r, false)
3232

3333
assert.Equal(t, 79, promptUsage.tokens)
3434
assert.Equal(t, 30, completionUsage.tokens)

cmd/cody-gateway/internal/httpapi/completions/google.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,9 @@ func (g *GoogleHandlerMethods) getAPIURL(feature codygateway.Feature, req google
6464
rpc := "generateContent"
6565
sseSuffix := ""
6666
// If we're streaming, we need to use the stream endpoint.
67-
if feature == codygateway.FeatureChatCompletions || req.ShouldStream() {
67+
if req.ShouldStream() {
6868
rpc = "streamGenerateContent"
6969
sseSuffix = "&alt=sse"
70-
req.Stream = true
7170
}
7271
return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:%s?key=%s%s", req.Model, rpc, g.config.AccessToken, sseSuffix)
7372
}
@@ -108,12 +107,12 @@ func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
108107
r.Header.Set("Content-Type", "application/json")
109108
}
110109

111-
func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody googleRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
110+
func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody googleRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {
112111
// First, extract prompt usage details from the request.
113112
promptUsage.characters = len(reqBody.BuildPrompt())
114113
// Try to parse the request we saw, if it was non-streaming, we can simply parse
115114
// it as JSON.
116-
if !reqBody.Stream && !reqBody.ShouldStream() {
115+
if !isStreamRequest {
117116
var res googleResponse
118117
if err := json.NewDecoder(r).Decode(&res); err != nil {
119118
logger.Error("failed to parse Google response as JSON", log.Error(err))
@@ -135,10 +134,10 @@ func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody go
135134
if err != nil {
136135
logger.Error("failed to decode Google streaming response", log.Error(err))
137136
}
137+
promptUsage.tokens, completionUsage.tokens = promptTokens, completionTokens
138138
if completionUsage.tokens == -1 || promptUsage.tokens == -1 {
139139
logger.Warn("did not extract token counts from Google streaming response", log.Int("prompt-tokens", promptUsage.tokens), log.Int("completion-tokens", completionUsage.tokens))
140140
}
141-
promptUsage.tokens, completionUsage.tokens = promptTokens, completionTokens
142141
return promptUsage, completionUsage
143142
}
144143

cmd/cody-gateway/internal/httpapi/completions/google_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func TestGoogleRequestGetTokenCount(t *testing.T) {
1717
req := googleRequest{Stream: true}
1818
r := strings.NewReader(googleStreamingResponse)
1919
handler := &GoogleHandlerMethods{}
20-
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
20+
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r, true)
2121

2222
assert.Equal(t, 21, promptUsage.tokens)
2323
assert.Equal(t, 87, completionUsage.tokens)
@@ -27,7 +27,7 @@ func TestGoogleRequestGetTokenCount(t *testing.T) {
2727
req := googleRequest{Stream: false}
2828
r := strings.NewReader(googleNonStreamingResponse)
2929
handler := &GoogleHandlerMethods{}
30-
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
30+
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r, false)
3131

3232
assert.Equal(t, 59, promptUsage.tokens)
3333
assert.Equal(t, 54, completionUsage.tokens)

cmd/cody-gateway/internal/httpapi/completions/openai.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func (o *OpenAIHandlerMethods) transformRequest(r *http.Request) {
153153
}
154154
}
155155

156-
func (*OpenAIHandlerMethods) parseResponseAndUsage(logger log.Logger, body openaiRequest, r io.Reader) (promptUsage, completionUsage usageStats) {
156+
func (*OpenAIHandlerMethods) parseResponseAndUsage(logger log.Logger, body openaiRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {
157157
// First, extract prompt usage details from the request.
158158
for _, m := range body.Messages {
159159
promptUsage.characters += len(m.Content)
@@ -164,7 +164,7 @@ func (*OpenAIHandlerMethods) parseResponseAndUsage(logger log.Logger, body opena
164164
completionUsage.tokenizerTokens = -1
165165
// Try to parse the request we saw, if it was non-streaming, we can simply parse
166166
// it as JSON.
167-
if !body.Stream {
167+
if !isStreamRequest {
168168
var res openaiResponse
169169
if err := json.NewDecoder(r).Decode(&res); err != nil {
170170
logger.Error("failed to parse OpenAI response as JSON", log.Error(err))

cmd/cody-gateway/internal/httpapi/completions/openai_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func TestOpenAIRequestGetTokenCount(t *testing.T) {
1515
req := openaiRequest{Stream: true}
1616
r := strings.NewReader(openaiStreamingResponse)
1717
handler := &OpenAIHandlerMethods{}
18-
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
18+
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r, true)
1919

2020
assert.Equal(t, 427, promptUsage.tokens)
2121
assert.Equal(t, 12, completionUsage.tokens)
@@ -25,7 +25,7 @@ func TestOpenAIRequestGetTokenCount(t *testing.T) {
2525
req := openaiRequest{Stream: false}
2626
r := strings.NewReader(openaiNonStreamingResponse)
2727
handler := &OpenAIHandlerMethods{}
28-
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r)
28+
promptUsage, completionUsage := handler.parseResponseAndUsage(logger, req, r, false)
2929

3030
assert.Equal(t, 12, promptUsage.tokens)
3131
assert.Equal(t, 9, completionUsage.tokens)

cmd/cody-gateway/internal/httpapi/completions/upstream.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ type upstreamHandlerMethods[ReqT UpstreamRequest] interface {
104104
//
105105
// If data is unavailable, implementations should set relevant usage fields
106106
// to -1 as a sentinel value.
107-
parseResponseAndUsage(log.Logger, ReqT, io.Reader) (promptUsage, completionUsage usageStats)
107+
parseResponseAndUsage(log.Logger, ReqT, io.Reader, bool) (promptUsage, completionUsage usageStats)
108108
}
109109

110110
type UpstreamRequest interface {
@@ -273,6 +273,13 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
273273
}
274274
}
275275

276+
// Get the URL to call for the upstream provider before we transform the request.
277+
upstreamURL := methods.getAPIURL(feature, body)
278+
279+
// Store the shouldStream value in case it changes during the transformation.
280+
// Example: We remove it for Google requests.
281+
shouldStream := body.ShouldStream()
282+
276283
// identifier that can be provided to upstream for abuse detection
277284
// has the format '$ACTOR_ID:$SG_ACTOR_ID'. The latter is anonymized
278285
// (specific per-instance)
@@ -288,7 +295,6 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
288295
}
289296

290297
// Create a new request to send upstream, making sure we retain the same context.
291-
upstreamURL := methods.getAPIURL(feature, body)
292298
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(upstreamPayload))
293299
if err != nil {
294300
response.JSONError(logger, w, http.StatusInternalServerError, errors.Wrap(err, "failed to create request"))
@@ -359,7 +365,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
359365
o.Feature = feature
360366
o.UpstreamLatency = upstreamLatency
361367
o.Provider = upstreamName
362-
o.Stream = body.ShouldStream()
368+
o.Stream = shouldStream
363369

364370
err := eventLogger.LogEvent(
365371
ctx,
@@ -461,7 +467,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
461467
// if this is a streaming request, we want to flush ourselves instead of leaving that to the http.Server
462468
// (so events are sent to the client as soon as possible)
463469
var responseWriter io.Writer = w
464-
if config.AutoFlushStreamingResponses && body.ShouldStream() {
470+
if config.AutoFlushStreamingResponses && shouldStream {
465471
if fw, err := response.NewAutoFlushingWriter(w); err == nil {
466472
responseWriter = fw
467473
} else {
@@ -475,7 +481,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
475481

476482
if upstreamStatusCode >= 200 && upstreamStatusCode < 300 {
477483
// Pass reader to response transformer to capture token counts.
478-
promptUsage, completionUsage = methods.parseResponseAndUsage(logger, body, &responseBuf)
484+
promptUsage, completionUsage = methods.parseResponseAndUsage(logger, body, &responseBuf, shouldStream)
479485
} else if upstreamStatusCode >= 500 {
480486
logger.Error("error from upstream",
481487
log.Int("status_code", upstreamStatusCode))

cmd/frontend/internal/dotcom/productsubscription/codygateway_dotcom_user.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,9 @@ var allCodeCompletionModels = slices.Concat([]string{"anthropic/" + anthropic.Cl
334334
"anthropic/claude-instant-1.2-cyan",
335335
"anthropic/claude-instant-1.2",
336336
"google/" + google.Gemini15Flash,
337+
"google/" + google.Gemini15FlashLatest,
338+
"google/" + google.GeminiPro,
339+
"google/" + google.GeminiProLatest,
337340
"fireworks/starcoder",
338341
"fireworks/" + fireworks.Llama213bCode,
339342
"fireworks/" + fireworks.StarcoderTwo15b,

0 commit comments

Comments
 (0)