Skip to content

Commit 7d7e64f

Browse files
committed
feat(vectorstore): add kotlin filter expression DSL for vector store
- Implement FilterExpressionDsl class for building complex filter expressions - Add support for various filter operations: eq, ne, gt, gte, lt, lte, in, nin - Implement logical operators: and, or, not - Add unit tests for filter expression DSL - Update project dependencies to include Kotlin stdlib Signed-off-by: Ahoo Wang <[email protected]>
1 parent 91937d0 commit 7d7e64f

File tree

3 files changed

+554
-0
lines changed

3 files changed

+554
-0
lines changed

spring-ai-vector-store/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@
5353
<optional>true</optional>
5454
</dependency>
5555

56+
<dependency>
57+
<groupId>org.jetbrains.kotlin</groupId>
58+
<artifactId>kotlin-stdlib</artifactId>
59+
<optional>true</optional>
60+
</dependency>
5661
<!-- test dependencies -->
5762

5863
<dependency>
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
package org.springframework.ai.vectorstore.filter
2+
3+
/**
4+
* DSL (Domain Specific Language) class for building filter expressions.
5+
* This class allows for the creation of complex filter expressions using a fluent API.
6+
*
7+
* example:
8+
* ``` kotlin
9+
* filterExpression {
10+
* "field1".eq("value1")
11+
* and {
12+
* "field2".ne("value2")
13+
* "field3".gt(3)
14+
* "field4".gte(4)
15+
* or {
16+
* "field5".lt(5)
17+
* "field6".lte(6)
18+
* }
19+
* }
20+
* }
21+
* ```
22+
* @author Ahoo Wang
23+
*/
24+
@Suppress("TooManyFunctions")
25+
class FilterExpressionDsl {
26+
// List to store individual filter expressions
27+
private val expressions: MutableList<Filter.Operand> = mutableListOf()
28+
29+
/**
30+
* Adds a filter expression to the list of expressions.
31+
*
32+
* @param expression The filter expression to add.
33+
*/
34+
fun expression(expression: Filter.Operand) {
35+
expressions.add(expression)
36+
}
37+
38+
/**
39+
* Creates and adds a filter expression based on the provided type, key, and value.
40+
*
41+
* @param type The type of the filter expression (e.g., EQ, GT, etc.).
42+
* @param key The key to filter on.
43+
* @param value The value to compare against.
44+
*/
45+
private fun expression(type: Filter.ExpressionType, key: String, value: Any) {
46+
expression(Filter.Expression(type, Filter.Key(key), Filter.Value(value)))
47+
}
48+
49+
/**
50+
* Assembles a list of filter operands into a single filter expression of the specified type.
51+
* Optionally, the resulting expression can be grouped.
52+
*
53+
* @param type The type of the filter expression (e.g., AND, OR).
54+
* @param group If true, the resulting expression will be wrapped in a group.
55+
* @return The assembled filter expression.
56+
*/
57+
@Suppress("ReturnCount")
58+
private fun List<Filter.Operand>.assembly(
59+
type: Filter.ExpressionType,
60+
group: Boolean = false,
61+
): Filter.Operand {
62+
if (this.size == 1) {
63+
return this[0]
64+
}
65+
var exp = Filter.Expression(type, this[0], this[1])
66+
for (i in 2..this.size - 1) {
67+
exp = Filter.Expression(type, exp, this[i])
68+
}
69+
if (!group) {
70+
return exp
71+
}
72+
73+
return Filter.Group(exp)
74+
}
75+
76+
/**
77+
* Combines a list of filter operands using the AND operator.
78+
*
79+
* @param group If true, the resulting expression will be wrapped in a group.
80+
* @return The combined filter expression.
81+
*/
82+
private fun List<Filter.Operand>.and(group: Boolean = false): Filter.Operand {
83+
return this.assembly(Filter.ExpressionType.AND, group)
84+
}
85+
86+
/**
87+
* Combines a list of filter operands using the OR operator.
88+
*
89+
* @param group If true, the resulting expression will be wrapped in a group.
90+
* @return The combined filter expression.
91+
*/
92+
private fun List<Filter.Operand>.or(group: Boolean = false): Filter.Operand {
93+
return this.assembly(Filter.ExpressionType.OR, group)
94+
}
95+
96+
/**
97+
* Creates a new FilterExpressionDsl instance, applies the provided block to it,
98+
* and combines the resulting expressions using the AND operator.
99+
*
100+
* @param group If true, the resulting expression will be wrapped in a group.
101+
* @param block A lambda that defines the filter expressions to be combined.
102+
*/
103+
fun and(group: Boolean = false, block: FilterExpressionDsl.() -> Unit) {
104+
val dsl = FilterExpressionDsl()
105+
dsl.block()
106+
if (dsl.expressions.isEmpty()) {
107+
return
108+
}
109+
expression(dsl.expressions.and(group))
110+
}
111+
112+
/**
113+
* Creates a new FilterExpressionDsl instance, applies the provided block to it,
114+
* and combines the resulting expressions using the OR operator.
115+
*
116+
* @param group If true, the resulting expression will be wrapped in a group.
117+
* @param block A lambda that defines the filter expressions to be combined.
118+
*/
119+
fun or(group: Boolean = false, block: FilterExpressionDsl.() -> Unit) {
120+
val dsl = FilterExpressionDsl()
121+
dsl.block()
122+
if (dsl.expressions.isEmpty()) {
123+
return
124+
}
125+
expression(dsl.expressions.or(group))
126+
}
127+
128+
/**
129+
* Creates a new FilterExpressionDsl instance, applies the provided block to it,
130+
* and negates the resulting expression using the NOT operator.
131+
*
132+
* @param block A lambda that defines the filter expression to be negated.
133+
*/
134+
fun not(block: FilterExpressionDsl.() -> Unit) {
135+
val dsl = FilterExpressionDsl()
136+
dsl.block()
137+
if (dsl.expressions.isEmpty()) {
138+
return
139+
}
140+
val nestedCondition = dsl.build()
141+
expression(Filter.Expression(Filter.ExpressionType.NOT, nestedCondition))
142+
}
143+
144+
/**
145+
* Creates an equality filter expression.
146+
*
147+
* @param value The value to compare against.
148+
*/
149+
infix fun String.eq(value: Any) {
150+
expression(Filter.ExpressionType.EQ, this, value)
151+
}
152+
153+
/**
154+
* Creates a non-equality filter expression.
155+
*
156+
* @param value The value to compare against.
157+
*/
158+
infix fun String.ne(value: Any) {
159+
expression(Filter.ExpressionType.NE, this, value)
160+
}
161+
162+
/**
163+
* Creates a greater-than filter expression.
164+
*
165+
* @param value The value to compare against.
166+
*/
167+
infix fun String.gt(value: Any) {
168+
expression(Filter.ExpressionType.GT, this, value)
169+
}
170+
171+
/**
172+
* Creates a greater-than-or-equal-to filter expression.
173+
*
174+
* @param value The value to compare against.
175+
*/
176+
infix fun String.gte(value: Any) {
177+
expression(Filter.ExpressionType.GTE, this, value)
178+
}
179+
180+
/**
181+
* Creates a less-than filter expression.
182+
*
183+
* @param value The value to compare against.
184+
*/
185+
infix fun String.lt(value: Any) {
186+
expression(Filter.ExpressionType.LT, this, value)
187+
}
188+
189+
/**
190+
* Creates a less-than-or-equal-to filter expression.
191+
*
192+
* @param value The value to compare against.
193+
*/
194+
infix fun String.lte(value: Any) {
195+
expression(Filter.ExpressionType.LTE, this, value)
196+
}
197+
198+
/**
199+
* Creates an "in" filter expression, checking if the key is in the provided list of values.
200+
*
201+
* @param key The key to filter on.
202+
* @param values The list of values to check against.
203+
*/
204+
fun isIn(key: String, vararg values: Any) {
205+
expression(Filter.ExpressionType.IN, key, values.toList())
206+
}
207+
208+
/**
209+
* Creates an "in" filter expression, checking if the key is in the provided list of values.
210+
*
211+
* @param values The list of values to check against.
212+
*/
213+
infix fun String.isIn(values: List<Any>) {
214+
expression(Filter.ExpressionType.IN, this, values)
215+
}
216+
217+
/**
218+
* Creates a "not in" filter expression, checking if the key is not in the provided list of values.
219+
*
220+
* @param key The key to filter on.
221+
* @param values The list of values to check against.
222+
*/
223+
fun nin(key: String, vararg values: Any) {
224+
expression(Filter.ExpressionType.NIN, key, values.toList())
225+
}
226+
227+
/**
228+
* Creates a "not in" filter expression, checking if the key is not in the provided list of values.
229+
*
230+
* @param values The list of values to check against.
231+
*/
232+
infix fun String.nin(values: List<Any>) {
233+
expression(Filter.ExpressionType.NIN, this, values)
234+
}
235+
236+
/**
237+
* Converts a filter operand into a filter expression.
238+
*
239+
* @return The filter expression.
240+
* @throws IllegalArgumentException if the operand type is unsupported.
241+
*/
242+
private fun Filter.Operand.asExpression(): Filter.Expression {
243+
return when (this) {
244+
is Filter.Expression -> this
245+
is Filter.Group -> this.content
246+
else -> throw IllegalArgumentException("Unsupported operand type: ${this::class.java.name}")
247+
}
248+
}
249+
250+
/**
251+
* Builds and returns the final filter expression.
252+
*
253+
* @return The final filter expression, or null if no expressions were added.
254+
*/
255+
@Suppress("ReturnCount")
256+
fun build(): Filter.Expression? {
257+
if (expressions.isEmpty()) {
258+
return null
259+
}
260+
if (expressions.size == 1) {
261+
return expressions[0].asExpression()
262+
}
263+
return expressions.and().asExpression()
264+
}
265+
266+
}
267+
268+
/**
269+
* Generates a filter expression based on the provided DSL block.
270+
*
271+
* This function takes a lambda expression as a parameter, which is executed in the context of `FilterExpressionDsl`,
272+
* allowing users to define filter conditions in a DSL manner.
273+
* Finally, the function returns a `Filter.Expression` object
274+
* representing the constructed filter expression.
275+
*
276+
* @param block A lambda expression executed in the context of `FilterExpressionDsl`, used to define filter conditions.
277+
* @return Returns a `Filter.Expression` object representing the filter expression constructed from the DSL.
278+
* If no conditions are defined in the DSL, it returns `null`.
279+
*/
280+
fun filterExpression(block: FilterExpressionDsl.() -> Unit): Filter.Expression? {
281+
val dsl = FilterExpressionDsl()
282+
dsl.block()
283+
return dsl.build()
284+
}

0 commit comments

Comments
 (0)