diff --git a/README.md b/README.md index b4ef04d..043a00f 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,16 @@ An example of basic configuration: "threshold": "LOW" } ], + "tools": [ + { + "name": "GOOGLE_SEARCH", + "enabled": true + }, + { + "name": "URL_CONTEXT", + "enabled": true + } + ], "history": { } } diff --git a/cmd/gemini/main.go b/cmd/gemini/main.go index aa42446..d3198a2 100644 --- a/cmd/gemini/main.go +++ b/cmd/gemini/main.go @@ -48,7 +48,7 @@ func run() int { } chatSession, err := gemini.NewChatSession(context.Background(), opts.GenerativeModel, - configuration.Data.GenaiSafetySettings()) + configuration.Data.GenaiContentConfig()) if err != nil { return err } diff --git a/gemini/chat_session.go b/gemini/chat_session.go index 7291eec..8e44acb 100644 --- a/gemini/chat_session.go +++ b/gemini/chat_session.go @@ -26,14 +26,14 @@ type ChatSession struct { } // NewChatSession returns a new [ChatSession]. -func NewChatSession(ctx context.Context, model string, safetySettings []*genai.SafetySetting) (*ChatSession, error) { +func NewChatSession(ctx context.Context, model string, + contentConfig *genai.GenerateContentConfig) (*ChatSession, error) { client, err := genai.NewClient(ctx, nil) if err != nil { return nil, fmt.Errorf("failed to create client: %w", err) } - config := &genai.GenerateContentConfig{SafetySettings: safetySettings} - chat, err := client.Chats.Create(ctx, model, config, nil) + chat, err := client.Chats.Create(ctx, model, contentConfig, nil) if err != nil { return nil, fmt.Errorf("failed to create chat: %w", err) } @@ -42,7 +42,7 @@ func NewChatSession(ctx context.Context, model string, safetySettings []*genai.S ctx: ctx, client: client, chat: chat, - config: config, + config: contentConfig, model: model, }, nil } diff --git a/internal/config/application_data.go b/internal/config/application_data.go index d19d704..2c61fa9 100644 --- a/internal/config/application_data.go +++ b/internal/config/application_data.go @@ -12,6 +12,11 @@ const ( thresholdOff = "OFF" ) +const ( + toolGoogleSearch = "GOOGLE_SEARCH" + toolURLContext = "URL_CONTEXT" +) + // Threshold is a custom type that wraps genai.HarmBlockThreshold // and uses the custom string for serialization. type Threshold string @@ -32,32 +37,45 @@ func (t Threshold) toGenai() genai.HarmBlockThreshold { } // SafetySetting is a custom type that wraps genai.SafetySetting -// and uses the custom HarmBlockThreshold for serialization. +// and uses the custom Threshold for serialization. type SafetySetting struct { Category genai.HarmCategory `json:"category"` Threshold Threshold `json:"threshold"` } +// Tool represents a model tool configuration. +type Tool struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` +} + // ApplicationData encapsulates application state and configuration. // Note that the chat history is stored in plain text format. type ApplicationData struct { SystemPrompts map[string]gemini.SystemInstruction `json:"system_prompts"` - SafetySettings []*SafetySetting `json:"safety_settings"` + SafetySettings []SafetySetting `json:"safety_settings"` + Tools []Tool `json:"tools"` History map[string][]*gemini.SerializableContent `json:"history"` } // newDefaultApplicationData returns a new ApplicationData with default values. func newDefaultApplicationData() *ApplicationData { - defaultSafetySettings := []*SafetySetting{ + defaultSafetySettings := []SafetySetting{ {Category: genai.HarmCategoryHarassment, Threshold: thresholdLow}, {Category: genai.HarmCategoryHateSpeech, Threshold: thresholdLow}, {Category: genai.HarmCategorySexuallyExplicit, Threshold: thresholdLow}, {Category: genai.HarmCategoryDangerousContent, Threshold: thresholdLow}, } + defaultTools := []Tool{ + {Name: toolGoogleSearch, Enabled: true}, + {Name: toolURLContext, Enabled: true}, + } + return &ApplicationData{ SystemPrompts: make(map[string]gemini.SystemInstruction), SafetySettings: defaultSafetySettings, + Tools: defaultTools, History: make(map[string][]*gemini.SerializableContent), } } @@ -84,3 +102,36 @@ func (d *ApplicationData) GenaiSafetySettings() []*genai.SafetySetting { return genaiSafetySettings } + +// GenaiTools builds a genai Tool slice using enabled entries. +func (d *ApplicationData) GenaiTools() []*genai.Tool { + tools := make([]*genai.Tool, 0, len(d.Tools)) + for _, tool := range d.Tools { + if !tool.Enabled { + continue + } + + var genaiTool *genai.Tool + switch tool.Name { + case toolGoogleSearch: + genaiTool = &genai.Tool{GoogleSearch: &genai.GoogleSearch{}} + case toolURLContext: + genaiTool = &genai.Tool{URLContext: &genai.URLContext{}} + default: + continue // Skip unknown tools + } + + tools = append(tools, genaiTool) + } + + return tools +} + +// GenaiContentConfig builds a genai GenerateContentConfig with the current +// safety settings and enabled tools. +func (d *ApplicationData) GenaiContentConfig() *genai.GenerateContentConfig { + return &genai.GenerateContentConfig{ + SafetySettings: d.GenaiSafetySettings(), + Tools: d.GenaiTools(), + } +}