|
1 |
| -use std::any::Any; |
2 | 1 | use std::collections::HashMap;
|
3 | 2 | use std::fmt::Formatter;
|
4 |
| -use std::sync::Arc; |
5 |
| - |
6 |
| -use arrow::array::{Array as _, ArrayRef, AsArray as _, BooleanArray, ListArray, as_list_array}; |
7 |
| -use arrow::compute::cast; |
8 |
| -use arrow::datatypes::{DataType, Field, TimeUnit}; |
9 |
| -use datafusion::common::ExprSchema as _; |
10 |
| -use datafusion::common::{DFSchema, Result as DataFusionResult, exec_err}; |
11 |
| -use datafusion::logical_expr::{ |
12 |
| - ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarFunctionArgs, ScalarUDF, |
13 |
| - ScalarUDFImpl, Signature, TypeSignature, Volatility, |
14 |
| -}; |
15 |
| -use datafusion::prelude::{Column, Expr, col}; |
16 | 3 |
|
17 |
| -use re_types_core::datatypes::TimeInt; |
18 |
| -use re_types_core::{Component as _, FIELD_METADATA_KEY_COMPONENT_TYPE, Loggable as _}; |
| 4 | +use arrow::datatypes::{DataType, Field}; |
| 5 | +use datafusion::common::{DFSchema, ExprSchema as _}; |
| 6 | +use datafusion::prelude::{Column, Expr}; |
| 7 | + |
| 8 | +use re_types_core::{Component as _, FIELD_METADATA_KEY_COMPONENT_TYPE}; |
19 | 9 |
|
20 | 10 | use super::{
|
21 | 11 | FloatFilter, IntFilter, NonNullableBooleanFilter, NullableBooleanFilter, StringFilter,
|
@@ -217,283 +207,10 @@ impl FilterKind {
|
217 | 207 | Self::NonNullableBoolean(boolean_filter) => {
|
218 | 208 | boolean_filter.as_filter_expression(column, field)
|
219 | 209 | }
|
220 |
| - |
221 |
| - Self::Int(_) | Self::Float(_) | Self::Timestamp(_) => { |
222 |
| - let udf = FilterKindUdf::new(self.clone()); |
223 |
| - let udf = ScalarUDF::new_from_impl(udf); |
224 |
| - |
225 |
| - Ok(udf.call(vec![col(column.clone())])) |
226 |
| - } |
227 |
| - |
| 210 | + Self::Int(int_filter) => Ok(int_filter.as_filter_expression(column)), |
| 211 | + Self::Float(float_filter) => Ok(float_filter.as_filter_expression(column)), |
228 | 212 | Self::String(string_filter) => Ok(string_filter.as_filter_expression(column)),
|
| 213 | + Self::Timestamp(timestamp_filter) => Ok(timestamp_filter.as_filter_expression(column)), |
229 | 214 | }
|
230 | 215 | }
|
231 | 216 | }
|
232 |
| - |
233 |
| -/// Custom UDF for evaluating some filters kinds. |
234 |
| -//TODO(ab): consider splitting the vectorized filtering part from the `any`/`all` aggregation. |
235 |
| -#[derive(Debug, Clone)] |
236 |
| -struct FilterKindUdf { |
237 |
| - op: FilterKind, |
238 |
| - signature: Signature, |
239 |
| -} |
240 |
| - |
241 |
| -impl FilterKindUdf { |
242 |
| - fn new(op: FilterKind) -> Self { |
243 |
| - let type_signature = match op { |
244 |
| - FilterKind::Int(_) | FilterKind::Float(_) => TypeSignature::Numeric(1), |
245 |
| - |
246 |
| - FilterKind::Timestamp(_) => TypeSignature::Any(1), |
247 |
| - |
248 |
| - // TODO(ab): add support for other filter types? |
249 |
| - // FilterKind::StringContains(_) => TypeSignature::String(1), |
250 |
| - // FilterKind::BooleanEquals(_) => TypeSignature::Exact(vec![DataType::Boolean]), |
251 |
| - _ => { |
252 |
| - debug_assert!(false, "Invalid filter kind"); |
253 |
| - TypeSignature::Any(1) |
254 |
| - } |
255 |
| - }; |
256 |
| - |
257 |
| - let signature = Signature::one_of( |
258 |
| - vec![ |
259 |
| - type_signature, |
260 |
| - TypeSignature::ArraySignature(ArrayFunctionSignature::Array { |
261 |
| - arguments: vec![ArrayFunctionArgument::Array], |
262 |
| - array_coercion: None, |
263 |
| - }), |
264 |
| - ], |
265 |
| - Volatility::Immutable, |
266 |
| - ); |
267 |
| - |
268 |
| - Self { op, signature } |
269 |
| - } |
270 |
| - |
271 |
| - /// Check if the provided _primitive_ type is valid. |
272 |
| - fn is_valid_primitive_input_type(&self, data_type: &DataType) -> bool { |
273 |
| - match data_type { |
274 |
| - _data_type if _data_type == &TimeInt::arrow_datatype() => { |
275 |
| - // TimeInt special case: we allow filtering by timestamp on Int64 columns |
276 |
| - matches!(&self.op, FilterKind::Int(_) | FilterKind::Timestamp(_)) |
277 |
| - } |
278 |
| - |
279 |
| - _data_type if data_type.is_integer() => { |
280 |
| - matches!(&self.op, FilterKind::Int(_)) |
281 |
| - } |
282 |
| - |
283 |
| - //TODO(ab): float16 support (use `is_floating()`) |
284 |
| - DataType::Float32 | DataType::Float64 => { |
285 |
| - matches!(&self.op, FilterKind::Float(_)) |
286 |
| - } |
287 |
| - |
288 |
| - DataType::Timestamp(_, _) => { |
289 |
| - matches!(&self.op, FilterKind::Timestamp(_)) |
290 |
| - } |
291 |
| - |
292 |
| - _ => false, |
293 |
| - } |
294 |
| - } |
295 |
| - |
296 |
| - fn is_valid_input_type(&self, data_type: &DataType) -> bool { |
297 |
| - match data_type { |
298 |
| - DataType::List(field) | DataType::ListView(field) => { |
299 |
| - // Note: we do not support double nested types |
300 |
| - self.is_valid_primitive_input_type(field.data_type()) |
301 |
| - } |
302 |
| - |
303 |
| - //TODO(ab): support other containers |
304 |
| - _ => self.is_valid_primitive_input_type(data_type), |
305 |
| - } |
306 |
| - } |
307 |
| - |
308 |
| - fn invoke_primitive_array(&self, array: &ArrayRef) -> DataFusionResult<BooleanArray> { |
309 |
| - macro_rules! int_float_case { |
310 |
| - ($op_arm:ident, $conv_fun:ident, $op:expr) => {{ |
311 |
| - let FilterKind::$op_arm(filter) = &$op else { |
312 |
| - return exec_err!( |
313 |
| - "Incompatible filter kind and data types {:?} - {}", |
314 |
| - $op, |
315 |
| - array.data_type() |
316 |
| - ); |
317 |
| - }; |
318 |
| - let array = datafusion::common::cast::$conv_fun(array)?; |
319 |
| - |
320 |
| - #[allow(trivial_numeric_casts)] |
321 |
| - let result: BooleanArray = array |
322 |
| - .iter() |
323 |
| - .map(|x| { |
324 |
| - let Some(rhs_value) = filter.rhs_value() else { |
325 |
| - return Some(true); |
326 |
| - }; |
327 |
| - |
328 |
| - x.map(|x| filter.comparison_operator().apply(x, rhs_value as _)) |
329 |
| - }) |
330 |
| - .collect(); |
331 |
| - |
332 |
| - Ok(result) |
333 |
| - }}; |
334 |
| - } |
335 |
| - |
336 |
| - macro_rules! timestamp_case { |
337 |
| - ($apply_fun:ident, $conv_fun:ident, $op:expr) => {{ |
338 |
| - let FilterKind::Timestamp(timestamp_filter) = &$op else { |
339 |
| - return exec_err!( |
340 |
| - "Incompatible filter and data types {:?} - {}", |
341 |
| - $op, |
342 |
| - array.data_type() |
343 |
| - ); |
344 |
| - }; |
345 |
| - let array = datafusion::common::cast::$conv_fun(array)?; |
346 |
| - let result: BooleanArray = array |
347 |
| - .iter() |
348 |
| - .map(|x| x.map(|v| timestamp_filter.$apply_fun(v))) |
349 |
| - .collect(); |
350 |
| - |
351 |
| - Ok(result) |
352 |
| - }}; |
353 |
| - } |
354 |
| - |
355 |
| - match array.data_type() { |
356 |
| - DataType::Int8 => int_float_case!(Int, as_int8_array, self.op), |
357 |
| - DataType::Int16 => int_float_case!(Int, as_int16_array, self.op), |
358 |
| - DataType::Int32 => int_float_case!(Int, as_int32_array, self.op), |
359 |
| - |
360 |
| - // Note: although `TimeInt` is Int64, by now we casted it to `Timestamp`, see |
361 |
| - // `invoke_list_array` impl. |
362 |
| - DataType::Int64 => int_float_case!(Int, as_int64_array, self.op), |
363 |
| - DataType::UInt8 => int_float_case!(Int, as_uint8_array, self.op), |
364 |
| - DataType::UInt16 => int_float_case!(Int, as_uint16_array, self.op), |
365 |
| - DataType::UInt32 => int_float_case!(Int, as_uint32_array, self.op), |
366 |
| - DataType::UInt64 => int_float_case!(Int, as_uint64_array, self.op), |
367 |
| - |
368 |
| - //TODO(ab): float16 support |
369 |
| - DataType::Float32 => int_float_case!(Float, as_float32_array, self.op), |
370 |
| - DataType::Float64 => int_float_case!(Float, as_float64_array, self.op), |
371 |
| - |
372 |
| - DataType::Timestamp(TimeUnit::Second, _) => { |
373 |
| - timestamp_case!(apply_seconds, as_timestamp_second_array, self.op) |
374 |
| - } |
375 |
| - DataType::Timestamp(TimeUnit::Millisecond, _) => { |
376 |
| - timestamp_case!(apply_milliseconds, as_timestamp_millisecond_array, self.op) |
377 |
| - } |
378 |
| - DataType::Timestamp(TimeUnit::Microsecond, _) => { |
379 |
| - timestamp_case!(apply_microseconds, as_timestamp_microsecond_array, self.op) |
380 |
| - } |
381 |
| - DataType::Timestamp(TimeUnit::Nanosecond, _) => { |
382 |
| - timestamp_case!(apply_nanoseconds, as_timestamp_nanosecond_array, self.op) |
383 |
| - } |
384 |
| - |
385 |
| - _ => { |
386 |
| - exec_err!("Unsupported data type {}", array.data_type()) |
387 |
| - } |
388 |
| - } |
389 |
| - } |
390 |
| - |
391 |
| - fn invoke_list_array(&self, list_array: &ListArray) -> DataFusionResult<BooleanArray> { |
392 |
| - // TimeInt special case: we cast the Int64 array TimestampNano |
393 |
| - let cast_list_array = if list_array.values().data_type() == &TimeInt::arrow_datatype() |
394 |
| - && matches!(self.op, FilterKind::Timestamp(_)) |
395 |
| - { |
396 |
| - let DataType::List(field) = list_array.data_type() else { |
397 |
| - unreachable!("ListArray must have a List data type"); |
398 |
| - }; |
399 |
| - let new_field = Arc::new(Arc::unwrap_or_clone(field.clone()).with_data_type( |
400 |
| - DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), |
401 |
| - )); |
402 |
| - |
403 |
| - Some(cast(list_array, &DataType::List(new_field))?) |
404 |
| - } else { |
405 |
| - None |
406 |
| - }; |
407 |
| - |
408 |
| - let cast_list_array = cast_list_array |
409 |
| - .as_ref() |
410 |
| - .map(|array| array.as_list()) |
411 |
| - .unwrap_or(list_array); |
412 |
| - |
413 |
| - // TODO(ab): we probably should do this in two steps: |
414 |
| - // 1) Convert the list array to a bool array (with same offsets and nulls) |
415 |
| - // 2) Apply the ANY (or, in the future, another) semantics to "merge" each row's instances |
416 |
| - // into the final bool. |
417 |
| - // TODO(ab): duplicated code with the other UDF, pliz unify. |
418 |
| - cast_list_array |
419 |
| - .iter() |
420 |
| - .map(|maybe_row| { |
421 |
| - maybe_row.map(|row| { |
422 |
| - // Note: we know this is a primitive array because we explicitly disallow nested |
423 |
| - // lists or other containers. |
424 |
| - let element_results = self.invoke_primitive_array(&row)?; |
425 |
| - |
426 |
| - // `ANY` semantics happening here |
427 |
| - Ok(element_results |
428 |
| - .iter() |
429 |
| - .map(|x| x.unwrap_or(false)) |
430 |
| - .find(|x| *x) |
431 |
| - .unwrap_or(false)) |
432 |
| - }) |
433 |
| - }) |
434 |
| - .map(|x| x.transpose()) |
435 |
| - .collect::<DataFusionResult<BooleanArray>>() |
436 |
| - } |
437 |
| -} |
438 |
| - |
439 |
| -impl ScalarUDFImpl for FilterKindUdf { |
440 |
| - fn as_any(&self) -> &dyn Any { |
441 |
| - self |
442 |
| - } |
443 |
| - |
444 |
| - fn name(&self) -> &'static str { |
445 |
| - "filter_kind" |
446 |
| - } |
447 |
| - |
448 |
| - fn signature(&self) -> &Signature { |
449 |
| - &self.signature |
450 |
| - } |
451 |
| - |
452 |
| - fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> { |
453 |
| - if arg_types.len() != 1 { |
454 |
| - return exec_err!( |
455 |
| - "expected a single column of input, received {}", |
456 |
| - arg_types.len() |
457 |
| - ); |
458 |
| - } |
459 |
| - |
460 |
| - if self.is_valid_input_type(&arg_types[0]) { |
461 |
| - Ok(DataType::Boolean) |
462 |
| - } else { |
463 |
| - exec_err!( |
464 |
| - "input data type {} not supported for filter {:?}", |
465 |
| - arg_types[0], |
466 |
| - self.op |
467 |
| - ) |
468 |
| - } |
469 |
| - } |
470 |
| - |
471 |
| - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> { |
472 |
| - let ColumnarValue::Array(input_array) = &args.args[0] else { |
473 |
| - return exec_err!("expected array inputs, not scalar values"); |
474 |
| - }; |
475 |
| - |
476 |
| - let results = match input_array.data_type() { |
477 |
| - DataType::List(_field) => { |
478 |
| - let array = as_list_array(input_array); |
479 |
| - self.invoke_list_array(array)? |
480 |
| - } |
481 |
| - |
482 |
| - //TODO(ab): float16 support (use `is_floating()`) |
483 |
| - DataType::Float32 | DataType::Float64 | DataType::Timestamp(_, _) => { |
484 |
| - self.invoke_primitive_array(input_array)? |
485 |
| - } |
486 |
| - |
487 |
| - _data_type if _data_type.is_integer() => self.invoke_primitive_array(input_array)?, |
488 |
| - |
489 |
| - _ => { |
490 |
| - return exec_err!( |
491 |
| - "DataType not implemented for FilterKindUdf: {}", |
492 |
| - input_array.data_type() |
493 |
| - ); |
494 |
| - } |
495 |
| - }; |
496 |
| - |
497 |
| - Ok(ColumnarValue::Array(Arc::new(results))) |
498 |
| - } |
499 |
| -} |
0 commit comments