Skip to content

Commit 713f8e4

Browse files
committed
add: gemini accumulator
1 parent 0315210 commit 713f8e4

File tree

6 files changed

+685
-15
lines changed

6 files changed

+685
-15
lines changed

model/gemini/accumulator.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//
2+
// Tencent is pleased to support the open source community by making trpc-agent-go available.
3+
//
4+
// Copyright (C) 2025 Tencent. All rights reserved.
5+
//
6+
// trpc-agent-go is licensed under the Apache License Version 2.0.
7+
//
8+
//
9+
10+
// Package gemini provides Gemini-compatible model implementations.
11+
package gemini
12+
13+
import (
14+
"strings"
15+
"time"
16+
"trpc.group/trpc-go/trpc-agent-go/model"
17+
)
18+
19+
// Accumulator aggregates streaming response chunks so that a single
// final response can be built once the stream ends.
type Accumulator struct {
	// Model is the model name reported by the stream.
	Model string
	// FullText concatenates the delta content of every chunk.
	FullText strings.Builder
	// ReasoningContent concatenates the delta reasoning content of every chunk.
	ReasoningContent strings.Builder
	// FinishReason concatenates every finish reason observed in the stream;
	// typically only the final chunk carries one.
	FinishReason strings.Builder
	// ToolCalls collects the tool calls emitted across all chunks.
	ToolCalls []model.ToolCall
	// Usage sums the token usage reported by the chunks.
	Usage model.Usage
}
28+
29+
// Accumulate builds up the Message incrementally from a model.Response. The Message then can be used as
30+
// any other Message, except with the caveat that the Message.JSON field which normally can be used to inspect
31+
// the JSON sent over the network may not be populated fully.
32+
func (a *Accumulator) Accumulate(resp *model.Response) {
33+
a.Model = resp.Model
34+
if len(resp.Choices) > 0 {
35+
for _, choice := range resp.Choices {
36+
if choice.FinishReason != nil {
37+
a.FinishReason.WriteString(*choice.FinishReason)
38+
}
39+
if choice.Delta.Content != "" {
40+
a.FullText.WriteString(choice.Delta.Content)
41+
}
42+
if choice.Delta.ReasoningContent != "" {
43+
a.ReasoningContent.WriteString(choice.Delta.ReasoningContent)
44+
}
45+
if len(choice.Delta.ToolCalls) > 0 {
46+
a.ToolCalls = append(a.ToolCalls, choice.Delta.ToolCalls...)
47+
}
48+
}
49+
}
50+
if resp.Usage != nil {
51+
a.Usage.PromptTokens += resp.Usage.PromptTokens
52+
a.Usage.CompletionTokens += resp.Usage.CompletionTokens
53+
a.Usage.TotalTokens += resp.Usage.TotalTokens
54+
}
55+
}
56+
57+
// BuildResponse builds up the final a model.Response.
58+
func (a *Accumulator) BuildResponse() *model.Response {
59+
now := time.Now()
60+
return &model.Response{
61+
Model: a.Model,
62+
Created: now.Unix(),
63+
Timestamp: now,
64+
Done: true,
65+
Choices: []model.Choice{
66+
{
67+
Message: model.Message{
68+
Content: a.FullText.String(),
69+
ReasoningContent: a.ReasoningContent.String(),
70+
ToolCalls: a.ToolCalls,
71+
Role: model.RoleAssistant,
72+
},
73+
FinishReason: func() *string {
74+
fr := a.FinishReason.String()
75+
if fr == "" {
76+
return nil
77+
}
78+
return &fr
79+
}(),
80+
},
81+
},
82+
Usage: &a.Usage,
83+
}
84+
}

model/gemini/accumulator_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package gemini
2+
3+
import (
4+
"github.com/stretchr/testify/assert"
5+
"strings"
6+
"testing"
7+
"trpc.group/trpc-go/trpc-agent-go/model"
8+
)
9+
10+
func TestAccumulator_BuildResponse(t *testing.T) {
11+
finishReason := "FinishReason"
12+
content := "Content"
13+
reasoningContent := "ReasoningContent"
14+
m := &model.Response{
15+
Usage: &model.Usage{
16+
PromptTokens: 1,
17+
CompletionTokens: 1,
18+
TotalTokens: 2,
19+
},
20+
Choices: []model.Choice{
21+
{
22+
Delta: model.Message{
23+
Content: content,
24+
ReasoningContent: reasoningContent,
25+
ToolCalls: []model.ToolCall{
26+
{
27+
ID: "id",
28+
},
29+
},
30+
},
31+
FinishReason: &finishReason,
32+
},
33+
},
34+
}
35+
type fields struct {
36+
Model string
37+
FullText strings.Builder
38+
ReasoningContent strings.Builder
39+
FinishReason strings.Builder
40+
ToolCalls []model.ToolCall
41+
Usage model.Usage
42+
}
43+
tests := []struct {
44+
name string
45+
fields fields
46+
want *model.Response
47+
}{
48+
{
49+
name: "Accumulate",
50+
fields: fields{
51+
Model: "gemini-pro",
52+
FullText: strings.Builder{},
53+
ReasoningContent: strings.Builder{},
54+
FinishReason: strings.Builder{},
55+
ToolCalls: []model.ToolCall{},
56+
Usage: model.Usage{},
57+
},
58+
want: &model.Response{
59+
Usage: &model.Usage{
60+
PromptTokens: 1,
61+
CompletionTokens: 1,
62+
TotalTokens: 2,
63+
},
64+
Choices: []model.Choice{
65+
{
66+
Delta: model.Message{
67+
Role: model.RoleAssistant,
68+
Content: content,
69+
ReasoningContent: reasoningContent,
70+
ToolCalls: []model.ToolCall{
71+
{
72+
ID: "id",
73+
},
74+
},
75+
},
76+
FinishReason: &finishReason,
77+
},
78+
},
79+
Done: true,
80+
},
81+
},
82+
}
83+
for _, tt := range tests {
84+
t.Run(tt.name, func(t *testing.T) {
85+
a := &Accumulator{
86+
Model: tt.fields.Model,
87+
FullText: tt.fields.FullText,
88+
ReasoningContent: tt.fields.ReasoningContent,
89+
FinishReason: tt.fields.FinishReason,
90+
ToolCalls: tt.fields.ToolCalls,
91+
Usage: tt.fields.Usage,
92+
}
93+
a.Accumulate(m)
94+
got := a.BuildResponse()
95+
assert.Equal(t, got.Choices[0].Message.Content, content)
96+
assert.Equal(t, got.Choices[0].Message.ReasoningContent, reasoningContent)
97+
assert.Equal(t, got.Choices[0].FinishReason, &finishReason)
98+
assert.Equal(t, got.Usage, m.Usage)
99+
})
100+
}
101+
}

