Skip to content

Commit 1af42ad

Browse files
apollo_propeller: implemented MessageProcessor
1 parent fafe1a0 commit 1af42ad

File tree

3 files changed

+310
-21
lines changed

3 files changed

+310
-21
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/apollo_propeller/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ asynchronous-codec.workspace = true
1616
futures.workspace = true
1717
libp2p.workspace = true
1818
prost.workspace = true
19+
rand.workspace = true
1920
reed-solomon-simd.workspace = true
2021
sha2.workspace = true
2122
thiserror.workspace = true

crates/apollo_propeller/src/message_processor.rs

Lines changed: 308 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,116 @@ use std::sync::Arc;
33
use std::time::Duration;
44

55
use libp2p::identity::{PeerId, PublicKey};
6+
use rand::seq::SliceRandom;
67
use tokio::sync::mpsc;
7-
use tokio::time::sleep_until;
8-
use tracing::{debug, trace};
8+
use tracing::{debug, error, trace};
99

10+
use crate::sharding::reconstruct_message_from_shards;
1011
use crate::tree::PropellerScheduleManager;
11-
use crate::types::{Channel, Event, MessageRoot};
12+
use crate::types::{Channel, Event, MessageRoot, ReconstructionError, ShardValidationError};
1213
use crate::unit::PropellerUnit;
13-
use crate::ShardIndex;
14+
use crate::unit_validator::UnitValidator;
15+
use crate::{MerkleProof, ShardIndex};
1416

1517
pub type UnitToValidate = (PeerId, PropellerUnit);
18+
type ValidationResult = (Result<(), ShardValidationError>, UnitValidator, PropellerUnit);
19+
type ReconstructionResult = Result<ReconstructionOutput, ReconstructionError>;
1620

1721
#[derive(Debug)]
1822
pub enum EventStateManagerToEngine {
1923
BehaviourEvent(Event),
2024
Finalized { channel: Channel, publisher: PeerId, message_root: MessageRoot },
21-
BroadcastUnit { unit: PropellerUnit, peers: Vec<PeerId> },
25+
SendUnitToPeers { unit: PropellerUnit, peers: Vec<PeerId> },
26+
}
27+
28+
#[derive(Debug)]
29+
struct ReconstructionOutput {
30+
message: Vec<u8>,
31+
my_shard: Vec<u8>,
32+
my_shard_proof: MerkleProof,
33+
}
34+
35+
/// Tracks reconstruction progress for a single message.
36+
enum ReconstructionState {
37+
PreConstruction {
38+
received_shards: Vec<PropellerUnit>,
39+
did_broadcast_my_shard: bool,
40+
signature: Option<Vec<u8>>,
41+
},
42+
/// Message was reconstructed but not yet delivered to the application. We keep collecting
43+
/// shards until the access threshold is reached, then emit the message.
44+
// No need to track the unit indices after reconstruction (unit duplication already validated)
45+
PostConstruction { reconstructed_message: Option<Vec<u8>>, num_received_shards: usize },
46+
}
47+
48+
impl ReconstructionState {
49+
fn new() -> Self {
50+
Self::PreConstruction {
51+
received_shards: Vec::new(),
52+
did_broadcast_my_shard: false,
53+
signature: None,
54+
}
55+
}
56+
57+
fn is_reconstructed(&self) -> bool {
58+
matches!(self, Self::PostConstruction { .. })
59+
}
60+
61+
fn did_broadcast_my_shard(&self) -> bool {
62+
match self {
63+
Self::PreConstruction { did_broadcast_my_shard, .. } => *did_broadcast_my_shard,
64+
Self::PostConstruction { .. } => true,
65+
}
66+
}
67+
68+
fn record_shard(&mut self, is_my_shard: bool) {
69+
match self {
70+
Self::PreConstruction { did_broadcast_my_shard, .. } => {
71+
if is_my_shard {
72+
*did_broadcast_my_shard = true;
73+
}
74+
}
75+
Self::PostConstruction { num_received_shards, .. } => {
76+
if !is_my_shard {
77+
*num_received_shards += 1;
78+
}
79+
}
80+
}
81+
}
82+
83+
fn capture_signature(&mut self, unit: &PropellerUnit) {
84+
if let Self::PreConstruction { signature, .. } = self {
85+
if signature.is_none() {
86+
*signature = Some(unit.signature().to_vec());
87+
}
88+
// The signature was already validated to be the same for all units.
89+
}
90+
}
91+
92+
fn take_shards(&mut self) -> Vec<PropellerUnit> {
93+
match self {
94+
Self::PreConstruction { received_shards, .. } => std::mem::take(received_shards),
95+
Self::PostConstruction { .. } => Vec::new(),
96+
}
97+
}
98+
99+
fn push_shard(&mut self, unit: PropellerUnit) {
100+
if let Self::PreConstruction { received_shards, .. } = self {
101+
received_shards.push(unit);
102+
}
103+
}
104+
105+
fn received_shard_count(&self) -> usize {
106+
match self {
107+
Self::PreConstruction { received_shards, .. } => received_shards.len(),
108+
Self::PostConstruction { .. } => 0,
109+
}
110+
}
111+
112+
fn transition_to_post(&mut self, message: Vec<u8>, num_received_shards: usize) {
113+
*self =
114+
Self::PostConstruction { reconstructed_message: Some(message), num_received_shards };
115+
}
22116
}
23117

