Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,36 @@ If it doesn't exist, the application will attempt to create it using default val
An example of basic configuration:
```json
{
"SystemPrompts": {
"system_prompts": {
"Software Engineer": "You are an experienced software engineer.",
"Technical Writer": "Act as a tech writer. I will provide you with the basic steps of an app functionality, and you will come up with an engaging article on how to do those steps."
},
"SafetySettings": [
"safety_settings": [
{
"Category": 7,
"Threshold": 1
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "LOW"
},
{
"Category": 10,
"Threshold": 1
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "LOW"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "LOW"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "LOW"
}
],
"History": {
"history": {
}
}
```
Upon user request, the `History` map will be populated with records. Note that the chat history is stored in plain
text format. See [history operations](#system-commands) for details.
<sup>1</sup> Valid safety settings threshold values include LOW (block more), MEDIUM, HIGH (block less), and OFF.

<sup>2</sup> Upon user request, the `history` map will be populated with records. Note that the chat history is stored
in plain text format. See [history operations](#system-commands) for details.

### CLI help
```console
Expand Down
9 changes: 2 additions & 7 deletions cmd/gemini/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

const (
version = "0.4.0"
apiKeyEnv = "GEMINI_API_KEY" //nolint:gosec
defaultConfigPath = "gemini_cli_config.json"
)

Expand Down Expand Up @@ -48,15 +47,11 @@ func run() int {
return err
}

modelBuilder := gemini.NewGenerativeModelBuilder().
WithName(opts.GenerativeModel).
WithSafetySettings(configuration.Data.SafetySettings)
apiKey := os.Getenv(apiKeyEnv)
chatSession, err := gemini.NewChatSession(context.Background(), modelBuilder, apiKey)
chatSession, err := gemini.NewChatSession(context.Background(), opts.GenerativeModel,
configuration.Data.GenaiSafetySettings())
if err != nil {
return err
}
defer func() { err = errors.Join(err, chatSession.Close()) }()

chatHandler, err := chat.New(getCurrentUser(), chatSession, configuration, &opts)
if err != nil {
Expand Down
108 changes: 62 additions & 46 deletions gemini/chat_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"context"
"encoding/json"
"fmt"
"iter"
"sync"

"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
"google.golang.org/genai"
)

const DefaultModel = "gemini-2.5-flash"
Expand All @@ -16,101 +16,117 @@ const DefaultModel = "gemini-2.5-flash"
type ChatSession struct {
ctx context.Context

client *genai.Client
model *genai.GenerativeModel
session *genai.ChatSession
client *genai.Client
chat *genai.Chat
config *genai.GenerateContentConfig
model string

loadModels sync.Once
models []string
}

// NewChatSession returns a new [ChatSession].
func NewChatSession(
ctx context.Context, modelBuilder *GenerativeModelBuilder, apiKey string,
) (*ChatSession, error) {
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
func NewChatSession(ctx context.Context, model string, safetySettings []*genai.SafetySetting) (*ChatSession, error) {
client, err := genai.NewClient(ctx, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create client: %w", err)
}

config := &genai.GenerateContentConfig{SafetySettings: safetySettings}
chat, err := client.Chats.Create(ctx, model, config, nil)
if err != nil {
return nil, fmt.Errorf("failed to create chat: %w", err)
}

generativeModel := modelBuilder.build(client)
return &ChatSession{
ctx: ctx,
client: client,
model: generativeModel,
session: generativeModel.StartChat(),
ctx: ctx,
client: client,
chat: chat,
config: config,
model: model,
}, nil
}

// SendMessage sends a request to the model as part of a chat session.
func (c *ChatSession) SendMessage(input string) (*genai.GenerateContentResponse, error) {
return c.session.SendMessage(c.ctx, genai.Text(input))
return c.chat.SendMessage(c.ctx, genai.Part{Text: input})
}

// SendMessageStream is like SendMessage, but with a streaming request.
func (c *ChatSession) SendMessageStream(input string) *genai.GenerateContentResponseIterator {
return c.session.SendMessageStream(c.ctx, genai.Text(input))
}

// SetModel sets a new generative model configured with the builder and starts
// a new chat session. It preserves the history of the previous chat session.
func (c *ChatSession) SetModel(modelBuilder *GenerativeModelBuilder) {
history := c.session.History
c.model = modelBuilder.build(c.client)
c.session = c.model.StartChat()
c.session.History = history
}

// CopyModelBuilder returns a copy builder for the chat generative model.
func (c *ChatSession) CopyModelBuilder() *GenerativeModelBuilder {
return newCopyGenerativeModelBuilder(c.model)
func (c *ChatSession) SendMessageStream(input string) iter.Seq2[*genai.GenerateContentResponse, error] {
return c.chat.SendMessageStream(c.ctx, genai.Part{Text: input})
}

// ModelInfo returns information about the chat generative model in JSON format.
func (c *ChatSession) ModelInfo() (string, error) {
modelInfo, err := c.model.Info(c.ctx)
modelInfo, err := c.client.Models.Get(c.ctx, c.model, nil)
if err != nil {
return "", err
}

encoded, err := json.MarshalIndent(modelInfo, "", " ")
if err != nil {
return "", fmt.Errorf("error encoding model info: %w", err)
}

return string(encoded), nil
}

// ListModels returns a list of the supported generative model names.
func (c *ChatSession) ListModels() []string {
c.loadModels.Do(func() {
c.models = []string{DefaultModel}
iter := c.client.ListModels(c.ctx)
for {
modelInfo, err := iter.Next()
for model, err := range c.client.Models.All(c.ctx) {
if err != nil {
break
continue
}
c.models = append(c.models, modelInfo.Name)
c.models = append(c.models, model.Name)
}
})
return c.models
}

// SetModel sets the chat generative model.
func (c *ChatSession) SetModel(model string) error {
chat, err := c.client.Chats.Create(c.ctx, model, c.config, c.GetHistory())
if err != nil {
return fmt.Errorf("failed to set model: %w", err)
}

c.model = model
c.chat = chat
return nil
}

// GetHistory returns the chat session history.
func (c *ChatSession) GetHistory() []*genai.Content {
return c.session.History
return c.chat.History(true)
}

// SetHistory sets the chat session history.
func (c *ChatSession) SetHistory(content []*genai.Content) {
c.session.History = content
func (c *ChatSession) SetHistory(history []*genai.Content) error {
chat, err := c.client.Chats.Create(c.ctx, c.model, c.config, history)
if err != nil {
return fmt.Errorf("failed to set history: %w", err)
}

c.chat = chat
return nil
}

// ClearHistory clears the chat session history.
func (c *ChatSession) ClearHistory() {
c.session.History = make([]*genai.Content, 0)
func (c *ChatSession) ClearHistory() error {
return c.SetHistory(nil)
}

// Close closes the chat session.
func (c *ChatSession) Close() error {
return c.client.Close()
// SetSystemInstruction sets the chat session system instruction.
func (c *ChatSession) SetSystemInstruction(systemInstruction *genai.Content) error {
c.config.SystemInstruction = systemInstruction
chat, err := c.client.Chats.Create(c.ctx, c.model, c.config, c.GetHistory())
if err != nil {
return fmt.Errorf("failed to set system instruction: %w", err)
}

c.chat = chat
return nil
}
134 changes: 0 additions & 134 deletions gemini/generative_model_builder.go

This file was deleted.

Loading
Loading