Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ An example of basic configuration:
"threshold": "LOW"
}
],
"tools": [
{
"name": "GOOGLE_SEARCH",
"enabled": true
},
{
"name": "URL_CONTEXT",
"enabled": true
}
],
"history": {
}
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/gemini/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func run() int {
}

chatSession, err := gemini.NewChatSession(context.Background(), opts.GenerativeModel,
configuration.Data.GenaiSafetySettings())
configuration.Data.GenaiContentConfig())
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions gemini/chat_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ type ChatSession struct {
}

// NewChatSession returns a new [ChatSession].
func NewChatSession(ctx context.Context, model string, safetySettings []*genai.SafetySetting) (*ChatSession, error) {
func NewChatSession(ctx context.Context, model string,
contentConfig *genai.GenerateContentConfig) (*ChatSession, error) {
client, err := genai.NewClient(ctx, nil)
if err != nil {
return nil, fmt.Errorf("failed to create client: %w", err)
}

config := &genai.GenerateContentConfig{SafetySettings: safetySettings}
chat, err := client.Chats.Create(ctx, model, config, nil)
chat, err := client.Chats.Create(ctx, model, contentConfig, nil)
if err != nil {
return nil, fmt.Errorf("failed to create chat: %w", err)
}
Expand All @@ -42,7 +42,7 @@ func NewChatSession(ctx context.Context, model string, safetySettings []*genai.S
ctx: ctx,
client: client,
chat: chat,
config: config,
config: contentConfig,
model: model,
}, nil
}
Expand Down
57 changes: 54 additions & 3 deletions internal/config/application_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ const (
thresholdOff = "OFF"
)

const (
toolGoogleSearch = "GOOGLE_SEARCH"
toolURLContext = "URL_CONTEXT"
)

// Threshold is a custom type that wraps genai.HarmBlockThreshold
// and uses the custom string for serialization.
type Threshold string
Expand All @@ -32,32 +37,45 @@ func (t Threshold) toGenai() genai.HarmBlockThreshold {
}

// SafetySetting is a custom type that wraps genai.SafetySetting
// and uses the custom HarmBlockThreshold for serialization.
// and uses the custom Threshold for serialization.
type SafetySetting struct {
Category genai.HarmCategory `json:"category"`
Threshold Threshold `json:"threshold"`
}

// Tool represents a model tool configuration.
type Tool struct {
Name string `json:"name"`
Enabled bool `json:"enabled"`
}

// ApplicationData encapsulates application state and configuration.
// Note that the chat history is stored in plain text format.
type ApplicationData struct {
SystemPrompts map[string]gemini.SystemInstruction `json:"system_prompts"`
SafetySettings []*SafetySetting `json:"safety_settings"`
SafetySettings []SafetySetting `json:"safety_settings"`
Tools []Tool `json:"tools"`
History map[string][]*gemini.SerializableContent `json:"history"`
}

// newDefaultApplicationData returns a new ApplicationData with default values.
func newDefaultApplicationData() *ApplicationData {
defaultSafetySettings := []*SafetySetting{
defaultSafetySettings := []SafetySetting{
{Category: genai.HarmCategoryHarassment, Threshold: thresholdLow},
{Category: genai.HarmCategoryHateSpeech, Threshold: thresholdLow},
{Category: genai.HarmCategorySexuallyExplicit, Threshold: thresholdLow},
{Category: genai.HarmCategoryDangerousContent, Threshold: thresholdLow},
}

defaultTools := []Tool{
{Name: toolGoogleSearch, Enabled: true},
{Name: toolURLContext, Enabled: true},
}

return &ApplicationData{
SystemPrompts: make(map[string]gemini.SystemInstruction),
SafetySettings: defaultSafetySettings,
Tools: defaultTools,
History: make(map[string][]*gemini.SerializableContent),
}
}
Expand All @@ -84,3 +102,36 @@ func (d *ApplicationData) GenaiSafetySettings() []*genai.SafetySetting {

return genaiSafetySettings
}

// GenaiTools builds a genai Tool slice using enabled entries.
func (d *ApplicationData) GenaiTools() []*genai.Tool {
tools := make([]*genai.Tool, 0, len(d.Tools))
for _, tool := range d.Tools {
if !tool.Enabled {
continue
}

var genaiTool *genai.Tool
switch tool.Name {
case toolGoogleSearch:
genaiTool = &genai.Tool{GoogleSearch: &genai.GoogleSearch{}}
case toolURLContext:
genaiTool = &genai.Tool{URLContext: &genai.URLContext{}}
default:
continue // Skip unknown tools
}

tools = append(tools, genaiTool)
}

return tools
}

// GenaiContentConfig builds a genai GenerateContentConfig with the current
// safety settings and enabled tools.
func (d *ApplicationData) GenaiContentConfig() *genai.GenerateContentConfig {
return &genai.GenerateContentConfig{
SafetySettings: d.GenaiSafetySettings(),
Tools: d.GenaiTools(),
}
}
Loading