
Commit 773e556

Merge pull request RooCodeInc#1312 from RooVetGit/count_tokens
Infrastructure to support calling token count APIs, starting with Anthropic
2 parents a2d441c + 7e62d34 commit 773e556

File tree

18 files changed: +543 -361 lines changed

src/api/index.ts

Lines changed: 10 additions & 0 deletions
@@ -27,6 +27,16 @@ export interface SingleCompletionHandler {
 export interface ApiHandler {
 	createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
 	getModel(): { id: string; info: ModelInfo }
+
+	/**
+	 * Counts tokens for content blocks
+	 * All providers extend BaseProvider which provides a default tiktoken implementation,
+	 * but they can override this to use their native token counting endpoints
+	 *
+	 * @param content The content to count tokens for
+	 * @returns A promise resolving to the token count
+	 */
+	countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number>
 }

 export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
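
A quick usage sketch of the new interface method (the import path and ApiConfiguration fields below are illustrative placeholders, not values from this commit):

```typescript
import { buildApiHandler } from "../api" // adjust the path to the caller's location

// Any handler returned by buildApiHandler now exposes countTokens, either via a
// native endpoint (e.g. Anthropic) or via the shared tiktoken-based default.
async function estimateInputTokens(): Promise<number> {
	const handler = buildApiHandler({ apiProvider: "anthropic", apiKey: "<api-key>" })
	return handler.countTokens([{ type: "text", text: "How many tokens is this sentence?" }])
}
```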

src/api/providers/anthropic.ts

Lines changed: 35 additions & 3 deletions
@@ -9,16 +9,17 @@ import {
 	ModelInfo,
 } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
+import { BaseProvider } from "./base-provider"
 import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
-import { ApiHandler, SingleCompletionHandler, getModelParams } from "../index"
+import { SingleCompletionHandler, getModelParams } from "../index"

-export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
+export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: Anthropic

 	constructor(options: ApiHandlerOptions) {
+		super()
 		this.options = options
-
 		this.client = new Anthropic({
 			apiKey: this.options.apiKey,
 			baseURL: this.options.anthropicBaseUrl || undefined,
@@ -212,4 +213,35 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 		const content = message.content.find(({ type }) => type === "text")
 		return content?.type === "text" ? content.text : ""
 	}
+
+	/**
+	 * Counts tokens for the given content using Anthropic's API
+	 *
+	 * @param content The content blocks to count tokens for
+	 * @returns A promise resolving to the token count
+	 */
+	override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
+		try {
+			// Use the current model
+			const actualModelId = this.getModel().id
+
+			const response = await this.client.messages.countTokens({
+				model: actualModelId,
+				messages: [
+					{
+						role: "user",
+						content: content,
+					},
+				],
+			})
+
+			return response.input_tokens
+		} catch (error) {
+			// Log error but fallback to tiktoken estimation
+			console.warn("Anthropic token counting failed, using fallback", error)
+
+			// Use the base provider's implementation as fallback
+			return super.countTokens(content)
+		}
+	}
 }
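
Because the override falls back to super.countTokens on any error, counting stays best-effort rather than hard-depending on the endpoint. For reference, a standalone sketch of the SDK call the handler wraps (the model id and key are placeholders, not values from this commit):

```typescript
import { Anthropic } from "@anthropic-ai/sdk"

// Standalone illustration of Anthropic's count-tokens endpoint used above.
const client = new Anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })

async function countViaApi(text: string): Promise<number> {
	const response = await client.messages.countTokens({
		model: "claude-3-5-sonnet-20241022", // placeholder model id
		messages: [{ role: "user", content: [{ type: "text", text }] }],
	})
	return response.input_tokens
}
```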

src/api/providers/base-provider.ts

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { ApiHandler } from ".."
+import { ModelInfo } from "../../shared/api"
+import { ApiStream } from "../transform/stream"
+import { Tiktoken } from "js-tiktoken/lite"
+import o200kBase from "js-tiktoken/ranks/o200k_base"
+
+// Reuse the fudge factor used in the original code
+const TOKEN_FUDGE_FACTOR = 1.5
+
+/**
+ * Base class for API providers that implements common functionality
+ */
+export abstract class BaseProvider implements ApiHandler {
+	// Cache the Tiktoken encoder instance since it's stateless
+	private encoder: Tiktoken | null = null
+	abstract createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
+	abstract getModel(): { id: string; info: ModelInfo }
+
+	/**
+	 * Default token counting implementation using tiktoken
+	 * Providers can override this to use their native token counting endpoints
+	 *
+	 * Uses a cached Tiktoken encoder instance for performance since it's stateless.
+	 * The encoder is created lazily on first use and reused for subsequent calls.
+	 *
+	 * @param content The content to count tokens for
+	 * @returns A promise resolving to the token count
+	 */
+	async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
+		if (!content || content.length === 0) return 0
+
+		let totalTokens = 0
+
+		// Lazily create and cache the encoder if it doesn't exist
+		if (!this.encoder) {
+			this.encoder = new Tiktoken(o200kBase)
+		}
+
+		// Process each content block using the cached encoder
+		for (const block of content) {
+			if (block.type === "text") {
+				// Use tiktoken for text token counting
+				const text = block.text || ""
+				if (text.length > 0) {
+					const tokens = this.encoder.encode(text)
+					totalTokens += tokens.length
+				}
+			} else if (block.type === "image") {
+				// For images, calculate based on data size
+				const imageSource = block.source
+				if (imageSource && typeof imageSource === "object" && "data" in imageSource) {
+					const base64Data = imageSource.data as string
+					totalTokens += Math.ceil(Math.sqrt(base64Data.length))
+				} else {
+					totalTokens += 300 // Conservative estimate for unknown images
+				}
+			}
+		}
+
+		// Add a fudge factor to account for the fact that tiktoken is not always accurate
+		return Math.ceil(totalTokens * TOKEN_FUDGE_FACTOR)
+	}
+}
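
A rough worked example of what this default estimator produces (a sketch, not part of the commit): text is encoded with the o200k_base ranks, an image contributes roughly the square root of its base64 length, and the sum is padded by the 1.5 fudge factor.

```typescript
import { Tiktoken } from "js-tiktoken/lite"
import o200kBase from "js-tiktoken/ranks/o200k_base"

// Mirrors BaseProvider.countTokens for one text block and one image block.
const encoder = new Tiktoken(o200kBase)
const textTokens = encoder.encode("Count the tokens in this sentence.").length

// An image whose base64 payload is ~40,000 characters adds ~200 tokens (sqrt heuristic).
const imageTokens = Math.ceil(Math.sqrt(40_000))

// The same 1.5 fudge factor pads the final estimate.
const estimate = Math.ceil((textTokens + imageTokens) * 1.5)
console.log({ textTokens, imageTokens, estimate })
```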

src/api/providers/bedrock.ts

Lines changed: 7 additions & 5 deletions
@@ -6,10 +6,11 @@ import {
 } from "@aws-sdk/client-bedrock-runtime"
 import { fromIni } from "@aws-sdk/credential-providers"
 import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiHandler, SingleCompletionHandler } from "../"
+import { SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
 import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format"
+import { BaseProvider } from "./base-provider"

 const BEDROCK_DEFAULT_TEMPERATURE = 0.3

@@ -46,11 +47,12 @@ export interface StreamEvent {
 	}
 }

-export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
-	private options: ApiHandlerOptions
+export class AwsBedrockHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
 	private client: BedrockRuntimeClient

 	constructor(options: ApiHandlerOptions) {
+		super()
 		this.options = options

 		const clientConfig: BedrockRuntimeClientConfig = {
@@ -74,7 +76,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
 		this.client = new BedrockRuntimeClient(clientConfig)
 	}

-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const modelConfig = this.getModel()

 		// Handle cross-region inference
@@ -205,7 +207,7 @@ export class AwsBedrockHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}

-	getModel(): { id: BedrockModelId | string; info: ModelInfo } {
+	override getModel(): { id: BedrockModelId | string; info: ModelInfo } {
 		const modelId = this.options.apiModelId
 		if (modelId) {
 			// For tests, allow any model ID

src/api/providers/gemini.ts

Lines changed: 7 additions & 5 deletions
@@ -1,22 +1,24 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { GoogleGenerativeAI } from "@google/generative-ai"
-import { ApiHandler, SingleCompletionHandler } from "../"
+import { SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
 import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
 import { ApiStream } from "../transform/stream"
+import { BaseProvider } from "./base-provider"

 const GEMINI_DEFAULT_TEMPERATURE = 0

-export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
-	private options: ApiHandlerOptions
+export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
 	private client: GoogleGenerativeAI

 	constructor(options: ApiHandlerOptions) {
+		super()
 		this.options = options
 		this.client = new GoogleGenerativeAI(options.geminiApiKey ?? "not-provided")
 	}

-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const model = this.client.getGenerativeModel({
 			model: this.getModel().id,
 			systemInstruction: systemPrompt,
@@ -44,7 +46,7 @@ export class GeminiHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}

-	getModel(): { id: GeminiModelId; info: ModelInfo } {
+	override getModel(): { id: GeminiModelId; info: ModelInfo } {
 		const modelId = this.options.apiModelId
 		if (modelId && modelId in geminiModels) {
 			const id = modelId as GeminiModelId

src/api/providers/glama.ts

Lines changed: 21 additions & 19 deletions
@@ -6,22 +6,39 @@ import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInf
 import { parseApiPrice } from "../../utils/cost"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
-import { ApiHandler, SingleCompletionHandler } from "../"
+import { SingleCompletionHandler } from "../"
+import { BaseProvider } from "./base-provider"

 const GLAMA_DEFAULT_TEMPERATURE = 0

-export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
-	private options: ApiHandlerOptions
+export class GlamaHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
 	private client: OpenAI

 	constructor(options: ApiHandlerOptions) {
+		super()
 		this.options = options
 		const baseURL = "https://glama.ai/api/gateway/openai/v1"
 		const apiKey = this.options.glamaApiKey ?? "not-provided"
 		this.client = new OpenAI({ baseURL, apiKey })
 	}

-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	private supportsTemperature(): boolean {
+		return !this.getModel().id.startsWith("openai/o3-mini")
+	}
+
+	override getModel(): { id: string; info: ModelInfo } {
+		const modelId = this.options.glamaModelId
+		const modelInfo = this.options.glamaModelInfo
+
+		if (modelId && modelInfo) {
+			return { id: modelId, info: modelInfo }
+		}
+
+		return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
+	}
+
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		// Convert Anthropic messages to OpenAI format
 		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
@@ -152,21 +169,6 @@ export class GlamaHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}

-	private supportsTemperature(): boolean {
-		return !this.getModel().id.startsWith("openai/o3-mini")
-	}
-
-	getModel(): { id: string; info: ModelInfo } {
-		const modelId = this.options.glamaModelId
-		const modelInfo = this.options.glamaModelInfo
-
-		if (modelId && modelInfo) {
-			return { id: modelId, info: modelInfo }
-		}
-
-		return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
-	}
-
 	async completePrompt(prompt: string): Promise<string> {
 		try {
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {

src/api/providers/lmstudio.ts

Lines changed: 7 additions & 5 deletions
@@ -2,26 +2,28 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 import axios from "axios"

-import { ApiHandler, SingleCompletionHandler } from "../"
+import { SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
+import { BaseProvider } from "./base-provider"

 const LMSTUDIO_DEFAULT_TEMPERATURE = 0

-export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
-	private options: ApiHandlerOptions
+export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
 	private client: OpenAI

 	constructor(options: ApiHandlerOptions) {
+		super()
 		this.options = options
 		this.client = new OpenAI({
 			baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
 			apiKey: "noop",
 		})
 	}

-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
 			...convertToOpenAiMessages(messages),
@@ -51,7 +53,7 @@ export class LmStudioHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}

-	getModel(): { id: string; info: ModelInfo } {
+	override getModel(): { id: string; info: ModelInfo } {
 		return {
 			id: this.options.lmStudioModelId || "",
 			info: openAiModelInfoSaneDefaults,

src/api/providers/mistral.ts

Lines changed: 7 additions & 5 deletions
@@ -1,6 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { Mistral } from "@mistralai/mistralai"
-import { ApiHandler } from "../"
+import { SingleCompletionHandler } from "../"
 import {
 	ApiHandlerOptions,
 	mistralDefaultModelId,
@@ -13,14 +13,16 @@ import {
 } from "../../shared/api"
 import { convertToMistralMessages } from "../transform/mistral-format"
 import { ApiStream } from "../transform/stream"
+import { BaseProvider } from "./base-provider"

 const MISTRAL_DEFAULT_TEMPERATURE = 0

-export class MistralHandler implements ApiHandler {
-	private options: ApiHandlerOptions
+export class MistralHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
 	private client: Mistral

 	constructor(options: ApiHandlerOptions) {
+		super()
 		if (!options.mistralApiKey) {
 			throw new Error("Mistral API key is required")
 		}
@@ -48,7 +50,7 @@ export class MistralHandler implements ApiHandler {
 		return "https://api.mistral.ai"
 	}

-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const response = await this.client.chat.stream({
 			model: this.options.apiModelId || mistralDefaultModelId,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
@@ -81,7 +83,7 @@ export class MistralHandler implements ApiHandler {
 		}
 	}

-	getModel(): { id: MistralModelId; info: ModelInfo } {
+	override getModel(): { id: MistralModelId; info: ModelInfo } {
 		const modelId = this.options.apiModelId
 		if (modelId && modelId in mistralModels) {
 			const id = modelId as MistralModelId
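
The pattern is the same across the providers above: extend BaseProvider, call super() in the constructor, and override countTokens only where a native counting endpoint exists. A hypothetical skeleton of a provider following that pattern (ExampleHandler, its model fallback, and the yielded chunk shape are illustrative, not part of this commit):

```typescript
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandlerOptions, ModelInfo } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"

// Hypothetical provider skeleton, not part of this commit.
export class ExampleHandler extends BaseProvider {
	protected options: ApiHandlerOptions

	constructor(options: ApiHandlerOptions) {
		super() // required now that providers share BaseProvider
		this.options = options
	}

	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		// A real provider would stream chunks from its API here.
		yield { type: "text", text: "..." }
	}

	override getModel(): { id: string; info: ModelInfo } {
		// A real provider would return its configured model id and metadata.
		return { id: this.options.apiModelId || "example-model", info: {} as ModelInfo }
	}

	// No countTokens override: BaseProvider's tiktoken-based default applies.
}
```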
