-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathgenerative_model_builder.go
More file actions
134 lines (118 loc) · 3.9 KB
/
generative_model_builder.go
File metadata and controls
134 lines (118 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package gemini
import (
"github.com/google/generative-ai-go/genai"
)
// boxed wraps a value of type T so that a *boxed[T] can act as an optional
// setting: a nil pointer means "not set", while a non-nil pointer carries the
// value even when that value is T's zero value.
type boxed[T any] struct{ value T }
// GenerativeModelBuilder implements the builder pattern for [genai.GenerativeModel].
//
// Each optional setting is held as a *boxed value so that "not set" (nil) can
// be told apart from "explicitly set to the zero value"; only non-nil settings
// are applied to the model when it is built.
type GenerativeModelBuilder struct {
	// copy is an existing model whose configuration serves as the defaults.
	// It is nil when the builder was created with NewGenerativeModelBuilder.
	copy *genai.GenerativeModel
	// name is the model name passed to client.GenerativeModel; nil if unset.
	name *boxed[string]
	// The remaining fields, when non-nil, override the matching
	// genai.GenerativeModel fields.
	generationConfig  *boxed[genai.GenerationConfig]
	safetySettings    *boxed[[]*genai.SafetySetting]
	tools             *boxed[[]*genai.Tool]
	toolConfig        *boxed[*genai.ToolConfig]
	systemInstruction *boxed[*genai.Content]
	cachedContentName *boxed[string]
}
// NewGenerativeModelBuilder returns an empty [GenerativeModelBuilder] with no
// source model and no settings. A model name must be supplied with
// [GenerativeModelBuilder.WithName] before the model can be built.
func NewGenerativeModelBuilder() *GenerativeModelBuilder {
	return new(GenerativeModelBuilder)
}
// newCopyGenerativeModelBuilder returns a [GenerativeModelBuilder] whose
// defaults come from model: its configuration fields are used unless the
// corresponding With* method is called on the builder.
func newCopyGenerativeModelBuilder(model *genai.GenerativeModel) *GenerativeModelBuilder {
	builder := new(GenerativeModelBuilder)
	builder.copy = model
	return builder
}
// WithName sets the model name that is passed to client.GenerativeModel when
// the model is built. It returns the builder so calls can be chained.
func (b *GenerativeModelBuilder) WithName(modelName string) *GenerativeModelBuilder {
	b.name = &boxed[string]{value: modelName}
	return b
}
// WithGenerationConfig sets the generation config, overriding any config the
// model would otherwise start with. It returns the builder for chaining.
func (b *GenerativeModelBuilder) WithGenerationConfig(generationConfig genai.GenerationConfig) *GenerativeModelBuilder {
	b.generationConfig = &boxed[genai.GenerationConfig]{value: generationConfig}
	return b
}
// WithSafetySettings sets the safety settings. The slice is stored as given
// (not copied). It returns the builder for chaining.
func (b *GenerativeModelBuilder) WithSafetySettings(safetySettings []*genai.SafetySetting) *GenerativeModelBuilder {
	b.safetySettings = &boxed[[]*genai.SafetySetting]{value: safetySettings}
	return b
}
// WithTools sets the tools available to the model. The slice is stored as
// given (not copied). It returns the builder for chaining.
func (b *GenerativeModelBuilder) WithTools(tools []*genai.Tool) *GenerativeModelBuilder {
	b.tools = &boxed[[]*genai.Tool]{value: tools}
	return b
}
// WithToolConfig sets the tool config. Passing nil explicitly clears any tool
// config the model would otherwise start with. It returns the builder for
// chaining.
func (b *GenerativeModelBuilder) WithToolConfig(toolConfig *genai.ToolConfig) *GenerativeModelBuilder {
	b.toolConfig = &boxed[*genai.ToolConfig]{value: toolConfig}
	return b
}
// WithSystemInstruction sets the system instruction. Passing nil explicitly
// clears any instruction the model would otherwise start with. It returns the
// builder for chaining.
func (b *GenerativeModelBuilder) WithSystemInstruction(systemInstruction *genai.Content) *GenerativeModelBuilder {
	b.systemInstruction = &boxed[*genai.Content]{value: systemInstruction}
	return b
}
// WithCachedContentName sets the name of the [genai.CachedContent] the model
// should use. It returns the builder for chaining.
func (b *GenerativeModelBuilder) WithCachedContentName(cachedContentName string) *GenerativeModelBuilder {
	b.cachedContentName = &boxed[string]{value: cachedContentName}
	return b
}
// build builds and returns a new [genai.GenerativeModel].
//
// If a model name was set with [GenerativeModelBuilder.WithName], the model is
// created with client.GenerativeModel. If the builder also has a source model,
// that model's configuration fields (generation config, safety settings,
// tools, tool config, system instruction, cached content name) are copied
// onto the new model.
//
// If no name was set, the result is a shallow copy of the source model. It
// keeps the source model's own client and name, and client is not used.
//
// In both cases the values set through the With* methods are applied last.
// The source model itself is never modified, and every call returns a
// distinct model.
//
// It panics if neither a model name nor a source model is set.
func (b *GenerativeModelBuilder) build(client *genai.Client) *genai.GenerativeModel {
	var model *genai.GenerativeModel
	switch {
	case b.name != nil:
		model = client.GenerativeModel(b.name.value)
		if b.copy != nil {
			model.GenerationConfig = b.copy.GenerationConfig
			model.SafetySettings = b.copy.SafetySettings
			model.Tools = b.copy.Tools
			model.ToolConfig = b.copy.ToolConfig
			model.SystemInstruction = b.copy.SystemInstruction
			model.CachedContentName = b.copy.CachedContentName
		}
	case b.copy != nil:
		// Copy the struct rather than reusing the pointer. Otherwise configure
		// would mutate the caller's model in place, and repeated builds would
		// all return the same object. A shallow copy is enough because
		// configure only reassigns top-level fields.
		clone := *b.copy
		model = &clone
	default:
		panic("model name is required")
	}
	b.configure(model)
	return model
}
// configure overwrites the fields of model with every value that was
// explicitly set on the builder. Fields whose builder value was never set are
// left untouched, so defaults already present on model survive. model is
// dereferenced only for the settings that are present.
func (b *GenerativeModelBuilder) configure(model *genai.GenerativeModel) {
	if cfg := b.generationConfig; cfg != nil {
		model.GenerationConfig = cfg.value
	}
	if safety := b.safetySettings; safety != nil {
		model.SafetySettings = safety.value
	}
	if tools := b.tools; tools != nil {
		model.Tools = tools.value
	}
	if toolCfg := b.toolConfig; toolCfg != nil {
		model.ToolConfig = toolCfg.value
	}
	if instr := b.systemInstruction; instr != nil {
		model.SystemInstruction = instr.value
	}
	if cached := b.cachedContentName; cached != nil {
		model.CachedContentName = cached.value
	}
}