Skip to content

Commit cc6b533

Browse files
authored
Faster matching with threads (#1)
* server: adding threads for faster matching * threads working
1 parent 8411025 commit cc6b533

File tree

5 files changed

+120
-91
lines changed

5 files changed

+120
-91
lines changed

mpc-server/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ dotenv = "0.15.0"
3737
toml = "0.8.22"
3838
hex = "0.4.3"
3939
axum-server = { version = "0.7.2", features = ["tls-rustls"] }
40+
rayon = "1.10.0"

mpc-server/src/db.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ use crate::matching::DATA_DIR;
66
#[derive(Debug, Clone)]
77
pub struct User {
88
pub id: String,
9+
#[allow(dead_code)]
910
pub twitter_handle: String,
1011
pub checked: Vec<String>,
1112
}
1213

1314
#[derive(Debug, Clone)]
1415
pub struct Match {
16+
#[allow(dead_code)]
1517
pub id: u32,
1618
pub user_id1: String,
1719
pub user_id2: String,
@@ -117,6 +119,22 @@ pub fn update_checked(
117119
Ok(())
118120
}
119121

122+
pub fn update_checked_many(
123+
conn: &Connection,
124+
user_ids: Vec<String>,
125+
new_checked: Vec<String>,
126+
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
127+
let mut stmt = conn.prepare("UPDATE users SET checked = ?1 WHERE id = ?2")?;
128+
129+
for user_id in user_ids {
130+
let user = get_user(conn, &user_id)?;
131+
let mut checked: HashSet<String> = user.checked.into_iter().collect();
132+
checked.extend(new_checked.clone());
133+
stmt.execute((serde_json::to_string(&checked)?, user_id))?;
134+
}
135+
Ok(())
136+
}
137+
120138
pub fn insert_matches(
121139
conn: &Connection,
122140
matches: Vec<(String, String)>,

mpc-server/src/main.rs

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use axum::{
55
routing::{get, post},
66
};
77
use axum_server::tls_rustls::RustlsConfig;
8-
use co_noir::{Address, Bn254, CrsParser, NetworkParty, PartyID, Utils};
8+
use co_noir::{Bn254, CrsParser, Utils};
99
use co_ultrahonk::prelude::{ProverCrs, ZeroKnowledge};
1010
use rustls::pki_types::CertificateDer;
1111
use serde::Deserialize;
@@ -61,22 +61,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
6161

6262
setup_db()?;
6363

64-
let parties = vec![
65-
NetworkParty::new(
66-
PartyID::ID0.into(),
67-
Address::new("localhost".to_string(), 10000),
68-
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert0.der"))?).into_owned(),
69-
),
70-
NetworkParty::new(
71-
PartyID::ID1.into(),
72-
Address::new("localhost".to_string(), 10001),
73-
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert1.der"))?).into_owned(),
74-
),
75-
NetworkParty::new(
76-
PartyID::ID2.into(),
77-
Address::new("localhost".to_string(), 10002),
78-
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert2.der"))?).into_owned(),
79-
),
64+
let parties_certs = [
65+
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert0.der"))?).into_owned(),
66+
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert1.der"))?).into_owned(),
67+
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert2.der"))?).into_owned(),
8068
];
8169

8270
let program_artifact = Utils::get_program_artifact_from_file(DATA_DIR.join("circuit.json"))?;
@@ -128,9 +116,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
128116
post(move |token: Token| async move {
129117
match run_matches(
130118
token.user_id,
131-
parties,
119+
parties_certs,
132120
&program_artifact,
133-
constraint_system,
121+
constraint_system.clone(),
134122
recursive,
135123
has_zk,
136124
prover_crs,
@@ -201,8 +189,8 @@ mod tests {
201189
use super::*;
202190
use crate::{matching::run_match, shares::split_input};
203191

204-
#[tokio::test]
205-
async fn test_match() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
192+
#[test]
193+
fn test_match() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
206194
let prover_toml = r#"[user1]
207195
age = 30
208196
region = 1
@@ -233,22 +221,10 @@ gender = 0"#;
233221
.install_default()
234222
.unwrap();
235223

236-
let parties = vec![
237-
NetworkParty::new(
238-
PartyID::ID0.into(),
239-
Address::new("localhost".to_string(), 10000),
240-
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert0.der"))?).into_owned(),
241-
),
242-
NetworkParty::new(
243-
PartyID::ID1.into(),
244-
Address::new("localhost".to_string(), 10001),
245-
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert1.der"))?).into_owned(),
246-
),
247-
NetworkParty::new(
248-
PartyID::ID2.into(),
249-
Address::new("localhost".to_string(), 10002),
250-
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert2.der"))?).into_owned(),
251-
),
224+
let parties_certs = [
225+
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert0.der"))?).into_owned(),
226+
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert1.der"))?).into_owned(),
227+
CertificateDer::from(std::fs::read(CONFIG_DIR.join("cert2.der"))?).into_owned(),
252228
];
253229

254230
let program_artifact =
@@ -274,19 +250,20 @@ gender = 0"#;
274250

275251
let shares = split_input(PathBuf::from("Prover.toml"), &program_artifact)?;
276252

277-
let result = run_match(
278-
shares,
279-
parties.clone(),
280-
&program_artifact,
281-
constraint_system.clone(),
282-
recursive,
283-
has_zk,
284-
prover_crs.clone(),
285-
verifier_crs.clone(),
286-
)
287-
.await;
288-
289-
println!("result: {:?}", result);
253+
for i in 0..5 {
254+
let result = run_match(
255+
i,
256+
shares.clone(),
257+
parties_certs.clone(),
258+
&program_artifact,
259+
constraint_system.clone(),
260+
recursive,
261+
has_zk,
262+
prover_crs.clone(),
263+
verifier_crs.clone(),
264+
);
265+
println!("result: {:?}", result);
266+
}
290267

291268
std::fs::remove_file("Prover.toml").unwrap();
292269

mpc-server/src/matching.rs

Lines changed: 73 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
use co_noir::{
2-
AcirFormat, Bn254, NetworkConfig, NetworkParty, PartyID, Poseidon2Sponge, Rep3CoUltraHonk,
3-
Rep3MpcNet, UltraHonk, merge_input_shares,
2+
AcirFormat, Address, Bn254, NetworkConfig, NetworkParty, PartyID, Poseidon2Sponge,
3+
Rep3CoUltraHonk, Rep3MpcNet, UltraHonk, merge_input_shares,
44
};
55
use co_ultrahonk::prelude::{ProverCrs, ZeroKnowledge};
66
use noirc_artifacts::program::ProgramArtifact;
77
use once_cell::sync::Lazy;
8+
use rayon::prelude::*;
9+
use rustls::pki_types::CertificateDer;
810
use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer};
9-
use std::sync::Arc;
10-
use std::thread;
1111
use std::{
1212
path::PathBuf,
13+
sync::Arc,
14+
thread,
1315
time::{Duration, Instant},
1416
};
1517

16-
use crate::db::{connect_db, get_all_users, get_user, insert_matches, update_checked};
18+
use crate::db::{
19+
connect_db, get_all_users, get_user, insert_matches, update_checked, update_checked_many,
20+
};
1721
use crate::shares::{Share, get_shares};
1822

1923
pub const DATA_DIR: Lazy<PathBuf> =
@@ -25,7 +29,7 @@ pub const SHARES_DIR_2: Lazy<PathBuf> = Lazy::new(|| DATA_DIR.join("user2"));
2529

2630
pub async fn run_matches(
2731
user_id: String,
28-
parties: Vec<NetworkParty>,
32+
parties_certs: [CertificateDer<'static>; 3],
2933
program_artifact: &ProgramArtifact,
3034
constraint_system: Arc<AcirFormat<ark_bn254::Fr>>,
3135
recursive: bool,
@@ -52,38 +56,45 @@ pub async fn run_matches(
5256
.collect::<Vec<String>>(),
5357
)?;
5458

55-
let mut verified_matches = Vec::new();
56-
57-
for user2 in all_users {
58-
update_checked(&conn, &user2.id, vec![user_id.clone()])?;
59+
let users2 = all_users
60+
.iter()
61+
.map(|u| u.id.clone())
62+
.collect::<Vec<String>>();
5963

60-
let shares_user1 = get_shares(&user1.id, true)?;
61-
let shares_user2 = get_shares(&user2.id, false)?;
64+
let verified_matches = all_users
65+
.into_par_iter()
66+
.enumerate()
67+
.map(
68+
|(thread_id, user2)| -> Result<String, Box<dyn std::error::Error + Send + Sync + 'static>> {
69+
let shares_user1 = get_shares(&user1.id, true)?;
70+
let shares_user2 = get_shares(&user2.id, false)?;
6271

63-
let share0 = merge_shares(shares_user1[0].clone(), shares_user2[0].clone())?;
64-
let share1 = merge_shares(shares_user1[1].clone(), shares_user2[1].clone())?;
65-
let share2 = merge_shares(shares_user1[2].clone(), shares_user2[2].clone())?;
72+
let share0 = merge_shares(shares_user1[0].clone(), shares_user2[0].clone())?;
73+
let share1 = merge_shares(shares_user1[1].clone(), shares_user2[1].clone())?;
74+
let share2 = merge_shares(shares_user1[2].clone(), shares_user2[2].clone())?;
6675

67-
match run_match(
68-
[share0, share1, share2],
69-
parties.clone(),
70-
program_artifact,
71-
constraint_system.clone(),
72-
recursive,
73-
has_zk,
74-
prover_crs.clone(),
75-
verifier_crs.clone(),
76+
match run_match(
77+
thread_id,
78+
[share0, share1, share2],
79+
parties_certs.clone(),
80+
program_artifact,
81+
constraint_system.clone(),
82+
recursive,
83+
has_zk,
84+
prover_crs.clone(),
85+
verifier_crs.clone(),
86+
) {
87+
Ok(_) => Ok(user2.id),
88+
Err(e) => Err(e),
89+
}
90+
},
7691
)
77-
.await
78-
{
79-
Ok(_) => {
80-
verified_matches.push(user2.id);
81-
}
82-
Err(e) => {
83-
println!("Error: {:?}", e);
84-
}
85-
}
86-
}
92+
.filter(|m| m.is_ok())
93+
.collect::<Result<Vec<String>, _>>()?;
94+
95+
println!("verified matches: {verified_matches:?}");
96+
97+
update_checked_many(&conn, users2, vec![user_id.clone()])?;
8798

8899
insert_matches(
89100
&conn,
@@ -95,9 +106,10 @@ pub async fn run_matches(
95106
Ok(())
96107
}
97108

98-
pub async fn run_match(
109+
pub fn run_match(
110+
thread_id: usize,
99111
[share0, share1, share2]: [Share; 3],
100-
parties: Vec<NetworkParty>,
112+
parties_certs: [CertificateDer<'static>; 3],
101113
program_artifact: &ProgramArtifact,
102114
constraint_system: Arc<AcirFormat<ark_bn254::Fr>>,
103115
recursive: bool,
@@ -107,9 +119,31 @@ pub async fn run_match(
107119
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
108120
let match_time = Instant::now();
109121

122+
let party0_port = 10000 + thread_id as u16;
123+
let party1_port = 11000 + thread_id as u16;
124+
let party2_port = 12000 + thread_id as u16;
125+
126+
let parties = vec![
127+
NetworkParty::new(
128+
PartyID::ID0.into(),
129+
Address::new("localhost".to_string(), party0_port),
130+
parties_certs[0].clone(),
131+
),
132+
NetworkParty::new(
133+
PartyID::ID1.into(),
134+
Address::new("localhost".to_string(), party1_port),
135+
parties_certs[1].clone(),
136+
),
137+
NetworkParty::new(
138+
PartyID::ID2.into(),
139+
Address::new("localhost".to_string(), party2_port),
140+
parties_certs[2].clone(),
141+
),
142+
];
143+
110144
let data0 = DataForThread {
111145
id: PartyID::ID0,
112-
port: 10000,
146+
port: party0_port,
113147
key: PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(std::fs::read(
114148
CONFIG_DIR.join("key0.der"),
115149
)?))
@@ -125,7 +159,7 @@ pub async fn run_match(
125159
};
126160
let data1 = DataForThread {
127161
id: PartyID::ID1,
128-
port: 10001,
162+
port: party1_port,
129163
key: PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(std::fs::read(
130164
CONFIG_DIR.join("key1.der"),
131165
)?))
@@ -141,7 +175,7 @@ pub async fn run_match(
141175
};
142176
let data2 = DataForThread {
143177
id: PartyID::ID2,
144-
port: 10002,
178+
port: party2_port,
145179
key: PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(std::fs::read(
146180
CONFIG_DIR.join("key2.der"),
147181
)?))
@@ -237,7 +271,6 @@ fn spawn_party(
237271
println!("pk time: {:?}", pk_time.elapsed());
238272

239273
let proof_time = Instant::now();
240-
// let (proof, _) = Rep3CoUltraHonk::<_, _, Poseidon2Sponge>::prove(net, pk, &prover_crs, has_zk)?;
241274
let (proof, _) = Rep3CoUltraHonk::<_, _, Poseidon2Sponge>::prove(net, pk, &prover_crs, has_zk)?;
242275
println!("proof time: {:?}", proof_time.elapsed());
243276

web-app/src/pages/Matches.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ export default function Matches() {
7373
</h2>
7474
<div className="p-4 bg-gradient-to-r from-pink-50 to-purple-50 rounded-xl border border-pink-100 mb-6">
7575
<p className="text-gray-600 mb-2">
76-
This might take up to 45 seconds - we&apos;re doing some heavy cryptographic lifting behind the scenes! 🔐
76+
This might take up to 15 seconds - we&apos;re doing some heavy cryptographic lifting behind the scenes! 🔐
7777
</p>
7878
<p className="text-sm text-gray-500 mb-2">While you wait, here&apos;s what&apos;s happening:</p>
7979
<ul className="list-disc list-inside text-sm text-gray-500 mt-2 space-y-1">

0 commit comments

Comments
 (0)