1
1
use std:: any:: Any ;
2
2
use std:: fmt:: Debug ;
3
- use std:: sync:: { Arc , OnceLock } ;
3
+ use std:: sync:: Arc ;
4
4
5
5
use arrow:: array:: { ArrayRef , BooleanArray , ListArray , as_list_array} ;
6
6
use arrow:: datatypes:: DataType ;
@@ -17,7 +17,7 @@ use datafusion::logical_expr::{
17
17
pub trait FilterUdf : Any + Clone + Debug + Send + Sync {
18
18
/// The scalar datafusion type signature for this UDF.
19
19
///
20
- /// The list version will automatically be accepted as well, see [`Self ::signature`] .
20
+ /// The list version will automatically be accepted as well, see `FilterUdfWrapper ::signature`.
21
21
const PRIMITIVE_SIGNATURE : TypeSignature ;
22
22
23
23
/// Name for this UDF.
@@ -38,27 +38,7 @@ pub trait FilterUdf: Any + Clone + Debug + Send + Sync {
38
38
39
39
/// Turn this type into a [`ScalarUDF`].
40
40
fn as_scalar_udf ( & self ) -> ScalarUDF {
41
- ScalarUDF :: new_from_impl ( FilterUdfWrapper ( self . clone ( ) ) )
42
- }
43
-
44
- /// Signature for this UDF.
45
- ///
46
- /// See [`ScalarUDFImpl::signature`].
47
- fn signature ( & self ) -> & Signature {
48
- static SIGNATURE : OnceLock < Signature > = OnceLock :: new ( ) ;
49
-
50
- SIGNATURE . get_or_init ( || {
51
- Signature :: one_of (
52
- vec ! [
53
- Self :: PRIMITIVE_SIGNATURE ,
54
- TypeSignature :: ArraySignature ( ArrayFunctionSignature :: Array {
55
- arguments: vec![ ArrayFunctionArgument :: Array ] ,
56
- array_coercion: None ,
57
- } ) ,
58
- ] ,
59
- Volatility :: Immutable ,
60
- )
61
- } )
41
+ ScalarUDF :: new_from_impl ( FilterUdfWrapper :: new ( self . clone ( ) ) )
62
42
}
63
43
64
44
/// Is this datatype valid?
@@ -105,21 +85,48 @@ pub trait FilterUdf: Any + Clone + Debug + Send + Sync {
105
85
}
106
86
}
107
87
108
- // shield against orphan rule
109
- #[ derive( Debug , Clone ) ]
110
- struct FilterUdfWrapper < T : FilterUdf + Debug + Send + Sync > ( T ) ;
88
+ /// Wrapper for implementor of [`FilterUdf`].
89
+ ///
90
+ /// This serves two purposes:
91
+ /// 1) Allow blanket implementation of [`ScalarUDFImpl`] (orphan rule)
92
+ /// 2) Cache the [`Signature`] (needed for [`ScalarUDFImpl::signature`])
93
+ #[ derive( Debug ) ]
94
+ struct FilterUdfWrapper < T : FilterUdf > {
95
+ inner : T ,
96
+ signature : Signature ,
97
+ }
98
+
99
+ impl < T : FilterUdf > FilterUdfWrapper < T > {
100
+ fn new ( filter : T ) -> Self {
101
+ let signature = Signature :: one_of (
102
+ vec ! [
103
+ T :: PRIMITIVE_SIGNATURE ,
104
+ TypeSignature :: ArraySignature ( ArrayFunctionSignature :: Array {
105
+ arguments: vec![ ArrayFunctionArgument :: Array ] ,
106
+ array_coercion: None ,
107
+ } ) ,
108
+ ] ,
109
+ Volatility :: Immutable ,
110
+ ) ;
111
+
112
+ Self {
113
+ inner : filter,
114
+ signature,
115
+ }
116
+ }
117
+ }
111
118
112
119
impl < T : FilterUdf + Debug + Send + Sync > ScalarUDFImpl for FilterUdfWrapper < T > {
113
120
fn as_any ( & self ) -> & dyn Any {
114
- & self . 0
121
+ & self . inner
115
122
}
116
123
117
124
fn name ( & self ) -> & ' static str {
118
- self . 0 . name ( )
125
+ self . inner . name ( )
119
126
}
120
127
121
128
fn signature ( & self ) -> & Signature {
122
- self . 0 . signature ( )
129
+ & self . signature
123
130
}
124
131
125
132
fn return_type ( & self , arg_types : & [ DataType ] ) -> DataFusionResult < DataType > {
@@ -136,7 +143,7 @@ impl<T: FilterUdf + Debug + Send + Sync> ScalarUDFImpl for FilterUdfWrapper<T> {
136
143
exec_err ! (
137
144
"input data type {} not supported for {} filter UDF" ,
138
145
arg_types[ 0 ] ,
139
- self . 0 . name( )
146
+ self . inner . name( )
140
147
)
141
148
}
142
149
}
@@ -149,18 +156,18 @@ impl<T: FilterUdf + Debug + Send + Sync> ScalarUDFImpl for FilterUdfWrapper<T> {
149
156
let results = match input_array. data_type ( ) {
150
157
DataType :: List ( _field) => {
151
158
let array = as_list_array ( input_array) ;
152
- self . 0 . invoke_list_array ( array) ?
159
+ self . inner . invoke_list_array ( array) ?
153
160
}
154
161
155
162
//TODO(ab): support other containers
156
163
data_type if T :: is_valid_primitive_input_type ( data_type) => {
157
- self . 0 . invoke_primitive_array ( input_array) ?
164
+ self . inner . invoke_primitive_array ( input_array) ?
158
165
}
159
166
160
167
_ => {
161
168
return exec_err ! (
162
169
"DataType not implemented for {} filter UDF: {}" ,
163
- self . 0 . name( ) ,
170
+ self . inner . name( ) ,
164
171
input_array. data_type( )
165
172
) ;
166
173
}
0 commit comments