Skip to content

Commit f66e517

Browse files
authored
[Feature](func) Support function PERIOD_ADD and PERIOD_DIFF (apache#56945)
```text mysql> SELECT PERIOD_ADD(2512, 1); +---------------------+ | PERIOD_ADD(2512, 1) | +---------------------+ | 202601 | +---------------------+ mysql> SELECT PERIOD_ADD(6901, 1); +---------------------+ | PERIOD_ADD(6901, 1) | +---------------------+ | 206902 | +---------------------+ mysql> SELECT PERIOD_ADD(7001, 1); +---------------------+ | PERIOD_ADD(7001, 1) | +---------------------+ | 197002 | +---------------------+ mysql> SELECT PERIOD_DIFF(2510, 2501); +-------------------------+ | PERIOD_DIFF(2510, 2501) | +-------------------------+ | 9 | +-------------------------+ mysql> SELECT PERIOD_DIFF(2501, 2510); +-------------------------+ | PERIOD_DIFF(2501, 2510) | +-------------------------+ | -9 | +-------------------------+ ```
1 parent 5b7acb0 commit f66e517

File tree

13 files changed

+590
-82
lines changed

13 files changed

+590
-82
lines changed

be/src/vec/functions/function.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "common/logging.h"
3333
#include "common/status.h"
3434
#include "olap/rowset/segment_v2/inverted_index_iterator.h" // IWYU pragma: keep
35+
#include "runtime/define_primitive_type.h"
3536
#include "udf/udf.h"
3637
#include "vec/core/block.h"
3738
#include "vec/core/column_numbers.h"

be/src/vec/functions/function_date_or_datetime_computation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ using FunctionSecToTime = FunctionCurrentDateOrDateTime<SecToTimeImpl>;
6464
using FunctionMicroSecToDateTime = TimestampToDateTime<MicroSec>;
6565
using FunctionMilliSecToDateTime = TimestampToDateTime<MilliSec>;
6666
using FunctionSecToDateTime = TimestampToDateTime<Sec>;
67+
using FunctionPeriodAdd = FunctionNeedsToHandleNull<PeriodAddImpl, PrimitiveType::TYPE_BIGINT>;
68+
using FunctionPeriodDiff = FunctionNeedsToHandleNull<PeriodDiffImpl, PrimitiveType::TYPE_BIGINT>;
6769

6870
void register_function_date_time_computation(SimpleFunctionFactory& factory) {
6971
factory.register_function<FunctionDateDiff>();
@@ -94,6 +96,8 @@ void register_function_date_time_computation(SimpleFunctionFactory& factory) {
9496
factory.register_function<FunctionMonthsBetween>();
9597
factory.register_function<FunctionTime>();
9698
factory.register_function<FunctionGetFormat>();
99+
factory.register_function<FunctionPeriodAdd>();
100+
factory.register_function<FunctionPeriodDiff>();
97101

98102
// alias
99103
factory.register_alias("days_add", "date_add");

be/src/vec/functions/function_date_or_datetime_computation.h

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
#include "vec/functions/datetime_errors.h"
5959
#include "vec/functions/function.h"
6060
#include "vec/functions/function_helpers.h"
61+
#include "vec/functions/function_needs_to_handle_null.h"
6162
#include "vec/runtime/time_value.h"
6263
#include "vec/runtime/vdatetime_value.h"
6364
#include "vec/utils/util.hpp"
@@ -1378,5 +1379,94 @@ class FunctionGetFormat : public IFunction {
13781379
static constexpr auto TIME_NAME = "TIME";
13791380
};
13801381

1382+
class PeriodHelper {
1383+
public:
1384+
// For two digit year, 70-99 -> 1970-1999, 00-69 -> 2000-2069
1385+
// this rule is same as MySQL
1386+
static constexpr int YY_PART_YEAR = 70;
1387+
static Status valid_period(int64_t period) {
1388+
if (period <= 0 || (period % 100) == 0 || (period % 100) > 12) {
1389+
return Status::InvalidArgument("Period function got invalid period: {}", period);
1390+
}
1391+
return Status::OK();
1392+
}
1393+
1394+
static int64_t check_and_convert_period_to_month(uint64_t period) {
1395+
THROW_IF_ERROR(valid_period(period));
1396+
uint64_t year = period / 100;
1397+
if (year < 100) {
1398+
year += (year >= YY_PART_YEAR) ? 1900 : 2000;
1399+
}
1400+
return year * 12LL + (period % 100) - 1;
1401+
}
1402+
1403+
static int64_t convert_month_to_period(uint64_t month) {
1404+
uint64_t year = month / 12;
1405+
if (year < 100) {
1406+
year += (year >= YY_PART_YEAR) ? 1900 : 2000;
1407+
}
1408+
return year * 100 + month % 12 + 1;
1409+
}
1410+
};
1411+
1412+
class PeriodAddImpl {
1413+
public:
1414+
static constexpr auto name = "period_add";
1415+
static size_t get_number_of_arguments() { return 2; }
1416+
static DataTypePtr get_return_type_impl(const DataTypes& arguments) {
1417+
return std::make_shared<DataTypeInt64>();
1418+
}
1419+
1420+
static void execute(const std::vector<ColumnWithConstAndNullMap>& cols_info,
1421+
ColumnInt64::MutablePtr& res_col, PaddedPODArray<UInt8>& res_null_map_data,
1422+
size_t input_rows_count) {
1423+
const auto& left_data =
1424+
assert_cast<const ColumnInt64*>(cols_info[0].nested_col)->get_data();
1425+
const auto& right_data =
1426+
assert_cast<const ColumnInt64*>(cols_info[1].nested_col)->get_data();
1427+
for (size_t i = 0; i < input_rows_count; ++i) {
1428+
if (cols_info[0].is_null_at(i) || cols_info[1].is_null_at(i)) {
1429+
res_col->insert_default();
1430+
res_null_map_data[i] = 1;
1431+
continue;
1432+
}
1433+
1434+
int64_t period = left_data[index_check_const(i, cols_info[0].is_const)];
1435+
int64_t months = right_data[index_check_const(i, cols_info[1].is_const)];
1436+
res_col->insert_value(PeriodHelper::convert_month_to_period(
1437+
PeriodHelper::check_and_convert_period_to_month(period) + months));
1438+
}
1439+
}
1440+
};
1441+
class PeriodDiffImpl {
1442+
public:
1443+
static constexpr auto name = "period_diff";
1444+
static size_t get_number_of_arguments() { return 2; }
1445+
static DataTypePtr get_return_type_impl(const DataTypes& arguments) {
1446+
return std::make_shared<DataTypeInt64>();
1447+
}
1448+
1449+
static void execute(const std::vector<ColumnWithConstAndNullMap>& cols_info,
1450+
ColumnInt64::MutablePtr& res_col, PaddedPODArray<UInt8>& res_null_map_data,
1451+
size_t input_rows_count) {
1452+
const auto& left_data =
1453+
assert_cast<const ColumnInt64*>(cols_info[0].nested_col)->get_data();
1454+
const auto& right_data =
1455+
assert_cast<const ColumnInt64*>(cols_info[1].nested_col)->get_data();
1456+
for (size_t i = 0; i < input_rows_count; ++i) {
1457+
if (cols_info[0].is_null_at(i) || cols_info[1].is_null_at(i)) {
1458+
res_col->insert_default();
1459+
res_null_map_data[i] = 1;
1460+
continue;
1461+
}
1462+
1463+
int64_t period1 = left_data[index_check_const(i, cols_info[0].is_const)];
1464+
int64_t period2 = right_data[index_check_const(i, cols_info[1].is_const)];
1465+
res_col->insert_value(PeriodHelper::check_and_convert_period_to_month(period1) -
1466+
PeriodHelper::check_and_convert_period_to_month(period2));
1467+
}
1468+
}
1469+
};
1470+
13811471
#include "common/compile_check_avoid_end.h"
13821472
} // namespace doris::vectorized
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
#pragma once
18+
#include <boost/mpl/aux_/na_fwd.hpp>
19+
20+
#include "vec/functions/function.h"
21+
22+
namespace doris::vectorized {
23+
#include "common/compile_check_begin.h"
24+
25+
// Helper struct to store information about const+nullable columns
26+
struct ColumnWithConstAndNullMap {
27+
const IColumn* nested_col = nullptr;
28+
const NullMap* null_map = nullptr;
29+
bool is_const = false;
30+
31+
bool is_null_at(size_t row) const { return (null_map && (*null_map)[is_const ? 0 : row]); }
32+
};
33+
34+
// For functions that need to handle const+nullable column combinations
35+
// means that functioin `use_default_implementation_for_nulls()` returns false
36+
template <typename Impl, PrimitiveType ResultPrimitiveType>
37+
class FunctionNeedsToHandleNull : public IFunction {
38+
public:
39+
using ResultColumnType = PrimitiveTypeTraits<ResultPrimitiveType>::ColumnType;
40+
41+
static constexpr auto name = Impl::name;
42+
String get_name() const override { return name; }
43+
44+
static std::shared_ptr<IFunction> create() {
45+
return std::make_shared<FunctionNeedsToHandleNull>();
46+
}
47+
48+
size_t get_number_of_arguments() const override { return Impl::get_number_of_arguments(); }
49+
50+
bool is_variadic() const override {
51+
if constexpr (requires { Impl::is_variadic(); }) {
52+
return Impl::is_variadic();
53+
}
54+
return false;
55+
}
56+
57+
bool use_default_implementation_for_nulls() const override { return false; }
58+
59+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
60+
return Impl::get_return_type_impl(arguments);
61+
}
62+
63+
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
64+
uint32_t result, size_t input_rows_count) const override {
65+
auto res_col = ResultColumnType::create();
66+
auto null_map = ColumnUInt8::create();
67+
auto& null_map_data = null_map->get_data();
68+
res_col->reserve(input_rows_count);
69+
null_map_data.resize_fill(input_rows_count, 0);
70+
71+
const size_t arg_size = arguments.size();
72+
73+
std::vector<ColumnWithConstAndNullMap> columns_info;
74+
columns_info.resize(arg_size);
75+
bool has_nullable = false;
76+
collect_columns_info(columns_info, block, arguments, has_nullable);
77+
78+
// Check if there is a const null
79+
for (size_t i = 0; i < arg_size; ++i) {
80+
if (columns_info[i].is_const && columns_info[i].null_map &&
81+
(*columns_info[i].null_map)[0] &&
82+
execute_const_null(res_col, null_map_data, input_rows_count, i)) {
83+
block.replace_by_position(
84+
result, ColumnNullable::create(std::move(res_col), std::move(null_map)));
85+
return Status::OK();
86+
}
87+
}
88+
89+
Impl::execute(columns_info, res_col, null_map_data, input_rows_count);
90+
91+
if (is_return_nullable(has_nullable, columns_info)) {
92+
block.replace_by_position(
93+
result, ColumnNullable::create(std::move(res_col), std::move(null_map)));
94+
} else {
95+
block.replace_by_position(result, std::move(res_col));
96+
}
97+
98+
return Status::OK();
99+
}
100+
101+
private:
102+
// Handle a NULL literal
103+
// Default behavior is fill result with all NULLs
104+
// return true when the res_col is ready to be written back to the block without further processing
105+
bool execute_const_null(typename ResultColumnType::MutablePtr& res_col,
106+
PaddedPODArray<UInt8>& res_null_map_data, size_t input_rows_count,
107+
size_t null_index) const {
108+
if constexpr (requires {
109+
Impl::execute_const_null(res_col, res_null_map_data, input_rows_count,
110+
null_index);
111+
}) {
112+
return Impl::execute_const_null(res_col, res_null_map_data, input_rows_count,
113+
null_index);
114+
}
115+
116+
res_col->insert_many_defaults(input_rows_count);
117+
res_null_map_data.assign(input_rows_count, (UInt8)1);
118+
119+
return true;
120+
}
121+
122+
// Collect the required information for each column into columns_info
123+
// Including whether it is a constant column, nested column and null map(if exists).
124+
void collect_columns_info(std::vector<ColumnWithConstAndNullMap>& columns_info,
125+
const Block& block, const ColumnNumbers& arguments,
126+
bool& has_nullable) const {
127+
for (size_t i = 0; i < arguments.size(); ++i) {
128+
ColumnPtr col_ptr;
129+
const auto& col_with_type = block.get_by_position(arguments[i]);
130+
std::tie(col_ptr, columns_info[i].is_const) = unpack_if_const(col_with_type.column);
131+
132+
if (is_column_nullable(*col_ptr)) {
133+
has_nullable = true;
134+
const auto* nullable = check_and_get_column<ColumnNullable>(col_ptr.get());
135+
columns_info[i].nested_col = &nullable->get_nested_column();
136+
columns_info[i].null_map = &nullable->get_null_map_data();
137+
} else {
138+
columns_info[i].nested_col = col_ptr.get();
139+
}
140+
}
141+
}
142+
143+
// Determine if the return type should be wrapped in nullable
144+
// Default behavior is return nullable if any argument is nullable
145+
bool is_return_nullable(bool has_nullable,
146+
const std::vector<ColumnWithConstAndNullMap>& cols_info) const {
147+
if constexpr (requires { Impl::is_return_nullable(has_nullable, cols_info); }) {
148+
return Impl::is_return_nullable(has_nullable, cols_info);
149+
}
150+
return has_nullable;
151+
}
152+
};
153+
#include "common/compile_check_end.h"
154+
} // namespace doris::vectorized

be/src/vec/functions/function_string.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,6 +1346,8 @@ using FunctionStringAppendTrailingCharIfAbsent =
13461346
using FunctionStringLPad = FunctionStringPad<StringLPad>;
13471347
using FunctionStringRPad = FunctionStringPad<StringRPad>;
13481348

1349+
using FunctionMakeSet = FunctionNeedsToHandleNull<MakeSetImpl, PrimitiveType::TYPE_STRING>;
1350+
13491351
void register_function_string(SimpleFunctionFactory& factory) {
13501352
factory.register_function<FunctionStringParseDataSize>();
13511353
factory.register_function<FunctionStringASCII>();

0 commit comments

Comments
 (0)