Skip to content

Commit b4149b5

Browse files
authored
vertex prompt caching (RooCodeInc#2026)
* vertex prompt caching
* duplicate declare
* changeset
1 parent cee959e commit b4149b5

File tree

2 files changed

+118
-9
lines changed

2 files changed

+118
-9
lines changed

.changeset/six-years-sip.md

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
---
2+
"claude-dev": minor
3+
---
4+
5+
re-added prompt caching to vertex

src/api/providers/vertex.ts

Lines changed: 113 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,113 @@ export class VertexHandler implements ApiHandler {
2121

2222
@withRetry()
2323
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
24-
const stream = await this.client.messages.create({
25-
model: this.getModel().id,
26-
max_tokens: this.getModel().info.maxTokens || 8192,
27-
temperature: 0,
28-
system: systemPrompt,
29-
messages,
30-
stream: true,
31-
})
24+
const model = this.getModel()
25+
const modelId = model.id
26+
27+
let stream
28+
switch (modelId) {
29+
case "claude-3-7-sonnet@20250219":
30+
case "claude-3-5-sonnet-v2@20241022":
31+
case "claude-3-5-sonnet@20240620":
32+
case "claude-3-5-haiku@20241022":
33+
case "claude-3-opus@20240229":
34+
case "claude-3-haiku@20240307": {
35+
// Find indices of user messages for cache control
36+
const userMsgIndices = messages.reduce(
37+
(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
38+
[] as number[],
39+
)
40+
const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
41+
const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
42+
43+
stream = await this.client.beta.messages.create(
44+
{
45+
model: modelId,
46+
max_tokens: model.info.maxTokens || 8192,
47+
temperature: 0,
48+
system: [
49+
{
50+
text: systemPrompt,
51+
type: "text",
52+
cache_control: { type: "ephemeral" },
53+
},
54+
],
55+
messages: messages.map((message, index) => {
56+
if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
57+
return {
58+
...message,
59+
content:
60+
typeof message.content === "string"
61+
? [
62+
{
63+
type: "text",
64+
text: message.content,
65+
cache_control: {
66+
type: "ephemeral",
67+
},
68+
},
69+
]
70+
: message.content.map((content, contentIndex) =>
71+
contentIndex === message.content.length - 1
72+
? {
73+
...content,
74+
cache_control: {
75+
type: "ephemeral",
76+
},
77+
}
78+
: content,
79+
),
80+
}
81+
}
82+
return {
83+
...message,
84+
content:
85+
typeof message.content === "string"
86+
? [
87+
{
88+
type: "text",
89+
text: message.content,
90+
},
91+
]
92+
: message.content,
93+
}
94+
}),
95+
stream: true,
96+
},
97+
{
98+
headers: {},
99+
},
100+
)
101+
break
102+
}
103+
default: {
104+
stream = await this.client.beta.messages.create({
105+
model: modelId,
106+
max_tokens: model.info.maxTokens || 8192,
107+
temperature: 0,
108+
system: [
109+
{
110+
text: systemPrompt,
111+
type: "text",
112+
},
113+
],
114+
messages: messages.map((message) => ({
115+
...message,
116+
content:
117+
typeof message.content === "string"
118+
? [
119+
{
120+
type: "text",
121+
text: message.content,
122+
},
123+
]
124+
: message.content,
125+
})),
126+
stream: true,
127+
})
128+
break
129+
}
130+
}
32131
for await (const chunk of stream) {
33132
switch (chunk.type) {
34133
case "message_start":
@@ -37,6 +136,8 @@ export class VertexHandler implements ApiHandler {
37136
type: "usage",
38137
inputTokens: usage.input_tokens || 0,
39138
outputTokens: usage.output_tokens || 0,
139+
cacheWriteTokens: usage.cache_creation_input_tokens || undefined,
140+
cacheReadTokens: usage.cache_read_input_tokens || undefined,
40141
}
41142
break
42143
case "message_delta":
@@ -46,7 +147,8 @@ export class VertexHandler implements ApiHandler {
46147
outputTokens: chunk.usage.output_tokens || 0,
47148
}
48149
break
49-
150+
case "message_stop":
151+
break
50152
case "content_block_start":
51153
switch (chunk.content_block.type) {
52154
case "text":
@@ -73,6 +175,8 @@ export class VertexHandler implements ApiHandler {
73175
break
74176
}
75177
break
178+
case "content_block_stop":
179+
break
76180
}
77181
}
78182
}

0 commit comments

Comments
 (0)