Skip to content

Commit bb8ad26

Browse files
hustxiayang and yuzisun
authored
fix: finish reason should be tool calls when the model responded with a tool call (envoyproxy#1486)
**Description** The finish reason should be `tool_calls` when the model returns a tool call response. The Vertex API has no tool-call finish reason, so a workaround is needed to make it compatible. --------- Signed-off-by: yxia216 <[email protected]> Co-authored-by: Dan Sun <[email protected]>
1 parent 73aa427 commit bb8ad26

File tree

3 files changed

+82
-44
lines changed

3 files changed

+82
-44
lines changed

internal/extproc/translator/gemini_helper.go

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,12 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
529529

530530
// Create the choice.
531531
choice := openai.ChatCompletionResponseChoice{
532-
Index: int64(idx),
533-
FinishReason: geminiFinishReasonToOpenAI(candidate.FinishReason),
532+
Index: int64(idx),
534533
}
535534

535+
toolCalls := []openai.ChatCompletionMessageToolCallParam{}
536+
var err error
537+
536538
if candidate.Content != nil {
537539
message := openai.ChatCompletionResponseChoiceMessage{
538540
Role: openai.ChatMessageRoleAssistant,
@@ -542,7 +544,7 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
542544
message.Content = &content
543545

544546
// Extract tool calls if any.
545-
toolCalls, err := extractToolCallsFromGeminiParts(candidate.Content.Parts)
547+
toolCalls, err = extractToolCallsFromGeminiParts(toolCalls, candidate.Content.Parts)
546548
if err != nil {
547549
return nil, fmt.Errorf("error extracting tool calls: %w", err)
548550
}
@@ -569,16 +571,26 @@ func geminiCandidatesToOpenAIChoices(candidates []*genai.Candidate, responseMode
569571
choice.Logprobs = geminiLogprobsToOpenAILogprobs(*candidate.LogprobsResult)
570572
}
571573

574+
choice.FinishReason = geminiFinishReasonToOpenAI(candidate.FinishReason, toolCalls)
575+
572576
choices = append(choices, choice)
573577
}
574578

575579
return choices, nil
576580
}
577581

582+
// Define a type constraint that includes both stream and non-stream tool call slice types.
583+
type toolCallSlice interface {
584+
[]openai.ChatCompletionMessageToolCallParam | []openai.ChatCompletionChunkChoiceDeltaToolCall
585+
}
586+
578587
// geminiFinishReasonToOpenAI converts Gemini finish reason to OpenAI finish reason.
579-
func geminiFinishReasonToOpenAI(reason genai.FinishReason) openai.ChatCompletionChoicesFinishReason {
588+
func geminiFinishReasonToOpenAI[T toolCallSlice](reason genai.FinishReason, toolCalls T) openai.ChatCompletionChoicesFinishReason {
580589
switch reason {
581590
case genai.FinishReasonStop:
591+
if len(toolCalls) > 0 {
592+
return openai.ChatCompletionChoicesFinishReasonToolCalls
593+
}
582594
return openai.ChatCompletionChoicesFinishReasonStop
583595
case genai.FinishReasonMaxTokens:
584596
return openai.ChatCompletionChoicesFinishReasonLength
@@ -611,9 +623,7 @@ func extractTextFromGeminiParts(parts []*genai.Part, responseMode geminiResponse
611623
}
612624

613625
// extractToolCallsFromGeminiParts extracts tool calls from Gemini parts.
614-
func extractToolCallsFromGeminiParts(parts []*genai.Part) ([]openai.ChatCompletionMessageToolCallParam, error) {
615-
var toolCalls []openai.ChatCompletionMessageToolCallParam
616-
626+
func extractToolCallsFromGeminiParts(toolCalls []openai.ChatCompletionMessageToolCallParam, parts []*genai.Part) ([]openai.ChatCompletionMessageToolCallParam, error) {
617627
for _, part := range parts {
618628
if part == nil || part.FunctionCall == nil {
619629
continue
@@ -650,8 +660,7 @@ func extractToolCallsFromGeminiParts(parts []*genai.Part) ([]openai.ChatCompleti
650660
// extractToolCallsFromGeminiPartsStream extracts tool calls from Gemini parts for streaming responses.
651661
// Each tool call is assigned an incremental index starting from 0, matching OpenAI's streaming protocol.
652662
// Returns ChatCompletionChunkChoiceDeltaToolCall types suitable for streaming responses, or nil if no tool calls are found.
653-
func extractToolCallsFromGeminiPartsStream(parts []*genai.Part) ([]openai.ChatCompletionChunkChoiceDeltaToolCall, error) {
654-
var toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall
663+
func extractToolCallsFromGeminiPartsStream(toolCalls []openai.ChatCompletionChunkChoiceDeltaToolCall, parts []*genai.Part) ([]openai.ChatCompletionChunkChoiceDeltaToolCall, error) {
655664
toolCallIndex := int64(0)
656665

657666
for _, part := range parts {
@@ -772,10 +781,11 @@ func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate, res
772781

773782
// Create the streaming choice.
774783
choice := openai.ChatCompletionResponseChunkChoice{
775-
Index: 0,
776-
FinishReason: geminiFinishReasonToOpenAI(candidate.FinishReason),
784+
Index: 0,
777785
}
778786

787+
toolCalls := []openai.ChatCompletionChunkChoiceDeltaToolCall{}
788+
var err error
779789
if candidate.Content != nil {
780790
delta := &openai.ChatCompletionResponseChunkChoiceDelta{
781791
Role: openai.ChatMessageRoleAssistant,
@@ -788,15 +798,15 @@ func geminiCandidatesToOpenAIStreamingChoices(candidates []*genai.Candidate, res
788798
}
789799

790800
// Extract tool calls if any.
791-
toolCalls, err := extractToolCallsFromGeminiPartsStream(candidate.Content.Parts)
801+
toolCalls, err = extractToolCallsFromGeminiPartsStream(toolCalls, candidate.Content.Parts)
792802
if err != nil {
793803
return nil, fmt.Errorf("error extracting tool calls: %w", err)
794804
}
795805
delta.ToolCalls = toolCalls
796806

797807
choice.Delta = delta
798808
}
799-
809+
choice.FinishReason = geminiFinishReasonToOpenAI(candidate.FinishReason, toolCalls)
800810
choices = append(choices, choice)
801811
}
802812

internal/extproc/translator/gemini_helper_test.go

Lines changed: 58 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,7 @@ func TestGeminiLogprobsToOpenAILogprobs(t *testing.T) {
12711271
}
12721272

12731273
func TestExtractToolCallsFromGeminiParts(t *testing.T) {
1274+
toolCalls := []openai.ChatCompletionMessageToolCallParam{}
12741275
tests := []struct {
12751276
name string
12761277
input []*genai.Part
@@ -1360,7 +1361,7 @@ func TestExtractToolCallsFromGeminiParts(t *testing.T) {
13601361

13611362
for _, tt := range tests {
13621363
t.Run(tt.name, func(t *testing.T) {
1363-
calls, err := extractToolCallsFromGeminiParts(tt.input)
1364+
calls, err := extractToolCallsFromGeminiParts(toolCalls, tt.input)
13641365

13651366
if tt.wantErr {
13661367
require.Error(t, err)
@@ -1381,56 +1382,80 @@ func TestExtractToolCallsFromGeminiParts(t *testing.T) {
13811382

13821383
func TestGeminiFinishReasonToOpenAI(t *testing.T) {
13831384
tests := []struct {
1384-
name string
1385-
input genai.FinishReason
1386-
expected openai.ChatCompletionChoicesFinishReason
1385+
name string
1386+
input genai.FinishReason
1387+
toolCalls []openai.ChatCompletionMessageToolCallParam
1388+
expected openai.ChatCompletionChoicesFinishReason
13871389
}{
13881390
{
1389-
name: "stop reason",
1390-
input: genai.FinishReasonStop,
1391-
expected: openai.ChatCompletionChoicesFinishReasonStop,
1391+
name: "stop reason",
1392+
input: genai.FinishReasonStop,
1393+
toolCalls: []openai.ChatCompletionMessageToolCallParam{},
1394+
expected: openai.ChatCompletionChoicesFinishReasonStop,
1395+
},
1396+
{
1397+
name: "tool calls reason",
1398+
input: genai.FinishReasonStop,
1399+
toolCalls: []openai.ChatCompletionMessageToolCallParam{
1400+
{
1401+
ID: ptr.To("tool_call_1"),
1402+
Function: openai.ChatCompletionMessageToolCallFunctionParam{
1403+
Name: "example_tool",
1404+
Arguments: "{\"param1\":\"value1\"}",
1405+
},
1406+
Type: openai.ChatCompletionMessageToolCallTypeFunction,
1407+
},
1408+
},
1409+
expected: openai.ChatCompletionChoicesFinishReasonToolCalls,
13921410
},
13931411
{
1394-
name: "max tokens reason",
1395-
input: genai.FinishReasonMaxTokens,
1396-
expected: openai.ChatCompletionChoicesFinishReasonLength,
1412+
name: "max tokens reason",
1413+
input: genai.FinishReasonMaxTokens,
1414+
toolCalls: []openai.ChatCompletionMessageToolCallParam{},
1415+
expected: openai.ChatCompletionChoicesFinishReasonLength,
13971416
},
13981417
{
1399-
name: "empty reason for streaming",
1400-
input: "",
1401-
expected: "",
1418+
name: "empty reason for streaming",
1419+
input: "",
1420+
toolCalls: []openai.ChatCompletionMessageToolCallParam{},
1421+
expected: "",
14021422
},
14031423
{
1404-
name: "safety reason",
1405-
input: genai.FinishReasonSafety,
1406-
expected: openai.ChatCompletionChoicesFinishReasonContentFilter,
1424+
name: "safety reason",
1425+
input: genai.FinishReasonSafety,
1426+
toolCalls: []openai.ChatCompletionMessageToolCallParam{},
1427+
expected: openai.ChatCompletionChoicesFinishReasonContentFilter,
14071428
},
14081429
{
1409-
name: "recitation reason",
1410-
input: genai.FinishReasonRecitation,
1411-
expected: openai.ChatCompletionChoicesFinishReasonContentFilter,
1430+
name: "recitation reason",
1431+
input: genai.FinishReasonRecitation,
1432+
toolCalls: []openai.ChatCompletionMessageToolCallParam{},
1433+
expected: openai.ChatCompletionChoicesFinishReasonContentFilter,
14121434
},
14131435
{
1414-
name: "other reason",
1415-
input: genai.FinishReasonOther,
1416-
expected: openai.ChatCompletionChoicesFinishReasonContentFilter,
1436+
name: "other reason",
1437+
input: genai.FinishReasonOther,
1438+
toolCalls: []openai.ChatCompletionMessageToolCallParam{},
1439+
expected: openai.ChatCompletionChoicesFinishReasonContentFilter,
14171440
},
14181441
{
1419-
name: "unknown reason",
1420-
input: genai.FinishReason("unknown_reason"),
1421-
expected: openai.ChatCompletionChoicesFinishReasonContentFilter,
1442+
name: "unknown reason",
1443+
input: genai.FinishReason("unknown_reason"),
1444+
toolCalls: []openai.ChatCompletionMessageToolCallParam{},
1445+
expected: openai.ChatCompletionChoicesFinishReasonContentFilter,
14221446
},
14231447
}
14241448

14251449
for _, tt := range tests {
14261450
t.Run(tt.name, func(t *testing.T) {
1427-
result := geminiFinishReasonToOpenAI(tt.input)
1451+
result := geminiFinishReasonToOpenAI(tt.input, tt.toolCalls)
14281452
require.Equal(t, tt.expected, result)
14291453
})
14301454
}
14311455
}
14321456

14331457
func TestExtractToolCallsFromGeminiPartsStream(t *testing.T) {
1458+
toolCalls := []openai.ChatCompletionChunkChoiceDeltaToolCall{}
14341459
tests := []struct {
14351460
name string
14361461
input []*genai.Part
@@ -1675,7 +1700,7 @@ func TestExtractToolCallsFromGeminiPartsStream(t *testing.T) {
16751700

16761701
for _, tt := range tests {
16771702
t.Run(tt.name, func(t *testing.T) {
1678-
calls, err := extractToolCallsFromGeminiPartsStream(tt.input)
1703+
calls, err := extractToolCallsFromGeminiPartsStream(toolCalls, tt.input)
16791704

16801705
if tt.wantErr {
16811706
require.Error(t, err)
@@ -1696,6 +1721,8 @@ func TestExtractToolCallsFromGeminiPartsStream(t *testing.T) {
16961721

16971722
// TestExtractToolCallsStreamVsNonStream tests the differences between streaming and non-streaming extraction
16981723
func TestExtractToolCallsStreamVsNonStream(t *testing.T) {
1724+
toolCalls := []openai.ChatCompletionMessageToolCallParam{}
1725+
toolCallsStream := []openai.ChatCompletionChunkChoiceDeltaToolCall{}
16991726
parts := []*genai.Part{
17001727
{
17011728
FunctionCall: &genai.FunctionCall{
@@ -1709,11 +1736,11 @@ func TestExtractToolCallsStreamVsNonStream(t *testing.T) {
17091736
}
17101737

17111738
// Get results from both functions
1712-
streamCalls, err := extractToolCallsFromGeminiPartsStream(parts)
1739+
streamCalls, err := extractToolCallsFromGeminiPartsStream(toolCallsStream, parts)
17131740
require.NoError(t, err)
17141741
require.Len(t, streamCalls, 1)
17151742

1716-
nonStreamCalls, err := extractToolCallsFromGeminiParts(parts)
1743+
nonStreamCalls, err := extractToolCallsFromGeminiParts(toolCalls, parts)
17171744
require.NoError(t, err)
17181745
require.Len(t, nonStreamCalls, 1)
17191746

@@ -1749,6 +1776,7 @@ func TestExtractToolCallsStreamVsNonStream(t *testing.T) {
17491776

17501777
// TestExtractToolCallsStreamIndexing specifically tests that multiple tool calls get correct indices
17511778
func TestExtractToolCallsStreamIndexing(t *testing.T) {
1779+
toolCalls := []openai.ChatCompletionChunkChoiceDeltaToolCall{}
17521780
parts := []*genai.Part{
17531781
{
17541782
FunctionCall: &genai.FunctionCall{
@@ -1771,7 +1799,7 @@ func TestExtractToolCallsStreamIndexing(t *testing.T) {
17711799
},
17721800
}
17731801

1774-
calls, err := extractToolCallsFromGeminiPartsStream(parts)
1802+
calls, err := extractToolCallsFromGeminiPartsStream(toolCalls, parts)
17751803
require.NoError(t, err)
17761804
require.Len(t, calls, 3)
17771805

tests/extproc/testupstream_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ func TestWithTestUpstream(t *testing.T) {
297297
responseStatus: strconv.Itoa(http.StatusOK),
298298
responseBody: `{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_delivery_date","args":{"order_id":"123"}}}]},"finishReason":"STOP","avgLogprobs":0.000001220789272338152}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":11,"totalTokenCount":61,"trafficType":"ON_DEMAND","promptTokensDetails":[{"modality":"TEXT","tokenCount":50}],"candidatesTokensDetails":[{"modality":"TEXT","tokenCount":11}]},"modelVersion":"gemini-2.0-flash-001","createTime":"2025-07-11T22:15:44.956335Z","responseId":"EI5xaK-vOtqJm22IPmuCR14AI"}`,
299299
expStatus: http.StatusOK,
300-
expResponseBody: `{"choices":[{"finish_reason":"stop","index":0,"message":{"role":"assistant","tool_calls":[{"id":"703482f8-2e5b-4dcc-a872-d74bd66c3866","function":{"arguments":"{\"order_id\":\"123\"}","name":"get_delivery_date"},"type":"function"}]}}],"model":"gemini-2.0-flash-001","object":"chat.completion","usage":{"completion_tokens":11,"completion_tokens_details":{},"prompt_tokens":50,"total_tokens":61,"prompt_tokens_details":{}}}`,
300+
expResponseBody: `{"choices":[{"finish_reason":"tool_calls","index":0,"message":{"role":"assistant","tool_calls":[{"id":"703482f8-2e5b-4dcc-a872-d74bd66c3866","function":{"arguments":"{\"order_id\":\"123\"}","name":"get_delivery_date"},"type":"function"}]}}],"model":"gemini-2.0-flash-001","object":"chat.completion","usage":{"completion_tokens":11,"completion_tokens_details":{},"prompt_tokens":50,"total_tokens":61,"prompt_tokens_details":{}}}`,
301301
},
302302
{
303303
name: "gcp-anthropicai - /v1/chat/completions",

0 commit comments

Comments (0)