Skip to content

Commit 31d9e2a

Browse files
icecrasher321, Vikhyath Mondreti, greptile-apps[bot], waleedlatif1
authored
feat(kb-tags-filtering): filter kb docs using pre-set tags (#648)
* feat(knowledge-base): tag filtering * fix lint * remove migrations * fix migrations * fix lint * Update apps/sim/app/api/knowledge/search/route.ts Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * fix lint * fix lint * UI --------- Co-authored-by: Vikhyath Mondreti <[email protected]> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: Waleed Latif <[email protected]>
1 parent e5080fe commit 31d9e2a

File tree

19 files changed

+6221
-142
lines changed

19 files changed

+6221
-142
lines changed

apps/sim/app/api/__test-utils__/utils.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,13 @@ export function mockKnowledgeSchemas() {
619619
processingCompletedAt: 'processing_completed_at',
620620
processingError: 'processing_error',
621621
enabled: 'enabled',
622+
tag1: 'tag1',
623+
tag2: 'tag2',
624+
tag3: 'tag3',
625+
tag4: 'tag4',
626+
tag5: 'tag5',
627+
tag6: 'tag6',
628+
tag7: 'tag7',
622629
uploadedAt: 'uploaded_at',
623630
deletedAt: 'deleted_at',
624631
},
@@ -631,6 +638,13 @@ export function mockKnowledgeSchemas() {
631638
embedding: 'embedding',
632639
tokenCount: 'token_count',
633640
characterCount: 'character_count',
641+
tag1: 'tag1',
642+
tag2: 'tag2',
643+
tag3: 'tag3',
644+
tag4: 'tag4',
645+
tag5: 'tag5',
646+
tag6: 'tag6',
647+
tag7: 'tag7',
634648
createdAt: 'created_at',
635649
},
636650
}))

apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,13 @@ export async function GET(
118118
enabled: embedding.enabled,
119119
startOffset: embedding.startOffset,
120120
endOffset: embedding.endOffset,
121-
metadata: embedding.metadata,
121+
tag1: embedding.tag1,
122+
tag2: embedding.tag2,
123+
tag3: embedding.tag3,
124+
tag4: embedding.tag4,
125+
tag5: embedding.tag5,
126+
tag6: embedding.tag6,
127+
tag7: embedding.tag7,
122128
createdAt: embedding.createdAt,
123129
updatedAt: embedding.updatedAt,
124130
})
@@ -239,7 +245,14 @@ export async function POST(
239245
embeddingModel: 'text-embedding-3-small',
240246
startOffset: 0, // Manual chunks don't have document offsets
241247
endOffset: validatedData.content.length,
242-
metadata: { manual: true }, // Mark as manually created
248+
// Inherit tags from parent document
249+
tag1: doc.tag1,
250+
tag2: doc.tag2,
251+
tag3: doc.tag3,
252+
tag4: doc.tag4,
253+
tag5: doc.tag5,
254+
tag6: doc.tag6,
255+
tag7: doc.tag7,
243256
enabled: validatedData.enabled,
244257
createdAt: now,
245258
updatedAt: now,

apps/sim/app/api/knowledge/[id]/documents/route.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,14 @@ const CreateDocumentSchema = z.object({
153153
fileUrl: z.string().url('File URL must be valid'),
154154
fileSize: z.number().min(1, 'File size must be greater than 0'),
155155
mimeType: z.string().min(1, 'MIME type is required'),
156+
// Document tags for filtering
157+
tag1: z.string().optional(),
158+
tag2: z.string().optional(),
159+
tag3: z.string().optional(),
160+
tag4: z.string().optional(),
161+
tag5: z.string().optional(),
162+
tag6: z.string().optional(),
163+
tag7: z.string().optional(),
156164
})
157165

158166
const BulkCreateDocumentsSchema = z.object({
@@ -229,6 +237,14 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
229237
processingError: document.processingError,
230238
enabled: document.enabled,
231239
uploadedAt: document.uploadedAt,
240+
// Include tags in response
241+
tag1: document.tag1,
242+
tag2: document.tag2,
243+
tag3: document.tag3,
244+
tag4: document.tag4,
245+
tag5: document.tag5,
246+
tag6: document.tag6,
247+
tag7: document.tag7,
232248
})
233249
.from(document)
234250
.where(and(...whereConditions))
@@ -298,6 +314,14 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
298314
processingStatus: 'pending' as const,
299315
enabled: true,
300316
uploadedAt: now,
317+
// Include tags from upload
318+
tag1: docData.tag1 || null,
319+
tag2: docData.tag2 || null,
320+
tag3: docData.tag3 || null,
321+
tag4: docData.tag4 || null,
322+
tag5: docData.tag5 || null,
323+
tag6: docData.tag6 || null,
324+
tag7: docData.tag7 || null,
301325
}
302326

303327
await tx.insert(document).values(newDocument)
@@ -372,6 +396,14 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
372396
characterCount: 0,
373397
enabled: true,
374398
uploadedAt: now,
399+
// Include tags from upload
400+
tag1: validatedData.tag1 || null,
401+
tag2: validatedData.tag2 || null,
402+
tag3: validatedData.tag3 || null,
403+
tag4: validatedData.tag4 || null,
404+
tag5: validatedData.tag5 || null,
405+
tag6: validatedData.tag6 || null,
406+
tag7: validatedData.tag7 || null,
375407
}
376408

