-
Notifications
You must be signed in to change notification settings - Fork 141
Expand file tree
/
Copy pathgroup_by_agg.cc
More file actions
142 lines (128 loc) · 5.19 KB
/
group_by_agg.cc
File metadata and controls
142 lines (128 loc) · 5.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Copyright 2025 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "libspu/kernel/hlo/group_by_agg.h"
#include "magic_enum.hpp"
namespace spu::kernel::hlo {
namespace {

// Rank of the party that owns a private value. Only call this on values for
// which isPrivate() holds (the caller in this file checks that first); what
// as<Private>() does for other visibilities is not visible from here.
inline int64_t _get_owner(const Value& x) {
  return x.storage_type().as<Private>()->owner();
}

// Returns true iff `values` contains no secret value and every private value
// in it is owned by the same party. Values that are neither private nor
// secret (e.g. public ones) are ignored, so an all-public or empty span
// yields true.
//
// NOTE(review): an owner rank of -1 doubles as the "no private value seen
// yet" marker; this presumes real ranks are never -1 — confirm.
bool _all_pub_or_pri_with_same_owner(absl::Span<spu::Value const> values) {
  int64_t common_owner = -1;  // -1: no private value encountered so far
  for (const auto& value : values) {
    if (!value.isPrivate()) {
      // Secret data cannot be grouped locally, so it disqualifies the span.
      if (value.isSecret()) {
        return false;
      }
      continue;
    }
    const int64_t owner = _get_owner(value);
    if (common_owner == -1) {
      common_owner = owner;
    } else if (owner != common_owner) {
      // Two different owners: no single party holds all the keys.
      return false;
    }
  }
  return true;
}

}  // namespace
// Groups rows by `keys` and aggregates `payloads` per group with `agg_func`.
//
// Validation performed here (in this order, so the first failure wins):
//   * `valid_bits` holds at most one entry (it is not otherwise read here);
//   * `keys` / `payloads` are non-empty, 1-d, and internally shape-equal;
//   * keys and payloads have the same element count;
//   * all keys share one visibility, and all payloads share one visibility.
//
// If the inputs have zero rows, the inputs are returned unchanged: all keys
// followed by all payloads.
//
// Supported aggregations: AggFunc::Sum and AggFunc::Avg. Both require
// options.mode != PrefixSumMode and options.output_format == OutputOrder.
// Both also require the keys to be all public or private to a single owner.
// They delegate to hal::private_groupby_sum_1d / hal::private_groupby_avg_1d,
// and the result is whatever those return. Presumably that is grouped keys
// followed by aggregated payloads, mirroring the empty case — confirm against
// hal. Every other combination throws.
std::vector<Value> GroupByAgg(SPUContext* ctx,
                              absl::Span<spu::Value const> keys,
                              absl::Span<spu::Value const> payloads,
                              AggFunc agg_func, absl::Span<int64_t> valid_bits,
                              const GroupByAggOptions& options) {
  // TODO(zjj): support more valid_bits hint after radix sort supporting.
  SPU_ENFORCE(valid_bits.size() <= 1);

  // ---- keys: non-empty, 1-d, uniform shape.
  SPU_ENFORCE(!keys.empty(), "keys should not be empty");
  const Value& key0 = keys[0];
  SPU_ENFORCE(key0.shape().ndim() == 1,
              "Keys should be 1-d but actually have {} dimensions",
              key0.shape().ndim());
  const bool keys_same_shape =
      std::all_of(keys.begin(), keys.end(), [&key0](const spu::Value& k) {
        return k.shape() == key0.shape();
      });
  SPU_ENFORCE(keys_same_shape, "Keys shape mismatched");

  // ---- payloads: non-empty, 1-d, uniform shape.
  SPU_ENFORCE(!payloads.empty(), "payloads should not be empty");
  const Value& payload0 = payloads[0];
  SPU_ENFORCE(payload0.shape().ndim() == 1,
              "Payloads should be 1-d but actually have {} dimensions",
              payload0.shape().ndim());
  const bool payloads_same_shape = std::all_of(
      payloads.begin(), payloads.end(), [&payload0](const spu::Value& p) {
        return p.shape() == payload0.shape();
      });
  SPU_ENFORCE(payloads_same_shape, "Payloads shape mismatched");

  // ---- cross-checks: row count and per-group visibility.
  SPU_ENFORCE(key0.numel() == payload0.numel(),
              "Keys and payloads shape mismatched");
  const bool keys_same_vis =
      std::all_of(keys.begin(), keys.end(), [&key0](const spu::Value& k) {
        return k.vtype() == key0.vtype();
      });
  SPU_ENFORCE(keys_same_vis, "Keys visibility mismatched");
  const bool payloads_same_vis = std::all_of(
      payloads.begin(), payloads.end(), [&payload0](const spu::Value& p) {
        return p.vtype() == payload0.vtype();
      });
  SPU_ENFORCE(payloads_same_vis, "Payloads visibility mismatched");

  // Zero rows: nothing to group, hand the inputs straight back.
  if (key0.numel() == 0) {
    std::vector<Value> passthrough;
    passthrough.reserve(keys.size() + payloads.size());
    for (const auto& k : keys) {
      passthrough.push_back(k);
    }
    for (const auto& p : payloads) {
      passthrough.push_back(p);
    }
    return passthrough;
  }

  // TODO(zjj): only support the private groupby sum for now.
  // maybe we should implement a switch function to dispatch different
  // groupby-agg implementations.
  // Both aggregations below are only wired to the private path, which needs
  // a non-prefix-sum mode and order-preserving output.
  const bool private_path_ok =
      options.mode != GroupByAggMode::PrefixSumMode &&
      options.output_format == OutputFormat::OutputOrder;

  switch (agg_func) {
    case AggFunc::Sum: {
      if (!private_path_ok) {
        SPU_THROW(
            "groupby sum with mode {} and output format {} is not "
            "supported now",
            magic_enum::enum_name(options.mode),
            magic_enum::enum_name(options.output_format));
      }
      SPU_ENFORCE(_all_pub_or_pri_with_same_owner(keys),
                  "keys should be all public or private with the same owner");
      return hal::private_groupby_sum_1d(ctx, keys, payloads);
    }
    case AggFunc::Avg: {
      if (!private_path_ok) {
        SPU_THROW(
            "groupby avg with mode {} and output format {} is not "
            "supported now",
            magic_enum::enum_name(options.mode),
            magic_enum::enum_name(options.output_format));
      }
      SPU_ENFORCE(_all_pub_or_pri_with_same_owner(keys),
                  "keys should be all public or private with the same owner");
      return hal::private_groupby_avg_1d(
          ctx, keys, payloads,
          options.unsafe_output_order_drop_rest /*unsafe_drop_rest*/);
    }
    default:
      SPU_THROW("groupby agg func {} is not supported now",
                magic_enum::enum_name(agg_func));
  }
  SPU_THROW("should not reach here");
}
} // namespace spu::kernel::hlo