model/gemini/gemini.go

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ type ChatStreamCompleteCallbackFunc func(
5252
ctx context.Context,
5353
chatRequest []*genai.Content,
5454
generateConfig *genai.GenerateContentConfig,
55-
chatResponse *genai.GenerateContentResponse,
55+
chatResponse *model.Response,
5656
)
5757

58-
// Model implements the model.Model interface for OpenAI API.
58+
// Model implements the model.Model interface for Gemini API.
5959
type Model struct {
6060
client *genai.Client
6161
name string
@@ -79,7 +79,7 @@ type Model struct {
7979
maxInputTokensRatio float64
8080
}
8181

82-
// New creates a new OpenAI-like model.
82+
// New creates a new Gemini-like model.
8383
func New(ctx context.Context, name string, opts ...Option) (*Model, error) {
8484
o := &options{
8585
channelBufferSize: defaultChannelBufferSize,
@@ -230,23 +230,33 @@ func (m *Model) handleStreamingResponse(
230230
) {
231231
chatCompletion := m.client.Models.GenerateContentStream(
232232
ctx, m.name, chatRequest, generateConfig)
233+
acc := &Accumulator{}
233234
for chunk := range chatCompletion {
234235
response := m.buildResponse(chunk)
235-
if response.IsPartial {
236-
if m.chatChunkCallback != nil {
237-
m.chatChunkCallback(ctx, chatRequest, generateConfig, chunk)
238-
}
239-
} else {
240-
if m.chatStreamCompleteCallback != nil {
241-
m.chatStreamCompleteCallback(ctx, chatRequest, generateConfig, chunk)
242-
}
236+
acc.Accumulate(response)
237+
response.Object = model.ObjectTypeChatCompletionChunk
238+
response.IsPartial = true
239+
if m.chatChunkCallback != nil {
240+
m.chatChunkCallback(ctx, chatRequest, generateConfig, chunk)
241+
}
242+
if m.chatChunkCallback != nil {
243+
m.chatChunkCallback(ctx, chatRequest, generateConfig, chunk)
243244
}
244245
select {
245246
case responseChan <- response:
246247
case <-ctx.Done():
247248
return
248249
}
249250
}
251+
finalResponse := acc.BuildResponse()
252+
if m.chatStreamCompleteCallback != nil {
253+
m.chatStreamCompleteCallback(ctx, chatRequest, generateConfig, finalResponse)
254+
}
255+
select {
256+
case responseChan <- finalResponse:
257+
case <-ctx.Done():
258+
return
259+
}
250260
}
251261

252262
// convertContentBlock builds a single assistant message from Gemini Candidate.
@@ -292,6 +302,8 @@ func (m *Model) convertContentBlock(candidates []*genai.Candidate) (model.Messag
292302
}, finishReason
293303
}
294304

305+
// buildResponse builds a partial streaming response for a chunk.
306+
// Returns an empty response when the chunk is nil.
295307
func (m *Model) buildResponse(chatCompletion *genai.GenerateContentResponse) *model.Response {
296308
if chatCompletion == nil {
297309
return &model.Response{}
@@ -300,7 +312,7 @@ func (m *Model) buildResponse(chatCompletion *genai.GenerateContentResponse) *mo
300312
ID: chatCompletion.ResponseID,
301313
Created: chatCompletion.CreateTime.Unix(),
302314
Model: chatCompletion.ModelVersion,
303-
Timestamp: time.Now(),
315+
Timestamp: chatCompletion.CreateTime,
304316
}
305317
message, finishReason := m.convertContentBlock(chatCompletion.Candidates)
306318
response.Choices = []model.Choice{
@@ -313,9 +325,6 @@ func (m *Model) buildResponse(chatCompletion *genai.GenerateContentResponse) *mo
313325
// Set finish reason.
314326
if finishReason != "" {
315327
response.Choices[0].FinishReason = &finishReason
316-
response.Done = true
317-
} else {
318-
response.IsPartial = true
319328
}
320329
// Convert usage information.
321330
response.Usage = m.completionUsageToModelUsage(chatCompletion.UsageMetadata)

0 commit comments

Comments
 (0)