Skip to content

Commit 7b35632

Browse files
authored
repo-sync-2025-03-19T11:04:50+0800 (#1074)
1 parent 9e688b2 commit 7b35632

File tree

15 files changed

+366
-77
lines changed

15 files changed

+366
-77
lines changed

MODULE.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
module(
2323
name = "spu",
24-
version = "0.9.4.dev20250312",
24+
version = "0.9.4.dev20250319",
2525
compatibility_level = 1,
2626
)
2727

src/MODULE.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
module(
2323
name = "spulib",
24-
version = "0.9.4.dev20250312",
24+
version = "0.9.4.dev20250319",
2525
compatibility_level = 1,
2626
)
2727

src/libspu/kernel/hal/permute.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,8 @@ std::vector<spu::Value> PrepareSort(SPUContext *ctx,
611611
// use a random permutation to break link of values, such that the following
612612
// comparison can be revealed without loss of information.
613613
for (const auto &input : inputs) {
614-
inp.emplace_back(
615-
std::move(_perm_ss(ctx, input, rand_perm).setDtype(input.dtype())));
614+
inp.emplace_back(std::move(
615+
_perm_ss(ctx, _2s(ctx, input), rand_perm).setDtype(input.dtype())));
616616
}
617617

618618
return inp;
@@ -1608,6 +1608,12 @@ std::vector<spu::Value> simple_sort1d(SPUContext *ctx,
16081608
// and the number of rounds increases (poly) logarithmically. In contrast,
16091609
// when the ring size doubles in radix sort, the communication (roughly)
16101610
// quadruples and the number of rounds doubles.
1611+
// 6. The above conclusions regarding performance apply only to
1612+
// the cases of SECRET input and SECRET permutation. In reality, only radix
1613+
// sort has implemented a complete mechanism for selecting the best
1614+
// implementation based on visibility. The other implementations will use
1615+
// local computation only when all keys are public; in other cases, they
1616+
// will revert to the scenarios of SECRET input and SECRET permutation.
16111617
//
16121618

16131619
// if all keys are public, fallback to plaintext sort.

src/libspu/kernel/hal/soprf.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ namespace spu::kernel::hal {
2323
Value soprf(SPUContext* ctx, const Value& x) {
2424
SPU_TRACE_HAL_LEAF(ctx, x);
2525

26+
if (x.numel() == 0) {
27+
return x;
28+
}
29+
2630
// currently, wo only support LowMC block cipher
2731
SPU_ENFORCE(ctx->hasKernel("lowmc_b"));
2832
auto inp = x;
@@ -64,6 +68,10 @@ Value soprf(SPUContext* ctx, absl::Span<const spu::Value> inputs) {
6468
}),
6569
"not all element has same dtype");
6670

71+
if (inputs.front().numel() == 0) {
72+
return inputs.front();
73+
}
74+
6775
std::vector<Value> inp;
6876
inp.reserve(inputs.size());
6977
for (const auto& v : inputs) {

src/libspu/kernel/hlo/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,9 @@ spu_cc_test(
349349
deps = [
350350
":shuffle",
351351
"//libspu/kernel:test_util",
352+
"//libspu/kernel/hlo:casting",
353+
"//libspu/kernel/hlo:const",
354+
"//libspu/mpc/utils:simulate",
352355
],
353356
)
354357

src/libspu/kernel/hlo/permute_test.cc

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ xt::xarray<T> evalSinglePermuteOp(SPUContext* ctx, VisType x_vis,
123123
PtBufferView perm,
124124
const PermuteFunc& perm_func,
125125
int64_t perm_dim = 0) {
126-
const auto prot = ctx->config().protocol;
127-
128126
auto x_v = makeTestValue(ctx, x, x_vis);
129127
auto perm_v = makeTestValue(ctx, perm, perm_vis);
130128

@@ -139,9 +137,7 @@ xt::xarray<T> evalSinglePermuteOp(SPUContext* ctx, VisType x_vis,
139137
EXPECT_EQ(send_round, 0);
140138
}
141139

142-
// costs of cheetah is highly dependant of OT kind, so we skip it.
143-
if (prot != CHEETAH && ctx->hasKernel("inv_perm_av") &&
144-
checkSpPass(x_vis, perm_vis)) {
140+
if (ctx->hasKernel("inv_perm_av") && checkSpPass(x_vis, perm_vis)) {
145141
auto n_repeat = x_v.shape().numel() / x_v.shape().dim(perm_dim);
146142
// For ss version, at least 3 rounds.
147143
EXPECT_LE(std::min(send_round, recv_round), 2 * n_repeat);
@@ -194,11 +190,10 @@ std::vector<PermuteParams> GetValidParamsCombinations() {
194190

195191
for (const auto& vis_x : kVisTypes) {
196192
for (const auto& vis_perm : kVisTypes) {
197-
for (const auto& protocol : {CHEETAH, SEMI2K, ABY3}) {
193+
for (const auto& protocol : {SEMI2K, ABY3}) {
198194
for (const auto& npc : {2, 3}) {
199-
// npc=2/3 is not valid in ABY3/CHEETAH
200-
if ((protocol == ABY3 && npc == 2) ||
201-
(protocol == CHEETAH && npc == 3)) {
195+
// npc=2 is not valid in ABY3
196+
if (protocol == ABY3 && npc == 2) {
202197
continue; // Skip invalid combinations
203198
}
204199
valid_combinations.emplace_back(vis_x, vis_perm, protocol, npc);
@@ -327,17 +322,14 @@ TEST_P(PermuteTest, MultiplePermuteWork) {
327322
class PermuteEmptyTest : public ::testing::TestWithParam<ProtocolKind> {};
328323

329324
INSTANTIATE_TEST_SUITE_P(
330-
PermuteEmpty, PermuteEmptyTest, testing::Values(CHEETAH, SEMI2K, ABY3),
325+
PermuteEmpty, PermuteEmptyTest, testing::Values(SEMI2K, ABY3),
331326
[](const testing::TestParamInfo<PermuteEmptyTest::ParamType>& p) {
332327
return fmt::format("{}", p.param);
333328
});
334329

335330
TEST_P(PermuteEmptyTest, Empty) {
336331
ProtocolKind prot = GetParam();
337332
size_t npc = 3;
338-
if (prot == CHEETAH) {
339-
npc = 2;
340-
}
341333

342334
mpc::utils::simulate(
343335
npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {

src/libspu/kernel/hlo/rank.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ std::vector<spu::Value> TopK(SPUContext *ctx, const spu::Value &input,
7979
int64_t k_lo, int64_t k_hi, bool largest,
8080
bool value_only) {
8181
const Shape &shape = input.shape();
82+
83+
if (shape.numel() == 0) {
84+
if (value_only) {
85+
return std::vector<spu::Value>(1, input);
86+
}
87+
return std::vector<spu::Value>(2, input);
88+
}
89+
8290
SPU_ENFORCE(shape.numel() > 0, "input must non-empty.");
8391
SPU_ENFORCE(
8492
k_lo <= shape.back() && k_lo > 0,

src/libspu/kernel/hlo/rank_test.cc

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,32 @@ TEST_P(TopkTest, ValueOnlyTest) {
374374
});
375375
}
376376

377+
TEST_P(TopkTest, EmptyTest) {
378+
size_t npc = std::get<0>(GetParam());
379+
FieldType field = std::get<1>(GetParam());
380+
ProtocolKind prot = std::get<2>(GetParam());
381+
382+
mpc::utils::simulate(
383+
npc, [&](const std::shared_ptr<yacl::link::Context> &lctx) {
384+
SPUContext sctx = test::makeSPUContext(prot, field, lctx);
385+
auto empty_x = test::makeValue(&sctx, 1, VIS_SECRET, DT_INVALID, {0});
386+
387+
auto out = TopK(&sctx, empty_x, 1, 1);
388+
EXPECT_EQ(out.size(), 2);
389+
EXPECT_EQ(out[0].numel(), 0);
390+
EXPECT_EQ(out[1].numel(), 0);
391+
EXPECT_EQ(out[0].shape().size(), 1);
392+
EXPECT_EQ(out[1].shape().size(), 1);
393+
EXPECT_EQ(out[0].shape()[0], 0);
394+
EXPECT_EQ(out[1].shape()[0], 0);
395+
});
396+
}
397+
377398
INSTANTIATE_TEST_SUITE_P(
378399
Topk2PCTestInstances, TopkTest,
379400
testing::Combine(testing::Values(2),
380401
testing::Values(FieldType::FM64, FieldType::FM128),
381-
testing::Values(ProtocolKind::SEMI2K,
382-
ProtocolKind::CHEETAH)),
402+
testing::Values(ProtocolKind::SEMI2K)),
383403
[](const testing::TestParamInfo<TopkTest::ParamType> &p) {
384404
return fmt::format("{}x{}x{}", std::get<0>(p.param), std::get<1>(p.param),
385405
std::get<2>(p.param));

src/libspu/kernel/hlo/shuffle.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,26 @@ std::vector<spu::Value> Shuffle(SPUContext* ctx,
3939
absl::Span<const spu::Value> inputs,
4040
int64_t axis) {
4141
SPU_ENFORCE_GT(inputs.size(), 0U);
42-
if (inputs[0].numel() == 0) {
42+
auto input_shape = inputs[0].shape();
43+
SPU_ENFORCE(std::all_of(inputs.begin() + 1, inputs.end(),
44+
[&](const spu::Value& v) {
45+
return v.shape() == input_shape;
46+
}),
47+
"all inputs should have the same shape");
48+
49+
// edge case: empty or single element tensor
50+
if (inputs[0].numel() <= 1) {
4351
return std::vector<spu::Value>(inputs.begin(), inputs.end());
4452
}
45-
auto input_shape = inputs[0].shape();
4653

4754
// TODO: Rename permute-related kernels
4855
if (ctx->hasKernel("rand_perm_m") && ctx->hasKernel("perm_am")) {
4956
auto shuffle_fn = [&](absl::Span<const spu::Value> input) {
5057
std::vector<spu::Value> rets;
51-
auto rand_perm = hal::_rand_perm_s(ctx, input_shape);
52-
for (size_t i = 0; i < input.size(); ++i) {
53-
rets.emplace_back(hal::_perm_ss(ctx, _2s(ctx, input[i]), rand_perm)
54-
.setDtype(input[i].dtype()));
58+
auto rand_perm = hal::_rand_perm_s(ctx, {input_shape.dim(axis)});
59+
for (const auto& inp : input) {
60+
rets.emplace_back(
61+
hal::_perm_ss(ctx, _2s(ctx, inp), rand_perm).setDtype(inp.dtype()));
5562
}
5663
return rets;
5764
};

0 commit comments

Comments
 (0)