377409
await db.insert(document).values(newDocument)

apps/sim/app/api/knowledge/search/route.ts

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,30 @@ import { embedding, knowledgeBase } from '@/db/schema'
1010

1111
const logger = createLogger('VectorSearchAPI')
1212

13+
/** Tag columns that can be used to pre-filter a vector search. */
const TAG_FILTER_KEYS = ['tag1', 'tag2', 'tag3', 'tag4', 'tag5', 'tag6', 'tag7'] as const

type TagFilterKey = (typeof TAG_FILTER_KEYS)[number]

/**
 * Builds one case-insensitive equality condition per usable tag filter.
 *
 * Only the known tag keys (`tag1` … `tag7`) are considered; any other key in
 * `filters` is ignored instead of producing a placeholder `1=1` predicate.
 * Filters whose value is missing or an empty string are skipped as well:
 * documents persist empty tags as `null`, so an empty-string comparison could
 * never match and would silently discard every row.
 *
 * @param filters - Requested tag values keyed by tag column name.
 * @param embedding - Object exposing the `tag1` … `tag7` columns to compare against.
 * @returns SQL conditions to be combined with AND; empty when no filter applies.
 */
function getTagFilters(filters: Record<string, string>, embedding: Record<TagFilterKey, unknown>) {
  const conditions: ReturnType<typeof sql>[] = []

  for (const key of TAG_FILTER_KEYS) {
    // Optional schema fields can surface as undefined at runtime despite the declared type.
    const value: unknown = filters[key]
    if (typeof value !== 'string' || value === '') {
      continue
    }
    conditions.push(sql`LOWER(${embedding[key]}) = LOWER(${value})`)
  }

  return conditions
}
36+
1337
class APIError extends Error {
1438
public status: number
1539

@@ -27,6 +51,18 @@ const VectorSearchSchema = z.object({
2751
]),
2852
query: z.string().min(1, 'Search query is required'),
2953
topK: z.number().min(1).max(100).default(10),
54+
// Tag filters for pre-filtering
55+
filters: z
56+
.object({
57+
tag1: z.string().optional(),
58+
tag2: z.string().optional(),
59+
tag3: z.string().optional(),
60+
tag4: z.string().optional(),
61+
tag5: z.string().optional(),
62+
tag6: z.string().optional(),
63+
tag7: z.string().optional(),
64+
})
65+
.optional(),
3066
})
3167

