
Commit 1e81cd6

fix(kb): added tiktoken for embedding token estimation (#1616)
* fix(kb): added tiktoken for embedding token estimation
* added missing mock
1 parent ec73e2e commit 1e81cd6

10 files changed: +250 -51 lines changed
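
In short: JSON/YAML chunks are now sized with real tiktoken counts instead of a character-based estimate, and generateEmbeddings packs chunks into batches of at most 8,000 tokens so each request stays under OpenAI's 8,191-token embedding limit. A minimal sketch of the intended flow using the two exported helpers this commit touches; the wiring itself is illustrative, not code from the diff:

```ts
import { JsonYamlChunker } from '@/lib/chunkers/json-yaml-chunker'
import { generateEmbeddings } from '@/lib/embeddings/utils'

// Hypothetical wiring: chunk a JSON/YAML document, then embed every chunk.
// generateEmbeddings (changed later in this diff) batches by token count
// before each API call, so no single request exceeds the model's input limit.
async function embedJsonDocument(content: string): Promise<number[][]> {
  const chunks = await JsonYamlChunker.chunkJsonYaml(content)
  return generateEmbeddings(
    chunks.map((chunk) => chunk.text),
    'text-embedding-3-small'
  )
}
```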

apps/sim/app/api/files/upload/route.ts

Lines changed: 3 additions & 15 deletions
@@ -9,23 +9,22 @@ import {
   InvalidRequestError,
 } from '@/app/api/files/utils'
 
-// Allowlist of permitted file extensions for security
 const ALLOWED_EXTENSIONS = new Set([
-  // Documents
   'pdf',
   'doc',
   'docx',
   'txt',
   'md',
-  // Images (safe formats)
   'png',
   'jpg',
   'jpeg',
   'gif',
-  // Data files
   'csv',
   'xlsx',
   'xls',
+  'json',
+  'yaml',
+  'yml',
 ])
 
 /**
@@ -50,19 +49,16 @@ export async function POST(request: NextRequest) {
 
     const formData = await request.formData()
 
-    // Check if multiple files are being uploaded or a single file
     const files = formData.getAll('file') as File[]
 
     if (!files || files.length === 0) {
       throw new InvalidRequestError('No files provided')
     }
 
-    // Get optional scoping parameters for execution-scoped storage
     const workflowId = formData.get('workflowId') as string | null
     const executionId = formData.get('executionId') as string | null
     const workspaceId = formData.get('workspaceId') as string | null
 
-    // Log storage mode
     const usingCloudStorage = isUsingCloudStorage()
     logger.info(`Using storage mode: ${usingCloudStorage ? 'Cloud' : 'Local'} for file upload`)
 
@@ -74,7 +70,6 @@ export async function POST(request: NextRequest) {
 
     const uploadResults = []
 
-    // Process each file
     for (const file of files) {
       const originalName = file.name
 
@@ -88,9 +83,7 @@ export async function POST(request: NextRequest) {
       const bytes = await file.arrayBuffer()
       const buffer = Buffer.from(bytes)
 
-      // For execution-scoped files, use the dedicated execution file storage
       if (workflowId && executionId) {
-        // Use the dedicated execution file storage system
         const { uploadExecutionFile } = await import('@/lib/workflows/execution-file-storage')
         const userFile = await uploadExecutionFile(
           {
@@ -107,13 +100,10 @@ export async function POST(request: NextRequest) {
         continue
       }
 
-      // Upload to cloud or local storage using the standard uploadFile function
       try {
        logger.info(`Uploading file: ${originalName}`)
        const result = await uploadFile(buffer, originalName, file.type, file.size)
 
-        // Generate a presigned URL for cloud storage with appropriate expiry
-        // Regular files get 24 hours (execution files are handled above)
        let presignedUrl: string | undefined
        if (usingCloudStorage) {
          try {
@@ -144,7 +134,6 @@ export async function POST(request: NextRequest) {
      }
    }
 
-    // Return all file information
    if (uploadResults.length === 1) {
      return NextResponse.json(uploadResults[0])
    }
@@ -155,7 +144,6 @@ export async function POST(request: NextRequest) {
   }
 }
 
-// Handle preflight requests
 export async function OPTIONS() {
   return createOptionsResponse()
 }
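
The hunks above only touch the ALLOWED_EXTENSIONS set; the check that consumes it lives elsewhere in the route and is unchanged here. A minimal sketch of what such an allowlist check typically looks like (the helper name and error message are illustrative, not taken from the diff):

```ts
// Illustrative helper: reject uploads whose extension is not allowlisted.
// ALLOWED_EXTENSIONS and InvalidRequestError are the names used in the route above.
function assertAllowedExtension(fileName: string): void {
  const extension = fileName.split('.').pop()?.toLowerCase() ?? ''
  if (!ALLOWED_EXTENSIONS.has(extension)) {
    throw new InvalidRequestError(`File extension '.${extension}' is not allowed`)
  }
}
```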

apps/sim/app/api/knowledge/search/utils.test.ts

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ vi.stubGlobal(
 
 vi.mock('@/lib/env', () => ({
   env: {},
+  getEnv: (key: string) => process.env[key],
   isTruthy: (value: string | boolean | number | undefined) =>
     typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
 }))

apps/sim/app/api/knowledge/utils.test.ts

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ vi.mock('drizzle-orm', () => ({
 
 vi.mock('@/lib/env', () => ({
   env: { OPENAI_API_KEY: 'test-key' },
+  getEnv: (key: string) => process.env[key],
   isTruthy: (value: string | boolean | number | undefined) =>
     typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
 }))

apps/sim/lib/chunkers/json-yaml-chunker.ts

Lines changed: 41 additions & 16 deletions
@@ -1,18 +1,29 @@
+import * as yaml from 'js-yaml'
+import { createLogger } from '@/lib/logs/console/logger'
+import { getAccurateTokenCount } from '@/lib/tokenization'
 import { estimateTokenCount } from '@/lib/tokenization/estimators'
 import type { Chunk, ChunkerOptions } from './types'
 
+const logger = createLogger('JsonYamlChunker')
+
 function getTokenCount(text: string): number {
-  const estimate = estimateTokenCount(text)
-  return estimate.count
+  try {
+    return getAccurateTokenCount(text, 'text-embedding-3-small')
+  } catch (error) {
+    logger.warn('Tiktoken failed, falling back to estimation')
+    const estimate = estimateTokenCount(text)
+    return estimate.count
+  }
 }
 
 /**
  * Configuration for JSON/YAML chunking
+ * Reduced limits to ensure we stay well under OpenAI's 8,191 token limit per embedding request
  */
 const JSON_YAML_CHUNKING_CONFIG = {
-  TARGET_CHUNK_SIZE: 2000, // Target tokens per chunk
+  TARGET_CHUNK_SIZE: 1000, // Target tokens per chunk
   MIN_CHUNK_SIZE: 100, // Minimum tokens per chunk
-  MAX_CHUNK_SIZE: 3000, // Maximum tokens per chunk
+  MAX_CHUNK_SIZE: 1500, // Maximum tokens per chunk
   MAX_DEPTH_FOR_SPLITTING: 5, // Maximum depth to traverse for splitting
 }
 
@@ -34,7 +45,6 @@ export class JsonYamlChunker {
       return true
     } catch {
       try {
-        const yaml = require('js-yaml')
        yaml.load(content)
        return true
      } catch {
@@ -48,9 +58,26 @@ export class JsonYamlChunker {
   */
  async chunk(content: string): Promise<Chunk[]> {
    try {
-      const data = JSON.parse(content)
-      return this.chunkStructuredData(data)
+      let data: any
+      try {
+        data = JSON.parse(content)
+      } catch {
+        data = yaml.load(content)
+      }
+      const chunks = this.chunkStructuredData(data)
+
+      const tokenCounts = chunks.map((c) => c.tokenCount)
+      const totalTokens = tokenCounts.reduce((a, b) => a + b, 0)
+      const maxTokens = Math.max(...tokenCounts)
+      const avgTokens = Math.round(totalTokens / chunks.length)
+
+      logger.info(
+        `JSON chunking complete: ${chunks.length} chunks, ${totalTokens} total tokens (avg: ${avgTokens}, max: ${maxTokens})`
+      )
+
+      return chunks
    } catch (error) {
+      logger.info('JSON parsing failed, falling back to text chunking')
      return this.chunkAsText(content)
    }
  }
@@ -102,7 +129,6 @@ export class JsonYamlChunker {
      const itemTokens = getTokenCount(itemStr)
 
      if (itemTokens > this.chunkSize) {
-        // Save current batch if it has items
        if (currentBatch.length > 0) {
          const batchContent = contextHeader + JSON.stringify(currentBatch, null, 2)
          chunks.push({
@@ -134,7 +160,7 @@ export class JsonYamlChunker {
      const batchContent = contextHeader + JSON.stringify(currentBatch, null, 2)
      chunks.push({
        text: batchContent,
-        tokenCount: currentTokens,
+        tokenCount: getTokenCount(batchContent),
        metadata: {
          startIndex: i - currentBatch.length,
          endIndex: i - 1,
@@ -152,7 +178,7 @@ export class JsonYamlChunker {
      const batchContent = contextHeader + JSON.stringify(currentBatch, null, 2)
      chunks.push({
        text: batchContent,
-        tokenCount: currentTokens,
+        tokenCount: getTokenCount(batchContent),
        metadata: {
          startIndex: arr.length - currentBatch.length,
          endIndex: arr.length - 1,
@@ -194,12 +220,11 @@ export class JsonYamlChunker {
      const valueTokens = getTokenCount(valueStr)
 
      if (valueTokens > this.chunkSize) {
-        // Save current object if it has properties
        if (Object.keys(currentObj).length > 0) {
          const objContent = JSON.stringify(currentObj, null, 2)
          chunks.push({
            text: objContent,
-            tokenCount: currentTokens,
+            tokenCount: getTokenCount(objContent),
            metadata: {
              startIndex: 0,
              endIndex: objContent.length,
@@ -230,7 +255,7 @@ export class JsonYamlChunker {
      const objContent = JSON.stringify(currentObj, null, 2)
      chunks.push({
        text: objContent,
-        tokenCount: currentTokens,
+        tokenCount: getTokenCount(objContent),
        metadata: {
          startIndex: 0,
          endIndex: objContent.length,
@@ -250,7 +275,7 @@ export class JsonYamlChunker {
      const objContent = JSON.stringify(currentObj, null, 2)
      chunks.push({
        text: objContent,
-        tokenCount: currentTokens,
+        tokenCount: getTokenCount(objContent),
        metadata: {
          startIndex: 0,
          endIndex: objContent.length,
@@ -262,7 +287,7 @@ export class JsonYamlChunker {
   }
 
   /**
-   * Fall back to text chunking if JSON parsing fails.
+   * Fall back to text chunking if JSON parsing fails
    */
  private async chunkAsText(content: string): Promise<Chunk[]> {
    const chunks: Chunk[] = []
@@ -308,7 +333,7 @@ export class JsonYamlChunker {
   }
 
   /**
-   * Static method for chunking JSON/YAML data with default options.
+   * Static method for chunking JSON/YAML data with default options
    */
  static async chunkJsonYaml(content: string, options: ChunkerOptions = {}): Promise<Chunk[]> {
    const chunker = new JsonYamlChunker(options)
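
getAccurateTokenCount comes from the repo's '@/lib/tokenization' module, which is not part of this diff. A sketch of how such a helper can be built on the js-tiktoken package (text-embedding-3-small uses the cl100k_base encoding); the body below is an assumption, only the name and call shape match the code above:

```ts
import { getEncoding, type Tiktoken } from 'js-tiktoken'

// Assumed implementation: encode with cl100k_base (the tokenizer behind
// text-embedding-3-small) and return the number of tokens. The real module
// may use a different tiktoken binding or map the model name dynamically.
let encoder: Tiktoken | undefined

export function getAccurateTokenCount(text: string, _model: string): number {
  // Lazily create and cache the encoder; constructing it repeatedly is costly.
  encoder ??= getEncoding('cl100k_base')
  return encoder.encode(text).length
}
```

The try/catch in getTokenCount keeps the old estimateTokenCount heuristic as a fallback in case the tokenizer fails to load.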

apps/sim/lib/embeddings/utils.ts

Lines changed: 33 additions & 11 deletions
@@ -1,9 +1,12 @@
 import { env } from '@/lib/env'
 import { isRetryableError, retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils'
 import { createLogger } from '@/lib/logs/console/logger'
+import { batchByTokenLimit, getTotalTokenCount } from '@/lib/tokenization'
 
 const logger = createLogger('EmbeddingUtils')
 
+const MAX_TOKENS_PER_REQUEST = 8000
+
 export class EmbeddingAPIError extends Error {
   public status: number
 
@@ -104,35 +107,54 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
 }
 
 /**
- * Generate embeddings for multiple texts with simple batching
+ * Generate embeddings for multiple texts with token-aware batching
+ * Uses tiktoken for token counting
 */
 export async function generateEmbeddings(
   texts: string[],
   embeddingModel = 'text-embedding-3-small'
 ): Promise<number[][]> {
   const config = getEmbeddingConfig(embeddingModel)
 
-  logger.info(`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for embeddings generation`)
+  logger.info(
+    `Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for embeddings generation (${texts.length} texts)`
+  )
+
+  const batches = batchByTokenLimit(texts, MAX_TOKENS_PER_REQUEST, embeddingModel)
+
+  logger.info(
+    `Split ${texts.length} texts into ${batches.length} batches (max ${MAX_TOKENS_PER_REQUEST} tokens per batch)`
+  )
 
-  // Reduced batch size to prevent API timeouts and improve reliability
-  const batchSize = 50 // Reduced from 100 to prevent issues with large documents
   const allEmbeddings: number[][] = []
 
-  for (let i = 0; i < texts.length; i += batchSize) {
-    const batch = texts.slice(i, i + batchSize)
-    const batchEmbeddings = await callEmbeddingAPI(batch, config)
-    allEmbeddings.push(...batchEmbeddings)
+  for (let i = 0; i < batches.length; i++) {
+    const batch = batches[i]
+    const batchTokenCount = getTotalTokenCount(batch, embeddingModel)
 
     logger.info(
-      `Generated embeddings for batch ${Math.floor(i / batchSize) + 1}/${Math.ceil(texts.length / batchSize)}`
+      `Processing batch ${i + 1}/${batches.length}: ${batch.length} texts, ${batchTokenCount} tokens`
    )
 
-    // Add small delay between batches to avoid rate limiting
-    if (i + batchSize < texts.length) {
+    try {
+      const batchEmbeddings = await callEmbeddingAPI(batch, config)
+      allEmbeddings.push(...batchEmbeddings)
+
+      logger.info(
+        `Generated ${batchEmbeddings.length} embeddings for batch ${i + 1}/${batches.length}`
+      )
+    } catch (error) {
+      logger.error(`Failed to generate embeddings for batch ${i + 1}:`, error)
+      throw error
+    }
+
+    if (i + 1 < batches.length) {
      await new Promise((resolve) => setTimeout(resolve, 100))
    }
  }
 
+  logger.info(`Successfully generated ${allEmbeddings.length} embeddings total`)
+
  return allEmbeddings
 }
 
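batchByTokenLimit and getTotalTokenCount also come from '@/lib/tokenization' and are not shown in this diff. A plausible greedy implementation, sketched under the assumption that texts are packed in order until the next one would push the batch over the limit; the real helpers may differ:

```ts
import { getEncoding } from 'js-tiktoken'

// cl100k_base is the encoding used by text-embedding-3-small.
const encoder = getEncoding('cl100k_base')

// Assumed shape of batchByTokenLimit: greedy, order-preserving packing of texts
// into batches whose combined token count stays at or below maxTokens.
export function batchByTokenLimit(
  texts: string[],
  maxTokens: number,
  _model: string
): string[][] {
  const batches: string[][] = []
  let current: string[] = []
  let currentTokens = 0

  for (const text of texts) {
    const tokens = encoder.encode(text).length
    // Close the current batch when adding this text would exceed the limit.
    if (current.length > 0 && currentTokens + tokens > maxTokens) {
      batches.push(current)
      current = []
      currentTokens = 0
    }
    current.push(text)
    currentTokens += tokens
  }

  if (current.length > 0) batches.push(current)
  return batches
}

// Assumed companion helper behind the per-batch log line above.
export function getTotalTokenCount(texts: string[], _model: string): number {
  return texts.reduce((sum, text) => sum + encoder.encode(text).length, 0)
}
```

With this packing, a single text larger than maxTokens would still end up in its own batch; the chunker's reduced MAX_CHUNK_SIZE of 1,500 tokens is what keeps individual chunks comfortably under that ceiling.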