24118
/// Message processor that handles validation and state management for a single message.
@@ -47,18 +141,10 @@ impl MessageProcessor {
47141
self.channel, self.publisher, self.message_root
48142
);
49143

50-
// Local state variables
51-
let deadline = tokio::time::Instant::now() + self.timeout;
144+
let timed_out = tokio::time::timeout(self.timeout, self.process_units()).await.is_err();
52145

53-
// TODO(AndrewL): remove this
54-
#[allow(clippy::never_loop)]
55-
loop {
56-
tokio::select! {
57-
_ = sleep_until(deadline) => {
58-
let _ = self.emit_timeout_and_finalize().await;
59-
break;
60-
}
61-
}
146+
if timed_out {
147+
self.emit_timeout_and_finalize();
62148
}
63149

64150
debug!(
@@ -67,32 +153,233 @@ impl MessageProcessor {
67153
);
68154
}
69155

70-
async fn emit_timeout_and_finalize(&mut self) -> ControlFlow<()> {
156+
async fn process_units(&mut self) {
157+
let mut validator = UnitValidator::new(
158+
self.channel,
159+
self.publisher,
160+
self.publisher_public_key.clone(),
161+
self.message_root,
162+
Arc::clone(&self.tree_manager),
163+
);
164+
let mut state = ReconstructionState::new();
165+
166+
while let Some((sender, unit)) = self.unit_rx.recv().await {
167+
// TODO(AndrewL): finalize immediately if first validation fails (DOS attack vector)
168+
trace!("[MSG_PROC] Validating unit from sender={:?} index={:?}", sender, unit.index());
169+
170+
// TODO(AndrewL): process multiple shards simultaneously instead of sequentially.
171+
let (result, returned_validator, unit) =
172+
Self::validate_blocking(validator, sender, unit).await;
173+
validator = returned_validator;
174+
175+
if let Err(err) = result {
176+
// TODO(AndrewL): penalize sender of bad shard.
177+
trace!("[MSG_PROC] Validation failed for index={:?}: {:?}", unit.index(), err);
178+
continue;
179+
}
180+
181+
self.maybe_broadcast_my_shard(&unit, &state);
182+
state.record_shard(unit.index() == self.my_shard_index);
183+
state.capture_signature(&unit);
184+
185+
if self.update_state(unit, &mut state).await.is_break() {
186+
return;
187+
}
188+
}
189+
71190
trace!(
72-
"[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
191+
"[MSG_PROC] All channels closed for channel={:?} publisher={:?} root={:?}",
73192
self.channel,
74193
self.publisher,
75194
self.message_root
76195
);
196+
self.finalize();
197+
}
77198

78-
self.emit_and_finalize(Event::MessageTimeout {
79-
channel: self.channel,
199+
/// Offloads CPU-bound validation (signature verification, merkle proofs) to a blocking thread
200+
/// to avoid blocking the tokio runtime.
201+
async fn validate_blocking(
202+
mut validator: UnitValidator,
203+
sender: PeerId,
204+
unit: PropellerUnit,
205+
) -> ValidationResult {
206+
tokio::task::spawn_blocking(move || {
207+
let result = validator.validate_shard(sender, &unit);
208+
(result, validator, unit)
209+
})
210+
.await
211+
.expect("Validation task panicked")
212+
}
213+
214+
fn maybe_broadcast_my_shard(&self, unit: &PropellerUnit, state: &ReconstructionState) {
215+
if unit.index() == self.my_shard_index && !state.did_broadcast_my_shard() {
216+
self.broadcast_unit(unit);
217+
}
218+
}
219+
220+
fn broadcast_unit(&self, unit: &PropellerUnit) {
221+
let mut peers: Vec<PeerId> = self
222+
.tree_manager
223+
.get_nodes()
224+
.iter()
225+
.map(|(p, _)| *p)
226+
.filter(|p| *p != self.publisher && *p != self.local_peer_id)
227+
.collect();
228+
// TODO(AndrewL): get seeded RNG for tests.
229+
peers.shuffle(&mut rand::thread_rng());
230+
trace!("[MSG_PROC] Broadcasting unit index={:?} to {} peers", unit.index(), peers.len());
231+
self.engine_tx
232+
.send(EventStateManagerToEngine::SendUnitToPeers { unit: unit.clone(), peers })
233+
.expect("Engine task has exited");
234+
}
235+
236+
async fn update_state(
237+
&self,
238+
unit: PropellerUnit,
239+
state: &mut ReconstructionState,
240+
) -> ControlFlow<()> {
241+
if state.is_reconstructed() {
242+
return self.maybe_emit_message(state);
243+
}
244+
245+
state.push_shard(unit);
246+
247+
let shard_count = state.received_shard_count();
248+
if !self.tree_manager.should_build(shard_count) {
249+
return ControlFlow::Continue(());
250+
}
251+
252+
trace!("[MSG_PROC] Starting reconstruction with {} shards", shard_count);
253+
254+
match self.reconstruct_blocking(state).await {
255+
Ok(output) => self.handle_reconstruction_output(output, shard_count, state),
256+
Err(e) => {
257+
error!("[MSG_PROC] Reconstruction failed: {:?}", e);
258+
self.emit_and_finalize(Event::MessageReconstructionFailed {
259+
publisher: self.publisher,
260+
message_root: self.message_root,
261+
error: e,
262+
})
263+
}
264+
}
265+
}
266+
267+
/// Offloads erasure-coding reconstruction to a blocking thread.
268+
async fn reconstruct_blocking(&self, state: &mut ReconstructionState) -> ReconstructionResult {
269+
let shards = state.take_shards();
270+
let message_root = self.message_root;
271+
let my_index: usize = self.my_shard_index.0.try_into().unwrap();
272+
let data_count = self.tree_manager.num_data_shards();
273+
let coding_count = self.tree_manager.num_coding_shards();
274+
275+
tokio::task::spawn_blocking(move || {
276+
reconstruct_message_from_shards(
277+
shards,
278+
message_root,
279+
my_index,
280+
data_count,
281+
coding_count,
282+
)
283+
.map(|(message, my_shard, my_shard_proof)| ReconstructionOutput {
284+
message,
285+
my_shard,
286+
my_shard_proof,
287+
})
288+
})
289+
.await
290+
.expect("Reconstruction task panicked")
291+
}
292+
293+
fn handle_reconstruction_output(
294+
&self,
295+
output: ReconstructionOutput,
296+
shard_count: usize,
297+
state: &mut ReconstructionState,
298+
) -> ControlFlow<()> {
299+
let ReconstructionOutput { message, my_shard, my_shard_proof } = output;
300+
301+
let should_broadcast = !state.did_broadcast_my_shard();
302+
if should_broadcast {
303+
let signature = match state {
304+
ReconstructionState::PreConstruction { signature, .. } => {
305+
signature.clone().expect("Signature must exist")
306+
}
307+
ReconstructionState::PostConstruction { .. } => {
308+
unreachable!("Cannot be PostConstruction before transition")
309+
}
310+
};
311+
let reconstructed_unit = PropellerUnit::new(
312+
self.channel,
313+
self.publisher,
314+
self.message_root,
315+
signature,
316+
self.my_shard_index,
317+
my_shard,
318+
my_shard_proof,
319+
);
320+
self.broadcast_unit(&reconstructed_unit);
321+
}
322+
323+
let total_shards = shard_count + usize::from(should_broadcast);
324+
state.transition_to_post(message, total_shards);
325+
self.maybe_emit_message(state)
326+
}
327+
328+
fn maybe_emit_message(&self, state: &mut ReconstructionState) -> ControlFlow<()> {
329+
let num = match state {
330+
ReconstructionState::PostConstruction { num_received_shards, .. } => {
331+
*num_received_shards
332+
}
333+
_ => return ControlFlow::Continue(()),
334+
};
335+
336+
if !self.tree_manager.should_receive(num) {
337+
return ControlFlow::Continue(());
338+
}
339+
340+
trace!("[MSG_PROC] Access threshold reached, emitting message");
341+
let message = match state {
342+
ReconstructionState::PostConstruction { reconstructed_message, .. } => {
343+
reconstructed_message.take().expect("Message already emitted")
344+
}
345+
_ => unreachable!(),
346+
};
347+
self.emit_and_finalize(Event::MessageReceived {
80348
publisher: self.publisher,
81349
message_root: self.message_root,
350+
message,
82351
})
83352
}
84353

354+
fn emit_timeout_and_finalize(&self) {
355+
trace!(
356+
"[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
357+
self.channel,
358+
self.publisher,
359+
self.message_root
360+
);
361+
let _ = self.emit_and_finalize(Event::MessageTimeout {
362+
channel: self.channel,
363+
publisher: self.publisher,
364+
message_root: self.message_root,
365+
});
366+
}
367+
85368
fn emit_and_finalize(&self, event: Event) -> ControlFlow<()> {
86369
self.engine_tx
87370
.send(EventStateManagerToEngine::BehaviourEvent(event))
88371
.expect("Engine task has exited");
372+
self.finalize();
373+
ControlFlow::Break(())
374+
}
375+
376+
fn finalize(&self) {
89377
self.engine_tx
90378
.send(EventStateManagerToEngine::Finalized {
91379
channel: self.channel,
92380
publisher: self.publisher,
93381
message_root: self.message_root,
94382
})
95383
.expect("Engine task has exited");
96-
ControlFlow::Break(())
97384
}
98385
}

0 commit comments

Comments
 (0)