Skip to content

Commit 8192bbf

Browse files
authored
[Feature](agg) Support bool agg functions (apache#55643)
Support `BOOL_AND(BOOLAND_AGG)`, `BOOL_OR(BOOLOR_AGG)`, `BOOL_XOR(BOOLXOR_AGG)`
1 parent e33275e commit 8192bbf

File tree

12 files changed

+653
-0
lines changed

12 files changed

+653
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "vec/aggregate_functions/aggregate_function_bool_union.h"
19+
20+
#include <fmt/format.h>
21+
22+
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
23+
#include "vec/aggregate_functions/helpers.h"
24+
#include "vec/data_types/data_type.h"
25+
26+
namespace doris::vectorized {
27+
#include "common/compile_check_begin.h"
28+
29+
void register_aggregate_function_bool_union(AggregateFunctionSimpleFactory& factory) {
30+
factory.register_function_both(
31+
"bool_or", creator_with_type_list<TYPE_BOOLEAN>::template creator<
32+
AggregateFunctionBitwise, AggregateFunctionGroupBitOrData>);
33+
factory.register_function_both(
34+
"bool_and", creator_with_type_list<TYPE_BOOLEAN>::template creator<
35+
AggregateFunctionBitwise, AggregateFunctionGroupBitAndData>);
36+
factory.register_function_both(
37+
"bool_xor",
38+
creator_without_type::creator<AggregateFuntionBoolUnion<AggregateFunctionBoolXorData>>);
39+
factory.register_alias("bool_or", "boolor_agg");
40+
factory.register_alias("bool_and", "booland_agg");
41+
factory.register_alias("bool_xor", "boolxor_agg");
42+
}
43+
} // namespace doris::vectorized
44+
#include "common/compile_check_end.h"
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#pragma once
19+
20+
#include "runtime/define_primitive_type.h"
21+
#include "runtime/primitive_type.h"
22+
#include "vec/aggregate_functions/aggregate_function.h"
23+
#include "vec/aggregate_functions/aggregate_function_bit.h"
24+
#include "vec/columns/column_nullable.h"
25+
#include "vec/columns/column_vector.h"
26+
#include "vec/common/assert_cast.h"
27+
#include "vec/core/types.h"
28+
#include "vec/data_types/data_type.h"
29+
#include "vec/data_types/data_type_nullable.h"
30+
#include "vec/data_types/data_type_number.h"
31+
32+
namespace doris::vectorized {
33+
#include "common/compile_check_begin.h"
34+
35+
struct AggregateFunctionBoolXorData {
36+
static constexpr auto name = "bool_xor";
37+
38+
void add(bool x) {
39+
if (x && _st < 2) {
40+
++_st;
41+
}
42+
}
43+
44+
void merge(const AggregateFunctionBoolXorData& rhs) {
45+
if (_st == 0) {
46+
_st = rhs._st;
47+
} else if (_st == 1) {
48+
_st = (rhs._st > 0) ? 2 : 1;
49+
}
50+
}
51+
52+
void write(BufferWritable& buf) const { buf.write_binary(_st); }
53+
54+
void read(BufferReadable& buf) { buf.read_binary(_st); }
55+
56+
void reset() { _st = 0; }
57+
58+
bool get() const { return _st == 1; }
59+
60+
private:
61+
// represents the current XOR state
62+
// '0': there are no true values currently
63+
// '1': exactly one true value has appeared
64+
// '2': two true values have already appeared and will not change thereafter
65+
uint8_t _st = 0;
66+
};
67+
68+
template <typename BoolFunc>
69+
class AggregateFuntionBoolUnion final
70+
: public IAggregateFunctionDataHelper<BoolFunc, AggregateFuntionBoolUnion<BoolFunc>>,
71+
NullableAggregateFunction,
72+
UnaryExpression {
73+
public:
74+
explicit AggregateFuntionBoolUnion(const DataTypes& argument_types_)
75+
: IAggregateFunctionDataHelper<BoolFunc, AggregateFuntionBoolUnion<BoolFunc>>(
76+
argument_types_) {}
77+
78+
String get_name() const override { return BoolFunc::name; }
79+
80+
DataTypePtr get_return_type() const override { return std::make_shared<DataTypeBool>(); }
81+
82+
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
83+
Arena&) const override {
84+
this->data(place).add(
85+
assert_cast<const ColumnUInt8&, TypeCheckOnRelease::DISABLE>(*columns[0])
86+
.get_element(row_num));
87+
}
88+
89+
void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
90+
91+
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
92+
Arena&) const override {
93+
this->data(place).merge(this->data(rhs));
94+
}
95+
96+
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
97+
this->data(place).write(buf);
98+
}
99+
100+
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
101+
Arena&) const override {
102+
this->data(place).read(buf);
103+
}
104+
105+
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
106+
assert_cast<ColumnUInt8&>(to).insert_value(this->data(place).get());
107+
}
108+
};
109+
} // namespace doris::vectorized
110+
111+
#include "common/compile_check_end.h"

