 import { Anthropic } from "@anthropic-ai/sdk"
 import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
-import { anthropicDefaultModelId, AnthropicModelId, anthropicModels, ModelInfo } from "../../shared/api"
+import { withRetry } from "../retry"
+import { anthropicDefaultModelId, AnthropicModelId, anthropicModels, ApiHandlerOptions, ModelInfo } from "../../shared/api"
+import { ApiHandler } from "../index"
 import { ApiStream } from "../transform/stream"
-import { ClaudeStreamingHandler } from "./claude-streaming"
 
-/**
- * Handles interactions with the Anthropic service.
- */
-export class AnthropicHandler extends ClaudeStreamingHandler<Anthropic> {
-	getClient() {
-		return new Anthropic({
+export class AnthropicHandler implements ApiHandler {
+	private options: ApiHandlerOptions
+	private client: Anthropic
+
+	constructor(options: ApiHandlerOptions) {
+		this.options = options
+		this.client = new Anthropic({
 			apiKey: this.options.apiKey,
-			baseURL: this.options.anthropicBaseUrl || null, // default baseURL: https://api.anthropic.com
+			baseURL: this.options.anthropicBaseUrl || undefined,
 		})
 	}
 
-	override async *createStreamingMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	@withRetry()
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const model = this.getModel()
+		let stream: AnthropicStream<Anthropic.Beta.PromptCaching.Messages.RawPromptCachingBetaMessageStreamEvent>
 		const modelId = model.id
-
-		let stream: AnthropicStream<Anthropic.Messages.RawMessageStreamEvent>
-
-		if (Object.keys(anthropicModels).includes(modelId)) {
-			stream = await this.createModelStream(
-				systemPrompt,
-				messages,
-				modelId,
-				model.info.maxTokens ?? AnthropicHandler.DEFAULT_TOKEN_SIZE,
-			)
-		} else {
-			throw new Error(`Invalid model ID: ${modelId}`)
+		switch (modelId) {
+			// 'latest' alias does not support cache_control
+			case "claude-3-7-sonnet-20250219":
+			case "claude-3-5-sonnet-20241022":
+			case "claude-3-5-haiku-20241022":
+			case "claude-3-opus-20240229":
+			case "claude-3-haiku-20240307": {
+				/*
+				The latest message will be the new user message, one before will be the assistant message from a previous request, and the user message before that will be a previously cached user message. So we need to mark the latest user message as ephemeral to cache it for the next request, and mark the second-to-last user message as ephemeral to let the server know the last message to retrieve from the cache for the current request.
+				*/
+				const userMsgIndices = messages.reduce(
+					(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
+					[] as number[],
+				)
+				const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+				const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
+				stream = await this.client.beta.promptCaching.messages.create(
+					{
+						model: modelId,
+						max_tokens: model.info.maxTokens || 8192,
+						temperature: 0,
+						system: [
+							{
+								text: systemPrompt,
+								type: "text",
+								cache_control: { type: "ephemeral" },
+							},
+						], // setting cache breakpoint for system prompt so new tasks can reuse it
+						messages: messages.map((message, index) => {
+							if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
+								return {
+									...message,
+									content:
+										typeof message.content === "string"
+											? [
+													{
+														type: "text",
+														text: message.content,
+														cache_control: {
+															type: "ephemeral",
+														},
+													},
+												]
+											: message.content.map((content, contentIndex) =>
+													contentIndex === message.content.length - 1
+														? {
+																...content,
+																cache_control: {
+																	type: "ephemeral",
+																},
+															}
+														: content,
+												),
+								}
+							}
+							return message
+						}),
+						// tools, // cache breakpoints go from tools > system > messages, and since tools dont change, we can just set the breakpoint at the end of system (this avoids having to set a breakpoint at the end of tools which by itself does not meet min requirements for haiku caching)
+						// tool_choice: { type: "auto" },
+						// tools: tools,
+						stream: true,
+					},
+					(() => {
+						// prompt caching: https://x.com/alexalbert__/status/1823751995901272068
+						// https://github.com/anthropics/anthropic-sdk-typescript?tab=readme-ov-file#default-headers
+						// https://github.com/anthropics/anthropic-sdk-typescript/commit/c920b77fc67bd839bfeb6716ceab9d7c9bbe7393
+						switch (modelId) {
+							case "claude-3-7-sonnet-20250219":
+							case "claude-3-5-sonnet-20241022":
+							case "claude-3-5-haiku-20241022":
+							case "claude-3-opus-20240229":
+							case "claude-3-haiku-20240307":
+								return {
+									headers: {
+										"anthropic-beta": "prompt-caching-2024-07-31",
+									},
+								}
+							default:
+								return undefined
+						}
+					})(),
+				)
+				break
+			}
+			default: {
+				stream = (await this.client.messages.create({
+					model: modelId,
+					max_tokens: model.info.maxTokens || 8192,
+					temperature: 0,
+					system: [{ text: systemPrompt, type: "text" }],
+					messages,
+					// tools,
+					// tool_choice: { type: "auto" },
+					stream: true,
+				})) as any
+				break
+			}
 		}
 
-		yield* this.processStream(stream)
-	}
+		for await (const chunk of stream) {
+			switch (chunk.type) {
+				case "message_start":
+					// tells us cache reads/writes/input/output
+					const usage = chunk.message.usage
+					yield {
+						type: "usage",
+						inputTokens: usage.input_tokens || 0,
+						outputTokens: usage.output_tokens || 0,
+						cacheWriteTokens: usage.cache_creation_input_tokens || undefined,
+						cacheReadTokens: usage.cache_read_input_tokens || undefined,
+					}
+					break
+				case "message_delta":
+					// tells us stop_reason, stop_sequence, and output tokens along the way and at the end of the message
 
-	override async createModelStream(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		modelId: string,
-		maxTokens: number,
-	): Promise<AnthropicStream<Anthropic.Messages.RawMessageStreamEvent>> {
-		/*
-		The latest message will be the new user message, one before will be the assistant message from a previous request, and the user message before that will be a previously cached user message. So we need to mark the latest user message as ephemeral to cache it for the next request, and mark the second-to-last user message as ephemeral to let the server know the last message to retrieve from the cache for the current request.
-		*/
-		const userMsgIndices = messages.reduce((acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc), [] as number[])
-		const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
-		const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-		return await this.client.messages.create({
-			model: modelId,
-			max_tokens: maxTokens || AnthropicHandler.DEFAULT_TOKEN_SIZE,
-			temperature: AnthropicHandler.DEFAULT_TEMPERATURE,
-			system: [
-				{
-					text: systemPrompt,
-					type: "text",
-					cache_control: { type: "ephemeral" },
-				},
-			], // setting cache breakpoint for system prompt so new tasks can reuse it
-			messages: messages.map((message, index) =>
-				this.transformMessage(message, index, lastUserMsgIndex, secondLastMsgUserIndex),
-			),
-			// tools, // cache breakpoints go from tools > system > messages, and since tools dont change, we can just set the breakpoint at the end of system (this avoids having to set a breakpoint at the end of tools which by itself does not meet min requirements for haiku caching)
-			// tool_choice: { type: "auto" },
-			// tools: tools,
-			stream: true,
-		})
+					yield {
+						type: "usage",
+						inputTokens: 0,
+						outputTokens: chunk.usage.output_tokens || 0,
+					}
+					break
+				case "message_stop":
+					// no usage data, just an indicator that the message is done
+					break
+				case "content_block_start":
+					switch (chunk.content_block.type) {
+						case "text":
+							// we may receive multiple text blocks, in which case just insert a line break between them
+							if (chunk.index > 0) {
+								yield {
+									type: "text",
+									text: "\n",
+								}
+							}
+							yield {
+								type: "text",
+								text: chunk.content_block.text,
+							}
+							break
+					}
+					break
+				case "content_block_delta":
+					switch (chunk.delta.type) {
+						case "text_delta":
+							yield {
+								type: "text",
+								text: chunk.delta.text,
+							}
+							break
+					}
+					break
+				case "content_block_stop":
+					break
+			}
+		}
 	}
 
-	override getModel(): { id: AnthropicModelId; info: ModelInfo } {
+	getModel(): { id: AnthropicModelId; info: ModelInfo } {
 		const modelId = this.options.apiModelId
 		if (modelId && modelId in anthropicModels) {
 			const id = modelId as AnthropicModelId
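For orientation, the sketch below shows how a caller might construct this handler and consume the ApiStream that createMessage yields. It is a minimal usage sketch, not code from this repository: the import path for AnthropicHandler is an assumption, and only the ApiHandlerOptions fields that actually appear in the diff (apiKey, apiModelId, anthropicBaseUrl) are used.

// Hypothetical usage sketch; the "./anthropic" import path is an assumption.
import { Anthropic } from "@anthropic-ai/sdk"
import { AnthropicHandler } from "./anthropic"

async function main() {
	// apiKey and apiModelId are options fields visible in the diff above; other fields are omitted.
	const handler = new AnthropicHandler({
		apiKey: process.env.ANTHROPIC_API_KEY ?? "",
		apiModelId: "claude-3-5-sonnet-20241022",
	})

	const systemPrompt = "You are a concise assistant."
	const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello!" }]

	let reply = ""
	// createMessage yields "usage" chunks (token/cache counts) and "text" chunks, as implemented above.
	for await (const chunk of handler.createMessage(systemPrompt, messages)) {
		if (chunk.type === "text") {
			reply += chunk.text
		} else if (chunk.type === "usage") {
			console.log("input:", chunk.inputTokens, "output:", chunk.outputTokens, "cache read:", chunk.cacheReadTokens ?? 0)
		}
	}
	console.log(reply)
}

main().catch(console.error)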