Skip to content

Commit 18483fc

Browse files
committed
google calls can now be cached (by default)
1 parent f4ce65a commit 18483fc

File tree

9 files changed

+565
-7
lines changed

9 files changed

+565
-7
lines changed

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
github.com/gage-technologies/mistral-go v1.1.0
1111
github.com/gen2brain/beeep v0.0.0-20240516210008-9c006672e7f4
1212
github.com/goccy/go-yaml v1.17.1
13+
github.com/google/generative-ai-go v0.15.1
1314
github.com/kirsle/configdir v0.0.0-20170128060238-e45d2f54772f
1415
github.com/kteru/reversereader v0.0.0-20190328040929-bd5e29d6c056
1516
github.com/mark3labs/mcp-go v0.27.1
@@ -44,7 +45,6 @@ require (
4445
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
4546
github.com/go-toast/toast v0.0.0-20190211030409-01e6764cf0a4 // indirect
4647
github.com/godbus/dbus/v5 v5.1.0 // indirect
47-
github.com/google/generative-ai-go v0.15.1 // indirect
4848
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
4949
github.com/google/s2a-go v0.1.9 // indirect
5050
github.com/google/uuid v1.6.0 // indirect
@@ -102,4 +102,4 @@ require (
102102
gopkg.in/yaml.v3 v3.0.1 // indirect
103103
)
104104

105-
replace github.com/tmc/langchaingo => github.com/rainu/langchaingo v0.0.0-20250530154254-565c5b692d5b
105+
replace github.com/tmc/langchaingo => github.com/rainu/langchaingo v0.0.0-20250531142006-3a73e38741a3

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ github.com/rainu/go-command-chain v0.4.0 h1:qgrNbNsqkTfJHdwGzVuGPPK+p+XSnGAhAT/8
134134
github.com/rainu/go-command-chain v0.4.0/go.mod h1:RvLsDKnTGD9XoUY7nmBz73ayffI0bFCDH/EVJPRgfks=
135135
github.com/rainu/go-yacl v0.2.1 h1:BZdwonr/JA8RiE/8xptE9RKT0+COwyfTmyB7kiANGPw=
136136
github.com/rainu/go-yacl v0.2.1/go.mod h1:cZwUkCDYE1w6xlTUi6vCqdV1O3iLvM/govQdUn6I9NU=
137-
github.com/rainu/langchaingo v0.0.0-20250530154254-565c5b692d5b h1:d6BoQOmei4gvKDB+a/V1+hhFbOm326C0EG/5xT2/BJo=
138-
github.com/rainu/langchaingo v0.0.0-20250530154254-565c5b692d5b/go.mod h1:5TXP7bKcWjN05g9e2+MpqUXprf5jwI62Q2xGtAKLno8=
137+
github.com/rainu/langchaingo v0.0.0-20250531142006-3a73e38741a3 h1:XgLNKmrXX3F+Je6DR8IwiKVV834ATnUP0A0dmIweIHA=
138+
github.com/rainu/langchaingo v0.0.0-20250531142006-3a73e38741a3/go.mod h1:5TXP7bKcWjN05g9e2+MpqUXprf5jwI62Q2xGtAKLno8=
139139
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
140140
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
141141
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=

internal/config/model/llm/google.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ import (
77
"github.com/rainu/ask-mai/internal/llms/google"
88
"github.com/rainu/go-yacl"
99
"github.com/tmc/langchaingo/llms/googleai"
10+
"time"
1011
)
1112

1213
type GoogleAIConfig struct {
1314
APIKey common.Secret `yaml:"api-key,omitempty" usage:"API Key"`
1415
Model string `yaml:"model,omitempty" usage:"Model"`
1516
HarmThreshold *int32 `yaml:"harm-threshold,omitempty"`
17+
18+
ToolCacheTTL *time.Duration `yaml:"tool-cache-ttl,omitempty" usage:"TTL for tool cache. 0 means no caching. Minimum is 1 minute."`
1619
}
1720

1821
func (c *GoogleAIConfig) SetDefaults() {
@@ -22,6 +25,9 @@ func (c *GoogleAIConfig) SetDefaults() {
2225
if c.HarmThreshold == nil {
2326
c.HarmThreshold = yacl.P(int32(googleai.HarmBlockUnspecified))
2427
}
28+
if c.ToolCacheTTL == nil {
29+
c.ToolCacheTTL = yacl.P(5 * time.Minute)
30+
}
2531
}
2632

2733
func (c *GoogleAIConfig) GetUsage(field string) string {
@@ -66,10 +72,13 @@ func (c *GoogleAIConfig) Validate() error {
6672
return fmt.Errorf("Invalid harm threshold value: %d", c.HarmThreshold)
6773
}
6874
}
75+
if c.ToolCacheTTL != nil && *c.ToolCacheTTL > 0 && *c.ToolCacheTTL < time.Minute {
76+
return fmt.Errorf("Tool cache TTL must be at least 1 minute, got: %s", c.ToolCacheTTL)
77+
}
6978

7079
return nil
7180
}
7281

7382
func (c *GoogleAIConfig) BuildLLM() (llmCommon.Model, error) {
74-
return google.New(c.AsOptions())
83+
return google.New(c.AsOptions(), *c.ToolCacheTTL)
7584
}

internal/llms/google/cache.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package google
2+
3+
import (
4+
"context"
5+
"github.com/google/generative-ai-go/genai"
6+
"github.com/tmc/langchaingo/llms/googleai"
7+
"log/slog"
8+
"slices"
9+
"strings"
10+
"sync"
11+
"time"
12+
)
13+
14+
func (g *Google) preSendingHook(ctx context.Context, model *genai.GenerativeModel, meta googleai.PreSendingHookMetadata) {
15+
key := getCacheKey(model)
16+
if key == "" {
17+
return
18+
}
19+
20+
if g.cacheNames.Get(key) == "" {
21+
err := g.createNewToolsCache(ctx, key, meta.Options.Model, model.Tools)
22+
if err != nil {
23+
slog.Warn("Error creating cache", "error", err)
24+
return
25+
}
26+
}
27+
28+
model.CachedContentName = g.cacheNames.Get(key)
29+
model.Tools = nil
30+
31+
return
32+
}
33+
34+
func getCacheKey(model *genai.GenerativeModel) string {
35+
var funcNames []string
36+
37+
for _, tool := range model.Tools {
38+
for _, fd := range tool.FunctionDeclarations {
39+
funcNames = append(funcNames, fd.Name)
40+
}
41+
}
42+
slices.Sort(funcNames)
43+
44+
return strings.Join(funcNames, "")
45+
}
46+
47+
func (g *Google) createNewToolsCache(ctx context.Context, cKey, model string, tools []*genai.Tool) error {
48+
cache, err := genaiCreateCachedContent(g, ctx, &genai.CachedContent{
49+
Expiration: genai.ExpireTimeOrTTL{
50+
TTL: g.cacheTTL,
51+
},
52+
Model: model,
53+
Tools: tools,
54+
})
55+
if err != nil {
56+
return err
57+
}
58+
59+
g.cacheNames.Write(cKey, cache.Name)
60+
go g.startCacheRefresher(cKey, cache)
61+
62+
return nil
63+
}
64+
65+
func (g *Google) startCacheRefresher(cKey string, cache *genai.CachedContent) {
66+
go func() {
67+
select {
68+
case <-g.clientCtx.Done():
69+
slog.Debug("cache refresher stopped", "cacheName", cache.Name)
70+
return
71+
case <-time.After(g.cacheRefresh):
72+
}
73+
74+
newCache, err := genaiUpdateCachedContent(g, g.clientCtx, cache, &genai.CachedContentToUpdate{
75+
Expiration: &genai.ExpireTimeOrTTL{
76+
TTL: g.cacheTTL,
77+
},
78+
})
79+
if err != nil {
80+
slog.Warn("Error refreshing cache", "cacheName", cache.Name, "error", err)
81+
return
82+
}
83+
g.cacheNames.Write(cKey, newCache.Name)
84+
85+
go g.startCacheRefresher(cKey, newCache)
86+
}()
87+
}
88+
89+
func (g *Google) removeAllCaches(ctx context.Context) {
90+
wg := &sync.WaitGroup{}
91+
g.cacheNames.For(func(_, name string) {
92+
wg.Add(1)
93+
go func() {
94+
defer wg.Done()
95+
g.removeCache(ctx, name)
96+
}()
97+
})
98+
99+
wg.Wait()
100+
}
101+
102+
func (g *Google) removeCache(ctx context.Context, name string) {
103+
if err := genaiDeleteCachedContent(g, ctx, name); err != nil {
104+
slog.Error("Error deleting cache", "cacheName", name, "error", err)
105+
} else {
106+
slog.Debug("Cache deleted", "cacheName", name)
107+
}
108+
}
109+
110+
var genaiCreateCachedContent = func(g *Google, ctx context.Context, cc *genai.CachedContent) (*genai.CachedContent, error) {
111+
return g.client.GetGenaiClient().CreateCachedContent(ctx, cc)
112+
}
113+
114+
var genaiUpdateCachedContent = func(g *Google, ctx context.Context, cc *genai.CachedContent, update *genai.CachedContentToUpdate) (*genai.CachedContent, error) {
115+
return g.client.GetGenaiClient().UpdateCachedContent(ctx, cc, update)
116+
}
117+
118+
var genaiDeleteCachedContent = func(g *Google, ctx context.Context, name string) error {
119+
return g.client.GetGenaiClient().DeleteCachedContent(ctx, name)
120+
}

0 commit comments

Comments
 (0)