Skip to content

Commit 3887733

Browse files
authored
feat(kb): added cost for kb blocks (#654)
* added cost to kb upload + search
* small fix
* ack PR comments
1 parent 614d826 commit 3887733

File tree

12 files changed

+957
-22
lines changed

12 files changed

+957
-22
lines changed

apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.test.ts

Lines changed: 413 additions & 0 deletions
Large diffs are not rendered by default.

apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ import { type NextRequest, NextResponse } from 'next/server'
44
import { z } from 'zod'
55
import { getSession } from '@/lib/auth'
66
import { createLogger } from '@/lib/logs/console-logger'
7+
import { estimateTokenCount } from '@/lib/tokenization/estimators'
78
import { getUserId } from '@/app/api/auth/oauth/utils'
89
import { db } from '@/db'
910
import { document, embedding } from '@/db/schema'
11+
import { calculateCost } from '@/providers/utils'
1012
import { checkDocumentAccess, generateEmbeddings } from '../../../../utils'
1113

1214
const logger = createLogger('DocumentChunksAPI')
@@ -217,6 +219,9 @@ export async function POST(
217219
logger.info(`[${requestId}] Generating embedding for manual chunk`)
218220
const embeddings = await generateEmbeddings([validatedData.content])
219221

222+
// Estimate the token count once so that database storage and cost calculation use the same value
223+
const tokenCount = estimateTokenCount(validatedData.content, 'openai')
224+
220225
const chunkId = crypto.randomUUID()
221226
const now = new Date()
222227

@@ -240,7 +245,7 @@ export async function POST(
240245
chunkHash: crypto.createHash('sha256').update(validatedData.content).digest('hex'),
241246
content: validatedData.content,
242247
contentLength: validatedData.content.length,
243-
tokenCount: Math.ceil(validatedData.content.length / 4), // Rough approximation
248+
tokenCount: tokenCount.count, // Token count estimated by estimateTokenCount
244249
embedding: embeddings[0],
245250
embeddingModel: 'text-embedding-3-small',
246251
startOffset: 0, // Manual chunks don't have document offsets
@@ -276,9 +281,38 @@ export async function POST(
276281

277282
logger.info(`[${requestId}] Manual chunk created: ${chunkId} in document ${documentId}`)
278283

284+
// Calculate cost for the chunk embedding; if calculation fails, cost stays null and is omitted from the response
285+
let cost = null
286+
try {
287+
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
288+
} catch (error) {
289+
logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, {
290+
error: error instanceof Error ? error.message : 'Unknown error',
291+
})
292+
// Continue without cost information rather than failing the upload
293+
}
294+
279295
return NextResponse.json({
280296
success: true,
281-
data: newChunk,
297+
data: {
298+
...newChunk,
299+
...(cost
300+
? {
301+
cost: {
302+
input: cost.input,
303+
output: cost.output,
304+
total: cost.total,
305+
tokens: {
306+
prompt: tokenCount.count,
307+
completion: 0,
308+
total: tokenCount.count,
309+
},
310+
model: 'text-embedding-3-small',
311+
pricing: cost.pricing,
312+
},
313+
}
314+
: {}),
315+
},
282316
})
283317
} catch (validationError) {
284318
if (validationError instanceof z.ZodError) {

apps/sim/app/api/knowledge/route.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import { document, knowledgeBase } from '@/db/schema'
88

99
const logger = createLogger('KnowledgeBaseAPI')
1010

11-
// Schema for knowledge base creation
1211
const CreateKnowledgeBaseSchema = z.object({
1312
name: z.string().min(1, 'Name is required'),
1413
description: z.string().optional(),

apps/sim/app/api/knowledge/search/route.test.ts

Lines changed: 143 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,23 @@ vi.mock('@/lib/documents/utils', () => ({
3434
retryWithExponentialBackoff: vi.fn().mockImplementation((fn) => fn()),
3535
}))
3636

37+
vi.mock('@/lib/tokenization/estimators', () => ({
38+
estimateTokenCount: vi.fn().mockReturnValue({ count: 521 }),
39+
}))
40+
41+
vi.mock('@/providers/utils', () => ({
42+
calculateCost: vi.fn().mockReturnValue({
43+
input: 0.00001042,
44+
output: 0,
45+
total: 0.00001042,
46+
pricing: {
47+
input: 0.02,
48+
output: 0,
49+
updatedAt: '2025-07-10',
50+
},
51+
}),
52+
}))
53+
3754
mockConsoleLogger()
3855

3956
describe('Knowledge Search API Route', () => {
@@ -206,7 +223,7 @@ describe('Knowledge Search API Route', () => {
206223
expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123')
207224
})
208225

209-
it('should return unauthorized for unauthenticated request', async () => {
226+
it.concurrent('should return unauthorized for unauthenticated request', async () => {
210227
mockGetUserId.mockResolvedValue(null)
211228

212229
const req = createMockRequest('POST', validSearchData)
@@ -218,7 +235,7 @@ describe('Knowledge Search API Route', () => {
218235
expect(data.error).toBe('Unauthorized')
219236
})
220237

221-
it('should return not found for workflow that does not exist', async () => {
238+
it.concurrent('should return not found for workflow that does not exist', async () => {
222239
const workflowData = {
223240
...validSearchData,
224241
workflowId: 'nonexistent-workflow',
@@ -268,7 +285,7 @@ describe('Knowledge Search API Route', () => {
268285
expect(data.error).toBe('Knowledge bases not found: kb-missing')
269286
})
270287

271-
it('should validate search parameters', async () => {
288+
it.concurrent('should validate search parameters', async () => {
272289
const invalidData = {
273290
knowledgeBaseIds: '', // Empty string
274291
query: '', // Empty query
@@ -314,7 +331,7 @@ describe('Knowledge Search API Route', () => {
314331
expect(data.data.topK).toBe(10) // Default value
315332
})
316333

317-
it('should handle OpenAI API errors', async () => {
334+
it.concurrent('should handle OpenAI API errors', async () => {
318335
mockGetUserId.mockResolvedValue('user-123')
319336
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
320337

@@ -334,7 +351,7 @@ describe('Knowledge Search API Route', () => {
334351
expect(data.error).toBe('Failed to perform vector search')
335352
})
336353

337-
it('should handle missing OpenAI API key', async () => {
354+
it.concurrent('should handle missing OpenAI API key', async () => {
338355
vi.doMock('@/lib/env', () => ({
339356
env: {
340357
OPENAI_API_KEY: undefined,
@@ -353,7 +370,7 @@ describe('Knowledge Search API Route', () => {
353370
expect(data.error).toBe('Failed to perform vector search')
354371
})
355372

356-
it('should handle database errors during search', async () => {
373+
it.concurrent('should handle database errors during search', async () => {
357374
mockGetUserId.mockResolvedValue('user-123')
358375
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
359376
mockDbChain.limit.mockRejectedValueOnce(new Error('Database error'))
@@ -375,7 +392,7 @@ describe('Knowledge Search API Route', () => {
375392
expect(data.error).toBe('Failed to perform vector search')
376393
})
377394

378-
it('should handle invalid OpenAI response format', async () => {
395+
it.concurrent('should handle invalid OpenAI response format', async () => {
379396
mockGetUserId.mockResolvedValue('user-123')
380397
mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases)
381398

@@ -395,5 +412,124 @@ describe('Knowledge Search API Route', () => {
395412
expect(response.status).toBe(500)
396413
expect(data.error).toBe('Failed to perform vector search')
397414
})
415+
416+
describe('Cost tracking', () => {
417+
it.concurrent('should include cost information in successful search response', async () => {
418+
mockGetUserId.mockResolvedValue('user-123')
419+
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
420+
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
421+
422+
mockFetch.mockResolvedValue({
423+
ok: true,
424+
json: () =>
425+
Promise.resolve({
426+
data: [{ embedding: mockEmbedding }],
427+
}),
428+
})
429+
430+
const req = createMockRequest('POST', validSearchData)
431+
const { POST } = await import('./route')
432+
const response = await POST(req)
433+
const data = await response.json()
434+
435+
expect(response.status).toBe(200)
436+
expect(data.success).toBe(true)
437+
438+
// Verify cost information is included
439+
expect(data.data.cost).toBeDefined()
440+
expect(data.data.cost.input).toBe(0.00001042)
441+
expect(data.data.cost.output).toBe(0)
442+
expect(data.data.cost.total).toBe(0.00001042)
443+
expect(data.data.cost.tokens).toEqual({
444+
prompt: 521,
445+
completion: 0,
446+
total: 521,
447+
})
448+
expect(data.data.cost.model).toBe('text-embedding-3-small')
449+
expect(data.data.cost.pricing).toEqual({
450+
input: 0.02,
451+
output: 0,
452+
updatedAt: '2025-07-10',
453+
})
454+
})
455+
456+
it('should call cost calculation functions with correct parameters', async () => {
457+
const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
458+
const { calculateCost } = await import('@/providers/utils')
459+
460+
mockGetUserId.mockResolvedValue('user-123')
461+
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
462+
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
463+
464+
mockFetch.mockResolvedValue({
465+
ok: true,
466+
json: () =>
467+
Promise.resolve({
468+
data: [{ embedding: mockEmbedding }],
469+
}),
470+
})
471+
472+
const req = createMockRequest('POST', validSearchData)
473+
const { POST } = await import('./route')
474+
await POST(req)
475+
476+
// Verify token estimation was called with correct parameters
477+
expect(estimateTokenCount).toHaveBeenCalledWith('test search query', 'openai')
478+
479+
// Verify cost calculation was called with correct parameters
480+
expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 521, 0, false)
481+
})
482+
483+
it('should handle cost calculation with different query lengths', async () => {
484+
const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
485+
const { calculateCost } = await import('@/providers/utils')
486+
487+
// Mock different token count for longer query
488+
vi.mocked(estimateTokenCount).mockReturnValue({
489+
count: 1042,
490+
confidence: 'high',
491+
provider: 'openai',
492+
method: 'precise',
493+
})
494+
vi.mocked(calculateCost).mockReturnValue({
495+
input: 0.00002084,
496+
output: 0,
497+
total: 0.00002084,
498+
pricing: {
499+
input: 0.02,
500+
output: 0,
501+
updatedAt: '2025-07-10',
502+
},
503+
})
504+
505+
const longQueryData = {
506+
...validSearchData,
507+
query:
508+
'This is a much longer search query with many more tokens to test cost calculation accuracy',
509+
}
510+
511+
mockGetUserId.mockResolvedValue('user-123')
512+
mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases)
513+
mockDbChain.limit.mockResolvedValueOnce(mockSearchResults)
514+
515+
mockFetch.mockResolvedValue({
516+
ok: true,
517+
json: () =>
518+
Promise.resolve({
519+
data: [{ embedding: mockEmbedding }],
520+
}),
521+
})
522+
523+
const req = createMockRequest('POST', longQueryData)
524+
const { POST } = await import('./route')
525+
const response = await POST(req)
526+
const data = await response.json()
527+
528+
expect(response.status).toBe(200)
529+
expect(data.data.cost.input).toBe(0.00002084)
530+
expect(data.data.cost.tokens.prompt).toBe(1042)
531+
expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 1042, 0, false)
532+
})
533+
})
398534
})
399535
})

apps/sim/app/api/knowledge/search/route.ts

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ import { z } from 'zod'
44
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
55
import { env } from '@/lib/env'
66
import { createLogger } from '@/lib/logs/console-logger'
7+
import { estimateTokenCount } from '@/lib/tokenization/estimators'
78
import { getUserId } from '@/app/api/auth/oauth/utils'
89
import { db } from '@/db'
910
import { embedding, knowledgeBase } from '@/db/schema'
11+
import { calculateCost } from '@/providers/utils'
1012

1113
const logger = createLogger('VectorSearchAPI')
1214

13-
// Helper function to create tag filters
1415
function getTagFilters(filters: Record<string, string>, embedding: any) {
1516
return Object.entries(filters).map(([key, value]) => {
1617
switch (key) {
@@ -51,7 +52,6 @@ const VectorSearchSchema = z.object({
5152
]),
5253
query: z.string().min(1, 'Search query is required'),
5354
topK: z.number().min(1).max(100).default(10),
54-
// Tag filters for pre-filtering
5555
filters: z
5656
.object({
5757
tag1: z.string().optional(),
@@ -166,7 +166,6 @@ async function executeParallelQueries(
166166
eq(embedding.knowledgeBaseId, kbId),
167167
eq(embedding.enabled, true),
168168
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`,
169-
// Apply tag filters if provided (case-insensitive)
170169
...(filters ? getTagFilters(filters, embedding) : [])
171170
)
172171
)
@@ -208,7 +207,6 @@ async function executeSingleQuery(
208207
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
209208
eq(embedding.enabled, true),
210209
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`,
211-
// Apply tag filters if provided (case-insensitive)
212210
...(filters
213211
? Object.entries(filters).map(([key, value]) => {
214212
switch (key) {
@@ -321,6 +319,19 @@ export async function POST(request: NextRequest) {
321319
)
322320
}
323321

322+
// Estimate tokens and calculate cost for the query embedding; if either step fails, cost is omitted from the response
323+
let cost = null
324+
let tokenCount = null
325+
try {
326+
tokenCount = estimateTokenCount(validatedData.query, 'openai')
327+
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
328+
} catch (error) {
329+
logger.warn(`[${requestId}] Failed to calculate cost for search query`, {
330+
error: error instanceof Error ? error.message : 'Unknown error',
331+
})
332+
// Continue without cost information rather than failing the search
333+
}
334+
324335
return NextResponse.json({
325336
success: true,
326337
data: {
@@ -343,6 +354,22 @@ export async function POST(request: NextRequest) {
343354
knowledgeBaseId: foundKbIds[0],
344355
topK: validatedData.topK,
345356
totalResults: results.length,
357+
...(cost && tokenCount
358+
? {
359+
cost: {
360+
input: cost.input,
361+
output: cost.output,
362+
total: cost.total,
363+
tokens: {
364+
prompt: tokenCount.count,
365+
completion: 0,
366+
total: tokenCount.count,
367+
},
368+
model: 'text-embedding-3-small',
369+
pricing: cost.pricing,
370+
},
371+
}
372+
: {}),
346373
},
347374
})
348375
} catch (validationError) {

0 commit comments

Comments
 (0)