Skip to content

Commit 8caf034

Browse files
authored
repo sync (#215)
1 parent 6e67a19 commit 8caf034

File tree

12 files changed

+69
-41
lines changed

12 files changed

+69
-41
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ chmod +x traceconv
205205
```
206206
4. Open chrome://tracing in your chrome and load JSON file.
207207

208+
208209
## PSI V2 Benchamrk
209210

210211
Please refer to [PSI V2 Benchmark](docs/user_guide/psi_v2_benchmark.md)

benchmark/docker-compose/.env

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
# OPENSOURCE-CLEANUP GSUB psi:latest secretflow/psi:latest
21
# docker env
3-
IMAGE_WITH_TAG=secretflow/psi-anolis8:0.4.2b0
2+
IMAGE_WITH_TAG=secretflow/psi:latest
43

54
# network env
65
# LATENCY=10ms

benchmark/stats.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import csv
1818
import sys
19+
import os
1920
import time
2021
from datetime import datetime
2122

@@ -40,7 +41,7 @@ def stream_container_stats(container_name, output_file):
4041
data = json.loads(stats)
4142
running_time_s = int(time.time()) - start_unix_time
4243
cpu_percent = ((data['cpu_stats']['cpu_usage']['total_usage'] - prev_cpu_total) /
43-
(data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100
44+
(data['cpu_stats']['system_cpu_usage'] - prev_cpu_system)) * 100 * os.cpu_count()
4445
mem_usage = (data['memory_stats']['usage'] - data['memory_stats']['stats']['inactive_file']) / 1024 / 1024
4546
mem_limit = data['memory_stats']['limit'] / 1024 / 1024
4647
net_tx = 0

docker/entry.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ cd src_copied
77

88
conda install -y perl=5.20.3.1
99

10-
bazel build psi:main -c opt --config=linux-release --repository_cache=/tmp/bazel_repo_cache
10+
bazel build psi:main -c opt --config=linux-release --remote_timeout=300s --remote_retries=10
1111
chmod 777 bazel-bin/psi/main
1212
mkdir -p ../src/docker/linux/amd64
1313
cp bazel-bin/psi/main ../src/docker/linux/amd64

experiment/pir/pps/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ psi_cc_library(
3434
deps = [
3535
":ggm_pset",
3636
"@yacl//yacl/base:dynamic_bitset",
37+
"@yacl//yacl/base:exception",
3738
"@yacl//yacl/crypto/rand",
3839
"@yacl//yacl/crypto/tools:prg",
3940
],

experiment/pir/pps/client.cc

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
#include <spdlog/spdlog.h>
1818

19+
#include "yacl/base/exception.h"
20+
1921
namespace pir::pps {
2022

2123
bool PpsPirClient::Bernoulli() {
@@ -34,17 +36,22 @@ uint64_t PpsPirClient::GetRandomU64Less() {
3436
// Generate sk and m random numbers \in [n]
3537
void PpsPirClient::Setup(PIRKey& sk, std::set<uint64_t>& deltas) {
3638
sk = pps_.Gen(lambda_);
39+
40+
size_t max_try_count = 10 * M();
41+
size_t count = 0;
42+
3743
// The map.size() must be equal to SET_SIZE.
38-
std::vector<uint64_t> rand =
39-
yacl::crypto::PrgAesCtr<uint64_t>(yacl::crypto::RandU64(), M());
40-
for (uint64_t i = 0; i < M(); i++) {
41-
// The most expensive operation.
42-
uint64_t r = LemireTrick(rand[i], universe_size_);
44+
size_t i = 0;
45+
while (i < M() && count < max_try_count) {
46+
count += 1;
47+
uint64_t r = LemireTrick(yacl::crypto::RandU64(), universe_size_);
4348
if (!deltas.insert(r).second) {
44-
rand[i] = yacl::crypto::RandU64();
45-
i--;
49+
continue;
4650
}
51+
++i;
4752
}
53+
54+
YACL_ENFORCE(count < max_try_count);
4855
}
4956

5057
// Params:
@@ -91,18 +98,24 @@ void PpsPirClient::Setup(std::vector<PIRKeyUnion>& ck,
9198
std::vector<std::unordered_set<uint64_t>>& v) {
9299
ck.resize(MM());
93100
v.resize(MM());
94-
std::vector<uint128_t> rand =
95-
yacl::crypto::PrgAesCtr<uint128_t>(yacl::crypto::RandU128(), MM());
96-
for (uint64_t i = 0; i < MM(); ++i) {
97-
pps_.Eval(rand[i], v[i]);
101+
102+
size_t max_try_count = 10 * MM();
103+
size_t count = 0;
104+
105+
size_t i = 0;
106+
while (i < MM() && count < max_try_count) {
107+
count += 1;
108+
auto rand = yacl::crypto::RandU128();
109+
pps_.Eval(rand, v[i]);
98110
if (v[i].size() == set_size_) {
99-
ck[i] = PIRKeyUnion(rand[i]);
111+
ck[i] = PIRKeyUnion(rand);
100112
} else {
101113
v[i].clear();
102-
rand[i] = yacl::crypto::RandU128();
103-
--i;
114+
continue;
104115
}
116+
++i;
105117
}
118+
YACL_ENFORCE(count < max_try_count);
106119
}
107120

108121
void PpsPirClient::Query(uint64_t i, std::vector<PIRKeyUnion>& ck,

experiment/pir/pps/pps_pir_benchmark.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ static void BM_PpsSingleBitPir(benchmark::State& state) {
4646
pir::pps::PpsPirServer pirOfflineServer(n * n, n);
4747
pir::pps::PpsPirServer pirOnlineServer(n * n, n);
4848

49-
pir::pps::PIRKey pirKey, pirKeyOffline;
50-
pir::pps::PIRQueryParam pirQueryParam;
51-
pir::pps::PIRPuncKey pirPuncKey, pirPuncKeyOnline;
52-
std::set<uint64_t> deltas, deltasOffline;
49+
pir::pps::PIRKey pirKey{}, pirKeyOffline{};
50+
pir::pps::PIRQueryParam pirQueryParam{};
51+
pir::pps::PIRPuncKey pirPuncKey{}, pirPuncKeyOnline{};
52+
std::set<uint64_t> deltas{}, deltasOffline{};
5353
yacl::dynamic_bitset<> bits;
5454
GenerateRandomBitString(bits, n * n);
5555
yacl::dynamic_bitset<> h, hOffline;
5656
uint64_t query_index = pirClient.GetRandomU64Less();
57-
bool query_result;
57+
bool query_result{};
5858

5959
constexpr int kWorldSize = 2;
6060
const auto contextsOffline = yacl::link::test::SetupWorld(kWorldSize);
@@ -102,7 +102,7 @@ static void BM_PpsSingleBitPir(benchmark::State& state) {
102102
recver_future.get();
103103

104104
bool a = pirOnlineServer.Answer(pirPuncKeyOnline, bits);
105-
bool aClient;
105+
bool aClient{};
106106

107107
sender_future =
108108
std::async(std::launch::async, pir::pps::OnlineServerSendToClient,
@@ -129,13 +129,13 @@ static void BM_PpsMultiBitsPir(benchmark::State& state) {
129129
pir::pps::PpsPirServer pirOfflineServer(n * n, n);
130130
pir::pps::PpsPirServer pirOnlineServer(n * n, n);
131131

132-
std::vector<pir::pps::PIRKeyUnion> pirKey, pirKeyOffline;
132+
std::vector<pir::pps::PIRKeyUnion> pirKey{}, pirKeyOffline{};
133133
yacl::dynamic_bitset<> bits;
134134
GenerateRandomBitString(bits, n * n);
135135
yacl::dynamic_bitset<> h, hOffline;
136-
pir::pps::PIRQueryParam pirParam;
136+
pir::pps::PIRQueryParam pirParam{};
137137

138-
bool aLeft, aRight, aLeftOnline, aRightOnline, queryResult;
138+
bool aLeft{}, aRight{}, aLeftOnline{}, aRightOnline{}, queryResult{};
139139
std::vector<std::unordered_set<uint64_t>> v;
140140

141141
constexpr int kWorldSize = 2;
@@ -170,8 +170,8 @@ static void BM_PpsMultiBitsPir(benchmark::State& state) {
170170
recver_future.get();
171171

172172
for (uint i = 0; i < n * n; ++i) {
173-
pir::pps::PIRPuncKey pirPuncKeyL, pirPuncKeyR;
174-
pir::pps::PIRPuncKey pirPuncKeyLOnline, pirPuncKeyROnline;
173+
pir::pps::PIRPuncKey pirPuncKeyL{}, pirPuncKeyR{};
174+
pir::pps::PIRPuncKey pirPuncKeyLOnline{}, pirPuncKeyROnline{};
175175

176176
pirClient.Query(i, pirKey, v, pirParam, pirPuncKeyL, pirPuncKeyR);
177177

experiment/pir/pps/sender.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
namespace pir::pps {
2424
std::array<std::byte, 16> Uint128_to_bytes(PIRKey sk) {
25-
std::array<std::byte, 16> bytes;
25+
std::array<std::byte, 16> bytes{};
2626
uint64_t high = static_cast<uint64_t>(sk >> 64);
2727
uint64_t low = static_cast<uint64_t>(sk & 0xFFFFFFFFFFFFFFFF);
2828
std::memcpy(bytes.data(), &high, sizeof(high));

psi/apsi_wrapper/api/receiver_c_wrapper.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Receiver* BucketReceiverMake(size_t bucket_cnt, size_t thread_count) {
3636
}
3737

3838
void BucketReceiverFree(Receiver** receiver) {
39-
if (receiver != nullptr || *receiver == nullptr) {
39+
if (receiver == nullptr || *receiver == nullptr) {
4040
return;
4141
}
4242
(void)std::unique_ptr<ApiReceiver>(reinterpret_cast<ApiReceiver*>(*receiver));

psi/rr22/rr22_psi.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,13 +257,18 @@ class Rr22Runner {
257257
futures[i] = std::async(
258258
std::launch::async,
259259
[&](size_t thread_idx) {
260+
std::shared_ptr<yacl::link::Context> spawn_read_lctx =
261+
read_lctx_->Spawn(std::to_string(thread_idx));
262+
std::shared_ptr<yacl::link::Context> spawn_run_lctx =
263+
run_lctx_->Spawn(std::to_string(thread_idx));
264+
std::shared_ptr<yacl::link::Context> spawn_intersection_lctx =
265+
intersection_lctx_->Spawn(std::to_string(thread_idx));
260266
for (size_t j = 0; j < bucket_num_; j++) {
261267
if (j % parallel_num == thread_idx) {
262268
auto runner = CreateBucketRunner(j, is_sender);
263-
runner->Prepare(read_lctx_->Spawn(std::to_string(thread_idx)));
264-
runner->RunOprf(run_lctx_->Spawn(std::to_string(thread_idx)));
265-
runner->GetIntersection(
266-
intersection_lctx_->Spawn(std::to_string(thread_idx)));
269+
runner->Prepare(spawn_read_lctx);
270+
runner->RunOprf(spawn_run_lctx);
271+
runner->GetIntersection(spawn_intersection_lctx);
267272
}
268273
}
269274
},

0 commit comments

Comments
 (0)