Skip to content

Commit 047965e

Browse files
committed
refactor: migrate to the new go-genai library
1 parent d2c8413 commit 047965e

File tree

12 files changed

+157
-254
lines changed

12 files changed

+157
-254
lines changed

cmd/gemini/main.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414

1515
const (
1616
version = "0.4.0"
17-
apiKeyEnv = "GEMINI_API_KEY" //nolint:gosec
1817
defaultConfigPath = "gemini_cli_config.json"
1918
)
2019

@@ -48,15 +47,11 @@ func run() int {
4847
return err
4948
}
5049

51-
modelBuilder := gemini.NewGenerativeModelBuilder().
52-
WithName(opts.GenerativeModel).
53-
WithSafetySettings(configuration.Data.SafetySettings)
54-
apiKey := os.Getenv(apiKeyEnv)
55-
chatSession, err := gemini.NewChatSession(context.Background(), modelBuilder, apiKey)
50+
chatSession, err := gemini.NewChatSession(context.Background(), opts.GenerativeModel,
51+
configuration.Data.GenaiSafetySettings())
5652
if err != nil {
5753
return err
5854
}
59-
defer func() { err = errors.Join(err, chatSession.Close()) }()
6055

6156
chatHandler, err := chat.New(getCurrentUser(), chatSession, configuration, &opts)
6257
if err != nil {

gemini/chat_session.go

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"iter"
78
"sync"
89

9-
"github.com/google/generative-ai-go/genai"
10-
"google.golang.org/api/option"
10+
"google.golang.org/genai"
1111
)
1212

1313
const DefaultModel = "gemini-2.5-flash"
@@ -16,101 +16,117 @@ const DefaultModel = "gemini-2.5-flash"
1616
type ChatSession struct {
1717
ctx context.Context
1818

19-
client *genai.Client
20-
model *genai.GenerativeModel
21-
session *genai.ChatSession
19+
client *genai.Client
20+
chat *genai.Chat
21+
config *genai.GenerateContentConfig
22+
model string
2223

2324
loadModels sync.Once
2425
models []string
2526
}
2627

