Commit 7769109

apollo_propeller: implemented MessageProcessor
1 parent 5428236 commit 7769109

4 files changed (+262 -21 lines)

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -336,6 +336,7 @@ pyo3-log = "0.8.1"
 quote = "1.0.26"
 rand = "0.8.5"
 rand_chacha = "0.3.1"
+rayon = "1.10"
 rand_distr = "0.4.3"
 reed-solomon-simd = "3.1.0"
 regex = "1.10.4"

crates/apollo_propeller/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ asynchronous-codec.workspace = true
 futures.workspace = true
 libp2p.workspace = true
 prost.workspace = true
+rand.workspace = true
+rayon.workspace = true
 reed-solomon-simd.workspace = true
 sha2.workspace = true
 thiserror.workspace = true
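
The two new workspace dependencies map directly onto the implementation below: rand supplies SliceRandom for shuffling broadcast peers, and rayon provides the thread pool that CPU-bound validation and reconstruction are offloaded to. A minimal illustrative sketch of the shuffle usage (not taken from the commit), matching the rand 0.8 API pinned above:

    use rand::seq::SliceRandom;

    fn main() {
        // Shuffle peers into a random broadcast order, as broadcast_shard does.
        let mut peers = vec!["peer_a", "peer_b", "peer_c", "peer_d"];
        peers.shuffle(&mut rand::thread_rng());
        println!("{peers:?}");
    }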

crates/apollo_propeller/src/message_processor.rs

Lines changed: 257 additions & 21 deletions
@@ -3,16 +3,20 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use libp2p::identity::{PeerId, PublicKey};
-use tokio::sync::mpsc;
-use tokio::time::sleep_until;
+use rand::seq::SliceRandom;
+use tokio::sync::{mpsc, oneshot};
 use tracing::{debug, trace};
 
+use crate::sharding::reconstruct_message_from_shards;
 use crate::tree::PropellerScheduleManager;
-use crate::types::{Channel, Event, MessageRoot};
+use crate::types::{Channel, Event, MessageRoot, ReconstructionError, ShardValidationError};
 use crate::unit::PropellerUnit;
-use crate::ShardIndex;
+use crate::unit_validator::UnitValidator;
+use crate::{MerkleProof, ShardIndex};
 
 pub type UnitToValidate = (PeerId, PropellerUnit);
+type ValidationResult = (Result<(), ShardValidationError>, UnitValidator, PropellerUnit);
+type ReconstructionResult = Result<ReconstructionSuccess, ReconstructionError>;
 
 #[derive(Debug)]
 pub enum EventStateManagerToEngine {
@@ -21,6 +25,61 @@ pub enum EventStateManagerToEngine {
     BroadcastUnit { unit: PropellerUnit, peers: Vec<PeerId> },
 }
 
+#[derive(Debug)]
+struct ReconstructionSuccess {
+    message: Vec<u8>,
+    my_shard: Vec<u8>,
+    my_shard_proof: MerkleProof,
+}
+
+/// Tracks reconstruction progress for a single message.
+struct ReconstructionState {
+    received_shards: Vec<PropellerUnit>,
+    received_my_index: bool,
+    signature: Option<Vec<u8>>,
+    reconstructed_message: Option<Vec<u8>>,
+    count_at_reconstruction: usize,
+    additional_shards_after_reconstruction: usize,
+}
+
+impl ReconstructionState {
+    fn new() -> Self {
+        Self {
+            received_shards: Vec::new(),
+            received_my_index: false,
+            signature: None,
+            reconstructed_message: None,
+            count_at_reconstruction: 0,
+            additional_shards_after_reconstruction: 0,
+        }
+    }
+
+    fn is_reconstructed(&self) -> bool {
+        self.reconstructed_message.is_some()
+    }
+
+    fn record_shard(&mut self, is_my_shard: bool) {
+        if is_my_shard {
+            self.received_my_index = true;
+        } else if self.is_reconstructed() {
+            self.additional_shards_after_reconstruction += 1;
+        }
+    }
+
+    fn capture_signature(&mut self, unit: &PropellerUnit) {
+        if self.signature.is_none() {
+            self.signature = Some(unit.signature().to_vec());
+        }
+    }
+
+    /// Total shard count used for the access-threshold check.
+    fn access_count(&self) -> usize {
+        self.count_at_reconstruction
+            + self.additional_shards_after_reconstruction
+            + usize::from(!self.received_my_index)
+    }
+}
+
 /// Message processor that handles validation and state management for a single message.
 pub struct MessageProcessor {
     pub channel: Channel,
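
The subtle piece of the new ReconstructionState above is access_count(): shards buffered at the moment of reconstruction, shards that keep arriving afterwards, and the node's own shard (which it rebuilds and broadcasts itself if it never received it) all count toward the should_receive threshold. A test-style sketch of that arithmetic, poking at the private fields purely for illustration:

    let mut state = ReconstructionState::new();
    // Reconstruction succeeds once two peer shards are buffered;
    // advance_reconstruction records the count at that moment...
    state.count_at_reconstruction = 2;
    // ...and handle_reconstruction_success stores the message.
    state.reconstructed_message = Some(b"payload".to_vec());
    // A third peer shard arrives after reconstruction.
    state.record_shard(false);
    // 2 buffered + 1 late + 1 for our own (rebuilt, never received) shard:
    assert_eq!(state.access_count(), 4);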
@@ -47,18 +106,10 @@ impl MessageProcessor {
             self.channel, self.publisher, self.message_root
         );
 
-        // Local state variables
-        let deadline = tokio::time::Instant::now() + self.timeout;
-
-        // TODO(AndrewL): remove this
-        #[allow(clippy::never_loop)]
-        loop {
-            tokio::select! {
-                _ = sleep_until(deadline) => {
-                    let _ = self.emit_timeout_and_finalize().await;
-                    break;
-                }
-            }
+        let timed_out = tokio::time::timeout(self.timeout, self.process_units()).await.is_err();
+
+        if timed_out {
+            self.emit_timeout_and_finalize();
         }
 
         debug!(
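
The hunk above replaces the hand-rolled sleep_until/select! loop with tokio::time::timeout, which races a future against a deadline and returns Err(Elapsed) when the deadline wins. A self-contained sketch of the pattern:

    use std::time::Duration;

    #[tokio::main]
    async fn main() {
        // The wrapped future loses the race against the 1s deadline, so
        // timeout(..) yields Err(Elapsed) and `timed_out` is true.
        let work = tokio::time::sleep(Duration::from_secs(2));
        let timed_out = tokio::time::timeout(Duration::from_secs(1), work).await.is_err();
        assert!(timed_out);
    }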
@@ -67,32 +118,217 @@ impl MessageProcessor {
         );
     }
 
-    async fn emit_timeout_and_finalize(&mut self) -> ControlFlow<()> {
+    async fn process_units(&mut self) {
+        let mut validator = UnitValidator::new(
+            self.channel,
+            self.publisher,
+            self.publisher_public_key.clone(),
+            self.message_root,
+            Arc::clone(&self.tree_manager),
+        );
+        let mut state = ReconstructionState::new();
+
+        while let Some((sender, unit)) = self.unit_rx.recv().await {
+            // TODO(AndrewL): finalize immediately if first validation fails (DOS attack vector)
+            trace!("[MSG_PROC] Validating unit from sender={:?} index={:?}", sender, unit.index());
+
+            let (result, returned_validator, unit) =
+                Self::validate_on_rayon(validator, sender, unit).await;
+            validator = returned_validator;
+
+            if let Err(err) = result {
+                trace!("[MSG_PROC] Validation failed for index={:?}: {:?}", unit.index(), err);
+                continue;
+            }
+
+            self.maybe_broadcast_my_shard(&unit, &state);
+            state.record_shard(unit.index() == self.my_shard_index);
+            state.capture_signature(&unit);
+
+            if self.advance_reconstruction(unit, &mut state).await.is_break() {
+                return;
+            }
+        }
+
         trace!(
-            "[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
+            "[MSG_PROC] All channels closed for channel={:?} publisher={:?} root={:?}",
             self.channel,
             self.publisher,
             self.message_root
         );
+        self.finalize();
+    }
 
-        self.emit_and_finalize(Event::MessageTimeout {
-            channel: self.channel,
+    // --- Validation --------------------------------------------------------
+
+    /// Offloads CPU-bound validation (signature verification, merkle proofs) to the rayon thread
+    /// pool to avoid blocking the tokio runtime. Benchmarked to outperform `spawn_blocking`.
+    async fn validate_on_rayon(
+        mut validator: UnitValidator,
+        sender: PeerId,
+        unit: PropellerUnit,
+    ) -> ValidationResult {
+        let (tx, rx) = oneshot::channel();
+        rayon::spawn(move || {
+            let result = validator.validate_shard(sender, &unit);
+            let _ = tx.send((result, validator, unit));
+        });
+        rx.await.expect("Validation task panicked")
+    }
+
+    // --- Broadcasting ------------------------------------------------------
+
+    fn maybe_broadcast_my_shard(&self, unit: &PropellerUnit, state: &ReconstructionState) {
+        if unit.index() == self.my_shard_index && !state.received_my_index {
+            self.broadcast_shard(unit);
+        }
+    }
+
+    fn broadcast_shard(&self, unit: &PropellerUnit) {
+        let mut peers: Vec<PeerId> = self
+            .tree_manager
+            .get_nodes()
+            .iter()
+            .map(|(p, _)| *p)
+            .filter(|p| *p != self.publisher && *p != self.local_peer_id)
+            .collect();
+        peers.shuffle(&mut rand::thread_rng());
+        trace!("[MSG_PROC] Broadcasting unit index={:?} to {} peers", unit.index(), peers.len());
+        self.engine_tx
+            .send(EventStateManagerToEngine::BroadcastUnit { unit: unit.clone(), peers })
+            .expect("Engine task has exited");
+    }
+
+    // --- Reconstruction ----------------------------------------------------
+
+    async fn advance_reconstruction(
+        &self,
+        unit: PropellerUnit,
+        state: &mut ReconstructionState,
+    ) -> ControlFlow<()> {
+        if state.is_reconstructed() {
+            return self.maybe_emit_message(state);
+        }
+
+        state.received_shards.push(unit);
+
+        if !self.tree_manager.should_build(state.received_shards.len()) {
+            return ControlFlow::Continue(());
+        }
+
+        trace!("[MSG_PROC] Starting reconstruction with {} shards", state.received_shards.len());
+        state.count_at_reconstruction = state.received_shards.len();
+
+        match self.reconstruct_on_rayon(state).await {
+            Ok(success) => self.handle_reconstruction_success(success, state),
+            Err(e) => {
+                tracing::error!("[MSG_PROC] Reconstruction failed: {:?}", e);
+                self.emit_and_finalize(Event::MessageReconstructionFailed {
+                    publisher: self.publisher,
+                    message_root: self.message_root,
+                    error: e,
+                })
+            }
+        }
+    }
+
+    /// Offloads erasure-coding reconstruction to rayon.
+    async fn reconstruct_on_rayon(&self, state: &mut ReconstructionState) -> ReconstructionResult {
+        let shards = std::mem::take(&mut state.received_shards);
+        let message_root = self.message_root;
+        let my_index: usize = self.my_shard_index.0.try_into().unwrap();
+        let data_count = self.tree_manager.num_data_shards();
+        let coding_count = self.tree_manager.num_coding_shards();
+
+        let (tx, rx) = oneshot::channel();
+        rayon::spawn(move || {
+            let result = reconstruct_message_from_shards(
+                shards,
+                message_root,
+                my_index,
+                data_count,
+                coding_count,
+            )
+            .map(|(message, my_shard, my_shard_proof)| ReconstructionSuccess {
+                message,
+                my_shard,
+                my_shard_proof,
+            });
+            let _ = tx.send(result);
+        });
+        rx.await.expect("Reconstruction task panicked")
+    }
+
+    fn handle_reconstruction_success(
+        &self,
+        success: ReconstructionSuccess,
+        state: &mut ReconstructionState,
+    ) -> ControlFlow<()> {
+        let ReconstructionSuccess { message, my_shard, my_shard_proof } = success;
+
+        if !state.received_my_index {
+            let signature = state.signature.clone().expect("Signature must exist");
+            let reconstructed_unit = PropellerUnit::new(
+                self.channel,
+                self.publisher,
+                self.message_root,
+                signature,
+                self.my_shard_index,
+                my_shard,
+                my_shard_proof,
+            );
+            self.broadcast_shard(&reconstructed_unit);
+        }
+
+        state.reconstructed_message = Some(message);
+        self.maybe_emit_message(state)
+    }
+
+    // --- Emission / finalization -------------------------------------------
+
+    fn maybe_emit_message(&self, state: &mut ReconstructionState) -> ControlFlow<()> {
+        if !self.tree_manager.should_receive(state.access_count()) {
+            return ControlFlow::Continue(());
+        }
+
+        trace!("[MSG_PROC] Access threshold reached, emitting message");
+        let message = state.reconstructed_message.take().expect("Message must exist");
+        self.emit_and_finalize(Event::MessageReceived {
             publisher: self.publisher,
             message_root: self.message_root,
+            message,
         })
     }
 
+    fn emit_timeout_and_finalize(&self) {
+        trace!(
+            "[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
+            self.channel,
+            self.publisher,
+            self.message_root
+        );
+        let _ = self.emit_and_finalize(Event::MessageTimeout {
+            channel: self.channel,
+            publisher: self.publisher,
+            message_root: self.message_root,
+        });
+    }
+
     fn emit_and_finalize(&self, event: Event) -> ControlFlow<()> {
         self.engine_tx
             .send(EventStateManagerToEngine::BehaviourEvent(event))
             .expect("Engine task has exited");
+        self.finalize();
+        ControlFlow::Break(())
+    }
+
+    fn finalize(&self) {
         self.engine_tx
             .send(EventStateManagerToEngine::Finalized {
                 channel: self.channel,
                 publisher: self.publisher,
                 message_root: self.message_root,
             })
             .expect("Engine task has exited");
-        ControlFlow::Break(())
     }
 }
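
Both validate_on_rayon and reconstruct_on_rayon use the same handoff: spawn the CPU-bound work on rayon's global pool, then hand the result back to the awaiting tokio task through a oneshot channel. A generic, self-contained sketch of that bridge (the helper name run_on_rayon is ours, not the commit's):

    use tokio::sync::oneshot;

    /// Runs a CPU-bound closure on rayon's global pool and awaits the result
    /// without blocking a tokio worker thread.
    async fn run_on_rayon<T, F>(f: F) -> T
    where
        T: Send + 'static,
        F: FnOnce() -> T + Send + 'static,
    {
        let (tx, rx) = oneshot::channel();
        rayon::spawn(move || {
            // The receiver is dropped only if the caller went away; ignore that.
            let _ = tx.send(f());
        });
        rx.await.expect("rayon task panicked")
    }

    #[tokio::main]
    async fn main() {
        let sum = run_on_rayon(|| (0u64..1_000_000).sum::<u64>()).await;
        println!("{sum}");
    }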
