@@ -23,6 +23,7 @@ use arrow::datatypes::{
2323 DataType , Int16Type , Int32Type , Int64Type , Int8Type , UInt16Type , UInt32Type ,
2424 UInt64Type , UInt8Type ,
2525} ;
26+ use datafusion_common:: cast:: as_boolean_array;
2627use datafusion_common:: { plan_err, Result } ;
2728use datafusion_expr:: {
2829 ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
@@ -46,6 +47,7 @@ impl SparkBitCount {
4647 Self {
4748 signature : Signature :: one_of (
4849 vec ! [
50+ TypeSignature :: Exact ( vec![ DataType :: Boolean ] ) ,
4951 TypeSignature :: Exact ( vec![ DataType :: Int8 ] ) ,
5052 TypeSignature :: Exact ( vec![ DataType :: Int16 ] ) ,
5153 TypeSignature :: Exact ( vec![ DataType :: Int32 ] ) ,
@@ -90,28 +92,34 @@ impl ScalarUDFImpl for SparkBitCount {
9092fn spark_bit_count ( value_array : & [ ArrayRef ] ) -> Result < ArrayRef > {
9193 let value_array = value_array[ 0 ] . as_ref ( ) ;
9294 match value_array. data_type ( ) {
95+ DataType :: Boolean => {
96+ let result: Int32Array = as_boolean_array ( value_array) ?
97+ . iter ( )
98+ . map ( |x| x. map ( |y| y as i32 ) )
99+ . collect ( ) ;
100+ Ok ( Arc :: new ( result) )
101+ }
93102 DataType :: Int8 => {
94103 let result: Int32Array = value_array
95104 . as_primitive :: < Int8Type > ( )
96- . unary ( |v| v . count_ones ( ) as i32 ) ;
105+ . unary ( |v| bit_count ( v . into ( ) ) ) ;
97106 Ok ( Arc :: new ( result) )
98107 }
99108 DataType :: Int16 => {
100109 let result: Int32Array = value_array
101110 . as_primitive :: < Int16Type > ( )
102- . unary ( |v| v . count_ones ( ) as i32 ) ;
111+ . unary ( |v| bit_count ( v . into ( ) ) ) ;
103112 Ok ( Arc :: new ( result) )
104113 }
105114 DataType :: Int32 => {
106115 let result: Int32Array = value_array
107116 . as_primitive :: < Int32Type > ( )
108- . unary ( |v| v . count_ones ( ) as i32 ) ;
117+ . unary ( |v| bit_count ( v . into ( ) ) ) ;
109118 Ok ( Arc :: new ( result) )
110119 }
111120 DataType :: Int64 => {
112- let result: Int32Array = value_array
113- . as_primitive :: < Int64Type > ( )
114- . unary ( |v| v. count_ones ( ) as i32 ) ;
121+ let result: Int32Array =
122+ value_array. as_primitive :: < Int64Type > ( ) . unary ( bit_count) ;
115123 Ok ( Arc :: new ( result) )
116124 }
117125 DataType :: UInt8 => {
@@ -147,12 +155,26 @@ fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
147155 }
148156}
149157
158+ // Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType)
159+ // Spark: https://github.com/apache/spark/blob/ac717dd7aec665de578d7c6b0070e8fcdde3cea9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala#L243
160+ // Java impl: https://github.com/openjdk/jdk/blob/d226023643f90027a8980d161ec6d423887ae3ce/src/java.base/share/classes/java/lang/Long.java#L1584
161+ fn bit_count ( i : i64 ) -> i32 {
162+ let mut u = i as u64 ;
163+ u = u - ( ( u >> 1 ) & 0x5555555555555555 ) ;
164+ u = ( u & 0x3333333333333333 ) + ( ( u >> 2 ) & 0x3333333333333333 ) ;
165+ u = ( u + ( u >> 4 ) ) & 0x0f0f0f0f0f0f0f0f ;
166+ u = u + ( u >> 8 ) ;
167+ u = u + ( u >> 16 ) ;
168+ u = u + ( u >> 32 ) ;
169+ ( u as i32 ) & 0x7f
170+ }
171+
150172#[ cfg( test) ]
151173mod tests {
152174 use super :: * ;
153175 use arrow:: array:: {
154- Array , Int16Array , Int32Array , Int64Array , Int8Array , UInt16Array , UInt32Array ,
155- UInt64Array , UInt8Array ,
176+ Array , BooleanArray , Int16Array , Int32Array , Int64Array , Int8Array , UInt16Array ,
177+ UInt32Array , UInt64Array , UInt8Array ,
156178 } ;
157179 use arrow:: datatypes:: Int32Type ;
158180
@@ -192,7 +214,18 @@ mod tests {
192214 assert_eq ! ( arr. value( 2 ) , 2 ) ;
193215 assert_eq ! ( arr. value( 3 ) , 3 ) ;
194216 assert_eq ! ( arr. value( 4 ) , 4 ) ;
195- assert_eq ! ( arr. value( 5 ) , 8 ) ;
217+ assert_eq ! ( arr. value( 5 ) , 64 ) ;
218+ }
219+
220+ #[ test]
221+ fn test_bit_count_boolean ( ) {
222+ // Test bit_count on BooleanArray
223+ let result =
224+ spark_bit_count ( & [ Arc :: new ( BooleanArray :: from ( vec ! [ true , false ] ) ) ] ) . unwrap ( ) ;
225+
226+ let arr = result. as_primitive :: < Int32Type > ( ) ;
227+ assert_eq ! ( arr. value( 0 ) , 1 ) ;
228+ assert_eq ! ( arr. value( 1 ) , 0 ) ;
196229 }
197230
198231 #[ test]
@@ -207,7 +240,7 @@ mod tests {
207240 assert_eq ! ( arr. value( 1 ) , 1 ) ;
208241 assert_eq ! ( arr. value( 2 ) , 8 ) ;
209242 assert_eq ! ( arr. value( 3 ) , 10 ) ;
210- assert_eq ! ( arr. value( 4 ) , 16 ) ;
243+ assert_eq ! ( arr. value( 4 ) , 64 ) ;
211244 }
212245
213246 #[ test]
@@ -222,7 +255,7 @@ mod tests {
222255 assert_eq ! ( arr. value( 1 ) , 1 ) ; // 0b00000000000000000000000000000001 = 1
223256 assert_eq ! ( arr. value( 2 ) , 8 ) ; // 0b00000000000000000000000011111111 = 8
224257 assert_eq ! ( arr. value( 3 ) , 10 ) ; // 0b00000000000000000000001111111111 = 10
225- assert_eq ! ( arr. value( 4 ) , 32 ) ; // -1 in two's complement = all 32 bits set
258+ assert_eq ! ( arr. value( 4 ) , 64 ) ; // -1 in two's complement = all 32 bits set
226259 }
227260
228261 #[ test]
0 commit comments