@@ -99,6 +99,30 @@ func canHandleEquals(schema sql.Schema, tableName string, eq *expression.Equals)
99
99
return false
100
100
}
101
101
102
+ // canHandleIn returns whether the given in expression can be handled as a selector.
103
+ // For that to happen, the left side must be a GetField expression and the right
104
+ // side must be a Tuple expression with only Literal expressions as children.
105
+ // The GetField expr must exist in the schema and match the given table name.
106
+ func canHandleIn (schema sql.Schema , tableName string , in * expression.In ) bool {
107
+ left , ok := in .Left ().(* expression.GetField )
108
+ if ! ok || ! schema .Contains (left .Name ()) || left .Table () != tableName {
109
+ return false
110
+ }
111
+
112
+ right , ok := in .Right ().(expression.Tuple )
113
+ if ! ok {
114
+ return false
115
+ }
116
+
117
+ for _ , elem := range right {
118
+ if _ , ok := elem .(* expression.Literal ); ! ok {
119
+ return false
120
+ }
121
+ }
122
+
123
+ return true
124
+ }
125
+
102
126
// getEqualityValues returns the field and value of the literal in the
103
127
// given equality expression.
104
128
func getEqualityValues (eq * expression.Equals ) (string , interface {}, error ) {
@@ -119,6 +143,36 @@ func getEqualityValues(eq *expression.Equals) (string, interface{}, error) {
119
143
return "" , "" , nil
120
144
}
121
145
146
+ // getInValues returns the field and values of the literals in the
147
+ // given in expression.
148
+ func getInValues (in * expression.In ) (string , []interface {}, error ) {
149
+ left , ok := in .Left ().(* expression.GetField )
150
+ if ! ok {
151
+ return "" , nil , nil
152
+ }
153
+
154
+ right , ok := in .Right ().(expression.Tuple )
155
+ if ! ok {
156
+ return "" , nil , nil
157
+ }
158
+
159
+ var values = make ([]interface {}, len (right ))
160
+ for i , elem := range right {
161
+ lit , ok := elem .(* expression.Literal )
162
+ if ! ok {
163
+ return "" , nil , nil
164
+ }
165
+
166
+ var err error
167
+ values [i ], err = lit .Eval (nil , nil )
168
+ if err != nil {
169
+ return "" , nil , err
170
+ }
171
+ }
172
+
173
+ return left .Name (), values , nil
174
+ }
175
+
122
176
// handledFilters returns the set of filters that can be handled with the given
123
177
// schema. That is, all expressions that don't have GetField expressions that
124
178
// don't belong to the given schema.
@@ -174,15 +228,36 @@ func classifyFilters(
174
228
continue
175
229
}
176
230
}
177
- // TODO: handle IN when it's implemented
231
+ case * expression.In :
232
+ if canHandleIn (schema , table , f ) {
233
+ field , vals , err := getInValues (f )
234
+ if err != nil {
235
+ return nil , nil , err
236
+ }
237
+
238
+ if stringContains (handledCols , field ) {
239
+ selectors [field ] = append (selectors [field ], selector (vals ))
240
+ continue
241
+ }
242
+ }
178
243
case * expression.Or :
179
244
exprs := unfoldOrs (f )
180
245
// check all unfolded exprs can be handled, if not we have to
181
246
// resort to treating them as conditions
182
247
valid := true
183
248
for _ , e := range exprs {
184
- f , ok := e .(* expression.Equals )
185
- if ! ok || ! canHandleEquals (schema , table , f ) {
249
+ switch e := e .(type ) {
250
+ case * expression.Equals :
251
+ if ! canHandleEquals (schema , table , e ) {
252
+ valid = false
253
+ break
254
+ }
255
+ case * expression.In :
256
+ if ! canHandleIn (schema , table , e ) {
257
+ valid = false
258
+ break
259
+ }
260
+ default :
186
261
valid = false
187
262
break
188
263
}
@@ -200,11 +275,9 @@ func classifyFilters(
200
275
}
201
276
202
277
for k , v := range sels {
203
- var values = make (selector , len (v ))
204
- for i , val := range v {
205
- if len (val ) > 0 {
206
- values [i ] = val [0 ]
207
- }
278
+ var values selector
279
+ for _ , vals := range v {
280
+ values = append (values , vals ... )
208
281
}
209
282
selectors [k ] = append (selectors [k ], values )
210
283
}
0 commit comments