be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ void register_aggregate_function_approx_top_k(AggregateFunctionSimpleFactory& fa
8080
void register_aggregate_function_approx_top_sum(AggregateFunctionSimpleFactory& factory);
8181
void register_aggregate_function_percentile_reservoir(AggregateFunctionSimpleFactory& factory);
8282
void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& factory);
83+
void register_aggregate_function_bool_union(AggregateFunctionSimpleFactory& factory);
8384

8485
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
8586
static std::once_flag oc;
@@ -137,6 +138,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
137138
register_aggregate_function_approx_top_sum(instance);
138139
register_aggregate_function_percentile_reservoir(instance);
139140
register_aggregate_function_ai_agg(instance);
141+
register_aggregate_function_bool_union(instance);
140142
// Register foreach and foreachv2 functions
141143
register_aggregate_function_combinator_foreach(instance);
142144
register_aggregate_function_combinator_foreachv2(instance);
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include <gtest/gtest.h>
19+
20+
#include "agg_function_test.h"
21+
#include "vec/data_types/data_type_number.h"
22+
23+
namespace doris::vectorized {
24+
25+
struct AggregateFunctionBoolUnionTest : public AggregateFunctiontest {};
26+
27+
TEST_F(AggregateFunctionBoolUnionTest, test_bool_or) {
28+
{
29+
create_agg("bool_or", false, {std::make_shared<DataTypeBool>()});
30+
31+
execute(Block({ColumnHelper::create_column_with_name<DataTypeBool>({false, false, true})}),
32+
ColumnHelper::create_column_with_name<DataTypeBool>({true}));
33+
}
34+
35+
{
36+
create_agg("boolor_agg", false, {std::make_shared<DataTypeBool>()});
37+
38+
execute(Block({ColumnHelper::create_column_with_name<DataTypeBool>({true, false, true})}),
39+
ColumnHelper::create_column_with_name<DataTypeBool>({true}));
40+
}
41+
}
42+
43+
TEST_F(AggregateFunctionBoolUnionTest, test_bool_and) {
44+
{
45+
create_agg("bool_and", false, {std::make_shared<DataTypeBool>()});
46+
47+
execute(Block({ColumnHelper::create_column_with_name<DataTypeBool>({true, true})}),
48+
ColumnHelper::create_column_with_name<DataTypeBool>({true}));
49+
}
50+
51+
{
52+
create_agg("booland_agg", false, {std::make_shared<DataTypeBool>()});
53+
54+
execute(Block({ColumnHelper::create_column_with_name<DataTypeBool>({true, false, true})}),
55+
ColumnHelper::create_column_with_name<DataTypeBool>({false}));
56+
}
57+
}
58+
59+
TEST_F(AggregateFunctionBoolUnionTest, test_bool_xor) {
60+
{
61+
create_agg("bool_xor", false, {std::make_shared<DataTypeBool>()});
62+
63+
execute(Block({ColumnHelper::create_column_with_name<DataTypeBool>({true, true, true})}),
64+
ColumnHelper::create_column_with_name<DataTypeBool>({false}));
65+
}
66+
67+
{
68+
create_agg("boolxor_agg", false, {std::make_shared<DataTypeBool>()});
69+
70+
execute(Block({ColumnHelper::create_column_with_name<DataTypeBool>({true, false, false})}),
71+
ColumnHelper::create_column_with_name<DataTypeBool>({true}));
72+
}
73+
}
74+
} // namespace doris::vectorized

fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
3030
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
3131
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
32+
import org.apache.doris.nereids.trees.expressions.functions.agg.BoolAnd;
33+
import org.apache.doris.nereids.trees.expressions.functions.agg.BoolOr;
34+
import org.apache.doris.nereids.trees.expressions.functions.agg.BoolXor;
3235
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
3336
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
3437
import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
@@ -117,6 +120,9 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
117120
agg(BitmapUnion.class, "bitmap_union"),
118121
agg(BitmapUnionCount.class, "bitmap_union_count"),
119122
agg(BitmapUnionInt.class, "bitmap_union_int"),
123+
agg(BoolOr.class, "bool_or", "boolor_agg"),
124+
agg(BoolAnd.class, "bool_and", "booland_agg"),
125+
agg(BoolXor.class, "bool_xor", "boolxor_agg"),
120126
agg(CollectList.class, "collect_list", "group_array"),
121127
agg(CollectSet.class, "collect_set", "group_uniq_array"),
122128
agg(Corr.class, "corr"),
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions.functions.agg;
19+
20+
import org.apache.doris.catalog.FunctionSignature;
21+
import org.apache.doris.nereids.exceptions.AnalysisException;
22+
import org.apache.doris.nereids.trees.expressions.Expression;
23+
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
24+
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
25+
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
26+
import org.apache.doris.nereids.types.BooleanType;
27+
import org.apache.doris.nereids.types.DataType;
28+
29+
import com.google.common.base.Preconditions;
30+
import com.google.common.collect.ImmutableList;
31+
32+
import java.util.List;
33+
34+
/**
35+
* AggregateFunction 'bool_and'.
36+
*/
37+
public class BoolAnd extends NullableAggregateFunction
38+
implements UnaryExpression, ExplicitlyCastableSignature {
39+
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
40+
FunctionSignature.ret(BooleanType.INSTANCE).args(BooleanType.INSTANCE)
41+
);
42+
43+
public BoolAnd(Expression child) {
44+
this(false, false, child);
45+
}
46+
47+
private BoolAnd(boolean distinct, Expression arg) {
48+
this(distinct, false, arg);
49+
}
50+
51+
private BoolAnd(boolean distinct, boolean alwaysNullable, Expression arg) {
52+
super("bool_and", distinct, alwaysNullable, arg);
53+
}
54+
55+
/**
56+
* constructor for withChildren and reuse signature
57+
*/
58+
private BoolAnd(NullableAggregateFunctionParams functionParams) {
59+
super(functionParams);
60+
}
61+
62+
@Override
63+
public BoolAnd withDistinctAndChildren(boolean distinct, List<Expression> children) {
64+
Preconditions.checkArgument(children.size() == 1);
65+
return new BoolAnd(getFunctionParams(distinct, children));
66+
}
67+
68+
@Override
69+
public void checkLegalityBeforeTypeCoercion() {
70+
DataType argType = child().getDataType();
71+
if (!(argType.isBooleanType() || argType.isNumericType())) {
72+
throw new AnalysisException("bool_and requires a boolean or numeric argument");
73+
}
74+
}
75+
76+
@Override
77+
public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
78+
return new BoolAnd(getAlwaysNullableFunctionParams(alwaysNullable));
79+
}
80+
81+
@Override
82+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
83+
return visitor.visitBoolAnd(this, context);
84+
}
85+
86+
@Override
87+
public List<FunctionSignature> getSignatures() {
88+
return SIGNATURES;
89+
}
90+
}

0 commit comments

Comments
 (0)