Skip to content

Commit 08dab67

Browse files
apollo_propeller: implemented MessageProcessor
1 parent 2dbf031 commit 08dab67

File tree

3 files changed

+253
-21
lines changed

3 files changed

+253
-21
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/apollo_propeller/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ asynchronous-codec.workspace = true
1616
futures.workspace = true
1717
libp2p.workspace = true
1818
prost.workspace = true
19+
rand.workspace = true
1920
reed-solomon-simd.workspace = true
2021
sha2.workspace = true
2122
thiserror.workspace = true

crates/apollo_propeller/src/message_processor.rs

Lines changed: 251 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,82 @@ use std::sync::Arc;
33
use std::time::Duration;
44

55
use libp2p::identity::{PeerId, PublicKey};
6+
use rand::seq::SliceRandom;
67
use tokio::sync::mpsc;
7-
use tokio::time::sleep_until;
88
use tracing::{debug, trace};
99

10+
use crate::sharding::reconstruct_message_from_shards;
1011
use crate::tree::PropellerScheduleManager;
11-
use crate::types::{Channel, Event, MessageRoot};
12+
use crate::types::{Channel, Event, MessageRoot, ReconstructionError, ShardValidationError};
1213
use crate::unit::PropellerUnit;
13-
use crate::ShardIndex;
14+
use crate::unit_validator::UnitValidator;
15+
use crate::{MerkleProof, ShardIndex};
1416

1517
pub type UnitToValidate = (PeerId, PropellerUnit);
18+
type ValidationResult = (Result<(), ShardValidationError>, UnitValidator, PropellerUnit);
19+
type ReconstructionResult = Result<ReconstructionSuccess, ReconstructionError>;
1620

1721
#[derive(Debug)]
1822
pub enum EventStateManagerToEngine {
1923
BehaviourEvent(Event),
2024
Finalized { channel: Channel, publisher: PeerId, message_root: MessageRoot },
21-
BroadcastUnit { unit: PropellerUnit, peers: Vec<PeerId> },
25+
SendUnitToPeers { unit: PropellerUnit, peers: Vec<PeerId> },
26+
}
27+
28+
#[derive(Debug)]
29+
struct ReconstructionSuccess {
30+
message: Vec<u8>,
31+
my_shard: Vec<u8>,
32+
my_shard_proof: MerkleProof,
33+
}
34+
35+
/// Tracks reconstruction progress for a single message.
36+
struct ReconstructionState {
37+
received_shards: Vec<PropellerUnit>,
38+
broadcast_my_shard: bool,
39+
signature: Option<Vec<u8>>,
40+
reconstructed_message: Option<Vec<u8>>,
41+
shard_count_at_reconstruction: usize,
42+
shards_received_after_reconstruction: usize,
43+
}
44+
45+
impl ReconstructionState {
46+
fn new() -> Self {
47+
Self {
48+
received_shards: Vec::new(),
49+
broadcast_my_shard: false,
50+
signature: None,
51+
reconstructed_message: None,
52+
shard_count_at_reconstruction: 0,
53+
shards_received_after_reconstruction: 0,
54+
}
55+
}
56+
57+
fn is_reconstructed(&self) -> bool {
58+
self.reconstructed_message.is_some()
59+
}
60+
61+
fn record_shard(&mut self, is_my_shard: bool) {
62+
if is_my_shard {
63+
self.broadcast_my_shard = true;
64+
}
65+
if self.is_reconstructed() {
66+
self.shards_received_after_reconstruction += 1;
67+
}
68+
}
69+
70+
fn capture_signature(&mut self, unit: &PropellerUnit) {
71+
if self.signature.is_none() {
72+
self.signature = Some(unit.signature().to_vec());
73+
}
74+
}
75+
76+
/// Total shard count used for the access-threshold check.
77+
fn effective_shard_count(&self) -> usize {
78+
self.shard_count_at_reconstruction
79+
+ self.shards_received_after_reconstruction
80+
+ usize::from(!self.broadcast_my_shard)
81+
}
2282
}
2383

2484
/// Message processor that handles validation and state management for a single message.
@@ -47,18 +107,10 @@ impl MessageProcessor {
47107
self.channel, self.publisher, self.message_root
48108
);
49109

