diff --git a/README.md b/README.md index b7f5eca..b4ef04d 100644 --- a/README.md +++ b/README.md @@ -74,26 +74,36 @@ If it doesn't exist, the application will attempt to create it using default val An example of basic configuration: ```json { - "SystemPrompts": { + "system_prompts": { "Software Engineer": "You are an experienced software engineer.", "Technical Writer": "Act as a tech writer. I will provide you with the basic steps of an app functionality, and you will come up with an engaging article on how to do those steps." }, - "SafetySettings": [ + "safety_settings": [ { - "Category": 7, - "Threshold": 1 + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "LOW" }, { - "Category": 10, - "Threshold": 1 + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "LOW" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "LOW" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "LOW" } ], - "History": { + "history": { } } ``` -Upon user request, the `History` map will be populated with records. Note that the chat history is stored in plain -text format. See [history operations](#system-commands) for details. +1 Valid safety settings threshold values include LOW (block more), MEDIUM, HIGH (block less), and OFF. + +2 Upon user request, the `history` map will be populated with records. Note that the chat history is stored +in plain text format. See [history operations](#system-commands) for details. ### CLI help ```console diff --git a/cmd/gemini/main.go b/cmd/gemini/main.go index 622f461..aa42446 100644 --- a/cmd/gemini/main.go +++ b/cmd/gemini/main.go @@ -14,7 +14,6 @@ import ( const ( version = "0.4.0" - apiKeyEnv = "GEMINI_API_KEY" //nolint:gosec defaultConfigPath = "gemini_cli_config.json" ) @@ -48,15 +47,11 @@ func run() int { return err } - modelBuilder := gemini.NewGenerativeModelBuilder(). - WithName(opts.GenerativeModel). - WithSafetySettings(configuration.Data.SafetySettings) - apiKey := os.Getenv(apiKeyEnv) - chatSession, err := gemini.NewChatSession(context.Background(), modelBuilder, apiKey) + chatSession, err := gemini.NewChatSession(context.Background(), opts.GenerativeModel, + configuration.Data.GenaiSafetySettings()) if err != nil { return err } - defer func() { err = errors.Join(err, chatSession.Close()) }() chatHandler, err := chat.New(getCurrentUser(), chatSession, configuration, &opts) if err != nil { diff --git a/gemini/chat_session.go b/gemini/chat_session.go index 115a413..7291eec 100644 --- a/gemini/chat_session.go +++ b/gemini/chat_session.go @@ -4,10 +4,10 @@ import ( "context" "encoding/json" "fmt" + "iter" "sync" - "github.com/google/generative-ai-go/genai" - "google.golang.org/api/option" + "google.golang.org/genai" ) const DefaultModel = "gemini-2.5-flash" @@ -16,66 +16,59 @@ const DefaultModel = "gemini-2.5-flash" type ChatSession struct { ctx context.Context - client *genai.Client - model *genai.GenerativeModel - session *genai.ChatSession + client *genai.Client + chat *genai.Chat + config *genai.GenerateContentConfig + model string loadModels sync.Once models []string } // NewChatSession returns a new [ChatSession]. -func NewChatSession( - ctx context.Context, modelBuilder *GenerativeModelBuilder, apiKey string, -) (*ChatSession, error) { - client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) +func NewChatSession(ctx context.Context, model string, safetySettings []*genai.SafetySetting) (*ChatSession, error) { + client, err := genai.NewClient(ctx, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create client: %w", err) + } + + config := &genai.GenerateContentConfig{SafetySettings: safetySettings} + chat, err := client.Chats.Create(ctx, model, config, nil) + if err != nil { + return nil, fmt.Errorf("failed to create chat: %w", err) } - generativeModel := modelBuilder.build(client) return &ChatSession{ - ctx: ctx, - client: client, - model: generativeModel, - session: generativeModel.StartChat(), + ctx: ctx, + client: client, + chat: chat, + config: config, + model: model, }, nil } // SendMessage sends a request to the model as part of a chat session. func (c *ChatSession) SendMessage(input string) (*genai.GenerateContentResponse, error) { - return c.session.SendMessage(c.ctx, genai.Text(input)) + return c.chat.SendMessage(c.ctx, genai.Part{Text: input}) } // SendMessageStream is like SendMessage, but with a streaming request. -func (c *ChatSession) SendMessageStream(input string) *genai.GenerateContentResponseIterator { - return c.session.SendMessageStream(c.ctx, genai.Text(input)) -} - -// SetModel sets a new generative model configured with the builder and starts -// a new chat session. It preserves the history of the previous chat session. -func (c *ChatSession) SetModel(modelBuilder *GenerativeModelBuilder) { - history := c.session.History - c.model = modelBuilder.build(c.client) - c.session = c.model.StartChat() - c.session.History = history -} - -// CopyModelBuilder returns a copy builder for the chat generative model. -func (c *ChatSession) CopyModelBuilder() *GenerativeModelBuilder { - return newCopyGenerativeModelBuilder(c.model) +func (c *ChatSession) SendMessageStream(input string) iter.Seq2[*genai.GenerateContentResponse, error] { + return c.chat.SendMessageStream(c.ctx, genai.Part{Text: input}) } // ModelInfo returns information about the chat generative model in JSON format. func (c *ChatSession) ModelInfo() (string, error) { - modelInfo, err := c.model.Info(c.ctx) + modelInfo, err := c.client.Models.Get(c.ctx, c.model, nil) if err != nil { return "", err } + encoded, err := json.MarshalIndent(modelInfo, "", " ") if err != nil { return "", fmt.Errorf("error encoding model info: %w", err) } + return string(encoded), nil } @@ -83,34 +76,57 @@ func (c *ChatSession) ModelInfo() (string, error) { func (c *ChatSession) ListModels() []string { c.loadModels.Do(func() { c.models = []string{DefaultModel} - iter := c.client.ListModels(c.ctx) - for { - modelInfo, err := iter.Next() + for model, err := range c.client.Models.All(c.ctx) { if err != nil { - break + continue } - c.models = append(c.models, modelInfo.Name) + c.models = append(c.models, model.Name) } }) return c.models } +// SetModel sets the chat generative model. +func (c *ChatSession) SetModel(model string) error { + chat, err := c.client.Chats.Create(c.ctx, model, c.config, c.GetHistory()) + if err != nil { + return fmt.Errorf("failed to set model: %w", err) + } + + c.model = model + c.chat = chat + return nil +} + // GetHistory returns the chat session history. func (c *ChatSession) GetHistory() []*genai.Content { - return c.session.History + return c.chat.History(true) } // SetHistory sets the chat session history. -func (c *ChatSession) SetHistory(content []*genai.Content) { - c.session.History = content +func (c *ChatSession) SetHistory(history []*genai.Content) error { + chat, err := c.client.Chats.Create(c.ctx, c.model, c.config, history) + if err != nil { + return fmt.Errorf("failed to set history: %w", err) + } + + c.chat = chat + return nil } // ClearHistory clears the chat session history. -func (c *ChatSession) ClearHistory() { - c.session.History = make([]*genai.Content, 0) +func (c *ChatSession) ClearHistory() error { + return c.SetHistory(nil) } -// Close closes the chat session. -func (c *ChatSession) Close() error { - return c.client.Close() +// SetSystemInstruction sets the chat session system instruction. +func (c *ChatSession) SetSystemInstruction(systemInstruction *genai.Content) error { + c.config.SystemInstruction = systemInstruction + chat, err := c.client.Chats.Create(c.ctx, c.model, c.config, c.GetHistory()) + if err != nil { + return fmt.Errorf("failed to set system instruction: %w", err) + } + + c.chat = chat + return nil } diff --git a/gemini/generative_model_builder.go b/gemini/generative_model_builder.go deleted file mode 100644 index aee3b09..0000000 --- a/gemini/generative_model_builder.go +++ /dev/null @@ -1,134 +0,0 @@ -package gemini - -import ( - "github.com/google/generative-ai-go/genai" -) - -type boxed[T any] struct { - value T -} - -// GenerativeModelBuilder implements the builder pattern for [genai.GenerativeModel]. -type GenerativeModelBuilder struct { - copy *genai.GenerativeModel - - name *boxed[string] - generationConfig *boxed[genai.GenerationConfig] - safetySettings *boxed[[]*genai.SafetySetting] - tools *boxed[[]*genai.Tool] - toolConfig *boxed[*genai.ToolConfig] - systemInstruction *boxed[*genai.Content] - cachedContentName *boxed[string] -} - -// NewGenerativeModelBuilder returns a new [GenerativeModelBuilder] with empty default values. -func NewGenerativeModelBuilder() *GenerativeModelBuilder { - return &GenerativeModelBuilder{} -} - -// newCopyGenerativeModelBuilder creates a new [GenerativeModelBuilder], -// taking the default values from an existing [genai.GenerativeModel] object. -func newCopyGenerativeModelBuilder(model *genai.GenerativeModel) *GenerativeModelBuilder { - return &GenerativeModelBuilder{copy: model} -} - -// WithName sets the model name. -func (b *GenerativeModelBuilder) WithName( - modelName string, -) *GenerativeModelBuilder { - b.name = &boxed[string]{modelName} - return b -} - -// WithGenerationConfig sets the generation config. -func (b *GenerativeModelBuilder) WithGenerationConfig( - generationConfig genai.GenerationConfig, -) *GenerativeModelBuilder { - b.generationConfig = &boxed[genai.GenerationConfig]{generationConfig} - return b -} - -// WithSafetySettings sets the safety settings. -func (b *GenerativeModelBuilder) WithSafetySettings( - safetySettings []*genai.SafetySetting, -) *GenerativeModelBuilder { - b.safetySettings = &boxed[[]*genai.SafetySetting]{safetySettings} - return b -} - -// WithTools sets the tools. -func (b *GenerativeModelBuilder) WithTools( - tools []*genai.Tool, -) *GenerativeModelBuilder { - b.tools = &boxed[[]*genai.Tool]{tools} - return b -} - -// WithToolConfig sets the tool config. -func (b *GenerativeModelBuilder) WithToolConfig( - toolConfig *genai.ToolConfig, -) *GenerativeModelBuilder { - b.toolConfig = &boxed[*genai.ToolConfig]{toolConfig} - return b -} - -// WithSystemInstruction sets the system instruction. -func (b *GenerativeModelBuilder) WithSystemInstruction( - systemInstruction *genai.Content, -) *GenerativeModelBuilder { - b.systemInstruction = &boxed[*genai.Content]{systemInstruction} - return b -} - -// WithCachedContentName sets the name of the [genai.CachedContent] to use. -func (b *GenerativeModelBuilder) WithCachedContentName( - cachedContentName string, -) *GenerativeModelBuilder { - b.cachedContentName = &boxed[string]{cachedContentName} - return b -} - -// build builds and returns a new [genai.GenerativeModel] using the given [genai.Client]. -// It will panic if the copy and the model name are not set. -func (b *GenerativeModelBuilder) build(client *genai.Client) *genai.GenerativeModel { - if b.copy == nil && b.name == nil { - panic("model name is required") - } - - model := b.copy - if b.name != nil { - model = client.GenerativeModel(b.name.value) - if b.copy != nil { - model.GenerationConfig = b.copy.GenerationConfig - model.SafetySettings = b.copy.SafetySettings - model.Tools = b.copy.Tools - model.ToolConfig = b.copy.ToolConfig - model.SystemInstruction = b.copy.SystemInstruction - model.CachedContentName = b.copy.CachedContentName - } - } - b.configure(model) - return model -} - -// configure configures the given generative model using the builder values. -func (b *GenerativeModelBuilder) configure(model *genai.GenerativeModel) { - if b.generationConfig != nil { - model.GenerationConfig = b.generationConfig.value - } - if b.safetySettings != nil { - model.SafetySettings = b.safetySettings.value - } - if b.tools != nil { - model.Tools = b.tools.value - } - if b.toolConfig != nil { - model.ToolConfig = b.toolConfig.value - } - if b.systemInstruction != nil { - model.SystemInstruction = b.systemInstruction.value - } - if b.cachedContentName != nil { - model.CachedContentName = b.cachedContentName.value - } -} diff --git a/gemini/serializable_content.go b/gemini/serializable_content.go index 1b267c8..4a39904 100644 --- a/gemini/serializable_content.go +++ b/gemini/serializable_content.go @@ -1,9 +1,7 @@ package gemini import ( - "fmt" - - "github.com/google/generative-ai-go/genai" + "google.golang.org/genai" ) // SerializableContent is the data type containing multipart text message content. @@ -22,8 +20,9 @@ type SerializableContent struct { func NewSerializableContent(c *genai.Content) *SerializableContent { parts := make([]string, len(c.Parts)) for i, part := range c.Parts { - parts[i] = partToString(part) + parts[i] = part.Text } + return &SerializableContent{ Parts: parts, Role: c.Role, @@ -32,21 +31,13 @@ func NewSerializableContent(c *genai.Content) *SerializableContent { // ToContent converts the SerializableContent into a [genai.Content]. func (c *SerializableContent) ToContent() *genai.Content { - parts := make([]genai.Part, len(c.Parts)) + parts := make([]*genai.Part, len(c.Parts)) for i, part := range c.Parts { - parts[i] = genai.Text(part) + parts[i] = genai.NewPartFromText(part) } + return &genai.Content{ Parts: parts, Role: c.Role, } } - -func partToString(part genai.Part) string { - switch p := part.(type) { - case genai.Text: - return string(p) - default: - panic(fmt.Errorf("unsupported part type: %T", part)) - } -} diff --git a/gemini/system_instruction.go b/gemini/system_instruction.go index 347470b..4883cba 100644 --- a/gemini/system_instruction.go +++ b/gemini/system_instruction.go @@ -1,6 +1,6 @@ package gemini -import "github.com/google/generative-ai-go/genai" +import "google.golang.org/genai" // SystemInstruction represents a serializable system prompt, a more forceful // instruction to the language model. The model will prioritize adhering to @@ -9,5 +9,5 @@ type SystemInstruction string // ToContent converts the SystemInstruction to [genai.Content]. func (si SystemInstruction) ToContent() *genai.Content { - return genai.NewUserContent(genai.Text(si)) + return genai.NewContentFromText(string(si), genai.RoleUser) } diff --git a/go.mod b/go.mod index 6fa84fb..72fb335 100644 --- a/go.mod +++ b/go.mod @@ -5,20 +5,16 @@ go 1.24.0 require ( github.com/charmbracelet/glamour v0.10.0 github.com/chzyer/readline v1.5.1 - github.com/google/generative-ai-go v0.20.1 github.com/manifoldco/promptui v0.9.0 github.com/muesli/termenv v0.16.0 github.com/spf13/cobra v1.10.1 - google.golang.org/api v0.249.0 + google.golang.org/genai v1.35.0 ) require ( cloud.google.com/go v0.116.0 // indirect - cloud.google.com/go/ai v0.8.2 // indirect cloud.google.com/go/auth v0.16.5 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.8.0 // indirect - cloud.google.com/go/longrunning v0.6.0 // indirect github.com/alecthomas/chroma/v2 v2.14.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect @@ -32,11 +28,12 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.9 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect github.com/gorilla/css v1.0.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -49,20 +46,15 @@ require ( github.com/yuin/goldmark v1.7.8 // indirect github.com/yuin/goldmark-emoji v1.0.5 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect golang.org/x/crypto v0.41.0 // indirect golang.org/x/net v0.43.0 // indirect - golang.org/x/oauth2 v0.30.0 // indirect - golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/term v0.34.0 // indirect golang.org/x/text v0.28.0 // indirect - golang.org/x/time v0.12.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect google.golang.org/grpc v1.75.0 // indirect google.golang.org/protobuf v1.36.8 // indirect diff --git a/go.sum b/go.sum index d22bf96..0306bc8 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,9 @@ cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= -cloud.google.com/go/ai v0.8.2 h1:LEaQwqBv+k2ybrcdTtCTc9OPZXoEdcQaGrfvDYS6Bnk= -cloud.google.com/go/ai v0.8.2/go.mod h1:Wb3EUUGWwB6yHBaUf/+oxUq/6XbCaU1yh0GrwUS8lr4= cloud.google.com/go/auth v0.16.5 h1:mFWNQ2FEVWAliEQWpAdH80omXFokmrnbDhUS9cBywsI= cloud.google.com/go/auth v0.16.5/go.mod h1:utzRfHMP+Vv0mpOkTRQoWD2q3BatTOoWbA7gCc2dUhQ= -cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= -cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= -cloud.google.com/go/longrunning v0.6.0 h1:mM1ZmaNsQsnb+5n1DNPeL0KwQd9jQRqSqSDEkBZr+aI= -cloud.google.com/go/longrunning v0.6.0/go.mod h1:uHzSZqW89h7/pasCWNYdUpwGz3PcVWhrWupreVPYLts= github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E= @@ -61,8 +55,6 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/generative-ai-go v0.20.1 h1:6dEIujpgN2V0PgLhr6c/M1ynRdc7ARtiIDPFzj45uNQ= -github.com/google/generative-ai-go v0.20.1/go.mod h1:TjOnZJmZKzarWbjUJgy+r3Ee7HGBRVLhOIgupnwR4Bg= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= @@ -75,6 +67,8 @@ github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81 github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -116,8 +110,6 @@ github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= @@ -136,8 +128,6 @@ golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZ golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= -golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= -golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -149,16 +139,10 @@ golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/api v0.249.0 h1:0VrsWAKzIZi058aeq+I86uIXbNhm9GxSHpbmZ92a38w= -google.golang.org/api v0.249.0/go.mod h1:dGk9qyI0UYPwO/cjt2q06LG/EhUpwZGdAbYF14wHHrQ= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= -google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 h1:FiusG7LWj+4byqhbvmB+Q93B/mOxJLN2DTozDuZm4EU= -google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:kXqgZtrWaf6qS3jZOCnCH7WYfrvFjkC51bM8fz3RsCA= +google.golang.org/genai v1.35.0 h1:Jo6g25CzVqFzGrX5mhWyBgQqXAUzxcx5jeK7U74zv9c= +google.golang.org/genai v1.35.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c h1:qXWI/sQtv5UKboZ/zUk7h+mrf/lXORyI+n9DKDAusdg= google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4= diff --git a/internal/config/application_data.go b/internal/config/application_data.go index 821b273..d19d704 100644 --- a/internal/config/application_data.go +++ b/internal/config/application_data.go @@ -1,26 +1,60 @@ package config import ( - "github.com/google/generative-ai-go/genai" "github.com/reugn/gemini-cli/gemini" + "google.golang.org/genai" ) +const ( + thresholdLow = "LOW" + thresholdMedium = "MEDIUM" + thresholdHigh = "HIGH" + thresholdOff = "OFF" +) + +// Threshold is a custom type that wraps genai.HarmBlockThreshold +// and uses the custom string for serialization. +type Threshold string + +func (t Threshold) toGenai() genai.HarmBlockThreshold { + switch t { + case thresholdLow: + return genai.HarmBlockThresholdBlockLowAndAbove + case thresholdMedium: + return genai.HarmBlockThresholdBlockMediumAndAbove + case thresholdHigh: + return genai.HarmBlockThresholdBlockOnlyHigh + case thresholdOff: + return genai.HarmBlockThresholdOff + default: + return genai.HarmBlockThresholdUnspecified + } +} + +// SafetySetting is a custom type that wraps genai.SafetySetting +// and uses the custom HarmBlockThreshold for serialization. +type SafetySetting struct { + Category genai.HarmCategory `json:"category"` + Threshold Threshold `json:"threshold"` +} + // ApplicationData encapsulates application state and configuration. // Note that the chat history is stored in plain text format. type ApplicationData struct { - SystemPrompts map[string]gemini.SystemInstruction - SafetySettings []*genai.SafetySetting - History map[string][]*gemini.SerializableContent + SystemPrompts map[string]gemini.SystemInstruction `json:"system_prompts"` + SafetySettings []*SafetySetting `json:"safety_settings"` + History map[string][]*gemini.SerializableContent `json:"history"` } // newDefaultApplicationData returns a new ApplicationData with default values. func newDefaultApplicationData() *ApplicationData { - defaultSafetySettings := []*genai.SafetySetting{ - {Category: genai.HarmCategoryHarassment, Threshold: genai.HarmBlockLowAndAbove}, - {Category: genai.HarmCategoryHateSpeech, Threshold: genai.HarmBlockLowAndAbove}, - {Category: genai.HarmCategorySexuallyExplicit, Threshold: genai.HarmBlockLowAndAbove}, - {Category: genai.HarmCategoryDangerousContent, Threshold: genai.HarmBlockLowAndAbove}, + defaultSafetySettings := []*SafetySetting{ + {Category: genai.HarmCategoryHarassment, Threshold: thresholdLow}, + {Category: genai.HarmCategoryHateSpeech, Threshold: thresholdLow}, + {Category: genai.HarmCategorySexuallyExplicit, Threshold: thresholdLow}, + {Category: genai.HarmCategoryDangerousContent, Threshold: thresholdLow}, } + return &ApplicationData{ SystemPrompts: make(map[string]gemini.SystemInstruction), SafetySettings: defaultSafetySettings, @@ -34,5 +68,19 @@ func (d *ApplicationData) AddHistoryRecord(label string, content []*genai.Conten for i, c := range content { serializableContent[i] = gemini.NewSerializableContent(c) } + d.History[label] = serializableContent } + +// GenaiSafetySettings converts the application data safety settings to genai safety settings. +func (d *ApplicationData) GenaiSafetySettings() []*genai.SafetySetting { + genaiSafetySettings := make([]*genai.SafetySetting, len(d.SafetySettings)) + for i, s := range d.SafetySettings { + genaiSafetySettings[i] = &genai.SafetySetting{ + Category: s.Category, + Threshold: s.Threshold.toGenai(), + } + } + + return genaiSafetySettings +} diff --git a/internal/handler/gemini_query.go b/internal/handler/gemini_query.go index 61ed204..3345c98 100644 --- a/internal/handler/gemini_query.go +++ b/internal/handler/gemini_query.go @@ -45,7 +45,7 @@ func (h *GeminiQuery) Handle(message string) (Response, bool) { var b strings.Builder for _, candidate := range response.Candidates { for _, part := range candidate.Content.Parts { - _, _ = fmt.Fprintf(&b, "%s", part) + _, _ = fmt.Fprint(&b, part.Text) } } diff --git a/internal/handler/history_command.go b/internal/handler/history_command.go index 0184554..f6dda06 100644 --- a/internal/handler/history_command.go +++ b/internal/handler/history_command.go @@ -5,10 +5,10 @@ import ( "slices" "time" - "github.com/google/generative-ai-go/genai" "github.com/manifoldco/promptui" "github.com/reugn/gemini-cli/gemini" "github.com/reugn/gemini-cli/internal/config" + "google.golang.org/genai" ) var historyOptions = []string{ @@ -63,7 +63,10 @@ func (h *HistoryCommand) Handle(_ string) (Response, bool) { // handleClear handles the chat history clear request. func (h *HistoryCommand) handleClear() Response { h.terminal.Write(h.terminalPrompt) - h.session.ClearHistory() + if err := h.session.ClearHistory(); err != nil { + return newErrorResponse(err) + } + return dataResponse("Cleared the chat history.") } @@ -77,6 +80,7 @@ func (h *HistoryCommand) handleStore() Response { timeLabel := time.Now().In(time.Local).Format(time.DateTime) recordLabel := fmt.Sprintf("%s - %s", timeLabel, historyLabel) + h.configuration.Data.AddHistoryRecord( recordLabel, h.session.GetHistory(), @@ -97,7 +101,10 @@ func (h *HistoryCommand) handleLoad() Response { return newErrorResponse(err) } - h.session.SetHistory(history) + if err := h.session.SetHistory(history); err != nil { + return newErrorResponse(err) + } + return dataResponse(fmt.Sprintf("%q has been loaded to the chat history.", label)) } diff --git a/internal/handler/model_command.go b/internal/handler/model_command.go index ee181ce..0a07534 100644 --- a/internal/handler/model_command.go +++ b/internal/handler/model_command.go @@ -49,6 +49,7 @@ func (h *ModelCommand) Handle(_ string) (Response, bool) { default: response = newErrorResponse(fmt.Errorf("unsupported option: %s", option)) } + return response, false } @@ -64,8 +65,10 @@ func (h *ModelCommand) handleSelectModel() Response { return dataResponse(unchangedMessage) } - modelBuilder := h.session.CopyModelBuilder().WithName(modelName) - h.session.SetModel(modelBuilder) + if err := h.session.SetModel(modelName); err != nil { + return newErrorResponse(err) + } + h.generativeModelName = modelName return dataResponse(fmt.Sprintf("Selected %q generative model.", modelName)) @@ -81,6 +84,7 @@ func (h *ModelCommand) handleModelInfo() Response { if err != nil { return newErrorResponse(err) } + return dataResponse(modelInfo) } diff --git a/internal/handler/prompt_command.go b/internal/handler/prompt_command.go index 70cb7c4..3607479 100644 --- a/internal/handler/prompt_command.go +++ b/internal/handler/prompt_command.go @@ -4,10 +4,10 @@ import ( "fmt" "slices" - "github.com/google/generative-ai-go/genai" "github.com/manifoldco/promptui" "github.com/reugn/gemini-cli/gemini" "github.com/reugn/gemini-cli/internal/config" + "google.golang.org/genai" ) // SystemPromptCommand processes the chat prompt system command. @@ -40,9 +40,9 @@ func (h *SystemPromptCommand) Handle(_ string) (Response, bool) { return newErrorResponse(err), false } - modelBuilder := h.session.CopyModelBuilder(). - WithSystemInstruction(systemPrompt) - h.session.SetModel(modelBuilder) + if err := h.session.SetSystemInstruction(systemPrompt); err != nil { + return newErrorResponse(err), false + } return dataResponse(fmt.Sprintf("Selected %q system instruction.", label)), false }