3268
async function generateSearchEmbedding(query: string): Promise<number[]> {
@@ -102,7 +138,8 @@ async function executeParallelQueries(
102138
knowledgeBaseIds: string[],
103139
queryVector: string,
104140
topK: number,
105-
distanceThreshold: number
141+
distanceThreshold: number,
142+
filters?: Record<string, string>
106143
) {
107144
const parallelLimit = Math.ceil(topK / knowledgeBaseIds.length) + 5
108145

@@ -113,7 +150,13 @@ async function executeParallelQueries(
113150
content: embedding.content,
114151
documentId: embedding.documentId,
115152
chunkIndex: embedding.chunkIndex,
116-
metadata: embedding.metadata,
153+
tag1: embedding.tag1,
154+
tag2: embedding.tag2,
155+
tag3: embedding.tag3,
156+
tag4: embedding.tag4,
157+
tag5: embedding.tag5,
158+
tag6: embedding.tag6,
159+
tag7: embedding.tag7,
117160
distance: sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'),
118161
knowledgeBaseId: embedding.knowledgeBaseId,
119162
})
@@ -122,7 +165,9 @@ async function executeParallelQueries(
122165
and(
123166
eq(embedding.knowledgeBaseId, kbId),
124167
eq(embedding.enabled, true),
125-
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
168+
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`,
169+
// Apply tag filters if provided (case-insensitive)
170+
...(filters ? getTagFilters(filters, embedding) : [])
126171
)
127172
)
128173
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
@@ -139,23 +184,53 @@ async function executeSingleQuery(
139184
knowledgeBaseIds: string[],
140185
queryVector: string,
141186
topK: number,
142-
distanceThreshold: number
187+
distanceThreshold: number,
188+
filters?: Record<string, string>
143189
) {
144190
return await db
145191
.select({
146192
id: embedding.id,
147193
content: embedding.content,
148194
documentId: embedding.documentId,
149195
chunkIndex: embedding.chunkIndex,
150-
metadata: embedding.metadata,
196+
tag1: embedding.tag1,
197+
tag2: embedding.tag2,
198+
tag3: embedding.tag3,
199+
tag4: embedding.tag4,
200+
tag5: embedding.tag5,
201+
tag6: embedding.tag6,
202+
tag7: embedding.tag7,
151203
distance: sql<number>`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'),
152204
})
153205
.from(embedding)
154206
.where(
155207
and(
156208
inArray(embedding.knowledgeBaseId, knowledgeBaseIds),
157209
eq(embedding.enabled, true),
158-
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`
210+
sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}`,
211+
// Apply tag filters if provided (case-insensitive)
212+
...(filters
213+
? Object.entries(filters).map(([key, value]) => {
214+
switch (key) {
215+
case 'tag1':
216+
return sql`LOWER(${embedding.tag1}) = LOWER(${value})`
217+
case 'tag2':
218+
return sql`LOWER(${embedding.tag2}) = LOWER(${value})`
219+
case 'tag3':
220+
return sql`LOWER(${embedding.tag3}) = LOWER(${value})`
221+
case 'tag4':
222+
return sql`LOWER(${embedding.tag4}) = LOWER(${value})`
223+
case 'tag5':
224+
return sql`LOWER(${embedding.tag5}) = LOWER(${value})`
225+
case 'tag6':
226+
return sql`LOWER(${embedding.tag6}) = LOWER(${value})`
227+
case 'tag7':
228+
return sql`LOWER(${embedding.tag7}) = LOWER(${value})`
229+
default:
230+
return sql`1=1` // No-op for unknown keys
231+
}
232+
})
233+
: [])
159234
)
160235
)
161236
.orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`)
@@ -231,7 +306,8 @@ export async function POST(request: NextRequest) {
231306
foundKbIds,
232307
queryVector,
233308
validatedData.topK,
234-
strategy.distanceThreshold
309+
strategy.distanceThreshold,
310+
validatedData.filters
235311
)
236312
results = mergeAndRankResults(parallelResults, validatedData.topK)
237313
} else {
@@ -240,7 +316,8 @@ export async function POST(request: NextRequest) {
240316
foundKbIds,
241317
queryVector,
242318
validatedData.topK,
243-
strategy.distanceThreshold
319+
strategy.distanceThreshold,
320+
validatedData.filters
244321
)
245322
}
246323

@@ -252,7 +329,13 @@ export async function POST(request: NextRequest) {
252329
content: result.content,
253330
documentId: result.documentId,
254331
chunkIndex: result.chunkIndex,
255-
metadata: result.metadata,
332+
tag1: result.tag1,
333+
tag2: result.tag2,
334+
tag3: result.tag3,
335+
tag4: result.tag4,
336+
tag5: result.tag5,
337+
tag6: result.tag6,
338+
tag7: result.tag7,
256339
similarity: 1 - result.distance,
257340
})),
258341
query: validatedData.query,

apps/sim/app/api/knowledge/utils.ts

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ export interface DocumentData {
7373
enabled: boolean
7474
deletedAt?: Date | null
7575
uploadedAt: Date
76+
// Document tags
77+
tag1?: string | null
78+
tag2?: string | null
79+
tag3?: string | null
80+
tag4?: string | null
81+
tag5?: string | null
82+
tag6?: string | null
83+
tag7?: string | null
7684
}
7785

7886
export interface EmbeddingData {
@@ -88,7 +96,14 @@ export interface EmbeddingData {
8896
embeddingModel: string
8997
startOffset: number
9098
endOffset: number
91-
metadata: unknown
99+
// Tag fields for filtering
100+
tag1?: string | null
101+
tag2?: string | null
102+
tag3?: string | null
103+
tag4?: string | null
104+
tag5?: string | null
105+
tag6?: string | null
106+
tag7?: string | null
92107
enabled: boolean
93108
createdAt: Date
94109
updatedAt: Date
@@ -445,7 +460,26 @@ export async function processDocumentAsync(
445460
const chunkTexts = processed.chunks.map((chunk) => chunk.text)
446461
const embeddings = chunkTexts.length > 0 ? await generateEmbeddings(chunkTexts) : []
447462

448-
logger.info(`[${documentId}] Embeddings generated, updating document record`)
463+
logger.info(`[${documentId}] Embeddings generated, fetching document tags`)
464+
465+
// Fetch document to get tags
466+
const documentRecord = await db
467+
.select({
468+
tag1: document.tag1,
469+
tag2: document.tag2,
470+
tag3: document.tag3,
471+
tag4: document.tag4,
472+
tag5: document.tag5,
473+
tag6: document.tag6,
474+
tag7: document.tag7,
475+
})
476+
.from(document)
477+
.where(eq(document.id, documentId))
478+
.limit(1)
479+
480+
const documentTags = documentRecord[0] || {}
481+
482+
logger.info(`[${documentId}] Creating embedding records with tags`)
449483

450484
const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({
451485
id: crypto.randomUUID(),
@@ -460,7 +494,14 @@ export async function processDocumentAsync(
460494
embeddingModel: 'text-embedding-3-small',
461495
startOffset: chunk.metadata.startIndex,
462496
endOffset: chunk.metadata.endIndex,
463-
metadata: {},
497+
// Copy tags from document
498+
tag1: documentTags.tag1,
499+
tag2: documentTags.tag2,
500+
tag3: documentTags.tag3,
501+
tag4: documentTags.tag4,
502+
tag5: documentTags.tag5,
503+
tag6: documentTags.tag6,
504+
tag7: documentTags.tag7,
464505
createdAt: now,
465506
updatedAt: now,
466507
}))

0 commit comments

Comments
 (0)