50-
// Local state variables
51-
let deadline = tokio::time::Instant::now() + self.timeout;
52-
53-
// TODO(AndrewL): remove this
54-
#[allow(clippy::never_loop)]
55-
loop {
56-
tokio::select! {
57-
_ = sleep_until(deadline) => {
58-
let _ = self.emit_timeout_and_finalize().await;
59-
break;
60-
}
61-
}
110+
let timed_out = tokio::time::timeout(self.timeout, self.process_units()).await.is_err();
111+
112+
if timed_out {
113+
self.emit_timeout_and_finalize();
62114
}
63115

64116
debug!(
@@ -67,32 +119,210 @@ impl MessageProcessor {
67119
);
68120
}
69121

70-
async fn emit_timeout_and_finalize(&mut self) -> ControlFlow<()> {
122+
async fn process_units(&mut self) {
123+
let mut validator = UnitValidator::new(
124+
self.channel,
125+
self.publisher,
126+
self.publisher_public_key.clone(),
127+
self.message_root,
128+
Arc::clone(&self.tree_manager),
129+
);
130+
let mut state = ReconstructionState::new();
131+
132+
while let Some((sender, unit)) = self.unit_rx.recv().await {
133+
// TODO(AndrewL): finalize immediately if first validation fails (DOS attack vector)
134+
trace!("[MSG_PROC] Validating unit from sender={:?} index={:?}", sender, unit.index());
135+
136+
let (result, returned_validator, unit) =
137+
Self::validate_blocking(validator, sender, unit).await;
138+
validator = returned_validator;
139+
140+
if let Err(err) = result {
141+
// TODO(AndrewL): penalize sender of bad shard.
142+
trace!("[MSG_PROC] Validation failed for index={:?}: {:?}", unit.index(), err);
143+
continue;
144+
}
145+
146+
self.maybe_broadcast_my_shard(&unit, &state);
147+
state.record_shard(unit.index() == self.my_shard_index);
148+
state.capture_signature(&unit);
149+
150+
if self.advance_reconstruction(unit, &mut state).await.is_break() {
151+
return;
152+
}
153+
}
154+
71155
trace!(
72-
"[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
156+
"[MSG_PROC] All channels closed for channel={:?} publisher={:?} root={:?}",
73157
self.channel,
74158
self.publisher,
75159
self.message_root
76160
);
161+
self.finalize();
162+
}
77163

78-
self.emit_and_finalize(Event::MessageTimeout {
79-
channel: self.channel,
164+
/// Offloads CPU-bound validation (signature verification, merkle proofs) to a blocking thread
165+
/// to avoid blocking the tokio runtime.
166+
async fn validate_blocking(
167+
mut validator: UnitValidator,
168+
sender: PeerId,
169+
unit: PropellerUnit,
170+
) -> ValidationResult {
171+
tokio::task::spawn_blocking(move || {
172+
let result = validator.validate_shard(sender, &unit);
173+
(result, validator, unit)
174+
})
175+
.await
176+
.expect("Validation task panicked")
177+
}
178+
179+
fn maybe_broadcast_my_shard(&self, unit: &PropellerUnit, state: &ReconstructionState) {
180+
if unit.index() == self.my_shard_index && !state.broadcast_my_shard {
181+
self.broadcast_shard(unit);
182+
}
183+
}
184+
185+
fn broadcast_shard(&self, unit: &PropellerUnit) {
186+
let mut peers: Vec<PeerId> = self
187+
.tree_manager
188+
.get_nodes()
189+
.iter()
190+
.map(|(p, _)| *p)
191+
.filter(|p| *p != self.publisher && *p != self.local_peer_id)
192+
.collect();
193+
peers.shuffle(&mut rand::thread_rng());
194+
trace!("[MSG_PROC] Broadcasting unit index={:?} to {} peers", unit.index(), peers.len());
195+
self.engine_tx
196+
.send(EventStateManagerToEngine::SendUnitToPeers { unit: unit.clone(), peers })
197+
.expect("Engine task has exited");
198+
}
199+
200+
async fn advance_reconstruction(
201+
&self,
202+
unit: PropellerUnit,
203+
state: &mut ReconstructionState,
204+
) -> ControlFlow<()> {
205+
if state.is_reconstructed() {
206+
return self.maybe_emit_message(state);
207+
}
208+
209+
state.received_shards.push(unit);
210+
211+
if !self.tree_manager.should_build(state.received_shards.len()) {
212+
return ControlFlow::Continue(());
213+
}
214+
215+
trace!("[MSG_PROC] Starting reconstruction with {} shards", state.received_shards.len());
216+
state.shard_count_at_reconstruction = state.received_shards.len();
217+
218+
match self.reconstruct_blocking(state).await {
219+
Ok(success) => self.handle_reconstruction_success(success, state),
220+
Err(e) => {
221+
tracing::error!("[MSG_PROC] Reconstruction failed: {:?}", e);
222+
self.emit_and_finalize(Event::MessageReconstructionFailed {
223+
publisher: self.publisher,
224+
message_root: self.message_root,
225+
error: e,
226+
})
227+
}
228+
}
229+
}
230+
231+
/// Offloads erasure-coding reconstruction to a blocking thread.
232+
async fn reconstruct_blocking(&self, state: &mut ReconstructionState) -> ReconstructionResult {
233+
let shards = std::mem::take(&mut state.received_shards);
234+
let message_root = self.message_root;
235+
let my_index: usize = self.my_shard_index.0.try_into().unwrap();
236+
let data_count = self.tree_manager.num_data_shards();
237+
let coding_count = self.tree_manager.num_coding_shards();
238+
239+
tokio::task::spawn_blocking(move || {
240+
reconstruct_message_from_shards(
241+
shards,
242+
message_root,
243+
my_index,
244+
data_count,
245+
coding_count,
246+
)
247+
.map(|(message, my_shard, my_shard_proof)| ReconstructionSuccess {
248+
message,
249+
my_shard,
250+
my_shard_proof,
251+
})
252+
})
253+
.await
254+
.expect("Reconstruction task panicked")
255+
}
256+
257+
fn handle_reconstruction_success(
258+
&self,
259+
success: ReconstructionSuccess,
260+
state: &mut ReconstructionState,
261+
) -> ControlFlow<()> {
262+
let ReconstructionSuccess { message, my_shard, my_shard_proof } = success;
263+
264+
if !state.broadcast_my_shard {
265+
let signature = state.signature.clone().expect("Signature must exist");
266+
let reconstructed_unit = PropellerUnit::new(
267+
self.channel,
268+
self.publisher,
269+
self.message_root,
270+
signature,
271+
self.my_shard_index,
272+
my_shard,
273+
my_shard_proof,
274+
);
275+
self.broadcast_shard(&reconstructed_unit);
276+
state.broadcast_my_shard = true;
277+
}
278+
279+
state.reconstructed_message = Some(message);
280+
self.maybe_emit_message(state)
281+
}
282+
283+
fn maybe_emit_message(&self, state: &mut ReconstructionState) -> ControlFlow<()> {
284+
if !self.tree_manager.should_receive(state.effective_shard_count()) {
285+
return ControlFlow::Continue(());
286+
}
287+
288+
trace!("[MSG_PROC] Access threshold reached, emitting message");
289+
let message = state.reconstructed_message.take().expect("Message must exist");
290+
self.emit_and_finalize(Event::MessageReceived {
80291
publisher: self.publisher,
81292
message_root: self.message_root,
293+
message,
82294
})
83295
}
84296

297+
fn emit_timeout_and_finalize(&self) {
298+
trace!(
299+
"[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
300+
self.channel,
301+
self.publisher,
302+
self.message_root
303+
);
304+
let _ = self.emit_and_finalize(Event::MessageTimeout {
305+
channel: self.channel,
306+
publisher: self.publisher,
307+
message_root: self.message_root,
308+
});
309+
}
310+
85311
fn emit_and_finalize(&self, event: Event) -> ControlFlow<()> {
86312
self.engine_tx
87313
.send(EventStateManagerToEngine::BehaviourEvent(event))
88314
.expect("Engine task has exited");
315+
self.finalize();
316+
ControlFlow::Break(())
317+
}
318+
319+
fn finalize(&self) {
89320
self.engine_tx
90321
.send(EventStateManagerToEngine::Finalized {
91322
channel: self.channel,
92323
publisher: self.publisher,
93324
message_root: self.message_root,
94325
})
95326
.expect("Engine task has exited");
96-
ControlFlow::Break(())
97327
}
98328
}

0 commit comments

Comments (0)