2728
// NewChatSession returns a new [ChatSession].
28-
func NewChatSession(
29-
ctx context.Context, modelBuilder *GenerativeModelBuilder, apiKey string,
30-
) (*ChatSession, error) {
31-
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
29+
func NewChatSession(ctx context.Context, model string, safetySettings []*genai.SafetySetting) (*ChatSession, error) {
30+
client, err := genai.NewClient(ctx, nil)
3231
if err != nil {
33-
return nil, err
32+
return nil, fmt.Errorf("failed to create client: %w", err)
33+
}
34+
35+
config := &genai.GenerateContentConfig{SafetySettings: safetySettings}
36+
chat, err := client.Chats.Create(ctx, model, config, nil)
37+
if err != nil {
38+
return nil, fmt.Errorf("failed to create chat: %w", err)
3439
}
3540

36-
generativeModel := modelBuilder.build(client)
3741
return &ChatSession{
38-
ctx: ctx,
39-
client: client,
40-
model: generativeModel,
41-
session: generativeModel.StartChat(),
42+
ctx: ctx,
43+
client: client,
44+
chat: chat,
45+
config: config,
46+
model: model,
4247
}, nil
4348
}
4449

4550
// SendMessage sends a request to the model as part of a chat session.
4651
func (c *ChatSession) SendMessage(input string) (*genai.GenerateContentResponse, error) {
47-
return c.session.SendMessage(c.ctx, genai.Text(input))
52+
return c.chat.SendMessage(c.ctx, genai.Part{Text: input})
4853
}
4954

5055
// SendMessageStream is like SendMessage, but with a streaming request.
51-
func (c *ChatSession) SendMessageStream(input string) *genai.GenerateContentResponseIterator {
52-
return c.session.SendMessageStream(c.ctx, genai.Text(input))
53-
}
54-
55-
// SetModel sets a new generative model configured with the builder and starts
56-
// a new chat session. It preserves the history of the previous chat session.
57-
func (c *ChatSession) SetModel(modelBuilder *GenerativeModelBuilder) {
58-
history := c.session.History
59-
c.model = modelBuilder.build(c.client)
60-
c.session = c.model.StartChat()
61-
c.session.History = history
62-
}
63-
64-
// CopyModelBuilder returns a copy builder for the chat generative model.
65-
func (c *ChatSession) CopyModelBuilder() *GenerativeModelBuilder {
66-
return newCopyGenerativeModelBuilder(c.model)
56+
func (c *ChatSession) SendMessageStream(input string) iter.Seq2[*genai.GenerateContentResponse, error] {
57+
return c.chat.SendMessageStream(c.ctx, genai.Part{Text: input})
6758
}
6859

6960
// ModelInfo returns information about the chat generative model in JSON format.
7061
func (c *ChatSession) ModelInfo() (string, error) {
71-
modelInfo, err := c.model.Info(c.ctx)
62+
modelInfo, err := c.client.Models.Get(c.ctx, c.model, nil)
7263
if err != nil {
7364
return "", err
7465
}
66+
7567
encoded, err := json.MarshalIndent(modelInfo, "", " ")
7668
if err != nil {
7769
return "", fmt.Errorf("error encoding model info: %w", err)
7870
}
71+
7972
return string(encoded), nil
8073
}
8174

8275
// ListModels returns a list of the supported generative model names.
8376
func (c *ChatSession) ListModels() []string {
8477
c.loadModels.Do(func() {
8578
c.models = []string{DefaultModel}
86-
iter := c.client.ListModels(c.ctx)
87-
for {
88-
modelInfo, err := iter.Next()
79+
for model, err := range c.client.Models.All(c.ctx) {
8980
if err != nil {
90-
break
81+
continue
9182
}
92-
c.models = append(c.models, modelInfo.Name)
83+
c.models = append(c.models, model.Name)
9384
}
9485
})
9586
return c.models
9687
}
9788

89+
// SetModel sets the chat generative model.
90+
func (c *ChatSession) SetModel(model string) error {
91+
chat, err := c.client.Chats.Create(c.ctx, model, c.config, c.GetHistory())
92+
if err != nil {
93+
return fmt.Errorf("failed to set model: %w", err)
94+
}
95+
96+
c.model = model
97+
c.chat = chat
98+
return nil
99+
}
100+
98101
// GetHistory returns the chat session history.
99102
func (c *ChatSession) GetHistory() []*genai.Content {
100-
return c.session.History
103+
return c.chat.History(true)
101104
}
102105

103106
// SetHistory sets the chat session history.
104-
func (c *ChatSession) SetHistory(content []*genai.Content) {
105-
c.session.History = content
107+
func (c *ChatSession) SetHistory(history []*genai.Content) error {
108+
chat, err := c.client.Chats.Create(c.ctx, c.model, c.config, history)
109+
if err != nil {
110+
return fmt.Errorf("failed to set history: %w", err)
111+
}
112+
113+
c.chat = chat
114+
return nil
106115
}
107116

108117
// ClearHistory clears the chat session history.
109-
func (c *ChatSession) ClearHistory() {
110-
c.session.History = make([]*genai.Content, 0)
118+
func (c *ChatSession) ClearHistory() error {
119+
return c.SetHistory(nil)
111120
}
112121

113-
// Close closes the chat session.
114-
func (c *ChatSession) Close() error {
115-
return c.client.Close()
122+
// SetSystemInstruction sets the chat session system instruction.
123+
func (c *ChatSession) SetSystemInstruction(systemInstruction *genai.Content) error {
124+
c.config.SystemInstruction = systemInstruction
125+
chat, err := c.client.Chats.Create(c.ctx, c.model, c.config, c.GetHistory())
126+
if err != nil {
127+
return fmt.Errorf("failed to set system instruction: %w", err)
128+
}
129+
130+
c.chat = chat
131+
return nil
116132
}

gemini/generative_model_builder.go

Lines changed: 0 additions & 134 deletions
This file was deleted.

gemini/serializable_content.go

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package gemini
22

33
import (
4-
"fmt"
5-
6-
"github.com/google/generative-ai-go/genai"
4+
"google.golang.org/genai"
75
)
86

97
// SerializableContent is the data type containing multipart text message content.
@@ -22,8 +20,9 @@ type SerializableContent struct {
2220
func NewSerializableContent(c *genai.Content) *SerializableContent {
2321
parts := make([]string, len(c.Parts))
2422
for i, part := range c.Parts {
25-
parts[i] = partToString(part)
23+
parts[i] = part.Text
2624
}
25+
2726
return &SerializableContent{
2827
Parts: parts,
2928
Role: c.Role,
@@ -32,21 +31,13 @@ func NewSerializableContent(c *genai.Content) *SerializableContent {
3231

3332
// ToContent converts the SerializableContent into a [genai.Content].
3433
func (c *SerializableContent) ToContent() *genai.Content {
35-
parts := make([]genai.Part, len(c.Parts))
34+
parts := make([]*genai.Part, len(c.Parts))
3635
for i, part := range c.Parts {
37-
parts[i] = genai.Text(part)
36+
parts[i] = genai.NewPartFromText(part)
3837
}
38+
3939
return &genai.Content{
4040
Parts: parts,
4141
Role: c.Role,
4242
}
4343
}
44-
45-
func partToString(part genai.Part) string {
46-
switch p := part.(type) {
47-
case genai.Text:
48-
return string(p)
49-
default:
50-
panic(fmt.Errorf("unsupported part type: %T", part))
51-
}
52-
}

gemini/system_instruction.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package gemini
22

3-
import "github.com/google/generative-ai-go/genai"
3+
import "google.golang.org/genai"
44

55
// SystemInstruction represents a serializable system prompt, a more forceful
66
// instruction to the language model. The model will prioritize adhering to
@@ -9,5 +9,5 @@ type SystemInstruction string
99

1010
// ToContent converts the SystemInstruction to [genai.Content].
1111
func (si SystemInstruction) ToContent() *genai.Content {
12-
return genai.NewUserContent(genai.Text(si))
12+
return genai.NewContentFromText(string(si), genai.RoleUser)
1313
}

0 commit comments

Comments
 (0)