Skip to content

Commit 861ed00

Browse files
kazantsev-maksimKazantsev MaksimJefffrey
authored andcommitted
Fix: spark bit_count function (apache#18322)
## Which issue does this PR close? Closes apache#18225 ## Rationale for this change After adding the bit_count function in Comet, we got different results from Spark. (apache/datafusion-comet#2553) ## Are these changes tested? Tested with existing unit tests --------- Co-authored-by: Kazantsev Maksim <[email protected]> Co-authored-by: Jeffrey Vo <[email protected]>
1 parent f6b3fa3 commit 861ed00

File tree

2 files changed

+50
-17
lines changed

2 files changed

+50
-17
lines changed

datafusion/spark/src/function/bitwise/bit_count.rs

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use arrow::datatypes::{
2323
DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
2424
UInt64Type, UInt8Type,
2525
};
26+
use datafusion_common::cast::as_boolean_array;
2627
use datafusion_common::{plan_err, Result};
2728
use datafusion_expr::{
2829
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
@@ -46,6 +47,7 @@ impl SparkBitCount {
4647
Self {
4748
signature: Signature::one_of(
4849
vec![
50+
TypeSignature::Exact(vec![DataType::Boolean]),
4951
TypeSignature::Exact(vec![DataType::Int8]),
5052
TypeSignature::Exact(vec![DataType::Int16]),
5153
TypeSignature::Exact(vec![DataType::Int32]),
@@ -90,28 +92,34 @@ impl ScalarUDFImpl for SparkBitCount {
9092
fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
9193
let value_array = value_array[0].as_ref();
9294
match value_array.data_type() {
95+
DataType::Boolean => {
96+
let result: Int32Array = as_boolean_array(value_array)?
97+
.iter()
98+
.map(|x| x.map(|y| y as i32))
99+
.collect();
100+
Ok(Arc::new(result))
101+
}
93102
DataType::Int8 => {
94103
let result: Int32Array = value_array
95104
.as_primitive::<Int8Type>()
96-
.unary(|v| v.count_ones() as i32);
105+
.unary(|v| bit_count(v.into()));
97106
Ok(Arc::new(result))
98107
}
99108
DataType::Int16 => {
100109
let result: Int32Array = value_array
101110
.as_primitive::<Int16Type>()
102-
.unary(|v| v.count_ones() as i32);
111+
.unary(|v| bit_count(v.into()));
103112
Ok(Arc::new(result))
104113
}
105114
DataType::Int32 => {
106115
let result: Int32Array = value_array
107116
.as_primitive::<Int32Type>()
108-
.unary(|v| v.count_ones() as i32);
117+
.unary(|v| bit_count(v.into()));
109118
Ok(Arc::new(result))
110119
}
111120
DataType::Int64 => {
112-
let result: Int32Array = value_array
113-
.as_primitive::<Int64Type>()
114-
.unary(|v| v.count_ones() as i32);
121+
let result: Int32Array =
122+
value_array.as_primitive::<Int64Type>().unary(bit_count);
115123
Ok(Arc::new(result))
116124
}
117125
DataType::UInt8 => {
@@ -147,12 +155,26 @@ fn spark_bit_count(value_array: &[ArrayRef]) -> Result<ArrayRef> {
147155
}
148156
}
149157

158+
// Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType)
159+
// Spark: https://github.com/apache/spark/blob/ac717dd7aec665de578d7c6b0070e8fcdde3cea9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala#L243
160+
// Java impl: https://github.com/openjdk/jdk/blob/d226023643f90027a8980d161ec6d423887ae3ce/src/java.base/share/classes/java/lang/Long.java#L1584
161+
fn bit_count(i: i64) -> i32 {
162+
let mut u = i as u64;
163+
u = u - ((u >> 1) & 0x5555555555555555);
164+
u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333);
165+
u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f;
166+
u = u + (u >> 8);
167+
u = u + (u >> 16);
168+
u = u + (u >> 32);
169+
(u as i32) & 0x7f
170+
}
171+
150172
#[cfg(test)]
151173
mod tests {
152174
use super::*;
153175
use arrow::array::{
154-
Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array,
155-
UInt64Array, UInt8Array,
176+
Array, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array,
177+
UInt32Array, UInt64Array, UInt8Array,
156178
};
157179
use arrow::datatypes::Int32Type;
158180

@@ -192,7 +214,18 @@ mod tests {
192214
assert_eq!(arr.value(2), 2);
193215
assert_eq!(arr.value(3), 3);
194216
assert_eq!(arr.value(4), 4);
195-
assert_eq!(arr.value(5), 8);
217+
assert_eq!(arr.value(5), 64);
218+
}
219+
220+
#[test]
221+
fn test_bit_count_boolean() {
222+
// Test bit_count on BooleanArray
223+
let result =
224+
spark_bit_count(&[Arc::new(BooleanArray::from(vec![true, false]))]).unwrap();
225+
226+
let arr = result.as_primitive::<Int32Type>();
227+
assert_eq!(arr.value(0), 1);
228+
assert_eq!(arr.value(1), 0);
196229
}
197230

198231
#[test]
@@ -207,7 +240,7 @@ mod tests {
207240
assert_eq!(arr.value(1), 1);
208241
assert_eq!(arr.value(2), 8);
209242
assert_eq!(arr.value(3), 10);
210-
assert_eq!(arr.value(4), 16);
243+
assert_eq!(arr.value(4), 64);
211244
}
212245

213246
#[test]
@@ -222,7 +255,7 @@ mod tests {
222255
assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1
223256
assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8
224257
assert_eq!(arr.value(3), 10); // 0b00000000000000000000001111111111 = 10
225-
assert_eq!(arr.value(4), 32); // -1 in two's complement = all 32 bits set
258+
assert_eq!(arr.value(4), 64); // -1 in two's complement = all 32 bits set
226259
}
227260

228261
#[test]

datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,17 @@ SELECT bit_count(1023::int);
5959
query I
6060
SELECT bit_count(-1::int);
6161
----
62-
32
62+
64
6363

6464
query I
6565
SELECT bit_count(-2::int);
6666
----
67-
31
67+
63
6868

6969
query I
7070
SELECT bit_count(-3::int);
7171
----
72-
31
72+
63
7373

7474
# Tests with different integer types
7575
query I
@@ -85,7 +85,7 @@ SELECT bit_count(arrow_cast(15, 'Int8'));
8585
query I
8686
SELECT bit_count(arrow_cast(-1, 'Int8'));
8787
----
88-
8
88+
64
8989

9090
query I
9191
SELECT bit_count(arrow_cast(0, 'Int16'));
@@ -100,7 +100,7 @@ SELECT bit_count(arrow_cast(255, 'Int16'));
100100
query I
101101
SELECT bit_count(arrow_cast(-1, 'Int16'));
102102
----
103-
16
103+
64
104104

105105
query I
106106
SELECT bit_count(arrow_cast(0, 'Int64'));
@@ -214,7 +214,7 @@ SELECT bit_count(arrow_cast(2147483647, 'Int32'));
214214
query I
215215
SELECT bit_count(arrow_cast(-2147483648, 'Int32'));
216216
----
217-
1
217+
33
218218

219219
query I
220220
SELECT bit_count(arrow_cast(9223372036854775807, 'Int64'));

0 commit comments

Comments
 (0)