Skip to content

Commit 06bc521

Browse files
authored
Introdce FilterUdf trait to reduce duplication in filter-related UDFs implementations (#11396)
1 parent 1db2047 commit 06bc521

File tree

6 files changed

+308
-600
lines changed

6 files changed

+308
-600
lines changed

crates/viewer/re_dataframe_ui/src/filters/column_filter_ui.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,9 @@ mod tests {
343343

344344
use super::super::{
345345
ComparisonOperator, FloatFilter, IntFilter, NonNullableBooleanFilter,
346-
NullableBooleanFilter, StringFilter, StringOperator, TimestampFilter,
346+
NullableBooleanFilter, StringFilter, StringOperator, TimestampFilter, TypedFilter,
347347
};
348348
use super::*;
349-
use crate::filters::TypedFilter;
350349

351350
fn test_cases() -> Vec<(TypedFilter, &'static str)> {
352351
// Let's remember to update this test when adding new filter types.
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
use std::any::Any;
2+
use std::fmt::Debug;
3+
use std::sync::{Arc, OnceLock};
4+
5+
use arrow::array::{ArrayRef, BooleanArray, ListArray, as_list_array};
6+
use arrow::datatypes::DataType;
7+
use datafusion::common::{Result as DataFusionResult, exec_err};
8+
use datafusion::logical_expr::{
9+
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarFunctionArgs, ScalarUDF,
10+
ScalarUDFImpl, Signature, TypeSignature, Volatility,
11+
};
12+
13+
/// Helper trait to make it straightforward to implement a filter UDF.
14+
///
15+
/// Note that a filter UDF is only a _building block_ towards creating a final expression for
16+
/// datafusion. See [`super::Filter::as_filter_expression`] in its implementation for more details.
17+
pub trait FilterUdf: Any + Clone + Debug + Send + Sync {
18+
/// The scalar datafusion type signature for this UDF.
19+
///
20+
/// The list version will automatically be accepted as well, see [`Self::signature`].
21+
const PRIMITIVE_SIGNATURE: TypeSignature;
22+
23+
/// Name for this UDF.
24+
///
25+
/// Keep it simple, it's also used in error. Example: "string" (for a string filter).
26+
fn name(&self) -> &'static str;
27+
28+
/// Which _primitive_ datatypes are supported?
29+
///
30+
/// Emphasis on "primitive". One layer of nested types (aka `List`) is automatically supported
31+
/// as well, see [`Self::is_valid_input_type`].
32+
fn is_valid_primitive_input_type(data_type: &DataType) -> bool;
33+
34+
/// Invoke this UDF on a primitive array.
35+
///
36+
/// Again, nested types (aka `List`) are automatically supported, see [`Self::invoke_list_array`].
37+
fn invoke_primitive_array(&self, array: &ArrayRef) -> DataFusionResult<BooleanArray>;
38+
39+
/// Turn this type into a [`ScalarUDF`].
40+
fn as_scalar_udf(&self) -> ScalarUDF {
41+
ScalarUDF::new_from_impl(FilterUdfWrapper(self.clone()))
42+
}
43+
44+
/// Signature for this UDF.
45+
///
46+
/// See [`ScalarUDFImpl::signature`].
47+
fn signature(&self) -> &Signature {
48+
static SIGNATURE: OnceLock<Signature> = OnceLock::new();
49+
50+
SIGNATURE.get_or_init(|| {
51+
Signature::one_of(
52+
vec![
53+
Self::PRIMITIVE_SIGNATURE,
54+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
55+
arguments: vec![ArrayFunctionArgument::Array],
56+
array_coercion: None,
57+
}),
58+
],
59+
Volatility::Immutable,
60+
)
61+
})
62+
}
63+
64+
/// Is this datatype valid?
65+
///
66+
/// Delegates to [`Self::is_valid_primitive_input_type`] for non-nested types.
67+
fn is_valid_input_type(data_type: &DataType) -> bool {
68+
match data_type {
69+
DataType::List(field) | DataType::ListView(field) => {
70+
// Note: we do not support double nested types
71+
Self::is_valid_primitive_input_type(field.data_type())
72+
}
73+
74+
//TODO(ab): support other containers
75+
_ => Self::is_valid_primitive_input_type(data_type),
76+
}
77+
}
78+
79+
/// Invoke this UDF for a list array.
80+
///
81+
/// Delegates actual implementation to [`Self::invoke_primitive_array`].
82+
fn invoke_list_array(&self, list_array: &ListArray) -> DataFusionResult<BooleanArray> {
83+
// TODO(ab): we probably should do this in two steps:
84+
// 1) Convert the list array to a bool array (with same offsets and nulls)
85+
// 2) Apply the ANY (or, in the future, another) semantics to "merge" each row's instances
86+
// into the final bool.
87+
list_array
88+
.iter()
89+
.map(|maybe_row| {
90+
maybe_row.map(|row| {
91+
// Note: we know this is a primitive array because we explicitly disallow nested
92+
// lists or other containers.
93+
let element_results = self.invoke_primitive_array(&row)?;
94+
95+
// `ANY` semantics happening here
96+
Ok(element_results
97+
.iter()
98+
.map(|x| x.unwrap_or(false))
99+
.find(|x| *x)
100+
.unwrap_or(false))
101+
})
102+
})
103+
.map(|x| x.transpose())
104+
.collect::<DataFusionResult<BooleanArray>>()
105+
}
106+
}
107+
108+
// shield against orphan rule
109+
#[derive(Debug, Clone)]
110+
struct FilterUdfWrapper<T: FilterUdf + Debug + Send + Sync>(T);
111+
112+
impl<T: FilterUdf + Debug + Send + Sync> ScalarUDFImpl for FilterUdfWrapper<T> {
113+
fn as_any(&self) -> &dyn Any {
114+
&self.0
115+
}
116+
117+
fn name(&self) -> &'static str {
118+
self.0.name()
119+
}
120+
121+
fn signature(&self) -> &Signature {
122+
self.0.signature()
123+
}
124+
125+
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
126+
if arg_types.len() != 1 {
127+
return exec_err!(
128+
"expected a single column of input, received {}",
129+
arg_types.len()
130+
);
131+
}
132+
133+
if T::is_valid_input_type(&arg_types[0]) {
134+
Ok(DataType::Boolean)
135+
} else {
136+
exec_err!(
137+
"input data type {} not supported for {} filter UDF",
138+
arg_types[0],
139+
self.0.name()
140+
)
141+
}
142+
}
143+
144+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
145+
let ColumnarValue::Array(input_array) = &args.args[0] else {
146+
return exec_err!("expected array inputs, not scalar values");
147+
};
148+
149+
let results = match input_array.data_type() {
150+
DataType::List(_field) => {
151+
let array = as_list_array(input_array);
152+
self.0.invoke_list_array(array)?
153+
}
154+
155+
//TODO(ab): support other containers
156+
data_type if T::is_valid_primitive_input_type(data_type) => {
157+
self.0.invoke_primitive_array(input_array)?
158+
}
159+
160+
_ => {
161+
return exec_err!(
162+
"DataType not implemented for {} filter UDF: {}",
163+
self.0.name(),
164+
input_array.data_type()
165+
);
166+
}
167+
};
168+
169+
Ok(ColumnarValue::Array(Arc::new(results)))
170+
}
171+
}

crates/viewer/re_dataframe_ui/src/filters/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@ mod boolean;
22
mod column_filter;
33
mod column_filter_ui;
44
mod filter;
5+
mod filter_udf;
56
mod numerical;
67
mod parse_timestamp;
78
mod string;
89
mod timestamp;
910
mod timestamp_formatted;
1011

1112
pub use self::{
12-
boolean::*, column_filter::*, column_filter_ui::*, filter::*, numerical::*, parse_timestamp::*,
13-
string::*, timestamp::*, timestamp_formatted::*,
13+
boolean::*, column_filter::*, column_filter_ui::*, filter::*, filter_udf::*, numerical::*,
14+
parse_timestamp::*, string::*, timestamp::*, timestamp_formatted::*,
1415
};

0 commit comments

Comments
 (0)